Add more contexts everywhere

pull/152/head
Tulir Asokan 2024-01-07 22:44:06 +02:00
parent 0a302c753d
commit 25bc36bc7a
37 changed files with 886 additions and 840 deletions

View File

@ -2,8 +2,8 @@
* **Breaking change *(bridge)*** Added raw event to portal membership handling
functions.
* **Breaking change *(client)*** Added context parameters to all functions
(thanks to [@recht] in [#144]).
* **Breaking change *(everything)*** Added context parameters to all functions
(started by [@recht] in [#144]).
* *(crypto)* Added experimental pure Go Olm implementation to replace libolm
(thanks to [@DerLukas15] in [#106]).
* You can use the `goolm` build tag to the new implementation.

View File

@ -93,12 +93,12 @@ type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{})
type StateStore interface {
mautrix.StateStore
IsRegistered(userID id.UserID) bool
MarkRegistered(userID id.UserID)
IsRegistered(ctx context.Context, userID id.UserID) (bool, error)
MarkRegistered(ctx context.Context, userID id.UserID) error
GetPowerLevel(roomID id.RoomID, userID id.UserID) int
GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int
HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool
GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error)
GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error)
HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error)
}
// AppService is the main config for all appservices.

View File

@ -236,7 +236,7 @@ func (as *AppService) handleEvents(ctx context.Context, evts []*event.Event, def
}
if evt.Type.IsState() {
mautrix.UpdateStateStore(as.StateStore, evt)
mautrix.UpdateStateStore(ctx, as.StateStore, evt)
}
var ch chan *event.Event
if evt.Type.Class == event.ToDeviceEventType {

View File

@ -13,6 +13,8 @@ import (
"strings"
"sync"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
@ -57,17 +59,26 @@ func (intent *IntentAPI) Register(ctx context.Context) error {
}
func (intent *IntentAPI) EnsureRegistered(ctx context.Context) error {
if intent.IsCustomPuppet {
return nil
}
intent.registerLock.Lock()
defer intent.registerLock.Unlock()
if intent.IsCustomPuppet || intent.as.StateStore.IsRegistered(intent.UserID) {
isRegistered, err := intent.as.StateStore.IsRegistered(ctx, intent.UserID)
if err != nil {
return fmt.Errorf("failed to check if user is registered: %w", err)
} else if isRegistered {
return nil
}
err := intent.Register(ctx)
err = intent.Register(ctx)
if err != nil && !errors.Is(err, mautrix.MUserInUse) {
return fmt.Errorf("failed to ensure registered: %w", err)
}
intent.as.StateStore.MarkRegistered(intent.UserID)
err = intent.as.StateStore.MarkRegistered(ctx, intent.UserID)
if err != nil {
return fmt.Errorf("failed to mark user as registered in state store: %w", err)
}
return nil
}
@ -83,7 +94,7 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext
} else if len(extra) == 1 {
params = extra[0]
}
if intent.as.StateStore.IsInRoom(roomID, intent.UserID) && !params.IgnoreCache {
if intent.as.StateStore.IsInRoom(ctx, roomID, intent.UserID) && !params.IgnoreCache {
return nil
}
@ -111,7 +122,10 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext
return fmt.Errorf("failed to ensure joined after invite: %w", err)
}
}
intent.as.StateStore.SetMembership(resp.RoomID, intent.UserID, event.MembershipJoin)
err = intent.as.StateStore.SetMembership(ctx, resp.RoomID, intent.UserID, event.MembershipJoin)
if err != nil {
return fmt.Errorf("failed to set membership in state store: %w", err)
}
return nil
}
@ -205,13 +219,14 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i
Membership: membership,
Reason: reason,
}
memberContent, ok := intent.as.StateStore.TryGetMember(roomID, target)
if !ok {
memberContent, err := intent.as.StateStore.TryGetMember(ctx, roomID, target)
if err != nil {
return nil, fmt.Errorf("failed to get old member content from state store: %w", err)
} else if memberContent == nil {
if intent.as.GetProfile != nil {
memberContent = intent.as.GetProfile(target, roomID)
ok = memberContent != nil
}
if !ok {
if memberContent == nil {
profile, err := intent.GetProfile(ctx, target)
if err != nil {
intent.Log.Debug().Err(err).
@ -224,7 +239,7 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i
}
}
}
if ok && memberContent != nil {
if memberContent != nil {
content.Displayname = memberContent.Displayname
content.AvatarURL = memberContent.AvatarURL
}
@ -297,15 +312,25 @@ func (intent *IntentAPI) UnbanUser(ctx context.Context, roomID id.RoomID, req *m
}
func (intent *IntentAPI) Member(ctx context.Context, roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
member, ok := intent.as.StateStore.TryGetMember(roomID, userID)
if !ok {
member, err := intent.as.StateStore.TryGetMember(ctx, roomID, userID)
if err != nil {
zerolog.Ctx(ctx).Warn().Err(err).
Str("room_id", roomID.String()).
Str("user_id", userID.String()).
Msg("Failed to get member from state store")
}
if member == nil {
_ = intent.StateEvent(ctx, roomID, event.StateMember, string(userID), &member)
}
return member
}
func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) {
pl = intent.as.StateStore.GetPowerLevels(roomID)
pl, err = intent.as.StateStore.GetPowerLevels(ctx, roomID)
if err != nil {
err = fmt.Errorf("failed to get cached power levels: %w", err)
return
}
if pl == nil {
pl = &event.PowerLevelsEventContent{}
err = intent.StateEvent(ctx, roomID, event.StatePowerLevels, "", pl)
@ -417,7 +442,7 @@ func (intent *IntentAPI) Whoami(ctx context.Context) (*mautrix.RespWhoami, error
}
func (intent *IntentAPI) EnsureInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) error {
if !intent.as.StateStore.IsInvited(roomID, userID) {
if !intent.as.StateStore.IsInvited(ctx, roomID, userID) {
_, err := intent.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{
UserID: userID,
})

View File

@ -10,9 +10,8 @@ import (
"os"
"regexp"
"gopkg.in/yaml.v3"
"go.mau.fi/util/random"
"gopkg.in/yaml.v3"
)
// Registration contains the data in a Matrix appservice registration.

View File

@ -215,15 +215,15 @@ type Bridge struct {
type Crypto interface {
HandleMemberEvent(*event.Event)
Decrypt(*event.Event) (*event.Event, error)
Encrypt(id.RoomID, event.Type, *event.Content) error
WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
Decrypt(context.Context, *event.Event) (*event.Event, error)
Encrypt(context.Context, id.RoomID, event.Type, *event.Content) error
WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID)
ResetSession(id.RoomID)
Init() error
ResetSession(context.Context, id.RoomID)
Init(ctx context.Context) error
Start()
Stop()
Reset(startAfterReset bool)
Reset(ctx context.Context, startAfterReset bool)
Client() *mautrix.Client
ShareKeys(context.Context) error
}
@ -650,10 +650,10 @@ func (br *Bridge) WaitWebsocketConnected() {
func (br *Bridge) start() {
br.ZLog.Debug().Msg("Running database upgrades")
err := br.DB.Upgrade()
err := br.DB.Upgrade(br.ZLog.With().Str("db_section", "main").Logger().WithContext(context.TODO()))
if err != nil {
br.LogDBUpgradeErrorAndExit("main", err)
} else if err = br.StateStore.Upgrade(); err != nil {
} else if err = br.StateStore.Upgrade(br.ZLog.With().Str("db_section", "matrix_state").Logger().WithContext(context.TODO())); err != nil {
br.LogDBUpgradeErrorAndExit("matrix_state", err)
}
@ -679,7 +679,7 @@ func (br *Bridge) start() {
go br.fetchMediaConfig(ctx)
if br.Crypto != nil {
err = br.Crypto.Init()
err = br.Crypto.Init(ctx)
if err != nil {
br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Error initializing end-to-bridge encryption")
os.Exit(19)

View File

@ -17,7 +17,7 @@ var CommandDiscardMegolmSession = &FullHandler{
if ce.Bridge.Crypto == nil {
ce.Reply("This bridge instance doesn't have end-to-bridge encryption enabled")
} else {
ce.Bridge.Crypto.ResetSession(ce.RoomID)
ce.Bridge.Crypto.ResetSession(ce.Ctx, ce.RoomID)
ce.Reply("Successfully reset Megolm session in this room. New decryption keys will be shared the next time a message is sent from the remote network.")
}
},

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -61,7 +61,7 @@ func NewCryptoHelper(bridge *Bridge) Crypto {
}
}
func (helper *CryptoHelper) Init() error {
func (helper *CryptoHelper) Init(ctx context.Context) error {
if len(helper.bridge.CryptoPickleKey) == 0 {
panic("CryptoPickleKey not set")
}
@ -75,13 +75,13 @@ func (helper *CryptoHelper) Init() error {
helper.bridge.CryptoPickleKey,
)
err := helper.store.DB.Upgrade()
err := helper.store.DB.Upgrade(ctx)
if err != nil {
helper.bridge.LogDBUpgradeErrorAndExit("crypto", err)
}
var isExistingDevice bool
helper.client, isExistingDevice, err = helper.loginBot(context.TODO())
helper.client, isExistingDevice, err = helper.loginBot(ctx)
if err != nil {
return err
}
@ -111,7 +111,7 @@ func (helper *CryptoHelper) Init() error {
}
if encryptionConfig.DeleteKeys.DeleteOutdatedInbound {
deleted, err := helper.store.RedactOutdatedGroupSessions()
deleted, err := helper.store.RedactOutdatedGroupSessions(ctx)
if err != nil {
return err
}
@ -123,12 +123,12 @@ func (helper *CryptoHelper) Init() error {
helper.client.Syncer = &cryptoSyncer{helper.mach}
helper.client.Store = helper.store
err = helper.mach.Load()
err = helper.mach.Load(ctx)
if err != nil {
return err
}
if isExistingDevice {
helper.verifyKeysAreOnServer(context.TODO())
helper.verifyKeysAreOnServer(ctx)
}
go helper.resyncEncryptionInfo(context.TODO())
@ -138,22 +138,16 @@ func (helper *CryptoHelper) Init() error {
func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
log := helper.log.With().Str("action", "resync encryption event").Logger()
rows, err := helper.bridge.DB.QueryContext(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`)
rows, err := helper.bridge.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`)
if err != nil {
log.Err(err).Msg("Failed to query rooms for resync")
return
}
var roomIDs []id.RoomID
for rows.Next() {
var roomID id.RoomID
err = rows.Scan(&roomID)
if err != nil {
log.Err(err).Msg("Failed to scan room ID")
continue
}
roomIDs = append(roomIDs, roomID)
roomIDs, err := dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList()
if err != nil {
log.Err(err).Msg("Failed to scan rooms for resync")
return
}
_ = rows.Close()
if len(roomIDs) > 0 {
log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms")
for _, roomID := range roomIDs {
@ -161,7 +155,7 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt)
if err != nil {
log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event")
_, err = helper.bridge.DB.ExecContext(ctx, `
_, err = helper.bridge.DB.Exec(ctx, `
UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}'
`, roomID)
if err != nil {
@ -182,7 +176,7 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
Int("max_messages", maxMessages).
Interface("content", &evt).
Msg("Resynced encryption event")
_, err = helper.bridge.DB.ExecContext(ctx, `
_, err = helper.bridge.DB.Exec(ctx, `
UPDATE crypto_megolm_inbound_session
SET max_age=$1, max_messages=$2
WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL
@ -223,8 +217,10 @@ func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device
}
func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool, error) {
deviceID := helper.store.FindDeviceID()
if len(deviceID) > 0 {
deviceID, err := helper.store.FindDeviceID(ctx)
if err != nil {
return nil, false, fmt.Errorf("failed to find existing device ID: %w", err)
} else if len(deviceID) > 0 {
helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database")
}
// Create a new client instance with the default AS settings (including as_token),
@ -270,7 +266,7 @@ func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) {
return
}
helper.log.Warn().Msg("Existing device doesn't have keys on server, resetting crypto")
helper.Reset(false)
helper.Reset(ctx, false)
}
func (helper *CryptoHelper) Start() {
@ -306,16 +302,16 @@ func (helper *CryptoHelper) Stop() {
helper.syncDone.Wait()
}
func (helper *CryptoHelper) clearDatabase() {
_, err := helper.store.DB.Exec("DELETE FROM crypto_account")
func (helper *CryptoHelper) clearDatabase(ctx context.Context) {
_, err := helper.store.DB.Exec(ctx, "DELETE FROM crypto_account")
if err != nil {
helper.log.Warn().Err(err).Msg("Failed to clear crypto_account table")
}
_, err = helper.store.DB.Exec("DELETE FROM crypto_olm_session")
_, err = helper.store.DB.Exec(ctx, "DELETE FROM crypto_olm_session")
if err != nil {
helper.log.Warn().Err(err).Msg("Failed to clear crypto_olm_session table")
}
_, err = helper.store.DB.Exec("DELETE FROM crypto_megolm_outbound_session")
_, err = helper.store.DB.Exec(ctx, "DELETE FROM crypto_megolm_outbound_session")
if err != nil {
helper.log.Warn().Err(err).Msg("Failed to clear crypto_megolm_outbound_session table")
}
@ -325,22 +321,22 @@ func (helper *CryptoHelper) clearDatabase() {
//_, _ = helper.store.DB.Exec("DELETE FROM crypto_cross_signing_signatures")
}
func (helper *CryptoHelper) Reset(startAfterReset bool) {
func (helper *CryptoHelper) Reset(ctx context.Context, startAfterReset bool) {
helper.lock.Lock()
defer helper.lock.Unlock()
helper.log.Info().Msg("Resetting end-to-bridge encryption device")
helper.Stop()
helper.log.Debug().Msg("Crypto syncer stopped, clearing database")
helper.clearDatabase()
helper.clearDatabase(ctx)
helper.log.Debug().Msg("Crypto database cleared, logging out of all sessions")
_, err := helper.client.LogoutAll(context.TODO())
_, err := helper.client.LogoutAll(ctx)
if err != nil {
helper.log.Warn().Err(err).Msg("Failed to log out all devices")
}
helper.client = nil
helper.store = nil
helper.mach = nil
err = helper.Init()
err = helper.Init(ctx)
if err != nil {
helper.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Error reinitializing end-to-bridge encryption")
os.Exit(50)
@ -355,25 +351,24 @@ func (helper *CryptoHelper) Client() *mautrix.Client {
return helper.client
}
func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
return helper.mach.DecryptMegolmEvent(context.TODO(), evt)
func (helper *CryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) {
return helper.mach.DecryptMegolmEvent(ctx, evt)
}
func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content *event.Content) (err error) {
func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content *event.Content) (err error) {
helper.lock.RLock()
defer helper.lock.RUnlock()
var encrypted *event.EncryptedEventContent
ctx := context.TODO()
encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content)
if err != nil {
if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession {
if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) {
return
}
helper.log.Debug().Err(err).
Str("room_id", roomID.String()).
Msg("Got error while encrypting event for room, sharing group session and trying again...")
var users []id.UserID
users, err = helper.store.GetRoomJoinedOrInvitedMembers(roomID)
users, err = helper.store.GetRoomJoinedOrInvitedMembers(ctx, roomID)
if err != nil {
err = fmt.Errorf("failed to get room member list: %w", err)
} else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil {
@ -389,10 +384,10 @@ func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, conten
return
}
func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
func (helper *CryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
helper.lock.RLock()
defer helper.lock.RUnlock()
return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout)
return helper.mach.WaitForSession(ctx, roomID, senderKey, sessionID, timeout)
}
func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
@ -419,10 +414,10 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID
}
}
func (helper *CryptoHelper) ResetSession(roomID id.RoomID) {
func (helper *CryptoHelper) ResetSession(ctx context.Context, roomID id.RoomID) {
helper.lock.RLock()
defer helper.lock.RUnlock()
err := helper.mach.CryptoStore.RemoveOutboundGroupSession(roomID)
err := helper.mach.CryptoStore.RemoveOutboundGroupSession(ctx, roomID)
if err != nil {
helper.log.Debug().Err(err).
Str("room_id", roomID.String()).
@ -499,18 +494,18 @@ type cryptoStateStore struct {
var _ crypto.StateStore = (*cryptoStateStore)(nil)
func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool {
func (c *cryptoStateStore) IsEncrypted(ctx context.Context, id id.RoomID) (bool, error) {
portal := c.bridge.Child.GetIPortal(id)
if portal != nil {
return portal.IsEncrypted()
return portal.IsEncrypted(), nil
}
return c.bridge.StateStore.IsEncrypted(id)
return c.bridge.StateStore.IsEncrypted(ctx, id)
}
func (c *cryptoStateStore) FindSharedRooms(id id.UserID) []id.RoomID {
return c.bridge.StateStore.FindSharedRooms(id)
func (c *cryptoStateStore) FindSharedRooms(ctx context.Context, id id.UserID) ([]id.RoomID, error) {
return c.bridge.StateStore.FindSharedRooms(ctx, id)
}
func (c *cryptoStateStore) GetEncryptionEvent(id id.RoomID) *event.EncryptionEventContent {
return c.bridge.StateStore.GetEncryptionEvent(id)
func (c *cryptoStateStore) GetEncryptionEvent(ctx context.Context, id id.RoomID) (*event.EncryptionEventContent, error) {
return c.bridge.StateStore.GetEncryptionEvent(ctx, id)
}

View File

@ -9,6 +9,8 @@
package bridge
import (
"context"
"github.com/lib/pq"
"go.mau.fi/util/dbutil"
@ -36,9 +38,9 @@ func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, userID id
}
}
func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (members []id.UserID, err error) {
func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) (members []id.UserID, err error) {
var rows dbutil.Rows
rows, err = store.DB.Query(`
rows, err = store.DB.Query(ctx, `
SELECT user_id FROM mx_user_profile
WHERE room_id=$1
AND (membership='join' OR membership='invite')

View File

@ -494,7 +494,7 @@ func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) {
log.Debug().Msg("Decrypting received event")
decryptionStart := time.Now()
decrypted, err := mx.bridge.Crypto.Decrypt(evt)
decrypted, err := mx.bridge.Crypto.Decrypt(ctx, evt)
decryptionRetryCount := 0
if errors.Is(err, NoSessionFound) {
decryptionRetryCount = 1
@ -502,9 +502,9 @@ func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) {
Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
Msg("Couldn't find session, waiting for keys to arrive...")
mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepDecrypted, err, false, 0)
if mx.bridge.Crypto.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
if mx.bridge.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
decrypted, err = mx.bridge.Crypto.Decrypt(evt)
decrypted, err = mx.bridge.Crypto.Decrypt(ctx, evt)
} else {
go mx.waitLongerForSession(ctx, evt, decryptionStart)
return
@ -529,14 +529,14 @@ func (mx *MatrixHandler) waitLongerForSession(ctx context.Context, evt *event.Ev
go mx.bridge.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
errorEventID := mx.sendCryptoStatusError(ctx, evt, "", fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), 1, false)
if !mx.bridge.Crypto.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
if !mx.bridge.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
log.Debug().Msg("Didn't get session, giving up trying to decrypt event")
mx.sendCryptoStatusError(ctx, evt, errorEventID, errNoDecryptionKeys, 2, true)
return
}
log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again")
decrypted, err := mx.bridge.Crypto.Decrypt(evt)
decrypted, err := mx.bridge.Crypto.Decrypt(ctx, evt)
if err != nil {
log.Error().Err(err).Msg("Failed to decrypt event")
mx.sendCryptoStatusError(ctx, evt, errorEventID, err, 2, true)

113
client.go
View File

@ -27,11 +27,11 @@ import (
)
type CryptoHelper interface {
Encrypt(id.RoomID, event.Type, any) (*event.EncryptedEventContent, error)
Decrypt(*event.Event) (*event.Event, error)
WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
Encrypt(context.Context, id.RoomID, event.Type, any) (*event.EncryptedEventContent, error)
Decrypt(context.Context, *event.Event) (*event.Event, error)
WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID)
Init() error
Init(context.Context) error
}
// Deprecated: switch to zerolog
@ -846,7 +846,10 @@ func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName strin
}
_, err = cli.MakeRequest(ctx, "POST", urlPath, content, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin)
err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin)
if err != nil {
err = fmt.Errorf("failed to update state store: %w", err)
}
}
return
}
@ -858,7 +861,10 @@ func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName strin
func (cli *Client) JoinRoomByID(ctx context.Context, roomID id.RoomID) (resp *RespJoinRoom, err error) {
_, err = cli.MakeRequest(ctx, "POST", cli.BuildClientURL("v3", "rooms", roomID, "join"), nil, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin)
err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin)
if err != nil {
err = fmt.Errorf("failed to update state store: %w", err)
}
}
return
}
@ -1000,13 +1006,20 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event
queryParams["fi.mau.event_id"] = req.MeowEventID.String()
}
if !req.DontEncrypt && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted && cli.StateStore.IsEncrypted(roomID) {
contentJSON, err = cli.Crypto.Encrypt(roomID, eventType, contentJSON)
if !req.DontEncrypt && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted {
var isEncrypted bool
isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID)
if err != nil {
err = fmt.Errorf("failed to encrypt event: %w", err)
err = fmt.Errorf("failed to check if room is encrypted: %w", err)
return
}
eventType = event.EventEncrypted
if isEncrypted {
if contentJSON, err = cli.Crypto.Encrypt(ctx, roomID, eventType, contentJSON); err != nil {
err = fmt.Errorf("failed to encrypt event: %w", err)
return
}
eventType = event.EventEncrypted
}
}
urlData := ClientURLPath{"v3", "rooms", roomID, "send", eventType.String(), txnID}
@ -1021,7 +1034,7 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy
urlPath := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey)
_, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp)
if err == nil && cli.StateStore != nil {
cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON)
cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON)
}
return
}
@ -1034,7 +1047,7 @@ func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID,
})
_, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp)
if err == nil && cli.StateStore != nil {
cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON)
cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON)
}
return
}
@ -1100,19 +1113,29 @@ func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *Re
urlPath := cli.BuildClientURL("v3", "createRoom")
_, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin)
storeErr := cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin)
if storeErr != nil {
cli.cliOrContextLog(ctx).Warn().Err(storeErr).
Stringer("creator_user_id", cli.UserID).
Msg("Failed to update creator membership in state store after creating room")
}
for _, evt := range req.InitialState {
UpdateStateStore(cli.StateStore, evt)
UpdateStateStore(ctx, cli.StateStore, evt)
}
inviteMembership := event.MembershipInvite
if req.BeeperAutoJoinInvites {
inviteMembership = event.MembershipJoin
}
for _, invitee := range req.Invite {
cli.StateStore.SetMembership(resp.RoomID, invitee, inviteMembership)
storeErr = cli.StateStore.SetMembership(ctx, resp.RoomID, invitee, inviteMembership)
if storeErr != nil {
cli.cliOrContextLog(ctx).Warn().Err(storeErr).
Stringer("invitee_user_id", invitee).
Msg("Failed to update membership in state store after creating room")
}
}
for _, evt := range req.InitialState {
cli.updateStoreWithOutgoingEvent(resp.RoomID, evt.Type, evt.GetStateKey(), &evt.Content)
cli.updateStoreWithOutgoingEvent(ctx, resp.RoomID, evt.Type, evt.GetStateKey(), &evt.Content)
}
}
return
@ -1129,7 +1152,10 @@ func (cli *Client) LeaveRoom(ctx context.Context, roomID id.RoomID, optionalReq
u := cli.BuildClientURL("v3", "rooms", roomID, "leave")
_, err = cli.MakeRequest(ctx, "POST", u, req, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.SetMembership(roomID, cli.UserID, event.MembershipLeave)
err = cli.StateStore.SetMembership(ctx, roomID, cli.UserID, event.MembershipLeave)
if err != nil {
err = fmt.Errorf("failed to update membership in state store: %w", err)
}
}
return
}
@ -1146,7 +1172,10 @@ func (cli *Client) InviteUser(ctx context.Context, roomID id.RoomID, req *ReqInv
u := cli.BuildClientURL("v3", "rooms", roomID, "invite")
_, err = cli.MakeRequest(ctx, "POST", u, req, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipInvite)
err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipInvite)
if err != nil {
err = fmt.Errorf("failed to update membership in state store: %w", err)
}
}
return
}
@ -1163,7 +1192,10 @@ func (cli *Client) KickUser(ctx context.Context, roomID id.RoomID, req *ReqKickU
u := cli.BuildClientURL("v3", "rooms", roomID, "kick")
_, err = cli.MakeRequest(ctx, "POST", u, req, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave)
err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipLeave)
if err != nil {
err = fmt.Errorf("failed to update membership in state store: %w", err)
}
}
return
}
@ -1173,7 +1205,10 @@ func (cli *Client) BanUser(ctx context.Context, roomID id.RoomID, req *ReqBanUse
u := cli.BuildClientURL("v3", "rooms", roomID, "ban")
_, err = cli.MakeRequest(ctx, "POST", u, req, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipBan)
err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipBan)
if err != nil {
err = fmt.Errorf("failed to update membership in state store: %w", err)
}
}
return
}
@ -1183,7 +1218,10 @@ func (cli *Client) UnbanUser(ctx context.Context, roomID id.RoomID, req *ReqUnba
u := cli.BuildClientURL("v3", "rooms", roomID, "unban")
_, err = cli.MakeRequest(ctx, "POST", u, req, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave)
err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipLeave)
if err != nil {
err = fmt.Errorf("failed to update membership in state store: %w", err)
}
}
return
}
@ -1216,7 +1254,7 @@ func (cli *Client) SetPresence(ctx context.Context, status event.Presence) (err
return
}
func (cli *Client) updateStoreWithOutgoingEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) {
func (cli *Client) updateStoreWithOutgoingEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) {
if cli.StateStore == nil {
return
}
@ -1246,7 +1284,7 @@ func (cli *Client) updateStoreWithOutgoingEvent(roomID id.RoomID, eventType even
}
return
}
UpdateStateStore(cli.StateStore, fakeEvt)
UpdateStateStore(ctx, cli.StateStore, fakeEvt)
}
// StateEvent gets a single state event in a room. It will attempt to JSON unmarshal into the given "outContent" struct with
@ -1256,7 +1294,7 @@ func (cli *Client) StateEvent(ctx context.Context, roomID id.RoomID, eventType e
u := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey)
_, err = cli.MakeRequest(ctx, "GET", u, nil, outContent)
if err == nil && cli.StateStore != nil {
cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, outContent)
cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, outContent)
}
return
}
@ -1310,10 +1348,13 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt
Handler: parseRoomStateArray,
})
if err == nil && cli.StateStore != nil {
cli.StateStore.ClearCachedMembers(roomID)
clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID)
cli.cliOrContextLog(ctx).Warn().Err(clearErr).
Stringer("room_id", roomID).
Msg("Failed to clear cached member list after fetching state")
for _, evts := range stateMap {
for _, evt := range evts {
UpdateStateStore(cli.StateStore, evt)
UpdateStateStore(ctx, cli.StateStore, evt)
}
}
}
@ -1630,13 +1671,22 @@ func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *R
u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members")
_, err = cli.MakeRequest(ctx, "GET", u, nil, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.ClearCachedMembers(roomID, event.MembershipJoin)
clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, event.MembershipJoin)
cli.cliOrContextLog(ctx).Warn().Err(clearErr).
Stringer("room_id", roomID).
Msg("Failed to clear cached member list after fetching joined members")
for userID, member := range resp.Joined {
cli.StateStore.SetMember(roomID, userID, &event.MemberEventContent{
updateErr := cli.StateStore.SetMember(ctx, roomID, userID, &event.MemberEventContent{
Membership: event.MembershipJoin,
AvatarURL: id.ContentURIString(member.AvatarURL),
Displayname: member.DisplayName,
})
if updateErr != nil {
cli.cliOrContextLog(ctx).Warn().Err(clearErr).
Stringer("room_id", roomID).
Stringer("user_id", userID).
Msg("Failed to update membership in state store after fetching joined members")
}
}
}
return
@ -1665,10 +1715,13 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb
clearMemberships = append(clearMemberships, extra.Membership)
}
if extra.NotMembership == "" {
cli.StateStore.ClearCachedMembers(roomID, clearMemberships...)
clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, clearMemberships...)
cli.cliOrContextLog(ctx).Warn().Err(clearErr).
Stringer("room_id", roomID).
Msg("Failed to clear cached member list after fetching joined members")
}
for _, evt := range resp.Chunk {
UpdateStateStore(cli.StateStore, evt)
UpdateStateStore(ctx, cli.StateStore, evt)
}
}
return

View File

@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -42,7 +42,7 @@ func (mach *OlmMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *Cross
}
func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id.UserID) (*CrossSigningPublicKeysCache, error) {
dbKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
dbKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID)
if err != nil {
return nil, fmt.Errorf("failed to get keys from database: %w", err)
}

View File

@ -1,5 +1,5 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -34,8 +34,8 @@ var (
ErrMismatchingMasterKeyMAC = errors.New("mismatching cross-signing master key MAC")
)
func (mach *OlmMachine) fetchMasterKey(device *id.Device, content *event.VerificationMacEventContent, verState *verificationState, transactionID string) (id.Ed25519, error) {
crossSignKeys, err := mach.CryptoStore.GetCrossSigningKeys(device.UserID)
func (mach *OlmMachine) fetchMasterKey(ctx context.Context, device *id.Device, content *event.VerificationMacEventContent, verState *verificationState, transactionID string) (id.Ed25519, error) {
crossSignKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, device.UserID)
if err != nil {
return "", fmt.Errorf("failed to fetch cross-signing keys: %w", err)
}
@ -85,7 +85,7 @@ func (mach *OlmMachine) SignUser(ctx context.Context, userID id.UserID, masterKe
Str("signature", signature).
Msg("Signed master key of user with our user-signing key")
if err := mach.CryptoStore.PutSignature(userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil {
if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil {
return fmt.Errorf("error storing signature in crypto store: %w", err)
}
@ -137,7 +137,7 @@ func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error {
return fmt.Errorf("%w: %+v", ErrSignatureUploadFail, resp.Failures)
}
if err := mach.CryptoStore.PutSignature(userID, masterKey, userID, mach.account.SigningKey(), signature); err != nil {
if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, userID, mach.account.SigningKey(), signature); err != nil {
return fmt.Errorf("error storing signature in crypto store: %w", err)
}
@ -178,7 +178,7 @@ func (mach *OlmMachine) SignOwnDevice(ctx context.Context, device *id.Device) er
Str("signature", signature).
Msg("Signed own device key with self-signing key")
if err := mach.CryptoStore.PutSignature(device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil {
if err := mach.CryptoStore.PutSignature(ctx, device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil {
return fmt.Errorf("error storing signature in crypto store: %w", err)
}

View File

@ -1,5 +1,5 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -19,7 +19,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
log := mach.machOrContextLog(ctx)
for userID, userKeys := range crossSigningKeys {
log := log.With().Str("user_id", userID.String()).Logger()
currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID)
if err != nil {
log.Error().Err(err).
Msg("Error fetching current cross-signing keys of user")
@ -32,7 +32,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
if newKeyUsage == curKeyUsage {
if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
// old key is not in the new key map, so we drop signatures made by it
if count, err := mach.CryptoStore.DropSignaturesByKey(userID, curKey.Key); err != nil {
if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
log.Error().Err(err).Msg("Error deleting old signatures made by user")
} else {
log.Debug().
@ -50,7 +50,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
log := log.With().Str("key", key.String()).Strs("usages", strishArray(userKeys.Usage)).Logger()
for _, usage := range userKeys.Usage {
log.Debug().Str("usage", string(usage)).Msg("Storing cross-signing key")
if err = mach.CryptoStore.PutCrossSigningKey(userID, usage, key); err != nil {
if err = mach.CryptoStore.PutCrossSigningKey(ctx, userID, usage, key); err != nil {
log.Error().Err(err).Msg("Error storing cross-signing key")
}
}
@ -85,7 +85,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
} else {
if verified {
log.Debug().Err(err).Msg("Cross-signing key signature verified")
err = mach.CryptoStore.PutSignature(userID, key, signUserID, signingKey, signature)
err = mach.CryptoStore.PutSignature(ctx, userID, key, signUserID, signingKey, signature)
if err != nil {
log.Error().Err(err).Msg("Error storing cross-signing key signature")
}

View File

@ -32,7 +32,7 @@ func getOlmMachine(t *testing.T) *OlmMachine {
t.Fatalf("Error opening db: %v", err)
}
sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test"))
if err = sqlStore.DB.Upgrade(); err != nil {
if err = sqlStore.DB.Upgrade(context.TODO()); err != nil {
t.Fatalf("Error creating tables: %v", err)
}
@ -41,9 +41,9 @@ func getOlmMachine(t *testing.T) *OlmMachine {
ssk, _ := olm.NewPkSigning()
usk, _ := olm.NewPkSigning()
sqlStore.PutCrossSigningKey(userID, id.XSUsageMaster, mk.PublicKey)
sqlStore.PutCrossSigningKey(userID, id.XSUsageSelfSigning, ssk.PublicKey)
sqlStore.PutCrossSigningKey(userID, id.XSUsageUserSigning, usk.PublicKey)
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageMaster, mk.PublicKey)
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageSelfSigning, ssk.PublicKey)
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageUserSigning, usk.PublicKey)
return &OlmMachine{
CryptoStore: sqlStore,
@ -70,9 +70,9 @@ func TestTrustOwnDevice(t *testing.T) {
t.Error("Own device trusted while it shouldn't be")
}
m.CryptoStore.PutSignature(ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey,
m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey,
ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1")
m.CryptoStore.PutSignature(ownDevice.UserID, ownDevice.SigningKey,
m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, ownDevice.SigningKey,
ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, "sig2")
if trusted, _ := m.IsUserTrusted(context.TODO(), ownDevice.UserID); !trusted {
@ -91,20 +91,20 @@ func TestTrustOtherUser(t *testing.T) {
}
theirMasterKey, _ := olm.NewPkSigning()
m.CryptoStore.PutCrossSigningKey(otherUser, id.XSUsageMaster, theirMasterKey.PublicKey)
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey)
m.CryptoStore.PutSignature(m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey,
m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey,
m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1")
// sign them with self-signing instead of user-signing key
m.CryptoStore.PutSignature(otherUser, theirMasterKey.PublicKey,
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey,
m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, "invalid_sig")
if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted {
t.Error("Other user trusted before their master key has been signed with our user-signing key")
}
m.CryptoStore.PutSignature(otherUser, theirMasterKey.PublicKey,
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey,
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, "sig2")
if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted {
@ -128,27 +128,27 @@ func TestTrustOtherDevice(t *testing.T) {
}
theirMasterKey, _ := olm.NewPkSigning()
m.CryptoStore.PutCrossSigningKey(otherUser, id.XSUsageMaster, theirMasterKey.PublicKey)
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey)
theirSSK, _ := olm.NewPkSigning()
m.CryptoStore.PutCrossSigningKey(otherUser, id.XSUsageSelfSigning, theirSSK.PublicKey)
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageSelfSigning, theirSSK.PublicKey)
m.CryptoStore.PutSignature(m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey,
m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey,
m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1")
m.CryptoStore.PutSignature(otherUser, theirMasterKey.PublicKey,
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey,
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, "sig2")
if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted {
t.Error("Other user not trusted while they should be")
}
m.CryptoStore.PutSignature(otherUser, theirSSK.PublicKey,
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey,
otherUser, theirMasterKey.PublicKey, "sig3")
if m.IsDeviceTrusted(theirDevice) {
t.Error("Other device trusted before it has been signed with user's SSK")
}
m.CryptoStore.PutSignature(otherUser, theirDevice.SigningKey,
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirDevice.SigningKey,
otherUser, theirSSK.PublicKey, "sig4")
if !m.IsDeviceTrusted(theirDevice) {

View File

@ -1,5 +1,5 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -23,7 +23,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi
if device.Trust == id.TrustStateVerified || device.Trust == id.TrustStateBlacklisted {
return device.Trust, nil
}
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(device.UserID)
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, device.UserID)
if err != nil {
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", device.UserID.String()).
@ -44,7 +44,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi
Msg("Self-signing key of user not found")
return id.TrustStateUnset, nil
}
sskSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, theirSSK.Key, device.UserID, theirMSK.Key)
sskSigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, device.UserID, theirSSK.Key, device.UserID, theirMSK.Key)
if err != nil {
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", device.UserID.String()).
@ -57,7 +57,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi
Msg("Self-signing key of user is not signed by their master key")
return id.TrustStateUnset, nil
}
deviceSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, device.SigningKey, device.UserID, theirSSK.Key)
deviceSigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, device.UserID, device.SigningKey, device.UserID, theirSSK.Key)
if err != nil {
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", device.UserID.String()).
@ -97,14 +97,14 @@ func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bo
return true, nil
}
// first we verify our user-signing key
ourUserSigningKeyTrusted, err := mach.CryptoStore.IsKeySignedBy(mach.Client.UserID, csPubkeys.UserSigningKey, mach.Client.UserID, csPubkeys.MasterKey)
ourUserSigningKeyTrusted, err := mach.CryptoStore.IsKeySignedBy(ctx, mach.Client.UserID, csPubkeys.UserSigningKey, mach.Client.UserID, csPubkeys.MasterKey)
if err != nil {
mach.machOrContextLog(ctx).Error().Err(err).Msg("Error retrieving our self-signing key signatures from database")
return false, err
} else if !ourUserSigningKeyTrusted {
return false, nil
}
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID)
if err != nil {
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", userID.String()).
@ -118,7 +118,7 @@ func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bo
Msg("Master key of user not found")
return false, nil
}
sigExists, err := mach.CryptoStore.IsKeySignedBy(userID, theirMskKey.Key, mach.Client.UserID, csPubkeys.UserSigningKey)
sigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, userID, theirMskKey.Key, mach.Client.UserID, csPubkeys.UserSigningKey)
if err != nil {
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", userID.String()).

View File

@ -105,7 +105,7 @@ func NewCryptoHelper(cli *mautrix.Client, pickleKey []byte, store any) (*CryptoH
}, nil
}
func (helper *CryptoHelper) Init() error {
func (helper *CryptoHelper) Init(ctx context.Context) error {
if helper == nil {
return fmt.Errorf("crypto helper is nil")
}
@ -116,7 +116,7 @@ func (helper *CryptoHelper) Init() error {
var stateStore crypto.StateStore
if helper.managedStateStore != nil {
err := helper.managedStateStore.Upgrade()
err := helper.managedStateStore.Upgrade(ctx)
if err != nil {
return fmt.Errorf("failed to upgrade client state store: %w", err)
}
@ -124,7 +124,6 @@ func (helper *CryptoHelper) Init() error {
} else {
stateStore = helper.client.StateStore.(crypto.StateStore)
}
ctx := context.TODO()
var cryptoStore crypto.Store
if helper.unmanagedCryptoStore == nil {
managedCryptoStore := crypto.NewSQLCryptoStore(helper.dbForManagedStores, dbutil.ZeroLogger(helper.log.With().Str("db_section", "crypto").Logger()), helper.DBAccountID, helper.client.DeviceID, helper.pickleKey)
@ -133,11 +132,14 @@ func (helper *CryptoHelper) Init() error {
} else if _, isMemory := helper.client.Store.(*mautrix.MemorySyncStore); isMemory {
helper.client.Store = managedCryptoStore
}
err := managedCryptoStore.DB.Upgrade()
err := managedCryptoStore.DB.Upgrade(ctx)
if err != nil {
return fmt.Errorf("failed to upgrade crypto state store: %w", err)
}
storedDeviceID := managedCryptoStore.FindDeviceID()
storedDeviceID, err := managedCryptoStore.FindDeviceID(ctx)
if err != nil {
return fmt.Errorf("failed to find existing device ID: %w", err)
}
if helper.LoginAs != nil {
if storedDeviceID != "" {
helper.LoginAs.DeviceID = storedDeviceID
@ -168,7 +170,7 @@ func (helper *CryptoHelper) Init() error {
return fmt.Errorf("the client must be logged in")
}
helper.mach = crypto.NewOlmMachine(helper.client, &helper.log, cryptoStore, stateStore)
err := helper.mach.Load()
err := helper.mach.Load(ctx)
if err != nil {
return fmt.Errorf("failed to load olm account: %w", err)
} else if err = helper.verifyDeviceKeysOnServer(ctx); err != nil {
@ -253,17 +255,18 @@ func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event.
Str("session_id", content.SessionID.String()).
Logger()
log.Debug().Msg("Decrypting received event")
ctx := log.WithContext(context.TODO())
decrypted, err := helper.Decrypt(evt)
decrypted, err := helper.Decrypt(ctx, evt)
if errors.Is(err, NoSessionFound) {
log.Debug().
Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
Msg("Couldn't find session, waiting for keys to arrive...")
if helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
decrypted, err = helper.Decrypt(evt)
decrypted, err = helper.Decrypt(ctx, evt)
} else {
go helper.waitLongerForSession(log, src, evt)
go helper.waitLongerForSession(ctx, log, src, evt)
return
}
}
@ -306,20 +309,20 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID
}
}
func (helper *CryptoHelper) waitLongerForSession(log zerolog.Logger, src mautrix.EventSource, evt *event.Event) {
func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, src mautrix.EventSource, evt *event.Event) {
content := evt.Content.AsEncrypted()
log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...")
go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
if !helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
log.Debug().Msg("Didn't get session, giving up")
helper.DecryptErrorCallback(evt, NoSessionFound)
return
}
log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again")
decrypted, err := helper.Decrypt(evt)
decrypted, err := helper.Decrypt(ctx, evt)
if err != nil {
log.Error().Err(err).Msg("Failed to decrypt event")
helper.DecryptErrorCallback(evt, err)
@ -329,32 +332,31 @@ func (helper *CryptoHelper) waitLongerForSession(log zerolog.Logger, src mautrix
helper.postDecrypt(src, decrypted)
}
func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
func (helper *CryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
if helper == nil {
return false
}
helper.lock.RLock()
defer helper.lock.RUnlock()
return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout)
return helper.mach.WaitForSession(ctx, roomID, senderKey, sessionID, timeout)
}
func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
func (helper *CryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) {
if helper == nil {
return nil, fmt.Errorf("crypto helper is nil")
}
return helper.mach.DecryptMegolmEvent(context.TODO(), evt)
return helper.mach.DecryptMegolmEvent(ctx, evt)
}
func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) {
func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) {
if helper == nil {
return nil, fmt.Errorf("crypto helper is nil")
}
helper.lock.RLock()
defer helper.lock.RUnlock()
ctx := context.TODO()
encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content)
if err != nil {
if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession {
if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) {
return
}
helper.log.Debug().
@ -362,7 +364,7 @@ func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, conten
Str("room_id", roomID.String()).
Msg("Got session error while encrypting event, sharing group session and trying again")
var users []id.UserID
users, err = helper.client.StateStore.GetRoomJoinedOrInvitedMembers(roomID)
users, err = helper.client.StateStore.GetRoomJoinedOrInvitedMembers(ctx, roomID)
if err != nil {
err = fmt.Errorf("failed to get room member list: %w", err)
} else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil {

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -91,7 +91,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
} else {
forwardedKeys = true
lastChainItem := sess.ForwardingChains[len(sess.ForwardingChains)-1]
device, _ = mach.CryptoStore.FindDeviceByKey(evt.Sender, id.IdentityKey(lastChainItem))
device, _ = mach.CryptoStore.FindDeviceByKey(ctx, evt.Sender, id.IdentityKey(lastChainItem))
if device != nil {
trustLevel = mach.ResolveTrust(device)
} else {
@ -188,7 +188,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
mach.megolmDecryptLock.Lock()
defer mach.megolmDecryptLock.Unlock()
sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID)
sess, err := mach.CryptoStore.GetGroupSession(ctx, encryptionRoomID, content.SenderKey, content.SessionID)
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
} else if sess == nil {
@ -250,7 +250,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
Int("max_messages", sess.MaxMessages).
Logger()
if sess.MaxMessages > 0 && int(ratchetTargetIndex) >= sess.MaxMessages && len(sess.RatchetSafety.MissedIndices) == 0 && mach.DeleteFullyUsedKeysOnDecrypt {
err = mach.CryptoStore.RedactGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached")
err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached")
if err != nil {
log.Err(err).Msg("Failed to delete fully used session")
return sess, plaintext, messageIndex, RatchetError
@ -261,14 +261,14 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
log.Err(err).Msg("Failed to ratchet session")
return sess, plaintext, messageIndex, RatchetError
} else if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
} else if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
log.Err(err).Msg("Failed to store ratcheted session")
return sess, plaintext, messageIndex, RatchetError
} else {
log.Info().Msg("Ratcheted session forward")
}
} else if didModify {
if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
log.Err(err).Msg("Failed to store updated ratchet safety data")
return sess, plaintext, messageIndex, RatchetError
} else {

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -159,7 +159,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
}
endTimeTrace = mach.timeTrace(ctx, "updating new session in database", time.Second)
err = mach.CryptoStore.UpdateSession(senderKey, session)
err = mach.CryptoStore.UpdateSession(ctx, senderKey, session)
endTimeTrace()
if err != nil {
log.Warn().Err(err).Msg("Failed to update new olm session in crypto store after decrypting")
@ -170,7 +170,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
log := *zerolog.Ctx(ctx)
endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second)
sessions, err := mach.CryptoStore.GetSessions(senderKey)
sessions, err := mach.CryptoStore.GetSessions(ctx, senderKey)
endTimeTrace()
if err != nil {
return nil, fmt.Errorf("failed to get session for %s: %w", senderKey, err)
@ -199,7 +199,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C
}
} else {
endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second)
err = mach.CryptoStore.UpdateSession(senderKey, session)
err = mach.CryptoStore.UpdateSession(ctx, senderKey, session)
endTimeTrace()
if err != nil {
log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting")
@ -217,7 +217,7 @@ func (mach *OlmMachine) createInboundSession(ctx context.Context, senderKey id.S
return nil, err
}
mach.saveAccount()
err = mach.CryptoStore.AddSession(senderKey, session)
err = mach.CryptoStore.AddSession(ctx, senderKey, session)
if err != nil {
zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to store created inbound session")
}

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -53,7 +53,7 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id
Str("signed_device_id", deviceID.String()).
Str("signature", signature).
Msg("Verified self-signing signature")
err = mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, pubKey, signature)
err = mach.CryptoStore.PutSignature(ctx, userID, id.Ed25519(signKey), signerUserID, pubKey, signature)
if err != nil {
log.Warn().Err(err).
Str("signer_user_id", signerUserID.String()).
@ -74,7 +74,7 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id
}
// save signature of device made by its own device signing key
if signKey, ok := deviceKeys.Keys[id.DeviceKeyID(signerKey)]; ok {
err := mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, id.Ed25519(signKey), signature)
err := mach.CryptoStore.PutSignature(ctx, userID, id.Ed25519(signKey), signerUserID, id.Ed25519(signKey), signature)
if err != nil {
log.Warn().Err(err).
Str("signer_user_id", signerUserID.String()).
@ -96,7 +96,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
log := mach.machOrContextLog(ctx)
if !includeUntracked {
var err error
users, err = mach.CryptoStore.FilterTrackedUsers(users)
users, err = mach.CryptoStore.FilterTrackedUsers(ctx, users)
if err != nil {
log.Warn().Err(err).Msg("Failed to filter tracked user list")
}
@ -123,7 +123,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
delete(req.DeviceKeys, userID)
newDevices := make(map[id.DeviceID]*id.Device)
existingDevices, err := mach.CryptoStore.GetDevices(userID)
existingDevices, err := mach.CryptoStore.GetDevices(ctx, userID)
if err != nil {
log.Warn().Err(err).Msg("Failed to get existing devices for user")
existingDevices = make(map[id.DeviceID]*id.Device)
@ -151,7 +151,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
}
}
log.Trace().Int("new_device_count", len(newDevices)).Msg("Storing new device list")
err = mach.CryptoStore.PutDevices(userID, newDevices)
err = mach.CryptoStore.PutDevices(ctx, userID, newDevices)
if err != nil {
log.Warn().Err(err).Msg("Failed to update device list")
}
@ -169,7 +169,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
Str("identity_key", device.IdentityKey.String()).
Str("signing_key", device.SigningKey.String()).
Logger()
sessionIDs, err := mach.CryptoStore.RedactGroupSessions("", device.IdentityKey, "device removed")
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(ctx, "", device.IdentityKey, "device removed")
if err != nil {
log.Err(err).Msg("Failed to redact megolm sessions from deleted device")
} else {
@ -179,7 +179,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
}
}
}
mach.OnDevicesChanged(userID)
mach.OnDevicesChanged(ctx, userID)
}
}
for userID := range req.DeviceKeys {
@ -197,18 +197,25 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
//
// This is called automatically whenever a device list change is noticed in ProcessSyncResponse and usually does
// not need to be called manually.
func (mach *OlmMachine) OnDevicesChanged(userID id.UserID) {
func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID) {
if mach.DisableDeviceChangeKeyRotation {
return
}
for _, roomID := range mach.StateStore.FindSharedRooms(userID) {
mach.Log.Debug().
rooms, err := mach.StateStore.FindSharedRooms(ctx, userID)
if err != nil {
mach.machOrContextLog(ctx).Err(err).
Stringer("with_user_id", userID).
Msg("Failed to find shared rooms to invalidate group sessions")
return
}
for _, roomID := range rooms {
mach.machOrContextLog(ctx).Debug().
Str("user_id", userID.String()).
Str("room_id", roomID.String()).
Msg("Invalidating group session in room due to device change notification")
err := mach.CryptoStore.RemoveOutboundGroupSession(roomID)
err = mach.CryptoStore.RemoveOutboundGroupSession(ctx, roomID)
if err != nil {
mach.Log.Warn().Err(err).
mach.machOrContextLog(ctx).Err(err).
Str("user_id", userID.String()).
Str("room_id", roomID.String()).
Msg("Failed to invalidate outbound group session")

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -84,7 +84,7 @@ func parseMessageIndex(ciphertext []byte) (uint, error) {
func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) {
mach.megolmEncryptLock.Lock()
defer mach.megolmEncryptLock.Unlock()
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
session, err := mach.CryptoStore.GetOutboundGroupSession(ctx, roomID)
if err != nil {
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
} else if session == nil {
@ -116,7 +116,7 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID
log = log.With().Uint("message_index", idx).Logger()
}
log.Debug().Msg("Encrypted event successfully")
err = mach.CryptoStore.UpdateOutboundGroupSession(session)
err = mach.CryptoStore.UpdateOutboundGroupSession(ctx, session)
if err != nil {
log.Warn().Err(err).Msg("Failed to update megolm session in crypto store after encrypting")
}
@ -137,7 +137,13 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID
}
func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) *OutboundGroupSession {
session := NewOutboundGroupSession(roomID, mach.StateStore.GetEncryptionEvent(roomID))
encryptionEvent, err := mach.StateStore.GetEncryptionEvent(ctx, roomID)
if err != nil {
mach.machOrContextLog(ctx).Err(err).
Stringer("room_id", roomID).
Msg("Failed to get encryption event in room")
}
session := NewOutboundGroupSession(roomID, encryptionEvent)
if !mach.DontStoreOutboundKeys {
signingKey, idKey := mach.account.Keys()
mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false)
@ -165,7 +171,7 @@ func strishArray[T ~string](arr []T) []string {
func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, users []id.UserID) error {
mach.megolmEncryptLock.Lock()
defer mach.megolmEncryptLock.Unlock()
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
session, err := mach.CryptoStore.GetOutboundGroupSession(ctx, roomID)
if err != nil {
return fmt.Errorf("failed to get previous outbound group session: %w", err)
} else if session != nil && session.Shared && !session.Expired() {
@ -192,7 +198,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
for _, userID := range users {
log := log.With().Str("target_user_id", userID.String()).Logger()
devices, err := mach.CryptoStore.GetDevices(userID)
devices, err := mach.CryptoStore.GetDevices(ctx, userID)
if err != nil {
log.Error().Err(err).Msg("Failed to get devices of user")
} else if devices == nil {
@ -292,7 +298,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
log.Debug().Msg("Group session successfully shared")
session.Shared = true
return mach.CryptoStore.AddOutboundGroupSession(session)
return mach.CryptoStore.AddOutboundGroupSession(ctx, session)
}
func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error {
@ -367,7 +373,7 @@ func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *Out
Reason: "This device does not encrypt messages for unverified devices",
}}
session.Users[userKey] = OGSIgnored
} else if deviceSession, err := mach.CryptoStore.GetLatestSession(device.IdentityKey); err != nil {
} else if deviceSession, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey); err != nil {
log.Error().Err(err).Msg("Failed to get olm session to encrypt group session")
} else if deviceSession == nil {
log.Warn().Err(err).Msg("Didn't find olm session to encrypt group session")

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -38,7 +38,7 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
Str("olm_session_description", session.Describe()).
Msg("Encrypting olm message")
msgType, ciphertext := session.Encrypt(plaintext)
err = mach.CryptoStore.UpdateSession(recipient.IdentityKey, session)
err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session)
if err != nil {
log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting")
}
@ -54,8 +54,8 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
}
}
func (mach *OlmMachine) shouldCreateNewSession(identityKey id.IdentityKey) bool {
if !mach.CryptoStore.HasSession(identityKey) {
func (mach *OlmMachine) shouldCreateNewSession(ctx context.Context, identityKey id.IdentityKey) bool {
if !mach.CryptoStore.HasSession(ctx, identityKey) {
return true
}
mach.devicesToUnwedgeLock.Lock()
@ -72,7 +72,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id
for userID, devices := range input {
request[userID] = make(map[id.DeviceID]id.KeyAlgorithm)
for deviceID, identity := range devices {
if mach.shouldCreateNewSession(identity.IdentityKey) {
if mach.shouldCreateNewSession(ctx, identity.IdentityKey) {
request[userID][deviceID] = id.KeyAlgorithmSignedCurve25519
}
}
@ -117,7 +117,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id
log.Error().Err(err).Msg("Failed to create outbound session with claimed one-time key")
} else {
wrapped := wrapSession(sess)
err = mach.CryptoStore.AddSession(identity.IdentityKey, wrapped)
err = mach.CryptoStore.AddSession(ctx, identity.IdentityKey, wrapped)
if err != nil {
log.Error().Err(err).Msg("Failed to store created outbound session")
} else {

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -8,6 +8,7 @@ package crypto
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
@ -91,7 +92,7 @@ func decryptKeyExport(passphrase string, exportData []byte) ([]ExportedSession,
return sessionsJSON, nil
}
func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, error) {
func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session ExportedSession) (bool, error) {
if session.Algorithm != id.AlgorithmMegolmV1 {
return false, ErrInvalidExportedAlgorithm
}
@ -112,12 +113,12 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er
ReceivedAt: time.Now().UTC(),
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(igs.RoomID, igs.SenderKey, igs.ID())
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID())
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
// We already have an equivalent or better session in the store, so don't override it.
return false, nil
}
err = mach.CryptoStore.PutGroupSession(igs.RoomID, igs.SenderKey, igs.ID(), igs)
err = mach.CryptoStore.PutGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID(), igs)
if err != nil {
return false, fmt.Errorf("failed to store imported session: %w", err)
}
@ -127,7 +128,7 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er
// ImportKeys imports data that was exported with the format specified in the Matrix spec.
// See https://spec.matrix.org/v1.2/client-server-api/#key-exports
func (mach *OlmMachine) ImportKeys(passphrase string, data []byte) (int, int, error) {
func (mach *OlmMachine) ImportKeys(ctx context.Context, passphrase string, data []byte) (int, int, error) {
exportData, err := decodeKeyExport(data)
if err != nil {
return 0, 0, err
@ -143,8 +144,11 @@ func (mach *OlmMachine) ImportKeys(passphrase string, data []byte) (int, int, er
Str("room_id", session.RoomID.String()).
Str("session_id", session.SessionID.String()).
Logger()
imported, err := mach.importExportedRoomKey(session)
imported, err := mach.importExportedRoomKey(ctx, session)
if err != nil {
if ctx.Err() != nil {
return count, len(sessions), ctx.Err()
}
log.Error().Err(err).Msg("Failed to import Megolm session from file")
} else if imported {
log.Debug().Msg("Imported Megolm session from file")

View File

@ -1,5 +1,5 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -152,7 +152,10 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
Msg("Mismatched session ID while creating inbound group session from forward")
return false
}
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
config, err := mach.StateStore.GetEncryptionEvent(ctx, content.RoomID)
if err != nil {
log.Error().Err(err).Msg("Failed to get encryption event for room")
}
var maxAge time.Duration
var maxMessages int
if config != nil {
@ -178,7 +181,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
MaxMessages: maxMessages,
IsScheduled: content.IsScheduled,
}
err = mach.CryptoStore.PutGroupSession(content.RoomID, content.SenderKey, content.SessionID, igs)
err = mach.CryptoStore.PutGroupSession(ctx, content.RoomID, content.SenderKey, content.SessionID, igs)
if err != nil {
log.Error().Err(err).Msg("Failed to store new inbound group session")
return false
@ -274,7 +277,7 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User
return
}
igs, err := mach.CryptoStore.GetGroupSession(content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID)
igs, err := mach.CryptoStore.GetGroupSession(ctx, content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID)
if err != nil {
if errors.Is(err, ErrGroupSessionWithheld) {
log.Debug().Err(err).Msg("Requested group session not available")
@ -331,7 +334,7 @@ func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.Us
Int("first_message_index", content.FirstMessageIndex).
Logger()
sess, err := mach.CryptoStore.GetGroupSession(content.RoomID, "", content.SessionID)
sess, err := mach.CryptoStore.GetGroupSession(ctx, content.RoomID, "", content.SessionID)
if err != nil {
if errors.Is(err, ErrGroupSessionWithheld) {
log.Debug().Err(err).Msg("Acked group session was already redacted")
@ -351,7 +354,7 @@ func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.Us
isInbound := sess.SenderKey == mach.OwnIdentity().IdentityKey
if isInbound && mach.DeleteOutboundKeysOnAck && content.FirstMessageIndex == 0 {
log.Debug().Msg("Redacting inbound copy of outbound group session after ack")
err = mach.CryptoStore.RedactGroupSession(content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked")
err = mach.CryptoStore.RedactGroupSession(ctx, content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked")
if err != nil {
log.Err(err).Msg("Failed to redact group session")
}

View File

@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -80,11 +80,11 @@ type OlmMachine struct {
// StateStore is used by OlmMachine to get room state information that's needed for encryption.
type StateStore interface {
// IsEncrypted returns whether a room is encrypted.
IsEncrypted(id.RoomID) bool
IsEncrypted(context.Context, id.RoomID) (bool, error)
// GetEncryptionEvent returns the encryption event's content for an encrypted room.
GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent
GetEncryptionEvent(context.Context, id.RoomID) (*event.EncryptionEventContent, error)
// FindSharedRooms returns the encrypted rooms that another user is also in for a user ID.
FindSharedRooms(id.UserID) []id.RoomID
FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, error)
}
// NewOlmMachine creates an OlmMachine with the given client, logger and stores.
@ -131,8 +131,8 @@ func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger {
// Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created.
// This must be called before using the machine.
func (mach *OlmMachine) Load() (err error) {
mach.account, err = mach.CryptoStore.GetAccount()
func (mach *OlmMachine) Load(ctx context.Context) (err error) {
mach.account, err = mach.CryptoStore.GetAccount(ctx)
if err != nil {
return
}
@ -143,15 +143,15 @@ func (mach *OlmMachine) Load() (err error) {
}
func (mach *OlmMachine) saveAccount() {
err := mach.CryptoStore.PutAccount(mach.account)
err := mach.CryptoStore.PutAccount(context.TODO(), mach.account)
if err != nil {
mach.Log.Error().Err(err).Msg("Failed to save account")
}
}
// FlushStore calls the Flush method of the CryptoStore.
func (mach *OlmMachine) FlushStore() error {
return mach.CryptoStore.Flush()
func (mach *OlmMachine) FlushStore(ctx context.Context) error {
return mach.CryptoStore.Flush(ctx)
}
func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() {
@ -284,7 +284,12 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
//
// client.Syncer.(mautrix.ExtensibleSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent)
func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Event) {
if !mach.StateStore.IsEncrypted(evt.RoomID) {
ctx := context.TODO()
if isEncrypted, err := mach.StateStore.IsEncrypted(ctx, evt.RoomID); err != nil {
mach.machOrContextLog(ctx).Err(err).Stringer("room_id", evt.RoomID).
Msg("Failed to check if room is encrypted to handle member event")
return
} else if !isEncrypted {
return
}
content := evt.Content.AsMember()
@ -311,7 +316,7 @@ func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Even
Str("prev_membership", string(prevContent.Membership)).
Str("new_membership", string(content.Membership)).
Msg("Got membership state change, invalidating group session in room")
err := mach.CryptoStore.RemoveOutboundGroupSession(evt.RoomID)
err := mach.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID)
if err != nil {
mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session")
}
@ -405,7 +410,7 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
// GetOrFetchDevice attempts to retrieve the device identity for the given device from the store
// and if it's not found it asks the server for it.
func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
device, err := mach.CryptoStore.GetDevice(userID, deviceID)
device, err := mach.CryptoStore.GetDevice(ctx, userID, deviceID)
if err != nil {
return nil, fmt.Errorf("failed to get sender device from store: %w", err)
} else if device != nil {
@ -425,7 +430,7 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID,
// store and if it's not found it asks the server for it. This returns nil if the server doesn't return a device with
// the given identity key.
func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(userID, identityKey)
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(ctx, userID, identityKey)
if err != nil || deviceIdentity != nil {
return deviceIdentity, err
}
@ -455,7 +460,7 @@ func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.De
mach.olmLock.Lock()
defer mach.olmLock.Unlock()
olmSess, err := mach.CryptoStore.GetLatestSession(device.IdentityKey)
olmSess, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey)
if err != nil {
return err
}
@ -499,7 +504,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
Msg("Mismatched session ID while creating inbound group session")
return
}
err = mach.CryptoStore.PutGroupSession(roomID, senderKey, sessionID, igs)
err = mach.CryptoStore.PutGroupSession(ctx, roomID, senderKey, sessionID, igs)
if err != nil {
log.Error().Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
return
@ -525,7 +530,7 @@ func (mach *OlmMachine) markSessionReceived(id id.SessionID) {
}
// WaitForSession waits for the given Megolm session to arrive.
func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
mach.keyWaitersLock.Lock()
ch, ok := mach.keyWaiters[sessionID]
if !ok {
@ -534,7 +539,7 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey,
}
mach.keyWaitersLock.Unlock()
// Handle race conditions where a session appears between the failed decryption and WaitForSession call.
sess, err := mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID)
sess, err := mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID)
if sess != nil || errors.Is(err, ErrGroupSessionWithheld) {
return true
}
@ -542,10 +547,12 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey,
case <-ch:
return true
case <-time.After(timeout):
sess, err = mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID)
sess, err = mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID)
// Check if the session somehow appeared in the store without telling us
// We accept withheld sessions as received, as then the decryption attempt will show the error.
return sess != nil || errors.Is(err, ErrGroupSessionWithheld)
case <-ctx.Done():
return false
}
}
@ -568,7 +575,10 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve
return
}
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
config, err := mach.StateStore.GetEncryptionEvent(ctx, content.RoomID)
if err != nil {
log.Error().Err(err).Msg("Failed to get encryption event for room")
}
var maxAge time.Duration
var maxMessages int
if config != nil {
@ -589,7 +599,7 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve
}
if mach.DeletePreviousKeysOnReceive && !content.IsScheduled {
log.Debug().Msg("Redacting previous megolm sessions from sender in room")
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(content.RoomID, evt.SenderKey, "received new key from device")
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(ctx, content.RoomID, evt.SenderKey, "received new key from device")
if err != nil {
log.Err(err).Msg("Failed to redact previous megolm sessions")
} else {
@ -606,7 +616,7 @@ func (mach *OlmMachine) handleRoomKeyWithheld(ctx context.Context, content *even
zerolog.Ctx(ctx).Debug().Interface("content", content).Msg("Non-megolm room key withheld event")
return
}
err := mach.CryptoStore.PutWithheldGroupSession(*content)
err := mach.CryptoStore.PutWithheldGroupSession(ctx, *content)
if err != nil {
zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to save room key withheld event")
}
@ -662,7 +672,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) {
log := mach.Log.With().Str("action", "redact expired sessions").Logger()
for {
sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions()
sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions(ctx)
if err != nil {
log.Err(err).Msg("Failed to redact expired megolm sessions")
} else if len(sessionIDs) > 0 {

View File

@ -20,18 +20,18 @@ import (
type mockStateStore struct{}
func (mockStateStore) IsEncrypted(id.RoomID) bool {
return true
func (mockStateStore) IsEncrypted(context.Context, id.RoomID) (bool, error) {
return true, nil
}
func (mockStateStore) GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent {
func (mockStateStore) GetEncryptionEvent(context.Context, id.RoomID) (*event.EncryptionEventContent, error) {
return &event.EncryptionEventContent{
RotationPeriodMessages: 3,
}
}, nil
}
func (mockStateStore) FindSharedRooms(id.UserID) []id.RoomID {
return []id.RoomID{"room1"}
func (mockStateStore) FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, error) {
return []id.RoomID{"room1"}, nil
}
func newMachine(t *testing.T, userID id.UserID) *OlmMachine {
@ -47,7 +47,7 @@ func newMachine(t *testing.T, userID id.UserID) *OlmMachine {
}
machine := NewOlmMachine(client, nil, gobStore, mockStateStore{})
if err := machine.Load(); err != nil {
if err := machine.Load(context.TODO()); err != nil {
t.Fatalf("Error creating account: %v", err)
}
@ -57,7 +57,7 @@ func newMachine(t *testing.T, userID id.UserID) *OlmMachine {
func TestRatchetMegolmSession(t *testing.T) {
mach := newMachine(t, "user1")
outSess := mach.newOutboundGroupSession(context.TODO(), "meow")
inSess, err := mach.CryptoStore.GetGroupSession("meow", mach.OwnIdentity().IdentityKey, outSess.ID())
inSess, err := mach.CryptoStore.GetGroupSession(context.TODO(), "meow", mach.OwnIdentity().IdentityKey, outSess.ID())
require.NoError(t, err)
assert.Equal(t, uint32(0), inSess.Internal.FirstKnownIndex())
err = inSess.RatchetTo(10)
@ -85,7 +85,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
}
// store sender device identity in receiving machine store
machineIn.CryptoStore.PutDevices("user1", map[id.DeviceID]*id.Device{
machineIn.CryptoStore.PutDevices(context.TODO(), "user1", map[id.DeviceID]*id.Device{
"device1": {
UserID: "user1",
DeviceID: "device1",
@ -97,7 +97,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
// create & store outbound megolm session for sending the event later
megolmOutSession := machineOut.newOutboundGroupSession(context.TODO(), "room1")
megolmOutSession.Shared = true
machineOut.CryptoStore.AddOutboundGroupSession(megolmOutSession)
machineOut.CryptoStore.AddOutboundGroupSession(context.TODO(), megolmOutSession)
// encrypt m.room_key event with olm session
deviceIdentity := &id.Device{
@ -125,7 +125,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
if err != nil {
t.Errorf("Error creating inbound megolm session: %v", err)
}
if err = machineIn.CryptoStore.PutGroupSession("room1", senderKey, igs.ID(), igs); err != nil {
if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), "room1", senderKey, igs.ID(), igs); err != nil {
t.Errorf("Error storing inbound megolm session: %v", err)
}
}

View File

@ -1,4 +1,4 @@
// Copyright (c) 2022 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -27,7 +27,7 @@ import (
"maunium.net/go/mautrix/id"
)
var PostgresArrayWrapper func(interface{}) interface {
var PostgresArrayWrapper func(any) interface {
driver.Valuer
sql.Scanner
}
@ -62,21 +62,21 @@ func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, accountID
}
// Flush does nothing for this implementation as data is already persisted in the database.
func (store *SQLCryptoStore) Flush() error {
func (store *SQLCryptoStore) Flush(_ context.Context) error {
return nil
}
// PutNextBatch stores the next sync batch token for the current account.
func (store *SQLCryptoStore) PutNextBatch(ctx context.Context, nextBatch string) error {
store.SyncToken = nextBatch
_, err := store.DB.ExecContext(ctx, `UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2`, store.SyncToken, store.AccountID)
_, err := store.DB.Exec(ctx, `UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2`, store.SyncToken, store.AccountID)
return err
}
// GetNextBatch retrieves the next sync batch token for the current account.
func (store *SQLCryptoStore) GetNextBatch(ctx context.Context) (string, error) {
if store.SyncToken == "" {
err := store.DB.
err := store.DB.Conn(ctx).
QueryRowContext(ctx, "SELECT sync_token FROM crypto_account WHERE account_id=$1", store.AccountID).
Scan(&store.SyncToken)
if !errors.Is(err, sql.ErrNoRows) {
@ -111,20 +111,19 @@ func (store *SQLCryptoStore) LoadNextBatch(ctx context.Context, _ id.UserID) (st
return nb, nil
}
func (store *SQLCryptoStore) FindDeviceID() (deviceID id.DeviceID) {
err := store.DB.QueryRow("SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
// TODO return error
store.DB.Log.Warn("Failed to scan device ID: %v", err)
func (store *SQLCryptoStore) FindDeviceID(ctx context.Context) (deviceID id.DeviceID, err error) {
err = store.DB.QueryRow(ctx, "SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
// PutAccount stores an OlmAccount in the database.
func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error {
func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount) error {
store.Account = account
bytes := account.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec(`
_, err := store.DB.Exec(ctx, `
INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token,
account=excluded.account, account_id=excluded.account_id
@ -133,9 +132,9 @@ func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error {
}
// GetAccount retrieves an OlmAccount from the database.
func (store *SQLCryptoStore) GetAccount() (*OlmAccount, error) {
func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) {
if store.Account == nil {
row := store.DB.QueryRow("SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID)
row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID)
acc := &OlmAccount{Internal: *olm.NewBlankAccount()}
var accountBytes []byte
err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes)
@ -154,7 +153,7 @@ func (store *SQLCryptoStore) GetAccount() (*OlmAccount, error) {
}
// HasSession returns whether there is an Olm session for the given sender key.
func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
func (store *SQLCryptoStore) HasSession(ctx context.Context, key id.SenderKey) bool {
store.olmSessionCacheLock.Lock()
cache, ok := store.olmSessionCache[key]
store.olmSessionCacheLock.Unlock()
@ -162,17 +161,17 @@ func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
return true
}
var sessionID id.SessionID
err := store.DB.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 LIMIT 1",
err := store.DB.QueryRow(ctx, "SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 LIMIT 1",
key, store.AccountID).Scan(&sessionID)
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return false
}
return len(sessionID) > 0
}
// GetSessions returns all the known Olm sessions for a sender key.
func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (OlmSessionList, error) {
rows, err := store.DB.Query("SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC",
func (store *SQLCryptoStore) GetSessions(ctx context.Context, key id.SenderKey) (OlmSessionList, error) {
rows, err := store.DB.Query(ctx, "SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC",
key, store.AccountID)
if err != nil {
return nil, err
@ -212,11 +211,11 @@ func (store *SQLCryptoStore) getOlmSessionCache(key id.SenderKey) map[id.Session
}
// GetLatestSession retrieves the Olm session for a given sender key from the database that has the largest ID.
func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, error) {
func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.SenderKey) (*OlmSession, error) {
store.olmSessionCacheLock.Lock()
defer store.olmSessionCacheLock.Unlock()
row := store.DB.QueryRow("SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1",
row := store.DB.QueryRow(ctx, "SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1",
key, store.AccountID)
sess := OlmSession{Internal: *olm.NewBlankSession()}
@ -224,7 +223,7 @@ func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, er
var sessionID id.SessionID
err := row.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.LastEncryptedTime, &sess.LastDecryptedTime)
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
@ -242,20 +241,20 @@ func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, er
}
// AddSession persists an Olm session for a sender in the database.
func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *OlmSession) error {
func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, session *OlmSession) error {
store.olmSessionCacheLock.Lock()
defer store.olmSessionCacheLock.Unlock()
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
_, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
session.ID(), key, sessionBytes, session.CreationTime, session.LastEncryptedTime, session.LastDecryptedTime, store.AccountID)
store.getOlmSessionCache(key)[session.ID()] = session
return err
}
// UpdateSession replaces the Olm session for a sender in the database.
func (store *SQLCryptoStore) UpdateSession(_ id.SenderKey, session *OlmSession) error {
func (store *SQLCryptoStore) UpdateSession(ctx context.Context, _ id.SenderKey, session *OlmSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec("UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5",
_, err := store.DB.Exec(ctx, "UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5",
sessionBytes, session.LastEncryptedTime, session.LastDecryptedTime, session.ID(), store.AccountID)
return err
}
@ -275,14 +274,14 @@ func datePtr(t time.Time) *time.Time {
}
// PutGroupSession stores an inbound Megolm group session for a room, sender and session.
func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error {
func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
forwardingChains := strings.Join(session.ForwardingChains, ",")
ratchetSafety, err := json.Marshal(&session.RatchetSafety)
if err != nil {
return fmt.Errorf("failed to marshal ratchet safety info: %w", err)
}
_, err = store.DB.Exec(`
_, err = store.DB.Exec(ctx, `
INSERT INTO crypto_megolm_inbound_session (
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id
@ -301,19 +300,19 @@ func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.Send
}
// GetGroupSession retrieves an inbound Megolm group session for a room, sender and session.
func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
var senderKeyDB, signingKey, forwardingChains, withheldCode, withheldReason sql.NullString
var sessionBytes, ratchetSafetyBytes []byte
var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
err := store.DB.QueryRow(`
err := store.DB.QueryRow(ctx, `
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`,
roomID, senderKey, sessionID, store.AccountID,
).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
@ -327,22 +326,7 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
Reason: withheldReason.String,
}
}
igs := olm.NewBlankInboundGroupSession()
err = igs.Unpickle(sessionBytes, store.PickleKey)
if err != nil {
return nil, err
}
var chains []string
if forwardingChains.String != "" {
chains = strings.Split(forwardingChains.String, ",")
}
var rs RatchetSafety
if len(ratchetSafetyBytes) > 0 {
err = json.Unmarshal(ratchetSafetyBytes, &rs)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
}
}
igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String)
if senderKey == "" {
senderKey = id.Curve25519(senderKeyDB.String)
}
@ -360,8 +344,8 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
}, nil
}
func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error {
_, err := store.DB.Exec(`
func (store *SQLCryptoStore) RedactGroupSession(ctx context.Context, _ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error {
_, err := store.DB.Exec(ctx, `
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL
@ -369,27 +353,24 @@ func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, ses
return err
}
func (store *SQLCryptoStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
func (store *SQLCryptoStore) RedactGroupSessions(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
if roomID == "" && senderKey == "" {
return nil, fmt.Errorf("room ID or sender key must be provided for redacting sessions")
}
res, err := store.DB.Query(`
res, err := store.DB.Query(ctx, `
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5
AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL
RETURNING session_id
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, roomID, senderKey, store.AccountID)
var sessionIDs []id.SessionID
for res.Next() {
var sessionID id.SessionID
_ = res.Scan(&sessionID)
sessionIDs = append(sessionIDs, sessionID)
if err != nil {
return nil, err
}
return sessionIDs, err
return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList()
}
func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error) {
func (store *SQLCryptoStore) RedactExpiredGroupSessions(ctx context.Context) ([]id.SessionID, error) {
var query string
switch store.DB.Dialect {
case dbutil.Postgres:
@ -413,46 +394,40 @@ func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error
default:
return nil, fmt.Errorf("unsupported dialect")
}
res, err := store.DB.Query(query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID)
var sessionIDs []id.SessionID
for res.Next() {
var sessionID id.SessionID
_ = res.Scan(&sessionID)
sessionIDs = append(sessionIDs, sessionID)
res, err := store.DB.Query(ctx, query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID)
if err != nil {
return nil, err
}
return sessionIDs, err
return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList()
}
func (store *SQLCryptoStore) RedactOutdatedGroupSessions() ([]id.SessionID, error) {
res, err := store.DB.Query(`
func (store *SQLCryptoStore) RedactOutdatedGroupSessions(ctx context.Context) ([]id.SessionID, error) {
res, err := store.DB.Query(ctx, `
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL
RETURNING session_id
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: outdated", store.AccountID)
var sessionIDs []id.SessionID
for res.Next() {
var sessionID id.SessionID
_ = res.Scan(&sessionID)
sessionIDs = append(sessionIDs, sessionID)
if err != nil {
return nil, err
}
return sessionIDs, err
return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList()
}
func (store *SQLCryptoStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
func (store *SQLCryptoStore) PutWithheldGroupSession(ctx context.Context, content event.RoomKeyWithheldEventContent) error {
_, err := store.DB.Exec(ctx, "INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID)
return err
}
func (store *SQLCryptoStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
func (store *SQLCryptoStore) GetWithheldGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
var code, reason sql.NullString
err := store.DB.QueryRow(`
err := store.DB.QueryRow(ctx, `
SELECT withheld_code, withheld_reason FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`,
roomID, senderKey, sessionID, store.AccountID,
).Scan(&code, &reason)
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil || !code.Valid {
return nil, err
@ -467,82 +442,79 @@ func (store *SQLCryptoStore) GetWithheldGroupSession(roomID id.RoomID, senderKey
}, nil
}
func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*InboundGroupSession, err error) {
for rows.Next() {
var roomID id.RoomID
var signingKey, senderKey, forwardingChains sql.NullString
var sessionBytes, ratchetSafetyBytes []byte
var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs *olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) {
igs = olm.NewBlankInboundGroupSession()
err = igs.Unpickle(sessionBytes, store.PickleKey)
if err != nil {
return
}
if forwardingChains != "" {
chains = strings.Split(forwardingChains, ",")
}
var rs RatchetSafety
if len(ratchetSafetyBytes) > 0 {
err = json.Unmarshal(ratchetSafetyBytes, &rs)
if err != nil {
return
err = fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
}
igs := olm.NewBlankInboundGroupSession()
err = igs.Unpickle(sessionBytes, store.PickleKey)
if err != nil {
return
}
var chains []string
if forwardingChains.String != "" {
chains = strings.Split(forwardingChains.String, ",")
}
var rs RatchetSafety
if len(ratchetSafetyBytes) > 0 {
err = json.Unmarshal(ratchetSafetyBytes, &rs)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
}
}
result = append(result, &InboundGroupSession{
Internal: *igs,
SigningKey: id.Ed25519(signingKey.String),
SenderKey: id.Curve25519(senderKey.String),
RoomID: roomID,
ForwardingChains: chains,
RatchetSafety: rs,
ReceivedAt: receivedAt.Time,
MaxAge: maxAge.Int64,
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
})
}
return
}
func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
rows, err := store.DB.Query(`
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*InboundGroupSession, error) {
var roomID id.RoomID
var signingKey, senderKey, forwardingChains sql.NullString
var sessionBytes, ratchetSafetyBytes []byte
var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
if err != nil {
return nil, err
}
igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String)
return &InboundGroupSession{
Internal: *igs,
SigningKey: id.Ed25519(signingKey.String),
SenderKey: id.Curve25519(senderKey.String),
RoomID: roomID,
ForwardingChains: chains,
RatchetSafety: rs,
ReceivedAt: receivedAt.Time,
MaxAge: maxAge.Int64,
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
}, nil
}
func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) {
rows, err := store.DB.Query(ctx, `
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
roomID, store.AccountID,
)
if err == sql.ErrNoRows {
return []*InboundGroupSession{}, nil
} else if err != nil {
if err != nil {
return nil, err
}
return store.scanGroupSessionList(rows)
return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList()
}
func (store *SQLCryptoStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
rows, err := store.DB.Query(`
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) ([]*InboundGroupSession, error) {
rows, err := store.DB.Query(ctx, `
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`,
store.AccountID,
)
if err == sql.ErrNoRows {
return []*InboundGroupSession{}, nil
} else if err != nil {
if err != nil {
return nil, err
}
return store.scanGroupSessionList(rows)
return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList()
}
// AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices.
func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSession) error {
func (store *SQLCryptoStore) AddOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec(`
_, err := store.DB.Exec(ctx, `
INSERT INTO crypto_megolm_outbound_session
(room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used, account_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
@ -556,24 +528,24 @@ func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSessi
}
// UpdateOutboundGroupSession replaces an outbound Megolm session with for same room and session ID.
func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *OutboundGroupSession) error {
func (store *SQLCryptoStore) UpdateOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec("UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6",
_, err := store.DB.Exec(ctx, "UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6",
sessionBytes, session.MessageCount, session.LastEncryptedTime, session.RoomID, session.ID(), store.AccountID)
return err
}
// GetOutboundGroupSession retrieves the outbound Megolm session for the given room ID.
func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
func (store *SQLCryptoStore) GetOutboundGroupSession(ctx context.Context, roomID id.RoomID) (*OutboundGroupSession, error) {
var ogs OutboundGroupSession
var sessionBytes []byte
var maxAgeMS int64
err := store.DB.QueryRow(`
err := store.DB.QueryRow(ctx, `
SELECT session, shared, max_messages, message_count, max_age, created_at, last_used
FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2`,
roomID, store.AccountID,
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &maxAgeMS, &ogs.CreationTime, &ogs.LastEncryptedTime)
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
@ -590,8 +562,8 @@ func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*Outboun
}
// RemoveOutboundGroupSession removes the outbound Megolm session for the given room ID.
func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
_, err := store.DB.Exec("DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2",
func (store *SQLCryptoStore) RemoveOutboundGroupSession(ctx context.Context, roomID id.RoomID) error {
_, err := store.DB.Exec(ctx, "DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2",
roomID, store.AccountID)
return err
}
@ -608,7 +580,7 @@ func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey
`
var expectedEventID id.EventID
var expectedTimestamp int64
err := store.DB.QueryRowContext(ctx, validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
err := store.DB.QueryRow(ctx, validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
if err != nil {
return false, err
} else if expectedEventID != eventID || expectedTimestamp != timestamp {
@ -623,69 +595,58 @@ func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey
return true, nil
}
func scanDevice(rows dbutil.Scannable) (*id.Device, error) {
var device id.Device
err := rows.Scan(&device.UserID, &device.DeviceID, &device.IdentityKey, &device.SigningKey, &device.Trust, &device.Deleted, &device.Name)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
}
return &device, nil
}
// GetDevices returns a map of device IDs to device identities, including the identity and signing keys, for a given user ID.
func (store *SQLCryptoStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) {
func (store *SQLCryptoStore) GetDevices(ctx context.Context, userID id.UserID) (map[id.DeviceID]*id.Device, error) {
var ignore id.UserID
err := store.DB.QueryRow("SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore)
if err == sql.ErrNoRows {
err := store.DB.QueryRow(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
}
rows, err := store.DB.Query("SELECT device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND deleted=false", userID)
rows, err := store.DB.Query(ctx, "SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND deleted=false", userID)
if err != nil {
return nil, err
}
data := make(map[id.DeviceID]*id.Device)
for rows.Next() {
var identity id.Device
err := rows.Scan(&identity.DeviceID, &identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
if err != nil {
return nil, err
}
identity.UserID = userID
data[identity.DeviceID] = &identity
err = dbutil.NewRowIter(rows, scanDevice).Iter(func(device *id.Device) (bool, error) {
data[device.DeviceID] = device
return true, nil
})
if err != nil {
return nil, err
}
return data, nil
}
// GetDevice returns the device dentity for a given user and device ID.
func (store *SQLCryptoStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
var identity id.Device
err := store.DB.QueryRow(`
SELECT identity_key, signing_key, trust, deleted, name
func (store *SQLCryptoStore) GetDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
return scanDevice(store.DB.QueryRow(ctx, `
SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name
FROM crypto_device WHERE user_id=$1 AND device_id=$2`,
userID, deviceID,
).Scan(&identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
identity.UserID = userID
identity.DeviceID = deviceID
return &identity, nil
))
}
// FindDeviceByKey finds a specific device by its sender key.
func (store *SQLCryptoStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
var identity id.Device
err := store.DB.QueryRow(`
SELECT device_id, signing_key, trust, deleted, name
func (store *SQLCryptoStore) FindDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
return scanDevice(store.DB.QueryRow(ctx, `
SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name
FROM crypto_device WHERE user_id=$1 AND identity_key=$2`,
userID, identityKey,
).Scan(&identity.DeviceID, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
identity.UserID = userID
identity.IdentityKey = identityKey
return &identity, nil
))
}
const deviceInsertQuery = `
@ -698,106 +659,84 @@ ON CONFLICT (user_id, device_id) DO UPDATE
var deviceMassInsertTemplate = strings.ReplaceAll(deviceInsertQuery, "($1, $2, $3, $4, $5, $6, $7)", "%s")
// PutDevice stores a single device for a user, replacing it if it exists already.
func (store *SQLCryptoStore) PutDevice(userID id.UserID, device *id.Device) error {
_, err := store.DB.Exec(deviceInsertQuery,
func (store *SQLCryptoStore) PutDevice(ctx context.Context, userID id.UserID, device *id.Device) error {
_, err := store.DB.Exec(ctx, deviceInsertQuery,
userID, device.DeviceID, device.IdentityKey, device.SigningKey, device.Trust, device.Deleted, device.Name)
return err
}
// PutDevices stores the device identity information for the given user ID.
func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error {
tx, err := store.DB.Begin()
if err != nil {
return err
}
_, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
if err != nil {
return fmt.Errorf("failed to add user to tracked users list: %w", err)
}
_, err = tx.Exec("UPDATE crypto_device SET deleted=true WHERE user_id=$1", userID)
if err != nil {
_ = tx.Rollback()
return fmt.Errorf("failed to delete old devices: %w", err)
}
if len(devices) == 0 {
err = tx.Commit()
func (store *SQLCryptoStore) PutDevices(ctx context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error {
return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
_, err := store.DB.Exec(ctx, "INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
if err != nil {
return fmt.Errorf("failed to commit changes (no devices added): %w", err)
return fmt.Errorf("failed to add user to tracked users list: %w", err)
}
_, err = store.DB.Exec(ctx, "UPDATE crypto_device SET deleted=true WHERE user_id=$1", userID)
if err != nil {
return fmt.Errorf("failed to delete old devices: %w", err)
}
if len(devices) == 0 {
return nil
}
deviceBatchLen := 5 // how many devices will be inserted per query
deviceIDs := make([]id.DeviceID, 0, len(devices))
for deviceID := range devices {
deviceIDs = append(deviceIDs, deviceID)
}
const valueStringFormat = "($1, $%d, $%d, $%d, $%d, $%d, $%d)"
for batchDeviceIdx := 0; batchDeviceIdx < len(deviceIDs); batchDeviceIdx += deviceBatchLen {
var batchDevices []id.DeviceID
if batchDeviceIdx+deviceBatchLen < len(deviceIDs) {
batchDevices = deviceIDs[batchDeviceIdx : batchDeviceIdx+deviceBatchLen]
} else {
batchDevices = deviceIDs[batchDeviceIdx:]
}
values := make([]interface{}, 1, len(devices)*6+1)
values[0] = userID
valueStrings := make([]string, 0, len(devices))
i := 2
for _, deviceID := range batchDevices {
identity := devices[deviceID]
values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name)
valueStrings = append(valueStrings, fmt.Sprintf(valueStringFormat, i, i+1, i+2, i+3, i+4, i+5))
i += 6
}
valueString := strings.Join(valueStrings, ",")
_, err = store.DB.Exec(ctx, fmt.Sprintf(deviceMassInsertTemplate, valueString), values...)
if err != nil {
return fmt.Errorf("failed to insert new devices: %w", err)
}
}
return nil
}
deviceBatchLen := 5 // how many devices will be inserted per query
deviceIDs := make([]id.DeviceID, 0, len(devices))
for deviceID := range devices {
deviceIDs = append(deviceIDs, deviceID)
}
const valueStringFormat = "($1, $%d, $%d, $%d, $%d, $%d, $%d)"
for batchDeviceIdx := 0; batchDeviceIdx < len(deviceIDs); batchDeviceIdx += deviceBatchLen {
var batchDevices []id.DeviceID
if batchDeviceIdx+deviceBatchLen < len(deviceIDs) {
batchDevices = deviceIDs[batchDeviceIdx : batchDeviceIdx+deviceBatchLen]
} else {
batchDevices = deviceIDs[batchDeviceIdx:]
}
values := make([]interface{}, 1, len(devices)*6+1)
values[0] = userID
valueStrings := make([]string, 0, len(devices))
i := 2
for _, deviceID := range batchDevices {
identity := devices[deviceID]
values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name)
valueStrings = append(valueStrings, fmt.Sprintf(valueStringFormat, i, i+1, i+2, i+3, i+4, i+5))
i += 6
}
valueString := strings.Join(valueStrings, ",")
_, err = tx.Exec(fmt.Sprintf(deviceMassInsertTemplate, valueString), values...)
if err != nil {
_ = tx.Rollback()
return fmt.Errorf("failed to insert new devices: %w", err)
}
}
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit changes: %w", err)
}
return nil
})
}
// FilterTrackedUsers finds all the user IDs out of the given ones for which the database contains identity information.
func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id.UserID) ([]id.UserID, error) {
var rows dbutil.Rows
var err error
if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil {
rows, err = store.DB.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users))
rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users))
} else {
queryString := make([]string, len(users))
params := make([]interface{}, len(users))
for i, user := range users {
queryString[i] = fmt.Sprintf("$%d", i+1)
queryString[i] = fmt.Sprintf("?%d", i+1)
params[i] = user
}
rows, err = store.DB.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...)
rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...)
}
if err != nil {
return users, err
}
var ptr int
for rows.Next() {
err = rows.Scan(&users[ptr])
if err != nil {
return users, err
} else {
ptr++
}
}
return users[:ptr], nil
return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList()
}
// PutCrossSigningKey stores a cross-signing key of some user along with its usage.
func (store *SQLCryptoStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
_, err := store.DB.Exec(`
func (store *SQLCryptoStore) PutCrossSigningKey(ctx context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
_, err := store.DB.Exec(ctx, `
INSERT INTO crypto_cross_signing_keys (user_id, usage, key, first_seen_key) VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, usage) DO UPDATE SET key=excluded.key
`, userID, usage, key, key)
@ -805,8 +744,8 @@ func (store *SQLCryptoStore) PutCrossSigningKey(userID id.UserID, usage id.Cross
}
// GetCrossSigningKeys retrieves a user's stored cross-signing keys.
func (store *SQLCryptoStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
rows, err := store.DB.Query("SELECT usage, key, first_seen_key FROM crypto_cross_signing_keys WHERE user_id=$1", userID)
func (store *SQLCryptoStore) GetCrossSigningKeys(ctx context.Context, userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
rows, err := store.DB.Query(ctx, "SELECT usage, key, first_seen_key FROM crypto_cross_signing_keys WHERE user_id=$1", userID)
if err != nil {
return nil, err
}
@ -825,8 +764,8 @@ func (store *SQLCryptoStore) GetCrossSigningKeys(userID id.UserID) (map[id.Cross
}
// PutSignature stores a signature of a cross-signing or device key along with the signer's user ID and key.
func (store *SQLCryptoStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
_, err := store.DB.Exec(`
func (store *SQLCryptoStore) PutSignature(ctx context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
_, err := store.DB.Exec(ctx, `
INSERT INTO crypto_cross_signing_signatures (signed_user_id, signed_key, signer_user_id, signer_key, signature) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (signed_user_id, signed_key, signer_user_id, signer_key) DO UPDATE SET signature=excluded.signature
`, signedUserID, signedKey, signerUserID, signerKey, signature)
@ -834,8 +773,8 @@ func (store *SQLCryptoStore) PutSignature(signedUserID id.UserID, signedKey id.E
}
// GetSignaturesForKeyBy retrieves the stored signatures for a given cross-signing or device key, by the given signer.
func (store *SQLCryptoStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
rows, err := store.DB.Query("SELECT signer_key, signature FROM crypto_cross_signing_signatures WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3", userID, key, signerID)
func (store *SQLCryptoStore) GetSignaturesForKeyBy(ctx context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
rows, err := store.DB.Query(ctx, "SELECT signer_key, signature FROM crypto_cross_signing_signatures WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3", userID, key, signerID)
if err != nil {
return nil, err
}
@ -854,18 +793,18 @@ func (store *SQLCryptoStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25
}
// IsKeySignedBy returns whether a cross-signing or device key is signed by the given signer.
func (store *SQLCryptoStore) IsKeySignedBy(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519) (isSigned bool, err error) {
func (store *SQLCryptoStore) IsKeySignedBy(ctx context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519) (isSigned bool, err error) {
q := `SELECT EXISTS(
SELECT 1 FROM crypto_cross_signing_signatures
WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3 AND signer_key=$4
)`
err = store.DB.QueryRow(q, signedUserID, signedKey, signerUserID, signerKey).Scan(&isSigned)
err = store.DB.QueryRow(ctx, q, signedUserID, signedKey, signerUserID, signerKey).Scan(&isSigned)
return
}
// DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted.
func (store *SQLCryptoStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) {
res, err := store.DB.Exec("DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2", userID, key)
func (store *SQLCryptoStore) DropSignaturesByKey(ctx context.Context, userID id.UserID, key id.Ed25519) (int64, error) {
res, err := store.DB.Exec(ctx, "DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2", userID, key)
if err != nil {
return 0, err
}

View File

@ -7,6 +7,7 @@
package sql_store_upgrade
import (
"context"
"embed"
"fmt"
@ -21,7 +22,7 @@ const VersionTableName = "crypto_version"
var fs embed.FS
func init() {
Table.Register(-1, 3, 0, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error {
Table.Register(-1, 3, 0, "Unsupported version", false, func(ctx context.Context, database *dbutil.Database) error {
return fmt.Errorf("upgrading from versions 1 and 2 of the crypto store is no longer supported in mautrix-go v0.12+")
})
Table.RegisterFS(fs)

View File

@ -1,4 +1,4 @@
// Copyright (c) 2022 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -26,64 +26,64 @@ var ErrGroupSessionWithheld error = &event.RoomKeyWithheldEventContent{}
type Store interface {
// Flush ensures that everything in the store is persisted to disk.
// This doesn't have to do anything, e.g. for database-backed implementations that persist everything immediately.
Flush() error
Flush(context.Context) error
// PutAccount updates the OlmAccount in the store.
PutAccount(*OlmAccount) error
PutAccount(context.Context, *OlmAccount) error
// GetAccount returns the OlmAccount in the store that was previously inserted with PutAccount.
GetAccount() (*OlmAccount, error)
GetAccount(ctx context.Context) (*OlmAccount, error)
// AddSession inserts an Olm session into the store.
AddSession(id.SenderKey, *OlmSession) error
AddSession(context.Context, id.SenderKey, *OlmSession) error
// HasSession returns whether or not the store has an Olm session with the given sender key.
HasSession(id.SenderKey) bool
HasSession(context.Context, id.SenderKey) bool
// GetSessions returns all Olm sessions in the store with the given sender key.
GetSessions(id.SenderKey) (OlmSessionList, error)
GetSessions(context.Context, id.SenderKey) (OlmSessionList, error)
// GetLatestSession returns the session with the highest session ID (lexiographically sorting).
// It's usually safe to return the most recently added session if sorting by session ID is too difficult.
GetLatestSession(id.SenderKey) (*OlmSession, error)
GetLatestSession(context.Context, id.SenderKey) (*OlmSession, error)
// UpdateSession updates a session that has previously been inserted with AddSession.
UpdateSession(id.SenderKey, *OlmSession) error
UpdateSession(context.Context, id.SenderKey, *OlmSession) error
// PutGroupSession inserts an inbound Megolm session into the store. If an earlier withhold event has been inserted
// with PutWithheldGroupSession, this call should replace that. However, PutWithheldGroupSession must not replace
// sessions inserted with this call.
PutGroupSession(id.RoomID, id.SenderKey, id.SessionID, *InboundGroupSession) error
PutGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, *InboundGroupSession) error
// GetGroupSession gets an inbound Megolm session from the store. If the group session has been withheld
// (i.e. a room key withheld event has been saved with PutWithheldGroupSession), this should return the
// ErrGroupSessionWithheld error. The caller may use GetWithheldGroupSession to find more details.
GetGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error)
GetGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error)
// RedactGroupSession removes the session data for the given inbound Megolm session from the store.
RedactGroupSession(id.RoomID, id.SenderKey, id.SessionID, string) error
RedactGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, string) error
// RedactGroupSessions removes the session data for all inbound Megolm sessions from a specific device and/or in a specific room.
RedactGroupSessions(id.RoomID, id.SenderKey, string) ([]id.SessionID, error)
RedactGroupSessions(context.Context, id.RoomID, id.SenderKey, string) ([]id.SessionID, error)
// RedactExpiredGroupSessions removes the session data for all inbound Megolm sessions that have expired.
RedactExpiredGroupSessions() ([]id.SessionID, error)
RedactExpiredGroupSessions(context.Context) ([]id.SessionID, error)
// RedactOutdatedGroupSessions removes the session data for all inbound Megolm sessions that are lacking the expiration metadata.
RedactOutdatedGroupSessions() ([]id.SessionID, error)
RedactOutdatedGroupSessions(context.Context) ([]id.SessionID, error)
// PutWithheldGroupSession tells the store that a specific Megolm session was withheld.
PutWithheldGroupSession(event.RoomKeyWithheldEventContent) error
PutWithheldGroupSession(context.Context, event.RoomKeyWithheldEventContent) error
// GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession.
GetWithheldGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*event.RoomKeyWithheldEventContent, error)
GetWithheldGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID) (*event.RoomKeyWithheldEventContent, error)
// GetGroupSessionsForRoom gets all the inbound Megolm sessions for a specific room. This is used for creating key
// export files. Unlike GetGroupSession, this should not return any errors about withheld keys.
GetGroupSessionsForRoom(id.RoomID) ([]*InboundGroupSession, error)
GetGroupSessionsForRoom(context.Context, id.RoomID) ([]*InboundGroupSession, error)
// GetAllGroupSessions gets all the inbound Megolm sessions in the store. This is used for creating key export
// files. Unlike GetGroupSession, this should not return any errors about withheld keys.
GetAllGroupSessions() ([]*InboundGroupSession, error)
GetAllGroupSessions(context.Context) ([]*InboundGroupSession, error)
// AddOutboundGroupSession inserts the given outbound Megolm session into the store.
//
// The store should index inserted sessions by the RoomID field to support getting and removing sessions.
// There will only be one outbound session per room ID at a time.
AddOutboundGroupSession(*OutboundGroupSession) error
AddOutboundGroupSession(context.Context, *OutboundGroupSession) error
// UpdateOutboundGroupSession updates the given outbound Megolm session in the store.
UpdateOutboundGroupSession(*OutboundGroupSession) error
UpdateOutboundGroupSession(context.Context, *OutboundGroupSession) error
// GetOutboundGroupSession gets the stored outbound Megolm session for the given room ID from the store.
GetOutboundGroupSession(id.RoomID) (*OutboundGroupSession, error)
GetOutboundGroupSession(context.Context, id.RoomID) (*OutboundGroupSession, error)
// RemoveOutboundGroupSession removes the stored outbound Megolm session for the given room ID.
RemoveOutboundGroupSession(id.RoomID) error
RemoveOutboundGroupSession(context.Context, id.RoomID) error
// ValidateMessageIndex validates that the given message details aren't from a replay attack.
//
@ -96,29 +96,29 @@ type Store interface {
ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error)
// GetDevices returns a map from device ID to id.Device struct containing all devices of a given user.
GetDevices(id.UserID) (map[id.DeviceID]*id.Device, error)
GetDevices(context.Context, id.UserID) (map[id.DeviceID]*id.Device, error)
// GetDevice returns a specific device of a given user.
GetDevice(id.UserID, id.DeviceID) (*id.Device, error)
GetDevice(context.Context, id.UserID, id.DeviceID) (*id.Device, error)
// PutDevice stores a single device for a user, replacing it if it exists already.
PutDevice(id.UserID, *id.Device) error
PutDevice(context.Context, id.UserID, *id.Device) error
// PutDevices overrides the stored device list for the given user with the given list.
PutDevices(id.UserID, map[id.DeviceID]*id.Device) error
PutDevices(context.Context, id.UserID, map[id.DeviceID]*id.Device) error
// FindDeviceByKey finds a specific device by its identity key.
FindDeviceByKey(id.UserID, id.IdentityKey) (*id.Device, error)
FindDeviceByKey(context.Context, id.UserID, id.IdentityKey) (*id.Device, error)
// FilterTrackedUsers returns a filtered version of the given list that only includes user IDs whose device lists
// have been stored with PutDevices. A user is considered tracked even if the PutDevices list was empty.
FilterTrackedUsers([]id.UserID) ([]id.UserID, error)
FilterTrackedUsers(context.Context, []id.UserID) ([]id.UserID, error)
// PutCrossSigningKey stores a cross-signing key of some user along with its usage.
PutCrossSigningKey(id.UserID, id.CrossSigningUsage, id.Ed25519) error
PutCrossSigningKey(context.Context, id.UserID, id.CrossSigningUsage, id.Ed25519) error
// GetCrossSigningKeys retrieves a user's stored cross-signing keys.
GetCrossSigningKeys(id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error)
GetCrossSigningKeys(context.Context, id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error)
// PutSignature stores a signature of a cross-signing or device key along with the signer's user ID and key.
PutSignature(signedUser id.UserID, signedKey id.Ed25519, signerUser id.UserID, signerKey id.Ed25519, signature string) error
PutSignature(ctx context.Context, signedUser id.UserID, signedKey id.Ed25519, signerUser id.UserID, signerKey id.Ed25519, signature string) error
// IsKeySignedBy returns whether a cross-signing or device key is signed by the given signer.
IsKeySignedBy(userID id.UserID, key id.Ed25519, signedByUser id.UserID, signedByKey id.Ed25519) (bool, error)
IsKeySignedBy(ctx context.Context, userID id.UserID, key id.Ed25519, signedByUser id.UserID, signedByKey id.Ed25519) (bool, error)
// DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted.
DropSignaturesByKey(id.UserID, id.Ed25519) (int64, error)
DropSignaturesByKey(context.Context, id.UserID, id.Ed25519) (int64, error)
}
type messageIndexKey struct {
@ -170,18 +170,18 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore {
}
}
func (gs *MemoryStore) Flush() error {
func (gs *MemoryStore) Flush(_ context.Context) error {
gs.lock.Lock()
err := gs.save()
gs.lock.Unlock()
return err
}
func (gs *MemoryStore) GetAccount() (*OlmAccount, error) {
func (gs *MemoryStore) GetAccount(_ context.Context) (*OlmAccount, error) {
return gs.Account, nil
}
func (gs *MemoryStore) PutAccount(account *OlmAccount) error {
func (gs *MemoryStore) PutAccount(_ context.Context, account *OlmAccount) error {
gs.lock.Lock()
gs.Account = account
err := gs.save()
@ -189,7 +189,7 @@ func (gs *MemoryStore) PutAccount(account *OlmAccount) error {
return err
}
func (gs *MemoryStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, error) {
func (gs *MemoryStore) GetSessions(_ context.Context, senderKey id.SenderKey) (OlmSessionList, error) {
gs.lock.Lock()
sessions, ok := gs.Sessions[senderKey]
if !ok {
@ -200,7 +200,7 @@ func (gs *MemoryStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, erro
return sessions, nil
}
func (gs *MemoryStore) AddSession(senderKey id.SenderKey, session *OlmSession) error {
func (gs *MemoryStore) AddSession(_ context.Context, senderKey id.SenderKey, session *OlmSession) error {
gs.lock.Lock()
sessions, _ := gs.Sessions[senderKey]
gs.Sessions[senderKey] = append(sessions, session)
@ -210,19 +210,19 @@ func (gs *MemoryStore) AddSession(senderKey id.SenderKey, session *OlmSession) e
return err
}
func (gs *MemoryStore) UpdateSession(_ id.SenderKey, _ *OlmSession) error {
func (gs *MemoryStore) UpdateSession(_ context.Context, _ id.SenderKey, _ *OlmSession) error {
// we don't need to do anything here because the session is a pointer and already stored in our map
return gs.save()
}
func (gs *MemoryStore) HasSession(senderKey id.SenderKey) bool {
func (gs *MemoryStore) HasSession(_ context.Context, senderKey id.SenderKey) bool {
gs.lock.RLock()
sessions, ok := gs.Sessions[senderKey]
gs.lock.RUnlock()
return ok && len(sessions) > 0 && !sessions[0].Expired()
}
func (gs *MemoryStore) GetLatestSession(senderKey id.SenderKey) (*OlmSession, error) {
func (gs *MemoryStore) GetLatestSession(_ context.Context, senderKey id.SenderKey) (*OlmSession, error) {
gs.lock.RLock()
sessions, ok := gs.Sessions[senderKey]
gs.lock.RUnlock()
@ -246,7 +246,7 @@ func (gs *MemoryStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey
return sender
}
func (gs *MemoryStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error {
func (gs *MemoryStore) PutGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error {
gs.lock.Lock()
gs.getGroupSessions(roomID, senderKey)[sessionID] = igs
err := gs.save()
@ -254,7 +254,7 @@ func (gs *MemoryStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey,
return err
}
func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
func (gs *MemoryStore) GetGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
gs.lock.Lock()
session, ok := gs.getGroupSessions(roomID, senderKey)[sessionID]
if !ok {
@ -269,7 +269,7 @@ func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey,
return session, nil
}
func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error {
func (gs *MemoryStore) RedactGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error {
gs.lock.Lock()
delete(gs.getGroupSessions(roomID, senderKey), sessionID)
err := gs.save()
@ -277,7 +277,7 @@ func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderK
return err
}
func (gs *MemoryStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
func (gs *MemoryStore) RedactGroupSessions(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
gs.lock.Lock()
var sessionIDs []id.SessionID
if roomID != "" && senderKey != "" {
@ -315,11 +315,11 @@ func (gs *MemoryStore) RedactGroupSessions(roomID id.RoomID, senderKey id.Sender
return sessionIDs, err
}
func (gs *MemoryStore) RedactExpiredGroupSessions() ([]id.SessionID, error) {
func (gs *MemoryStore) RedactExpiredGroupSessions(_ context.Context) ([]id.SessionID, error) {
return nil, fmt.Errorf("not implemented")
}
func (gs *MemoryStore) RedactOutdatedGroupSessions() ([]id.SessionID, error) {
func (gs *MemoryStore) RedactOutdatedGroupSessions(_ context.Context) ([]id.SessionID, error) {
return nil, fmt.Errorf("not implemented")
}
@ -337,7 +337,7 @@ func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.S
return sender
}
func (gs *MemoryStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
func (gs *MemoryStore) PutWithheldGroupSession(_ context.Context, content event.RoomKeyWithheldEventContent) error {
gs.lock.Lock()
gs.getWithheldGroupSessions(content.RoomID, content.SenderKey)[content.SessionID] = &content
err := gs.save()
@ -345,7 +345,7 @@ func (gs *MemoryStore) PutWithheldGroupSession(content event.RoomKeyWithheldEven
return err
}
func (gs *MemoryStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
func (gs *MemoryStore) GetWithheldGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
gs.lock.Lock()
session, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID]
gs.lock.Unlock()
@ -355,7 +355,7 @@ func (gs *MemoryStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.Se
return session, nil
}
func (gs *MemoryStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) {
gs.lock.Lock()
defer gs.lock.Unlock()
room, ok := gs.GroupSessions[roomID]
@ -371,7 +371,7 @@ func (gs *MemoryStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGrou
return result, nil
}
func (gs *MemoryStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) ([]*InboundGroupSession, error) {
gs.lock.Lock()
var result []*InboundGroupSession
for _, room := range gs.GroupSessions {
@ -385,7 +385,7 @@ func (gs *MemoryStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
return result, nil
}
func (gs *MemoryStore) AddOutboundGroupSession(session *OutboundGroupSession) error {
func (gs *MemoryStore) AddOutboundGroupSession(_ context.Context, session *OutboundGroupSession) error {
gs.lock.Lock()
gs.OutGroupSessions[session.RoomID] = session
err := gs.save()
@ -393,12 +393,12 @@ func (gs *MemoryStore) AddOutboundGroupSession(session *OutboundGroupSession) er
return err
}
func (gs *MemoryStore) UpdateOutboundGroupSession(_ *OutboundGroupSession) error {
func (gs *MemoryStore) UpdateOutboundGroupSession(_ context.Context, _ *OutboundGroupSession) error {
// we don't need to do anything here because the session is a pointer and already stored in our map
return gs.save()
}
func (gs *MemoryStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
func (gs *MemoryStore) GetOutboundGroupSession(_ context.Context, roomID id.RoomID) (*OutboundGroupSession, error) {
gs.lock.RLock()
session, ok := gs.OutGroupSessions[roomID]
gs.lock.RUnlock()
@ -408,7 +408,7 @@ func (gs *MemoryStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroup
return session, nil
}
func (gs *MemoryStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
func (gs *MemoryStore) RemoveOutboundGroupSession(_ context.Context, roomID id.RoomID) error {
gs.lock.Lock()
session, ok := gs.OutGroupSessions[roomID]
if !ok || session == nil {
@ -443,7 +443,7 @@ func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.Send
return true, nil
}
func (gs *MemoryStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) {
func (gs *MemoryStore) GetDevices(_ context.Context, userID id.UserID) (map[id.DeviceID]*id.Device, error) {
gs.lock.RLock()
devices, ok := gs.Devices[userID]
if !ok {
@ -453,7 +453,7 @@ func (gs *MemoryStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device,
return devices, nil
}
func (gs *MemoryStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
func (gs *MemoryStore) GetDevice(_ context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
devices, ok := gs.Devices[userID]
@ -467,7 +467,7 @@ func (gs *MemoryStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.De
return device, nil
}
func (gs *MemoryStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
func (gs *MemoryStore) FindDeviceByKey(_ context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
devices, ok := gs.Devices[userID]
@ -482,7 +482,7 @@ func (gs *MemoryStore) FindDeviceByKey(userID id.UserID, identityKey id.Identity
return nil, nil
}
func (gs *MemoryStore) PutDevice(userID id.UserID, device *id.Device) error {
func (gs *MemoryStore) PutDevice(_ context.Context, userID id.UserID, device *id.Device) error {
gs.lock.Lock()
devices, ok := gs.Devices[userID]
if !ok {
@ -495,7 +495,7 @@ func (gs *MemoryStore) PutDevice(userID id.UserID, device *id.Device) error {
return err
}
func (gs *MemoryStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error {
func (gs *MemoryStore) PutDevices(_ context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error {
gs.lock.Lock()
gs.Devices[userID] = devices
err := gs.save()
@ -503,7 +503,7 @@ func (gs *MemoryStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.
return err
}
func (gs *MemoryStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
func (gs *MemoryStore) FilterTrackedUsers(_ context.Context, users []id.UserID) ([]id.UserID, error) {
gs.lock.RLock()
var ptr int
for _, userID := range users {
@ -517,7 +517,7 @@ func (gs *MemoryStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error
return users[:ptr], nil
}
func (gs *MemoryStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
func (gs *MemoryStore) PutCrossSigningKey(_ context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
gs.lock.RLock()
userKeys, ok := gs.CrossSigningKeys[userID]
if !ok {
@ -539,7 +539,7 @@ func (gs *MemoryStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSignin
return err
}
func (gs *MemoryStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
func (gs *MemoryStore) GetCrossSigningKeys(_ context.Context, userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
keys, ok := gs.CrossSigningKeys[userID]
@ -549,7 +549,7 @@ func (gs *MemoryStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSignin
return keys, nil
}
func (gs *MemoryStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
func (gs *MemoryStore) PutSignature(_ context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
gs.lock.RLock()
signedUserSigs, ok := gs.KeySignatures[signedUserID]
if !ok {
@ -572,7 +572,7 @@ func (gs *MemoryStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519
return err
}
func (gs *MemoryStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
func (gs *MemoryStore) GetSignaturesForKeyBy(_ context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
userKeys, ok := gs.KeySignatures[userID]
@ -590,8 +590,8 @@ func (gs *MemoryStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, s
return sigsBySigner, nil
}
func (gs *MemoryStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) {
sigs, err := gs.GetSignaturesForKeyBy(userID, key, signerID)
func (gs *MemoryStore) IsKeySignedBy(ctx context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) {
sigs, err := gs.GetSignaturesForKeyBy(ctx, userID, key, signerID)
if err != nil {
return false, err
}
@ -599,7 +599,7 @@ func (gs *MemoryStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID
return ok, nil
}
func (gs *MemoryStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) {
func (gs *MemoryStore) DropSignaturesByKey(_ context.Context, userID id.UserID, key id.Ed25519) (int64, error) {
var count int64
gs.lock.RLock()
for _, userSigs := range gs.KeySignatures {

View File

@ -36,7 +36,7 @@ func getCryptoStores(t *testing.T) map[string]Store {
t.Fatalf("Error opening db: %v", err)
}
sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test"))
if err = sqlStore.DB.Upgrade(); err != nil {
if err = sqlStore.DB.Upgrade(context.TODO()); err != nil {
t.Fatalf("Error creating tables: %v", err)
}
@ -65,8 +65,8 @@ func TestPutAccount(t *testing.T) {
for storeName, store := range stores {
t.Run(storeName, func(t *testing.T) {
acc := NewOlmAccount()
store.PutAccount(acc)
retrieved, err := store.GetAccount()
store.PutAccount(context.TODO(), acc)
retrieved, err := store.GetAccount(context.TODO())
if err != nil {
t.Fatalf("Error retrieving account: %v", err)
}
@ -105,7 +105,7 @@ func TestStoreOlmSession(t *testing.T) {
stores := getCryptoStores(t)
for storeName, store := range stores {
t.Run(storeName, func(t *testing.T) {
if store.HasSession(olmSessID) {
if store.HasSession(context.TODO(), olmSessID) {
t.Error("Found Olm session before inserting it")
}
olmInternal, err := olm.SessionFromPickled([]byte(olmPickled), []byte("test"))
@ -117,15 +117,15 @@ func TestStoreOlmSession(t *testing.T) {
id: olmSessID,
Internal: *olmInternal,
}
err = store.AddSession(olmSessID, &olmSess)
err = store.AddSession(context.TODO(), olmSessID, &olmSess)
if err != nil {
t.Errorf("Error storing Olm session: %v", err)
}
if !store.HasSession(olmSessID) {
if !store.HasSession(context.TODO(), olmSessID) {
t.Error("Not found Olm session after inserting it")
}
retrieved, err := store.GetLatestSession(olmSessID)
retrieved, err := store.GetLatestSession(context.TODO(), olmSessID)
if err != nil {
t.Errorf("Failed retrieving Olm session: %v", err)
}
@ -158,12 +158,12 @@ func TestStoreMegolmSession(t *testing.T) {
RoomID: "room1",
}
err = store.PutGroupSession("room1", acc.IdentityKey(), igs.ID(), igs)
err = store.PutGroupSession(context.TODO(), "room1", acc.IdentityKey(), igs.ID(), igs)
if err != nil {
t.Errorf("Error storing inbound group session: %v", err)
}
retrieved, err := store.GetGroupSession("room1", acc.IdentityKey(), igs.ID())
retrieved, err := store.GetGroupSession(context.TODO(), "room1", acc.IdentityKey(), igs.ID())
if err != nil {
t.Errorf("Error retrieving inbound group session: %v", err)
}
@ -179,7 +179,7 @@ func TestStoreOutboundMegolmSession(t *testing.T) {
stores := getCryptoStores(t)
for storeName, store := range stores {
t.Run(storeName, func(t *testing.T) {
sess, err := store.GetOutboundGroupSession("room1")
sess, err := store.GetOutboundGroupSession(context.TODO(), "room1")
if sess != nil {
t.Error("Got outbound session before inserting")
}
@ -188,12 +188,12 @@ func TestStoreOutboundMegolmSession(t *testing.T) {
}
outbound := NewOutboundGroupSession("room1", nil)
err = store.AddOutboundGroupSession(outbound)
err = store.AddOutboundGroupSession(context.TODO(), outbound)
if err != nil {
t.Errorf("Error inserting outbound session: %v", err)
}
sess, err = store.GetOutboundGroupSession("room1")
sess, err = store.GetOutboundGroupSession(context.TODO(), "room1")
if sess == nil {
t.Error("Did not get outbound session after inserting")
}
@ -201,12 +201,12 @@ func TestStoreOutboundMegolmSession(t *testing.T) {
t.Errorf("Error retrieving outbound session: %v", err)
}
err = store.RemoveOutboundGroupSession("room1")
err = store.RemoveOutboundGroupSession(context.TODO(), "room1")
if err != nil {
t.Errorf("Error deleting outbound session: %v", err)
}
sess, err = store.GetOutboundGroupSession("room1")
sess, err = store.GetOutboundGroupSession(context.TODO(), "room1")
if sess != nil {
t.Error("Got outbound session after deleting")
}
@ -232,11 +232,11 @@ func TestStoreDevices(t *testing.T) {
SigningKey: acc.SigningKey(),
}
}
err := store.PutDevices("user1", deviceMap)
err := store.PutDevices(context.TODO(), "user1", deviceMap)
if err != nil {
t.Errorf("Error string devices: %v", err)
}
devs, err := store.GetDevices("user1")
devs, err := store.GetDevices(context.TODO(), "user1")
if err != nil {
t.Errorf("Error getting devices: %v", err)
}
@ -250,7 +250,7 @@ func TestStoreDevices(t *testing.T) {
t.Errorf("Last device identity key does not match")
}
filtered, err := store.FilterTrackedUsers([]id.UserID{"user0", "user1", "user2"})
filtered, err := store.FilterTrackedUsers(context.TODO(), []id.UserID{"user0", "user1", "user2"})
if err != nil {
t.Errorf("Error filtering tracked users: %v", err)
} else if len(filtered) != 1 || filtered[0] != "user1" {

View File

@ -507,7 +507,7 @@ func (mach *OlmMachine) handleVerificationMAC(ctx context.Context, userID id.Use
// we can finally trust this device
device.Trust = id.TrustStateVerified
err = mach.CryptoStore.PutDevice(device.UserID, device)
err = mach.CryptoStore.PutDevice(ctx, device.UserID, device)
if err != nil {
mach.Log.Warn().Msgf("Failed to put device after verifying: %v", err)
}
@ -521,7 +521,7 @@ func (mach *OlmMachine) handleVerificationMAC(ctx context.Context, userID id.Use
mach.Log.Debug().Msgf("Cross-signed own device %v after SAS verification", device.DeviceID)
}
} else {
masterKey, err := mach.fetchMasterKey(device, content, verState, transactionID)
masterKey, err := mach.fetchMasterKey(ctx, device, content, verState, transactionID)
if err != nil {
mach.Log.Warn().Msgf("Failed to fetch %s's master key: %v", device.UserID, err)
} else {

2
go.mod
View File

@ -12,7 +12,7 @@ require (
github.com/tidwall/gjson v1.17.0
github.com/tidwall/sjson v1.2.5
github.com/yuin/goldmark v1.6.0
go.mau.fi/util v0.2.2-0.20231228160822-a6d40c214e80
go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894
go.mau.fi/zeroconfig v0.1.2
golang.org/x/crypto v0.17.0
golang.org/x/exp v0.0.0-20231226003508-02704c960a9b

4
go.sum
View File

@ -36,8 +36,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68=
github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.mau.fi/util v0.2.2-0.20231228160822-a6d40c214e80 h1:zcfIxHgzZpgGSJv/FUVbOjO4ZWa12En4TGhxgUI/QH0=
go.mau.fi/util v0.2.2-0.20231228160822-a6d40c214e80/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs=
go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894 h1:CuR5LDSxBQLETorfwJ9vRtySeLHjMvJ7//lnCMw7Dy8=
go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs=
go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto=
go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=

View File

@ -1,4 +1,4 @@
// Copyright (c) 2022 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -7,6 +7,7 @@
package sqlstatestore
import (
"context"
"database/sql"
"embed"
"encoding/json"
@ -15,6 +16,7 @@ import (
"strconv"
"strings"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/event"
@ -44,26 +46,28 @@ func NewSQLStateStore(db *dbutil.Database, log dbutil.DatabaseLogger, isBridge b
}
}
func (store *SQLStateStore) IsRegistered(userID id.UserID) bool {
func (store *SQLStateStore) IsRegistered(ctx context.Context, userID id.UserID) (bool, error) {
var isRegistered bool
err := store.
QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID).
QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID).
Scan(&isRegistered)
if err != nil {
store.Log.Warn("Failed to scan registration existence for %s: %v", userID, err)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return isRegistered
return isRegistered, err
}
func (store *SQLStateStore) MarkRegistered(userID id.UserID) {
_, err := store.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
if err != nil {
store.Log.Warn("Failed to mark %s as registered: %v", userID, err)
}
func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID) error {
_, err := store.Exec(ctx, "INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
return err
}
func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID, memberships ...event.Membership) map[id.UserID]*event.MemberEventContent {
members := make(map[id.UserID]*event.MemberEventContent)
type Member struct {
id.UserID
event.MemberEventContent
}
func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) (map[id.UserID]*event.MemberEventContent, error) {
args := make([]any, len(memberships)+1)
args[0] = roomID
query := "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1"
@ -75,25 +79,26 @@ func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID, memberships ...even
}
query = fmt.Sprintf("%s AND membership IN (%s)", query, strings.Join(placeholders, ","))
}
rows, err := store.Query(query, args...)
rows, err := store.Query(ctx, query, args...)
if err != nil {
return members
return nil, err
}
var userID id.UserID
var member event.MemberEventContent
for rows.Next() {
err = rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL)
if err != nil {
store.Log.Warn("Failed to scan member in %s: %v", roomID, err)
} else {
members[userID] = &member
}
}
return members
members := make(map[id.UserID]*event.MemberEventContent)
return members, dbutil.NewRowIter(rows, func(row dbutil.Scannable) (ret Member, err error) {
err = row.Scan(&ret.UserID, &ret.Membership, &ret.Displayname, &ret.AvatarURL)
return
}).Iter(func(m Member) (bool, error) {
members[m.UserID] = &m.MemberEventContent
return true, nil
})
}
func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (members []id.UserID, err error) {
memberMap := store.GetRoomMembers(roomID, event.MembershipJoin, event.MembershipInvite)
func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) (members []id.UserID, err error) {
var memberMap map[id.UserID]*event.MemberEventContent
memberMap, err = store.GetRoomMembers(ctx, roomID, event.MembershipJoin, event.MembershipInvite)
if err != nil {
return
}
members = make([]id.UserID, len(memberMap))
i := 0
for userID := range memberMap {
@ -103,37 +108,39 @@ func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (mem
return
}
func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
membership := event.MembershipLeave
err := store.
QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
func (store *SQLStateStore) GetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID) (membership event.Membership, err error) {
err = store.
QueryRow(ctx, "SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
Scan(&membership)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan membership of %s in %s: %v", userID, roomID, err)
if errors.Is(err, sql.ErrNoRows) {
membership = event.MembershipLeave
err = nil
}
return membership
return
}
func (store *SQLStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
member, ok := store.TryGetMember(roomID, userID)
if !ok {
member.Membership = event.MembershipLeave
func (store *SQLStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
member, err := store.TryGetMember(ctx, roomID, userID)
if member == nil && err == nil {
member = &event.MemberEventContent{Membership: event.MembershipLeave}
}
return member
return member, err
}
func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) {
func (store *SQLStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
var member event.MemberEventContent
err := store.
QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
QueryRow(ctx, "SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
Scan(&member.Membership, &member.Displayname, &member.AvatarURL)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan member info of %s in %s: %v", userID, roomID, err)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
}
return &member, err == nil
return &member, nil
}
func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) {
func (store *SQLStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) ([]id.RoomID, error) {
query := `
SELECT room_id FROM mx_user_profile
LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id
@ -141,38 +148,32 @@ func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID
`
if !store.IsBridge {
query = `
SELECT mx_user_profile.room_id FROM mx_user_profile
LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id
WHERE mx_user_profile.user_id=$1 AND mx_room_state.encryption IS NOT NULL
`
SELECT mx_user_profile.room_id FROM mx_user_profile
LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id
WHERE mx_user_profile.user_id=$1 AND mx_room_state.encryption IS NOT NULL
`
}
rows, err := store.Query(query, userID)
rows, err := store.Query(ctx, query, userID)
if err != nil {
store.Log.Warn("Failed to query shared rooms with %s: %v", userID, err)
return
return nil, err
}
for rows.Next() {
var roomID id.RoomID
err = rows.Scan(&roomID)
if err != nil {
store.Log.Warn("Failed to scan room ID: %v", err)
} else {
rooms = append(rooms, roomID)
}
return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList()
}
func (store *SQLStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(ctx, roomID, userID, "join")
}
func (store *SQLStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(ctx, roomID, userID, "join", "invite")
}
func (store *SQLStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
membership, err := store.GetMembership(ctx, roomID, userID)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get membership")
return false
}
return
}
func (store *SQLStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join")
}
func (store *SQLStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join", "invite")
}
func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
membership := store.GetMembership(roomID, userID)
for _, allowedMembership := range allowedMemberships {
if allowedMembership == membership {
return true
@ -181,27 +182,23 @@ func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, all
return false
}
func (store *SQLStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
_, err := store.Exec(`
func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
_, err := store.Exec(ctx, `
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, '', '')
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership
`, roomID, userID, membership)
if err != nil {
store.Log.Warn("Failed to set membership of %s in %s to %s: %v", userID, roomID, membership, err)
}
return err
}
func (store *SQLStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
_, err := store.Exec(`
func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
_, err := store.Exec(ctx, `
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url
`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL)
if err != nil {
store.Log.Warn("Failed to set membership of %s in %s to %s: %v", userID, roomID, member, err)
}
return err
}
func (store *SQLStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) {
func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error {
query := "DELETE FROM mx_user_profile WHERE room_id=$1"
params := make([]any, len(memberships)+1)
params[0] = roomID
@ -213,109 +210,85 @@ func (store *SQLStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...
}
query += fmt.Sprintf(" AND membership IN (%s)", strings.Join(placeholders, ","))
}
_, err := store.Exec(query, params...)
if err != nil {
store.Log.Warn("Failed to clear cached members of %s: %v", roomID, err)
}
_, err := store.Exec(ctx, query, params...)
return err
}
func (store *SQLStateStore) SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) {
func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
contentBytes, err := json.Marshal(content)
if err != nil {
store.Log.Warn("Failed to marshal encryption config of %s: %v", roomID, err)
return
return fmt.Errorf("failed to marshal content JSON: %w", err)
}
_, err = store.Exec(`
_, err = store.Exec(ctx, `
INSERT INTO mx_room_state (room_id, encryption) VALUES ($1, $2)
ON CONFLICT (room_id) DO UPDATE SET encryption=excluded.encryption
`, roomID, contentBytes)
if err != nil {
store.Log.Warn("Failed to store encryption config of %s: %v", roomID, err)
}
return err
}
func (store *SQLStateStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
func (store *SQLStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) {
var data []byte
err := store.
QueryRow("SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID).
QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID).
Scan(&data)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan encryption config of %s: %v", roomID, err)
}
return nil
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
} else if data == nil {
return nil
return nil, nil
}
content := &event.EncryptionEventContent{}
err = json.Unmarshal(data, content)
var content event.EncryptionEventContent
err = json.Unmarshal(data, &content)
if err != nil {
store.Log.Warn("Failed to parse encryption config of %s: %v", roomID, err)
return nil
return nil, fmt.Errorf("failed to parse content JSON: %w", err)
}
return content
return &content, nil
}
func (store *SQLStateStore) IsEncrypted(roomID id.RoomID) bool {
cfg := store.GetEncryptionEvent(roomID)
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1
func (store *SQLStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) {
cfg, err := store.GetEncryptionEvent(ctx, roomID)
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err
}
func (store *SQLStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
levelsBytes, err := json.Marshal(levels)
if err != nil {
store.Log.Warn("Failed to marshal power levels of %s: %v", roomID, err)
return
}
_, err = store.Exec(`
func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
_, err := store.Exec(ctx, `
INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2)
ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels
`, roomID, levelsBytes)
if err != nil {
store.Log.Warn("Failed to store power levels of %s: %v", roomID, err)
}
`, roomID, dbutil.JSON{Data: levels})
return err
}
func (store *SQLStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
var data []byte
err := store.
QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
Scan(&data)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan power levels of %s: %v", roomID, err)
}
return
} else if data == nil {
return
}
levels = &event.PowerLevelsEventContent{}
err = json.Unmarshal(data, levels)
if err != nil {
store.Log.Warn("Failed to parse power levels of %s: %v", roomID, err)
return nil
func (store *SQLStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
err = store.
QueryRow(ctx, "SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
Scan(&dbutil.JSON{Data: &levels})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
func (store *SQLStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
func (store *SQLStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) {
if store.Dialect == dbutil.Postgres {
var powerLevel int
err := store.
QueryRow(`
QueryRow(ctx, `
SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
FROM mx_room_state WHERE room_id=$1
`, roomID, userID).
Scan(&powerLevel)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan power level of %s in %s: %v", userID, roomID, err)
return powerLevel, err
} else {
levels, err := store.GetPowerLevels(ctx, roomID)
if err != nil {
return 0, err
}
return powerLevel
return levels.GetUserLevel(userID), nil
}
return store.GetPowerLevels(roomID).GetUserLevel(userID)
}
func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
func (store *SQLStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) {
if store.Dialect == dbutil.Postgres {
defaultType := "events_default"
defaultValue := 0
@ -325,23 +298,26 @@ func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType
}
var powerLevel int
err := store.
QueryRow(`
QueryRow(ctx, `
SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4)
FROM mx_room_state WHERE room_id=$1
`, roomID, eventType.Type, defaultType, defaultValue).
Scan(&powerLevel)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
}
return defaultValue
if errors.Is(err, sql.ErrNoRows) {
err = nil
powerLevel = defaultValue
}
return powerLevel
return powerLevel, err
} else {
levels, err := store.GetPowerLevels(ctx, roomID)
if err != nil {
return 0, err
}
return levels.GetEventLevel(eventType), nil
}
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
}
func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
func (store *SQLStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) {
if store.Dialect == dbutil.Postgres {
defaultType := "events_default"
defaultValue := 0
@ -351,19 +327,22 @@ func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, ev
}
var hasPower bool
err := store.
QueryRow(`SELECT
QueryRow(ctx, `SELECT
COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
>=
COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5)
FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue).
Scan(&hasPower)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
}
return defaultValue == 0
if errors.Is(err, sql.ErrNoRows) {
err = nil
hasPower = defaultValue == 0
}
return hasPower
return hasPower, err
} else {
levels, err := store.GetPowerLevels(ctx, roomID)
if err != nil {
return false, err
}
return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil
}
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
}

View File

@ -1,19 +1,20 @@
package sqlstatestore
import (
"context"
"fmt"
"go.mau.fi/util/dbutil"
)
func init() {
UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", true, func(tx dbutil.Execable, db *dbutil.Database) error {
portalExists, err := db.TableExists(tx, "portal")
UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", true, func(ctx context.Context, db *dbutil.Database) error {
portalExists, err := db.TableExists(ctx, "portal")
if err != nil {
return fmt.Errorf("failed to check if portal table exists")
}
if portalExists {
_, err = tx.Exec(`
_, err = db.Exec(ctx, `
INSERT INTO mx_room_state (room_id, encryption)
SELECT portal.mxid, '{"resync":true}' FROM portal WHERE portal.encrypted=true AND portal.mxid IS NOT NULL
ON CONFLICT (room_id) DO UPDATE

View File

@ -7,33 +7,37 @@
package mautrix
import (
"context"
"sync"
"github.com/rs/zerolog"
"go.mau.fi/util/exerrors"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// StateStore is an interface for storing basic room state information.
type StateStore interface {
IsInRoom(roomID id.RoomID, userID id.UserID) bool
IsInvited(roomID id.RoomID, userID id.UserID) bool
IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool
GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent
TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool)
SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership)
SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent)
ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership)
IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool
IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool
IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool
GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error)
TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error)
SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error
SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error
ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error
SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent)
GetPowerLevels(roomID id.RoomID) *event.PowerLevelsEventContent
SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error
GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error)
SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent)
IsEncrypted(roomID id.RoomID) bool
SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error
IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error)
GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ([]id.UserID, error)
GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error)
}
func UpdateStateStore(store StateStore, evt *event.Event) {
func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) {
if store == nil || evt == nil || evt.StateKey == nil {
return
}
@ -41,13 +45,20 @@ func UpdateStateStore(store StateStore, evt *event.Event) {
if evt.Type != event.StateMember && evt.GetStateKey() != "" {
return
}
var err error
switch content := evt.Content.Parsed.(type) {
case *event.MemberEventContent:
store.SetMember(evt.RoomID, id.UserID(evt.GetStateKey()), content)
err = store.SetMember(ctx, evt.RoomID, id.UserID(evt.GetStateKey()), content)
case *event.PowerLevelsEventContent:
store.SetPowerLevels(evt.RoomID, content)
err = store.SetPowerLevels(ctx, evt.RoomID, content)
case *event.EncryptionEventContent:
store.SetEncryptionEvent(evt.RoomID, content)
err = store.SetEncryptionEvent(ctx, evt.RoomID, content)
}
if err != nil {
zerolog.Ctx(ctx).Warn().Err(err).
Stringer("event_id", evt.ID).
Str("event_type", evt.Type.Type).
Msg("Failed to update state store")
}
}
@ -57,7 +68,7 @@ func UpdateStateStore(store StateStore, evt *event.Event) {
//
// DefaultSyncer.ParseEventContent must also be true for this to work (which it is by default).
func (cli *Client) StateStoreSyncHandler(_ EventSource, evt *event.Event) {
UpdateStateStore(cli.StateStore, evt)
UpdateStateStore(cli.Log.WithContext(context.TODO()), cli.StateStore, evt)
}
type MemoryStateStore struct {
@ -81,20 +92,21 @@ func NewMemoryStateStore() StateStore {
}
}
func (store *MemoryStateStore) IsRegistered(userID id.UserID) bool {
func (store *MemoryStateStore) IsRegistered(_ context.Context, userID id.UserID) (bool, error) {
store.registrationsLock.RLock()
defer store.registrationsLock.RUnlock()
registered, ok := store.Registrations[userID]
return ok && registered
return ok && registered, nil
}
func (store *MemoryStateStore) MarkRegistered(userID id.UserID) {
func (store *MemoryStateStore) MarkRegistered(_ context.Context, userID id.UserID) error {
store.registrationsLock.Lock()
defer store.registrationsLock.Unlock()
store.Registrations[userID] = true
return nil
}
func (store *MemoryStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent {
func (store *MemoryStateStore) GetRoomMembers(_ context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
store.membersLock.RLock()
members, ok := store.Members[roomID]
store.membersLock.RUnlock()
@ -104,11 +116,14 @@ func (store *MemoryStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*e
store.Members[roomID] = members
store.membersLock.Unlock()
}
return members
return members, nil
}
func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ([]id.UserID, error) {
members := store.GetRoomMembers(roomID)
func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) {
members, err := store.GetRoomMembers(ctx, roomID)
if err != nil {
return nil, err
}
ids := make([]id.UserID, 0, len(members))
for id := range members {
ids = append(ids, id)
@ -116,39 +131,39 @@ func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (
return ids, nil
}
func (store *MemoryStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
return store.GetMember(roomID, userID).Membership
func (store *MemoryStateStore) GetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID) (event.Membership, error) {
return exerrors.Must(store.GetMember(ctx, roomID, userID)).Membership, nil
}
func (store *MemoryStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
member, ok := store.TryGetMember(roomID, userID)
if !ok {
func (store *MemoryStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
member, err := store.TryGetMember(ctx, roomID, userID)
if member == nil && err == nil {
member = &event.MemberEventContent{Membership: event.MembershipLeave}
}
return member
return member, err
}
func (store *MemoryStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, ok bool) {
func (store *MemoryStateStore) TryGetMember(_ context.Context, roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, err error) {
store.membersLock.RLock()
defer store.membersLock.RUnlock()
members, membersOk := store.Members[roomID]
if !membersOk {
return
}
member, ok = members[userID]
member = members[userID]
return
}
func (store *MemoryStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join")
func (store *MemoryStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(ctx, roomID, userID, "join")
}
func (store *MemoryStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join", "invite")
func (store *MemoryStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(ctx, roomID, userID, "join", "invite")
}
func (store *MemoryStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
membership := store.GetMembership(roomID, userID)
func (store *MemoryStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
membership := exerrors.Must(store.GetMembership(ctx, roomID, userID))
for _, allowedMembership := range allowedMemberships {
if allowedMembership == membership {
return true
@ -157,7 +172,7 @@ func (store *MemoryStateStore) IsMembership(roomID id.RoomID, userID id.UserID,
return false
}
func (store *MemoryStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
func (store *MemoryStateStore) SetMembership(_ context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
store.membersLock.Lock()
members, ok := store.Members[roomID]
if !ok {
@ -175,9 +190,10 @@ func (store *MemoryStateStore) SetMembership(roomID id.RoomID, userID id.UserID,
}
store.Members[roomID] = members
store.membersLock.Unlock()
return nil
}
func (store *MemoryStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
func (store *MemoryStateStore) SetMember(_ context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
store.membersLock.Lock()
members, ok := store.Members[roomID]
if !ok {
@ -189,14 +205,15 @@ func (store *MemoryStateStore) SetMember(roomID id.RoomID, userID id.UserID, mem
}
store.Members[roomID] = members
store.membersLock.Unlock()
return nil
}
func (store *MemoryStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) {
func (store *MemoryStateStore) ClearCachedMembers(_ context.Context, roomID id.RoomID, memberships ...event.Membership) error {
store.membersLock.Lock()
defer store.membersLock.Unlock()
members, ok := store.Members[roomID]
if !ok {
return
return nil
}
for userID, member := range members {
for _, membership := range memberships {
@ -206,46 +223,49 @@ func (store *MemoryStateStore) ClearCachedMembers(roomID id.RoomID, memberships
}
}
}
return nil
}
func (store *MemoryStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
store.powerLevelsLock.Lock()
store.PowerLevels[roomID] = levels
store.powerLevelsLock.Unlock()
return nil
}
func (store *MemoryStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
func (store *MemoryStateStore) GetPowerLevels(_ context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
store.powerLevelsLock.RLock()
levels = store.PowerLevels[roomID]
store.powerLevelsLock.RUnlock()
return
}
func (store *MemoryStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
return store.GetPowerLevels(roomID).GetUserLevel(userID)
func (store *MemoryStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) {
return exerrors.Must(store.GetPowerLevels(ctx, roomID)).GetUserLevel(userID), nil
}
func (store *MemoryStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
func (store *MemoryStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) {
return exerrors.Must(store.GetPowerLevels(ctx, roomID)).GetEventLevel(eventType), nil
}
func (store *MemoryStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
func (store *MemoryStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) {
return exerrors.Must(store.GetPowerLevel(ctx, roomID, userID)) >= exerrors.Must(store.GetPowerLevelRequirement(ctx, roomID, eventType)), nil
}
func (store *MemoryStateStore) SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) {
func (store *MemoryStateStore) SetEncryptionEvent(_ context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
store.encryptionLock.Lock()
store.Encryption[roomID] = content
store.encryptionLock.Unlock()
return nil
}
func (store *MemoryStateStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) {
store.encryptionLock.RLock()
defer store.encryptionLock.RUnlock()
return store.Encryption[roomID]
return store.Encryption[roomID], nil
}
func (store *MemoryStateStore) IsEncrypted(roomID id.RoomID) bool {
cfg := store.GetEncryptionEvent(roomID)
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1
func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) {
cfg, err := store.GetEncryptionEvent(ctx, roomID)
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err
}