mirror of https://github.com/mautrix/go.git
Add more contexts everywhere
parent
0a302c753d
commit
25bc36bc7a
|
@ -2,8 +2,8 @@
|
|||
|
||||
* **Breaking change *(bridge)*** Added raw event to portal membership handling
|
||||
functions.
|
||||
* **Breaking change *(client)*** Added context parameters to all functions
|
||||
(thanks to [@recht] in [#144]).
|
||||
* **Breaking change *(everything)*** Added context parameters to all functions
|
||||
(started by [@recht] in [#144]).
|
||||
* *(crypto)* Added experimental pure Go Olm implementation to replace libolm
|
||||
(thanks to [@DerLukas15] in [#106]).
|
||||
* You can use the `goolm` build tag to the new implementation.
|
||||
|
|
|
@ -93,12 +93,12 @@ type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{})
|
|||
type StateStore interface {
|
||||
mautrix.StateStore
|
||||
|
||||
IsRegistered(userID id.UserID) bool
|
||||
MarkRegistered(userID id.UserID)
|
||||
IsRegistered(ctx context.Context, userID id.UserID) (bool, error)
|
||||
MarkRegistered(ctx context.Context, userID id.UserID) error
|
||||
|
||||
GetPowerLevel(roomID id.RoomID, userID id.UserID) int
|
||||
GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int
|
||||
HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool
|
||||
GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error)
|
||||
GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error)
|
||||
HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error)
|
||||
}
|
||||
|
||||
// AppService is the main config for all appservices.
|
||||
|
|
|
@ -236,7 +236,7 @@ func (as *AppService) handleEvents(ctx context.Context, evts []*event.Event, def
|
|||
}
|
||||
|
||||
if evt.Type.IsState() {
|
||||
mautrix.UpdateStateStore(as.StateStore, evt)
|
||||
mautrix.UpdateStateStore(ctx, as.StateStore, evt)
|
||||
}
|
||||
var ch chan *event.Event
|
||||
if evt.Type.Class == event.ToDeviceEventType {
|
||||
|
|
|
@ -13,6 +13,8 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
@ -57,17 +59,26 @@ func (intent *IntentAPI) Register(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (intent *IntentAPI) EnsureRegistered(ctx context.Context) error {
|
||||
if intent.IsCustomPuppet {
|
||||
return nil
|
||||
}
|
||||
intent.registerLock.Lock()
|
||||
defer intent.registerLock.Unlock()
|
||||
if intent.IsCustomPuppet || intent.as.StateStore.IsRegistered(intent.UserID) {
|
||||
isRegistered, err := intent.as.StateStore.IsRegistered(ctx, intent.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if user is registered: %w", err)
|
||||
} else if isRegistered {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := intent.Register(ctx)
|
||||
err = intent.Register(ctx)
|
||||
if err != nil && !errors.Is(err, mautrix.MUserInUse) {
|
||||
return fmt.Errorf("failed to ensure registered: %w", err)
|
||||
}
|
||||
intent.as.StateStore.MarkRegistered(intent.UserID)
|
||||
err = intent.as.StateStore.MarkRegistered(ctx, intent.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mark user as registered in state store: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -83,7 +94,7 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext
|
|||
} else if len(extra) == 1 {
|
||||
params = extra[0]
|
||||
}
|
||||
if intent.as.StateStore.IsInRoom(roomID, intent.UserID) && !params.IgnoreCache {
|
||||
if intent.as.StateStore.IsInRoom(ctx, roomID, intent.UserID) && !params.IgnoreCache {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -111,7 +122,10 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext
|
|||
return fmt.Errorf("failed to ensure joined after invite: %w", err)
|
||||
}
|
||||
}
|
||||
intent.as.StateStore.SetMembership(resp.RoomID, intent.UserID, event.MembershipJoin)
|
||||
err = intent.as.StateStore.SetMembership(ctx, resp.RoomID, intent.UserID, event.MembershipJoin)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set membership in state store: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -205,13 +219,14 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i
|
|||
Membership: membership,
|
||||
Reason: reason,
|
||||
}
|
||||
memberContent, ok := intent.as.StateStore.TryGetMember(roomID, target)
|
||||
if !ok {
|
||||
memberContent, err := intent.as.StateStore.TryGetMember(ctx, roomID, target)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get old member content from state store: %w", err)
|
||||
} else if memberContent == nil {
|
||||
if intent.as.GetProfile != nil {
|
||||
memberContent = intent.as.GetProfile(target, roomID)
|
||||
ok = memberContent != nil
|
||||
}
|
||||
if !ok {
|
||||
if memberContent == nil {
|
||||
profile, err := intent.GetProfile(ctx, target)
|
||||
if err != nil {
|
||||
intent.Log.Debug().Err(err).
|
||||
|
@ -224,7 +239,7 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i
|
|||
}
|
||||
}
|
||||
}
|
||||
if ok && memberContent != nil {
|
||||
if memberContent != nil {
|
||||
content.Displayname = memberContent.Displayname
|
||||
content.AvatarURL = memberContent.AvatarURL
|
||||
}
|
||||
|
@ -297,15 +312,25 @@ func (intent *IntentAPI) UnbanUser(ctx context.Context, roomID id.RoomID, req *m
|
|||
}
|
||||
|
||||
func (intent *IntentAPI) Member(ctx context.Context, roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
|
||||
member, ok := intent.as.StateStore.TryGetMember(roomID, userID)
|
||||
if !ok {
|
||||
member, err := intent.as.StateStore.TryGetMember(ctx, roomID, userID)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Warn().Err(err).
|
||||
Str("room_id", roomID.String()).
|
||||
Str("user_id", userID.String()).
|
||||
Msg("Failed to get member from state store")
|
||||
}
|
||||
if member == nil {
|
||||
_ = intent.StateEvent(ctx, roomID, event.StateMember, string(userID), &member)
|
||||
}
|
||||
return member
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) {
|
||||
pl = intent.as.StateStore.GetPowerLevels(roomID)
|
||||
pl, err = intent.as.StateStore.GetPowerLevels(ctx, roomID)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to get cached power levels: %w", err)
|
||||
return
|
||||
}
|
||||
if pl == nil {
|
||||
pl = &event.PowerLevelsEventContent{}
|
||||
err = intent.StateEvent(ctx, roomID, event.StatePowerLevels, "", pl)
|
||||
|
@ -417,7 +442,7 @@ func (intent *IntentAPI) Whoami(ctx context.Context) (*mautrix.RespWhoami, error
|
|||
}
|
||||
|
||||
func (intent *IntentAPI) EnsureInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) error {
|
||||
if !intent.as.StateStore.IsInvited(roomID, userID) {
|
||||
if !intent.as.StateStore.IsInvited(ctx, roomID, userID) {
|
||||
_, err := intent.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{
|
||||
UserID: userID,
|
||||
})
|
||||
|
|
|
@ -10,9 +10,8 @@ import (
|
|||
"os"
|
||||
"regexp"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"go.mau.fi/util/random"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Registration contains the data in a Matrix appservice registration.
|
||||
|
|
|
@ -215,15 +215,15 @@ type Bridge struct {
|
|||
|
||||
type Crypto interface {
|
||||
HandleMemberEvent(*event.Event)
|
||||
Decrypt(*event.Event) (*event.Event, error)
|
||||
Encrypt(id.RoomID, event.Type, *event.Content) error
|
||||
WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
|
||||
Decrypt(context.Context, *event.Event) (*event.Event, error)
|
||||
Encrypt(context.Context, id.RoomID, event.Type, *event.Content) error
|
||||
WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
|
||||
RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID)
|
||||
ResetSession(id.RoomID)
|
||||
Init() error
|
||||
ResetSession(context.Context, id.RoomID)
|
||||
Init(ctx context.Context) error
|
||||
Start()
|
||||
Stop()
|
||||
Reset(startAfterReset bool)
|
||||
Reset(ctx context.Context, startAfterReset bool)
|
||||
Client() *mautrix.Client
|
||||
ShareKeys(context.Context) error
|
||||
}
|
||||
|
@ -650,10 +650,10 @@ func (br *Bridge) WaitWebsocketConnected() {
|
|||
|
||||
func (br *Bridge) start() {
|
||||
br.ZLog.Debug().Msg("Running database upgrades")
|
||||
err := br.DB.Upgrade()
|
||||
err := br.DB.Upgrade(br.ZLog.With().Str("db_section", "main").Logger().WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
br.LogDBUpgradeErrorAndExit("main", err)
|
||||
} else if err = br.StateStore.Upgrade(); err != nil {
|
||||
} else if err = br.StateStore.Upgrade(br.ZLog.With().Str("db_section", "matrix_state").Logger().WithContext(context.TODO())); err != nil {
|
||||
br.LogDBUpgradeErrorAndExit("matrix_state", err)
|
||||
}
|
||||
|
||||
|
@ -679,7 +679,7 @@ func (br *Bridge) start() {
|
|||
go br.fetchMediaConfig(ctx)
|
||||
|
||||
if br.Crypto != nil {
|
||||
err = br.Crypto.Init()
|
||||
err = br.Crypto.Init(ctx)
|
||||
if err != nil {
|
||||
br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Error initializing end-to-bridge encryption")
|
||||
os.Exit(19)
|
||||
|
|
|
@ -17,7 +17,7 @@ var CommandDiscardMegolmSession = &FullHandler{
|
|||
if ce.Bridge.Crypto == nil {
|
||||
ce.Reply("This bridge instance doesn't have end-to-bridge encryption enabled")
|
||||
} else {
|
||||
ce.Bridge.Crypto.ResetSession(ce.RoomID)
|
||||
ce.Bridge.Crypto.ResetSession(ce.Ctx, ce.RoomID)
|
||||
ce.Reply("Successfully reset Megolm session in this room. New decryption keys will be shared the next time a message is sent from the remote network.")
|
||||
}
|
||||
},
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@ -61,7 +61,7 @@ func NewCryptoHelper(bridge *Bridge) Crypto {
|
|||
}
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Init() error {
|
||||
func (helper *CryptoHelper) Init(ctx context.Context) error {
|
||||
if len(helper.bridge.CryptoPickleKey) == 0 {
|
||||
panic("CryptoPickleKey not set")
|
||||
}
|
||||
|
@ -75,13 +75,13 @@ func (helper *CryptoHelper) Init() error {
|
|||
helper.bridge.CryptoPickleKey,
|
||||
)
|
||||
|
||||
err := helper.store.DB.Upgrade()
|
||||
err := helper.store.DB.Upgrade(ctx)
|
||||
if err != nil {
|
||||
helper.bridge.LogDBUpgradeErrorAndExit("crypto", err)
|
||||
}
|
||||
|
||||
var isExistingDevice bool
|
||||
helper.client, isExistingDevice, err = helper.loginBot(context.TODO())
|
||||
helper.client, isExistingDevice, err = helper.loginBot(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -111,7 +111,7 @@ func (helper *CryptoHelper) Init() error {
|
|||
}
|
||||
|
||||
if encryptionConfig.DeleteKeys.DeleteOutdatedInbound {
|
||||
deleted, err := helper.store.RedactOutdatedGroupSessions()
|
||||
deleted, err := helper.store.RedactOutdatedGroupSessions(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -123,12 +123,12 @@ func (helper *CryptoHelper) Init() error {
|
|||
helper.client.Syncer = &cryptoSyncer{helper.mach}
|
||||
helper.client.Store = helper.store
|
||||
|
||||
err = helper.mach.Load()
|
||||
err = helper.mach.Load(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isExistingDevice {
|
||||
helper.verifyKeysAreOnServer(context.TODO())
|
||||
helper.verifyKeysAreOnServer(ctx)
|
||||
}
|
||||
|
||||
go helper.resyncEncryptionInfo(context.TODO())
|
||||
|
@ -138,22 +138,16 @@ func (helper *CryptoHelper) Init() error {
|
|||
|
||||
func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
|
||||
log := helper.log.With().Str("action", "resync encryption event").Logger()
|
||||
rows, err := helper.bridge.DB.QueryContext(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`)
|
||||
rows, err := helper.bridge.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to query rooms for resync")
|
||||
return
|
||||
}
|
||||
var roomIDs []id.RoomID
|
||||
for rows.Next() {
|
||||
var roomID id.RoomID
|
||||
err = rows.Scan(&roomID)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to scan room ID")
|
||||
continue
|
||||
}
|
||||
roomIDs = append(roomIDs, roomID)
|
||||
roomIDs, err := dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList()
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to scan rooms for resync")
|
||||
return
|
||||
}
|
||||
_ = rows.Close()
|
||||
if len(roomIDs) > 0 {
|
||||
log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms")
|
||||
for _, roomID := range roomIDs {
|
||||
|
@ -161,7 +155,7 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
|
|||
err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt)
|
||||
if err != nil {
|
||||
log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event")
|
||||
_, err = helper.bridge.DB.ExecContext(ctx, `
|
||||
_, err = helper.bridge.DB.Exec(ctx, `
|
||||
UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}'
|
||||
`, roomID)
|
||||
if err != nil {
|
||||
|
@ -182,7 +176,7 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) {
|
|||
Int("max_messages", maxMessages).
|
||||
Interface("content", &evt).
|
||||
Msg("Resynced encryption event")
|
||||
_, err = helper.bridge.DB.ExecContext(ctx, `
|
||||
_, err = helper.bridge.DB.Exec(ctx, `
|
||||
UPDATE crypto_megolm_inbound_session
|
||||
SET max_age=$1, max_messages=$2
|
||||
WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL
|
||||
|
@ -223,8 +217,10 @@ func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device
|
|||
}
|
||||
|
||||
func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool, error) {
|
||||
deviceID := helper.store.FindDeviceID()
|
||||
if len(deviceID) > 0 {
|
||||
deviceID, err := helper.store.FindDeviceID(ctx)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to find existing device ID: %w", err)
|
||||
} else if len(deviceID) > 0 {
|
||||
helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database")
|
||||
}
|
||||
// Create a new client instance with the default AS settings (including as_token),
|
||||
|
@ -270,7 +266,7 @@ func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) {
|
|||
return
|
||||
}
|
||||
helper.log.Warn().Msg("Existing device doesn't have keys on server, resetting crypto")
|
||||
helper.Reset(false)
|
||||
helper.Reset(ctx, false)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Start() {
|
||||
|
@ -306,16 +302,16 @@ func (helper *CryptoHelper) Stop() {
|
|||
helper.syncDone.Wait()
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) clearDatabase() {
|
||||
_, err := helper.store.DB.Exec("DELETE FROM crypto_account")
|
||||
func (helper *CryptoHelper) clearDatabase(ctx context.Context) {
|
||||
_, err := helper.store.DB.Exec(ctx, "DELETE FROM crypto_account")
|
||||
if err != nil {
|
||||
helper.log.Warn().Err(err).Msg("Failed to clear crypto_account table")
|
||||
}
|
||||
_, err = helper.store.DB.Exec("DELETE FROM crypto_olm_session")
|
||||
_, err = helper.store.DB.Exec(ctx, "DELETE FROM crypto_olm_session")
|
||||
if err != nil {
|
||||
helper.log.Warn().Err(err).Msg("Failed to clear crypto_olm_session table")
|
||||
}
|
||||
_, err = helper.store.DB.Exec("DELETE FROM crypto_megolm_outbound_session")
|
||||
_, err = helper.store.DB.Exec(ctx, "DELETE FROM crypto_megolm_outbound_session")
|
||||
if err != nil {
|
||||
helper.log.Warn().Err(err).Msg("Failed to clear crypto_megolm_outbound_session table")
|
||||
}
|
||||
|
@ -325,22 +321,22 @@ func (helper *CryptoHelper) clearDatabase() {
|
|||
//_, _ = helper.store.DB.Exec("DELETE FROM crypto_cross_signing_signatures")
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Reset(startAfterReset bool) {
|
||||
func (helper *CryptoHelper) Reset(ctx context.Context, startAfterReset bool) {
|
||||
helper.lock.Lock()
|
||||
defer helper.lock.Unlock()
|
||||
helper.log.Info().Msg("Resetting end-to-bridge encryption device")
|
||||
helper.Stop()
|
||||
helper.log.Debug().Msg("Crypto syncer stopped, clearing database")
|
||||
helper.clearDatabase()
|
||||
helper.clearDatabase(ctx)
|
||||
helper.log.Debug().Msg("Crypto database cleared, logging out of all sessions")
|
||||
_, err := helper.client.LogoutAll(context.TODO())
|
||||
_, err := helper.client.LogoutAll(ctx)
|
||||
if err != nil {
|
||||
helper.log.Warn().Err(err).Msg("Failed to log out all devices")
|
||||
}
|
||||
helper.client = nil
|
||||
helper.store = nil
|
||||
helper.mach = nil
|
||||
err = helper.Init()
|
||||
err = helper.Init(ctx)
|
||||
if err != nil {
|
||||
helper.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Error reinitializing end-to-bridge encryption")
|
||||
os.Exit(50)
|
||||
|
@ -355,25 +351,24 @@ func (helper *CryptoHelper) Client() *mautrix.Client {
|
|||
return helper.client
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
|
||||
return helper.mach.DecryptMegolmEvent(context.TODO(), evt)
|
||||
func (helper *CryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) {
|
||||
return helper.mach.DecryptMegolmEvent(ctx, evt)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content *event.Content) (err error) {
|
||||
func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content *event.Content) (err error) {
|
||||
helper.lock.RLock()
|
||||
defer helper.lock.RUnlock()
|
||||
var encrypted *event.EncryptedEventContent
|
||||
ctx := context.TODO()
|
||||
encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content)
|
||||
if err != nil {
|
||||
if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession {
|
||||
if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) {
|
||||
return
|
||||
}
|
||||
helper.log.Debug().Err(err).
|
||||
Str("room_id", roomID.String()).
|
||||
Msg("Got error while encrypting event for room, sharing group session and trying again...")
|
||||
var users []id.UserID
|
||||
users, err = helper.store.GetRoomJoinedOrInvitedMembers(roomID)
|
||||
users, err = helper.store.GetRoomJoinedOrInvitedMembers(ctx, roomID)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to get room member list: %w", err)
|
||||
} else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil {
|
||||
|
@ -389,10 +384,10 @@ func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, conten
|
|||
return
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
||||
func (helper *CryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
||||
helper.lock.RLock()
|
||||
defer helper.lock.RUnlock()
|
||||
return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout)
|
||||
return helper.mach.WaitForSession(ctx, roomID, senderKey, sessionID, timeout)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
|
||||
|
@ -419,10 +414,10 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID
|
|||
}
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) ResetSession(roomID id.RoomID) {
|
||||
func (helper *CryptoHelper) ResetSession(ctx context.Context, roomID id.RoomID) {
|
||||
helper.lock.RLock()
|
||||
defer helper.lock.RUnlock()
|
||||
err := helper.mach.CryptoStore.RemoveOutboundGroupSession(roomID)
|
||||
err := helper.mach.CryptoStore.RemoveOutboundGroupSession(ctx, roomID)
|
||||
if err != nil {
|
||||
helper.log.Debug().Err(err).
|
||||
Str("room_id", roomID.String()).
|
||||
|
@ -499,18 +494,18 @@ type cryptoStateStore struct {
|
|||
|
||||
var _ crypto.StateStore = (*cryptoStateStore)(nil)
|
||||
|
||||
func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool {
|
||||
func (c *cryptoStateStore) IsEncrypted(ctx context.Context, id id.RoomID) (bool, error) {
|
||||
portal := c.bridge.Child.GetIPortal(id)
|
||||
if portal != nil {
|
||||
return portal.IsEncrypted()
|
||||
return portal.IsEncrypted(), nil
|
||||
}
|
||||
return c.bridge.StateStore.IsEncrypted(id)
|
||||
return c.bridge.StateStore.IsEncrypted(ctx, id)
|
||||
}
|
||||
|
||||
func (c *cryptoStateStore) FindSharedRooms(id id.UserID) []id.RoomID {
|
||||
return c.bridge.StateStore.FindSharedRooms(id)
|
||||
func (c *cryptoStateStore) FindSharedRooms(ctx context.Context, id id.UserID) ([]id.RoomID, error) {
|
||||
return c.bridge.StateStore.FindSharedRooms(ctx, id)
|
||||
}
|
||||
|
||||
func (c *cryptoStateStore) GetEncryptionEvent(id id.RoomID) *event.EncryptionEventContent {
|
||||
return c.bridge.StateStore.GetEncryptionEvent(id)
|
||||
func (c *cryptoStateStore) GetEncryptionEvent(ctx context.Context, id id.RoomID) (*event.EncryptionEventContent, error) {
|
||||
return c.bridge.StateStore.GetEncryptionEvent(ctx, id)
|
||||
}
|
||||
|
|
|
@ -9,6 +9,8 @@
|
|||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
|
@ -36,9 +38,9 @@ func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, userID id
|
|||
}
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (members []id.UserID, err error) {
|
||||
func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) (members []id.UserID, err error) {
|
||||
var rows dbutil.Rows
|
||||
rows, err = store.DB.Query(`
|
||||
rows, err = store.DB.Query(ctx, `
|
||||
SELECT user_id FROM mx_user_profile
|
||||
WHERE room_id=$1
|
||||
AND (membership='join' OR membership='invite')
|
||||
|
|
|
@ -494,7 +494,7 @@ func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) {
|
|||
log.Debug().Msg("Decrypting received event")
|
||||
|
||||
decryptionStart := time.Now()
|
||||
decrypted, err := mx.bridge.Crypto.Decrypt(evt)
|
||||
decrypted, err := mx.bridge.Crypto.Decrypt(ctx, evt)
|
||||
decryptionRetryCount := 0
|
||||
if errors.Is(err, NoSessionFound) {
|
||||
decryptionRetryCount = 1
|
||||
|
@ -502,9 +502,9 @@ func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) {
|
|||
Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
|
||||
Msg("Couldn't find session, waiting for keys to arrive...")
|
||||
mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepDecrypted, err, false, 0)
|
||||
if mx.bridge.Crypto.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
|
||||
if mx.bridge.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
|
||||
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
|
||||
decrypted, err = mx.bridge.Crypto.Decrypt(evt)
|
||||
decrypted, err = mx.bridge.Crypto.Decrypt(ctx, evt)
|
||||
} else {
|
||||
go mx.waitLongerForSession(ctx, evt, decryptionStart)
|
||||
return
|
||||
|
@ -529,14 +529,14 @@ func (mx *MatrixHandler) waitLongerForSession(ctx context.Context, evt *event.Ev
|
|||
go mx.bridge.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
|
||||
errorEventID := mx.sendCryptoStatusError(ctx, evt, "", fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), 1, false)
|
||||
|
||||
if !mx.bridge.Crypto.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
|
||||
if !mx.bridge.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
|
||||
log.Debug().Msg("Didn't get session, giving up trying to decrypt event")
|
||||
mx.sendCryptoStatusError(ctx, evt, errorEventID, errNoDecryptionKeys, 2, true)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again")
|
||||
decrypted, err := mx.bridge.Crypto.Decrypt(evt)
|
||||
decrypted, err := mx.bridge.Crypto.Decrypt(ctx, evt)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to decrypt event")
|
||||
mx.sendCryptoStatusError(ctx, evt, errorEventID, err, 2, true)
|
||||
|
|
113
client.go
113
client.go
|
@ -27,11 +27,11 @@ import (
|
|||
)
|
||||
|
||||
type CryptoHelper interface {
|
||||
Encrypt(id.RoomID, event.Type, any) (*event.EncryptedEventContent, error)
|
||||
Decrypt(*event.Event) (*event.Event, error)
|
||||
WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
|
||||
Encrypt(context.Context, id.RoomID, event.Type, any) (*event.EncryptedEventContent, error)
|
||||
Decrypt(context.Context, *event.Event) (*event.Event, error)
|
||||
WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
|
||||
RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID)
|
||||
Init() error
|
||||
Init(context.Context) error
|
||||
}
|
||||
|
||||
// Deprecated: switch to zerolog
|
||||
|
@ -846,7 +846,10 @@ func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName strin
|
|||
}
|
||||
_, err = cli.MakeRequest(ctx, "POST", urlPath, content, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin)
|
||||
err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to update state store: %w", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -858,7 +861,10 @@ func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName strin
|
|||
func (cli *Client) JoinRoomByID(ctx context.Context, roomID id.RoomID) (resp *RespJoinRoom, err error) {
|
||||
_, err = cli.MakeRequest(ctx, "POST", cli.BuildClientURL("v3", "rooms", roomID, "join"), nil, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin)
|
||||
err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to update state store: %w", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -1000,13 +1006,20 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event
|
|||
queryParams["fi.mau.event_id"] = req.MeowEventID.String()
|
||||
}
|
||||
|
||||
if !req.DontEncrypt && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted && cli.StateStore.IsEncrypted(roomID) {
|
||||
contentJSON, err = cli.Crypto.Encrypt(roomID, eventType, contentJSON)
|
||||
if !req.DontEncrypt && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted {
|
||||
var isEncrypted bool
|
||||
isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to encrypt event: %w", err)
|
||||
err = fmt.Errorf("failed to check if room is encrypted: %w", err)
|
||||
return
|
||||
}
|
||||
eventType = event.EventEncrypted
|
||||
if isEncrypted {
|
||||
if contentJSON, err = cli.Crypto.Encrypt(ctx, roomID, eventType, contentJSON); err != nil {
|
||||
err = fmt.Errorf("failed to encrypt event: %w", err)
|
||||
return
|
||||
}
|
||||
eventType = event.EventEncrypted
|
||||
}
|
||||
}
|
||||
|
||||
urlData := ClientURLPath{"v3", "rooms", roomID, "send", eventType.String(), txnID}
|
||||
|
@ -1021,7 +1034,7 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy
|
|||
urlPath := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey)
|
||||
_, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON)
|
||||
cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -1034,7 +1047,7 @@ func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID,
|
|||
})
|
||||
_, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON)
|
||||
cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -1100,19 +1113,29 @@ func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *Re
|
|||
urlPath := cli.BuildClientURL("v3", "createRoom")
|
||||
_, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin)
|
||||
storeErr := cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin)
|
||||
if storeErr != nil {
|
||||
cli.cliOrContextLog(ctx).Warn().Err(storeErr).
|
||||
Stringer("creator_user_id", cli.UserID).
|
||||
Msg("Failed to update creator membership in state store after creating room")
|
||||
}
|
||||
for _, evt := range req.InitialState {
|
||||
UpdateStateStore(cli.StateStore, evt)
|
||||
UpdateStateStore(ctx, cli.StateStore, evt)
|
||||
}
|
||||
inviteMembership := event.MembershipInvite
|
||||
if req.BeeperAutoJoinInvites {
|
||||
inviteMembership = event.MembershipJoin
|
||||
}
|
||||
for _, invitee := range req.Invite {
|
||||
cli.StateStore.SetMembership(resp.RoomID, invitee, inviteMembership)
|
||||
storeErr = cli.StateStore.SetMembership(ctx, resp.RoomID, invitee, inviteMembership)
|
||||
if storeErr != nil {
|
||||
cli.cliOrContextLog(ctx).Warn().Err(storeErr).
|
||||
Stringer("invitee_user_id", invitee).
|
||||
Msg("Failed to update membership in state store after creating room")
|
||||
}
|
||||
}
|
||||
for _, evt := range req.InitialState {
|
||||
cli.updateStoreWithOutgoingEvent(resp.RoomID, evt.Type, evt.GetStateKey(), &evt.Content)
|
||||
cli.updateStoreWithOutgoingEvent(ctx, resp.RoomID, evt.Type, evt.GetStateKey(), &evt.Content)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
@ -1129,7 +1152,10 @@ func (cli *Client) LeaveRoom(ctx context.Context, roomID id.RoomID, optionalReq
|
|||
u := cli.BuildClientURL("v3", "rooms", roomID, "leave")
|
||||
_, err = cli.MakeRequest(ctx, "POST", u, req, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.SetMembership(roomID, cli.UserID, event.MembershipLeave)
|
||||
err = cli.StateStore.SetMembership(ctx, roomID, cli.UserID, event.MembershipLeave)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to update membership in state store: %w", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -1146,7 +1172,10 @@ func (cli *Client) InviteUser(ctx context.Context, roomID id.RoomID, req *ReqInv
|
|||
u := cli.BuildClientURL("v3", "rooms", roomID, "invite")
|
||||
_, err = cli.MakeRequest(ctx, "POST", u, req, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipInvite)
|
||||
err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipInvite)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to update membership in state store: %w", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -1163,7 +1192,10 @@ func (cli *Client) KickUser(ctx context.Context, roomID id.RoomID, req *ReqKickU
|
|||
u := cli.BuildClientURL("v3", "rooms", roomID, "kick")
|
||||
_, err = cli.MakeRequest(ctx, "POST", u, req, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave)
|
||||
err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipLeave)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to update membership in state store: %w", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -1173,7 +1205,10 @@ func (cli *Client) BanUser(ctx context.Context, roomID id.RoomID, req *ReqBanUse
|
|||
u := cli.BuildClientURL("v3", "rooms", roomID, "ban")
|
||||
_, err = cli.MakeRequest(ctx, "POST", u, req, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipBan)
|
||||
err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipBan)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to update membership in state store: %w", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -1183,7 +1218,10 @@ func (cli *Client) UnbanUser(ctx context.Context, roomID id.RoomID, req *ReqUnba
|
|||
u := cli.BuildClientURL("v3", "rooms", roomID, "unban")
|
||||
_, err = cli.MakeRequest(ctx, "POST", u, req, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave)
|
||||
err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipLeave)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to update membership in state store: %w", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -1216,7 +1254,7 @@ func (cli *Client) SetPresence(ctx context.Context, status event.Presence) (err
|
|||
return
|
||||
}
|
||||
|
||||
func (cli *Client) updateStoreWithOutgoingEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) {
|
||||
func (cli *Client) updateStoreWithOutgoingEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) {
|
||||
if cli.StateStore == nil {
|
||||
return
|
||||
}
|
||||
|
@ -1246,7 +1284,7 @@ func (cli *Client) updateStoreWithOutgoingEvent(roomID id.RoomID, eventType even
|
|||
}
|
||||
return
|
||||
}
|
||||
UpdateStateStore(cli.StateStore, fakeEvt)
|
||||
UpdateStateStore(ctx, cli.StateStore, fakeEvt)
|
||||
}
|
||||
|
||||
// StateEvent gets a single state event in a room. It will attempt to JSON unmarshal into the given "outContent" struct with
|
||||
|
@ -1256,7 +1294,7 @@ func (cli *Client) StateEvent(ctx context.Context, roomID id.RoomID, eventType e
|
|||
u := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey)
|
||||
_, err = cli.MakeRequest(ctx, "GET", u, nil, outContent)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, outContent)
|
||||
cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, outContent)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -1310,10 +1348,13 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt
|
|||
Handler: parseRoomStateArray,
|
||||
})
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.ClearCachedMembers(roomID)
|
||||
clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID)
|
||||
cli.cliOrContextLog(ctx).Warn().Err(clearErr).
|
||||
Stringer("room_id", roomID).
|
||||
Msg("Failed to clear cached member list after fetching state")
|
||||
for _, evts := range stateMap {
|
||||
for _, evt := range evts {
|
||||
UpdateStateStore(cli.StateStore, evt)
|
||||
UpdateStateStore(ctx, cli.StateStore, evt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1630,13 +1671,22 @@ func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *R
|
|||
u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members")
|
||||
_, err = cli.MakeRequest(ctx, "GET", u, nil, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.ClearCachedMembers(roomID, event.MembershipJoin)
|
||||
clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, event.MembershipJoin)
|
||||
cli.cliOrContextLog(ctx).Warn().Err(clearErr).
|
||||
Stringer("room_id", roomID).
|
||||
Msg("Failed to clear cached member list after fetching joined members")
|
||||
for userID, member := range resp.Joined {
|
||||
cli.StateStore.SetMember(roomID, userID, &event.MemberEventContent{
|
||||
updateErr := cli.StateStore.SetMember(ctx, roomID, userID, &event.MemberEventContent{
|
||||
Membership: event.MembershipJoin,
|
||||
AvatarURL: id.ContentURIString(member.AvatarURL),
|
||||
Displayname: member.DisplayName,
|
||||
})
|
||||
if updateErr != nil {
|
||||
cli.cliOrContextLog(ctx).Warn().Err(clearErr).
|
||||
Stringer("room_id", roomID).
|
||||
Stringer("user_id", userID).
|
||||
Msg("Failed to update membership in state store after fetching joined members")
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
|
@ -1665,10 +1715,13 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb
|
|||
clearMemberships = append(clearMemberships, extra.Membership)
|
||||
}
|
||||
if extra.NotMembership == "" {
|
||||
cli.StateStore.ClearCachedMembers(roomID, clearMemberships...)
|
||||
clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, clearMemberships...)
|
||||
cli.cliOrContextLog(ctx).Warn().Err(clearErr).
|
||||
Stringer("room_id", roomID).
|
||||
Msg("Failed to clear cached member list after fetching joined members")
|
||||
}
|
||||
for _, evt := range resp.Chunk {
|
||||
UpdateStateStore(cli.StateStore, evt)
|
||||
UpdateStateStore(ctx, cli.StateStore, evt)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2020 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@ -42,7 +42,7 @@ func (mach *OlmMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *Cross
|
|||
}
|
||||
|
||||
func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id.UserID) (*CrossSigningPublicKeysCache, error) {
|
||||
dbKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
|
||||
dbKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get keys from database: %w", err)
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// Copyright (c) 2020 Nikos Filippakis
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@ -34,8 +34,8 @@ var (
|
|||
ErrMismatchingMasterKeyMAC = errors.New("mismatching cross-signing master key MAC")
|
||||
)
|
||||
|
||||
func (mach *OlmMachine) fetchMasterKey(device *id.Device, content *event.VerificationMacEventContent, verState *verificationState, transactionID string) (id.Ed25519, error) {
|
||||
crossSignKeys, err := mach.CryptoStore.GetCrossSigningKeys(device.UserID)
|
||||
func (mach *OlmMachine) fetchMasterKey(ctx context.Context, device *id.Device, content *event.VerificationMacEventContent, verState *verificationState, transactionID string) (id.Ed25519, error) {
|
||||
crossSignKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, device.UserID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to fetch cross-signing keys: %w", err)
|
||||
}
|
||||
|
@ -85,7 +85,7 @@ func (mach *OlmMachine) SignUser(ctx context.Context, userID id.UserID, masterKe
|
|||
Str("signature", signature).
|
||||
Msg("Signed master key of user with our user-signing key")
|
||||
|
||||
if err := mach.CryptoStore.PutSignature(userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil {
|
||||
if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil {
|
||||
return fmt.Errorf("error storing signature in crypto store: %w", err)
|
||||
}
|
||||
|
||||
|
@ -137,7 +137,7 @@ func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error {
|
|||
return fmt.Errorf("%w: %+v", ErrSignatureUploadFail, resp.Failures)
|
||||
}
|
||||
|
||||
if err := mach.CryptoStore.PutSignature(userID, masterKey, userID, mach.account.SigningKey(), signature); err != nil {
|
||||
if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, userID, mach.account.SigningKey(), signature); err != nil {
|
||||
return fmt.Errorf("error storing signature in crypto store: %w", err)
|
||||
}
|
||||
|
||||
|
@ -178,7 +178,7 @@ func (mach *OlmMachine) SignOwnDevice(ctx context.Context, device *id.Device) er
|
|||
Str("signature", signature).
|
||||
Msg("Signed own device key with self-signing key")
|
||||
|
||||
if err := mach.CryptoStore.PutSignature(device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil {
|
||||
if err := mach.CryptoStore.PutSignature(ctx, device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil {
|
||||
return fmt.Errorf("error storing signature in crypto store: %w", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// Copyright (c) 2020 Nikos Filippakis
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@ -19,7 +19,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
|
|||
log := mach.machOrContextLog(ctx)
|
||||
for userID, userKeys := range crossSigningKeys {
|
||||
log := log.With().Str("user_id", userID.String()).Logger()
|
||||
currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
|
||||
currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).
|
||||
Msg("Error fetching current cross-signing keys of user")
|
||||
|
@ -32,7 +32,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
|
|||
if newKeyUsage == curKeyUsage {
|
||||
if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
|
||||
// old key is not in the new key map, so we drop signatures made by it
|
||||
if count, err := mach.CryptoStore.DropSignaturesByKey(userID, curKey.Key); err != nil {
|
||||
if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
|
||||
log.Error().Err(err).Msg("Error deleting old signatures made by user")
|
||||
} else {
|
||||
log.Debug().
|
||||
|
@ -50,7 +50,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
|
|||
log := log.With().Str("key", key.String()).Strs("usages", strishArray(userKeys.Usage)).Logger()
|
||||
for _, usage := range userKeys.Usage {
|
||||
log.Debug().Str("usage", string(usage)).Msg("Storing cross-signing key")
|
||||
if err = mach.CryptoStore.PutCrossSigningKey(userID, usage, key); err != nil {
|
||||
if err = mach.CryptoStore.PutCrossSigningKey(ctx, userID, usage, key); err != nil {
|
||||
log.Error().Err(err).Msg("Error storing cross-signing key")
|
||||
}
|
||||
}
|
||||
|
@ -85,7 +85,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
|
|||
} else {
|
||||
if verified {
|
||||
log.Debug().Err(err).Msg("Cross-signing key signature verified")
|
||||
err = mach.CryptoStore.PutSignature(userID, key, signUserID, signingKey, signature)
|
||||
err = mach.CryptoStore.PutSignature(ctx, userID, key, signUserID, signingKey, signature)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error storing cross-signing key signature")
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ func getOlmMachine(t *testing.T) *OlmMachine {
|
|||
t.Fatalf("Error opening db: %v", err)
|
||||
}
|
||||
sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test"))
|
||||
if err = sqlStore.DB.Upgrade(); err != nil {
|
||||
if err = sqlStore.DB.Upgrade(context.TODO()); err != nil {
|
||||
t.Fatalf("Error creating tables: %v", err)
|
||||
}
|
||||
|
||||
|
@ -41,9 +41,9 @@ func getOlmMachine(t *testing.T) *OlmMachine {
|
|||
ssk, _ := olm.NewPkSigning()
|
||||
usk, _ := olm.NewPkSigning()
|
||||
|
||||
sqlStore.PutCrossSigningKey(userID, id.XSUsageMaster, mk.PublicKey)
|
||||
sqlStore.PutCrossSigningKey(userID, id.XSUsageSelfSigning, ssk.PublicKey)
|
||||
sqlStore.PutCrossSigningKey(userID, id.XSUsageUserSigning, usk.PublicKey)
|
||||
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageMaster, mk.PublicKey)
|
||||
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageSelfSigning, ssk.PublicKey)
|
||||
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageUserSigning, usk.PublicKey)
|
||||
|
||||
return &OlmMachine{
|
||||
CryptoStore: sqlStore,
|
||||
|
@ -70,9 +70,9 @@ func TestTrustOwnDevice(t *testing.T) {
|
|||
t.Error("Own device trusted while it shouldn't be")
|
||||
}
|
||||
|
||||
m.CryptoStore.PutSignature(ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey,
|
||||
m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey,
|
||||
ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1")
|
||||
m.CryptoStore.PutSignature(ownDevice.UserID, ownDevice.SigningKey,
|
||||
m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, ownDevice.SigningKey,
|
||||
ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, "sig2")
|
||||
|
||||
if trusted, _ := m.IsUserTrusted(context.TODO(), ownDevice.UserID); !trusted {
|
||||
|
@ -91,20 +91,20 @@ func TestTrustOtherUser(t *testing.T) {
|
|||
}
|
||||
|
||||
theirMasterKey, _ := olm.NewPkSigning()
|
||||
m.CryptoStore.PutCrossSigningKey(otherUser, id.XSUsageMaster, theirMasterKey.PublicKey)
|
||||
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey)
|
||||
|
||||
m.CryptoStore.PutSignature(m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey,
|
||||
m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey,
|
||||
m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1")
|
||||
|
||||
// sign them with self-signing instead of user-signing key
|
||||
m.CryptoStore.PutSignature(otherUser, theirMasterKey.PublicKey,
|
||||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey,
|
||||
m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, "invalid_sig")
|
||||
|
||||
if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted {
|
||||
t.Error("Other user trusted before their master key has been signed with our user-signing key")
|
||||
}
|
||||
|
||||
m.CryptoStore.PutSignature(otherUser, theirMasterKey.PublicKey,
|
||||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey,
|
||||
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, "sig2")
|
||||
|
||||
if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted {
|
||||
|
@ -128,27 +128,27 @@ func TestTrustOtherDevice(t *testing.T) {
|
|||
}
|
||||
|
||||
theirMasterKey, _ := olm.NewPkSigning()
|
||||
m.CryptoStore.PutCrossSigningKey(otherUser, id.XSUsageMaster, theirMasterKey.PublicKey)
|
||||
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey)
|
||||
theirSSK, _ := olm.NewPkSigning()
|
||||
m.CryptoStore.PutCrossSigningKey(otherUser, id.XSUsageSelfSigning, theirSSK.PublicKey)
|
||||
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageSelfSigning, theirSSK.PublicKey)
|
||||
|
||||
m.CryptoStore.PutSignature(m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey,
|
||||
m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey,
|
||||
m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1")
|
||||
m.CryptoStore.PutSignature(otherUser, theirMasterKey.PublicKey,
|
||||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey,
|
||||
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, "sig2")
|
||||
|
||||
if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted {
|
||||
t.Error("Other user not trusted while they should be")
|
||||
}
|
||||
|
||||
m.CryptoStore.PutSignature(otherUser, theirSSK.PublicKey,
|
||||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey,
|
||||
otherUser, theirMasterKey.PublicKey, "sig3")
|
||||
|
||||
if m.IsDeviceTrusted(theirDevice) {
|
||||
t.Error("Other device trusted before it has been signed with user's SSK")
|
||||
}
|
||||
|
||||
m.CryptoStore.PutSignature(otherUser, theirDevice.SigningKey,
|
||||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirDevice.SigningKey,
|
||||
otherUser, theirSSK.PublicKey, "sig4")
|
||||
|
||||
if !m.IsDeviceTrusted(theirDevice) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// Copyright (c) 2020 Nikos Filippakis
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@ -23,7 +23,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi
|
|||
if device.Trust == id.TrustStateVerified || device.Trust == id.TrustStateBlacklisted {
|
||||
return device.Trust, nil
|
||||
}
|
||||
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(device.UserID)
|
||||
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, device.UserID)
|
||||
if err != nil {
|
||||
mach.machOrContextLog(ctx).Error().Err(err).
|
||||
Str("user_id", device.UserID.String()).
|
||||
|
@ -44,7 +44,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi
|
|||
Msg("Self-signing key of user not found")
|
||||
return id.TrustStateUnset, nil
|
||||
}
|
||||
sskSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, theirSSK.Key, device.UserID, theirMSK.Key)
|
||||
sskSigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, device.UserID, theirSSK.Key, device.UserID, theirMSK.Key)
|
||||
if err != nil {
|
||||
mach.machOrContextLog(ctx).Error().Err(err).
|
||||
Str("user_id", device.UserID.String()).
|
||||
|
@ -57,7 +57,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi
|
|||
Msg("Self-signing key of user is not signed by their master key")
|
||||
return id.TrustStateUnset, nil
|
||||
}
|
||||
deviceSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, device.SigningKey, device.UserID, theirSSK.Key)
|
||||
deviceSigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, device.UserID, device.SigningKey, device.UserID, theirSSK.Key)
|
||||
if err != nil {
|
||||
mach.machOrContextLog(ctx).Error().Err(err).
|
||||
Str("user_id", device.UserID.String()).
|
||||
|
@ -97,14 +97,14 @@ func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bo
|
|||
return true, nil
|
||||
}
|
||||
// first we verify our user-signing key
|
||||
ourUserSigningKeyTrusted, err := mach.CryptoStore.IsKeySignedBy(mach.Client.UserID, csPubkeys.UserSigningKey, mach.Client.UserID, csPubkeys.MasterKey)
|
||||
ourUserSigningKeyTrusted, err := mach.CryptoStore.IsKeySignedBy(ctx, mach.Client.UserID, csPubkeys.UserSigningKey, mach.Client.UserID, csPubkeys.MasterKey)
|
||||
if err != nil {
|
||||
mach.machOrContextLog(ctx).Error().Err(err).Msg("Error retrieving our self-signing key signatures from database")
|
||||
return false, err
|
||||
} else if !ourUserSigningKeyTrusted {
|
||||
return false, nil
|
||||
}
|
||||
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
|
||||
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID)
|
||||
if err != nil {
|
||||
mach.machOrContextLog(ctx).Error().Err(err).
|
||||
Str("user_id", userID.String()).
|
||||
|
@ -118,7 +118,7 @@ func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bo
|
|||
Msg("Master key of user not found")
|
||||
return false, nil
|
||||
}
|
||||
sigExists, err := mach.CryptoStore.IsKeySignedBy(userID, theirMskKey.Key, mach.Client.UserID, csPubkeys.UserSigningKey)
|
||||
sigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, userID, theirMskKey.Key, mach.Client.UserID, csPubkeys.UserSigningKey)
|
||||
if err != nil {
|
||||
mach.machOrContextLog(ctx).Error().Err(err).
|
||||
Str("user_id", userID.String()).
|
||||
|
|
|
@ -105,7 +105,7 @@ func NewCryptoHelper(cli *mautrix.Client, pickleKey []byte, store any) (*CryptoH
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Init() error {
|
||||
func (helper *CryptoHelper) Init(ctx context.Context) error {
|
||||
if helper == nil {
|
||||
return fmt.Errorf("crypto helper is nil")
|
||||
}
|
||||
|
@ -116,7 +116,7 @@ func (helper *CryptoHelper) Init() error {
|
|||
|
||||
var stateStore crypto.StateStore
|
||||
if helper.managedStateStore != nil {
|
||||
err := helper.managedStateStore.Upgrade()
|
||||
err := helper.managedStateStore.Upgrade(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upgrade client state store: %w", err)
|
||||
}
|
||||
|
@ -124,7 +124,6 @@ func (helper *CryptoHelper) Init() error {
|
|||
} else {
|
||||
stateStore = helper.client.StateStore.(crypto.StateStore)
|
||||
}
|
||||
ctx := context.TODO()
|
||||
var cryptoStore crypto.Store
|
||||
if helper.unmanagedCryptoStore == nil {
|
||||
managedCryptoStore := crypto.NewSQLCryptoStore(helper.dbForManagedStores, dbutil.ZeroLogger(helper.log.With().Str("db_section", "crypto").Logger()), helper.DBAccountID, helper.client.DeviceID, helper.pickleKey)
|
||||
|
@ -133,11 +132,14 @@ func (helper *CryptoHelper) Init() error {
|
|||
} else if _, isMemory := helper.client.Store.(*mautrix.MemorySyncStore); isMemory {
|
||||
helper.client.Store = managedCryptoStore
|
||||
}
|
||||
err := managedCryptoStore.DB.Upgrade()
|
||||
err := managedCryptoStore.DB.Upgrade(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upgrade crypto state store: %w", err)
|
||||
}
|
||||
storedDeviceID := managedCryptoStore.FindDeviceID()
|
||||
storedDeviceID, err := managedCryptoStore.FindDeviceID(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find existing device ID: %w", err)
|
||||
}
|
||||
if helper.LoginAs != nil {
|
||||
if storedDeviceID != "" {
|
||||
helper.LoginAs.DeviceID = storedDeviceID
|
||||
|
@ -168,7 +170,7 @@ func (helper *CryptoHelper) Init() error {
|
|||
return fmt.Errorf("the client must be logged in")
|
||||
}
|
||||
helper.mach = crypto.NewOlmMachine(helper.client, &helper.log, cryptoStore, stateStore)
|
||||
err := helper.mach.Load()
|
||||
err := helper.mach.Load(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load olm account: %w", err)
|
||||
} else if err = helper.verifyDeviceKeysOnServer(ctx); err != nil {
|
||||
|
@ -253,17 +255,18 @@ func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event.
|
|||
Str("session_id", content.SessionID.String()).
|
||||
Logger()
|
||||
log.Debug().Msg("Decrypting received event")
|
||||
ctx := log.WithContext(context.TODO())
|
||||
|
||||
decrypted, err := helper.Decrypt(evt)
|
||||
decrypted, err := helper.Decrypt(ctx, evt)
|
||||
if errors.Is(err, NoSessionFound) {
|
||||
log.Debug().
|
||||
Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
|
||||
Msg("Couldn't find session, waiting for keys to arrive...")
|
||||
if helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
|
||||
if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
|
||||
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
|
||||
decrypted, err = helper.Decrypt(evt)
|
||||
decrypted, err = helper.Decrypt(ctx, evt)
|
||||
} else {
|
||||
go helper.waitLongerForSession(log, src, evt)
|
||||
go helper.waitLongerForSession(ctx, log, src, evt)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -306,20 +309,20 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID
|
|||
}
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) waitLongerForSession(log zerolog.Logger, src mautrix.EventSource, evt *event.Event) {
|
||||
func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, src mautrix.EventSource, evt *event.Event) {
|
||||
content := evt.Content.AsEncrypted()
|
||||
log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...")
|
||||
|
||||
go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
|
||||
|
||||
if !helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
|
||||
if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
|
||||
log.Debug().Msg("Didn't get session, giving up")
|
||||
helper.DecryptErrorCallback(evt, NoSessionFound)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again")
|
||||
decrypted, err := helper.Decrypt(evt)
|
||||
decrypted, err := helper.Decrypt(ctx, evt)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to decrypt event")
|
||||
helper.DecryptErrorCallback(evt, err)
|
||||
|
@ -329,32 +332,31 @@ func (helper *CryptoHelper) waitLongerForSession(log zerolog.Logger, src mautrix
|
|||
helper.postDecrypt(src, decrypted)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
||||
func (helper *CryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
||||
if helper == nil {
|
||||
return false
|
||||
}
|
||||
helper.lock.RLock()
|
||||
defer helper.lock.RUnlock()
|
||||
return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout)
|
||||
return helper.mach.WaitForSession(ctx, roomID, senderKey, sessionID, timeout)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
|
||||
func (helper *CryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) {
|
||||
if helper == nil {
|
||||
return nil, fmt.Errorf("crypto helper is nil")
|
||||
}
|
||||
return helper.mach.DecryptMegolmEvent(context.TODO(), evt)
|
||||
return helper.mach.DecryptMegolmEvent(ctx, evt)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) {
|
||||
func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) {
|
||||
if helper == nil {
|
||||
return nil, fmt.Errorf("crypto helper is nil")
|
||||
}
|
||||
helper.lock.RLock()
|
||||
defer helper.lock.RUnlock()
|
||||
ctx := context.TODO()
|
||||
encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content)
|
||||
if err != nil {
|
||||
if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession {
|
||||
if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) {
|
||||
return
|
||||
}
|
||||
helper.log.Debug().
|
||||
|
@ -362,7 +364,7 @@ func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, conten
|
|||
Str("room_id", roomID.String()).
|
||||
Msg("Got session error while encrypting event, sharing group session and trying again")
|
||||
var users []id.UserID
|
||||
users, err = helper.client.StateStore.GetRoomJoinedOrInvitedMembers(roomID)
|
||||
users, err = helper.client.StateStore.GetRoomJoinedOrInvitedMembers(ctx, roomID)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to get room member list: %w", err)
|
||||
} else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@ -91,7 +91,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
|
|||
} else {
|
||||
forwardedKeys = true
|
||||
lastChainItem := sess.ForwardingChains[len(sess.ForwardingChains)-1]
|
||||
device, _ = mach.CryptoStore.FindDeviceByKey(evt.Sender, id.IdentityKey(lastChainItem))
|
||||
device, _ = mach.CryptoStore.FindDeviceByKey(ctx, evt.Sender, id.IdentityKey(lastChainItem))
|
||||
if device != nil {
|
||||
trustLevel = mach.ResolveTrust(device)
|
||||
} else {
|
||||
|
@ -188,7 +188,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
|
|||
mach.megolmDecryptLock.Lock()
|
||||
defer mach.megolmDecryptLock.Unlock()
|
||||
|
||||
sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID)
|
||||
sess, err := mach.CryptoStore.GetGroupSession(ctx, encryptionRoomID, content.SenderKey, content.SessionID)
|
||||
if err != nil {
|
||||
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
|
||||
} else if sess == nil {
|
||||
|
@ -250,7 +250,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
|
|||
Int("max_messages", sess.MaxMessages).
|
||||
Logger()
|
||||
if sess.MaxMessages > 0 && int(ratchetTargetIndex) >= sess.MaxMessages && len(sess.RatchetSafety.MissedIndices) == 0 && mach.DeleteFullyUsedKeysOnDecrypt {
|
||||
err = mach.CryptoStore.RedactGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached")
|
||||
err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to delete fully used session")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
|
@ -261,14 +261,14 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
|
|||
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
|
||||
log.Err(err).Msg("Failed to ratchet session")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
|
||||
} else if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
|
||||
log.Err(err).Msg("Failed to store ratcheted session")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else {
|
||||
log.Info().Msg("Ratcheted session forward")
|
||||
}
|
||||
} else if didModify {
|
||||
if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
|
||||
if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
|
||||
log.Err(err).Msg("Failed to store updated ratchet safety data")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@ -159,7 +159,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
|
|||
}
|
||||
|
||||
endTimeTrace = mach.timeTrace(ctx, "updating new session in database", time.Second)
|
||||
err = mach.CryptoStore.UpdateSession(senderKey, session)
|
||||
err = mach.CryptoStore.UpdateSession(ctx, senderKey, session)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to update new olm session in crypto store after decrypting")
|
||||
|
@ -170,7 +170,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
|
|||
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
|
||||
log := *zerolog.Ctx(ctx)
|
||||
endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second)
|
||||
sessions, err := mach.CryptoStore.GetSessions(senderKey)
|
||||
sessions, err := mach.CryptoStore.GetSessions(ctx, senderKey)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get session for %s: %w", senderKey, err)
|
||||
|
@ -199,7 +199,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C
|
|||
}
|
||||
} else {
|
||||
endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second)
|
||||
err = mach.CryptoStore.UpdateSession(senderKey, session)
|
||||
err = mach.CryptoStore.UpdateSession(ctx, senderKey, session)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting")
|
||||
|
@ -217,7 +217,7 @@ func (mach *OlmMachine) createInboundSession(ctx context.Context, senderKey id.S
|
|||
return nil, err
|
||||
}
|
||||
mach.saveAccount()
|
||||
err = mach.CryptoStore.AddSession(senderKey, session)
|
||||
err = mach.CryptoStore.AddSession(ctx, senderKey, session)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to store created inbound session")
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@ -53,7 +53,7 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id
|
|||
Str("signed_device_id", deviceID.String()).
|
||||
Str("signature", signature).
|
||||
Msg("Verified self-signing signature")
|
||||
err = mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, pubKey, signature)
|
||||
err = mach.CryptoStore.PutSignature(ctx, userID, id.Ed25519(signKey), signerUserID, pubKey, signature)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).
|
||||
Str("signer_user_id", signerUserID.String()).
|
||||
|
@ -74,7 +74,7 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id
|
|||
}
|
||||
// save signature of device made by its own device signing key
|
||||
if signKey, ok := deviceKeys.Keys[id.DeviceKeyID(signerKey)]; ok {
|
||||
err := mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, id.Ed25519(signKey), signature)
|
||||
err := mach.CryptoStore.PutSignature(ctx, userID, id.Ed25519(signKey), signerUserID, id.Ed25519(signKey), signature)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).
|
||||
Str("signer_user_id", signerUserID.String()).
|
||||
|
@ -96,7 +96,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
|
|||
log := mach.machOrContextLog(ctx)
|
||||
if !includeUntracked {
|
||||
var err error
|
||||
users, err = mach.CryptoStore.FilterTrackedUsers(users)
|
||||
users, err = mach.CryptoStore.FilterTrackedUsers(ctx, users)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to filter tracked user list")
|
||||
}
|
||||
|
@ -123,7 +123,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
|
|||
delete(req.DeviceKeys, userID)
|
||||
|
||||
newDevices := make(map[id.DeviceID]*id.Device)
|
||||
existingDevices, err := mach.CryptoStore.GetDevices(userID)
|
||||
existingDevices, err := mach.CryptoStore.GetDevices(ctx, userID)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to get existing devices for user")
|
||||
existingDevices = make(map[id.DeviceID]*id.Device)
|
||||
|
@ -151,7 +151,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
|
|||
}
|
||||
}
|
||||
log.Trace().Int("new_device_count", len(newDevices)).Msg("Storing new device list")
|
||||
err = mach.CryptoStore.PutDevices(userID, newDevices)
|
||||
err = mach.CryptoStore.PutDevices(ctx, userID, newDevices)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to update device list")
|
||||
}
|
||||
|
@ -169,7 +169,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
|
|||
Str("identity_key", device.IdentityKey.String()).
|
||||
Str("signing_key", device.SigningKey.String()).
|
||||
Logger()
|
||||
sessionIDs, err := mach.CryptoStore.RedactGroupSessions("", device.IdentityKey, "device removed")
|
||||
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(ctx, "", device.IdentityKey, "device removed")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to redact megolm sessions from deleted device")
|
||||
} else {
|
||||
|
@ -179,7 +179,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
|
|||
}
|
||||
}
|
||||
}
|
||||
mach.OnDevicesChanged(userID)
|
||||
mach.OnDevicesChanged(ctx, userID)
|
||||
}
|
||||
}
|
||||
for userID := range req.DeviceKeys {
|
||||
|
@ -197,18 +197,25 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
|
|||
//
|
||||
// This is called automatically whenever a device list change is noticed in ProcessSyncResponse and usually does
|
||||
// not need to be called manually.
|
||||
func (mach *OlmMachine) OnDevicesChanged(userID id.UserID) {
|
||||
func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID) {
|
||||
if mach.DisableDeviceChangeKeyRotation {
|
||||
return
|
||||
}
|
||||
for _, roomID := range mach.StateStore.FindSharedRooms(userID) {
|
||||
mach.Log.Debug().
|
||||
rooms, err := mach.StateStore.FindSharedRooms(ctx, userID)
|
||||
if err != nil {
|
||||
mach.machOrContextLog(ctx).Err(err).
|
||||
Stringer("with_user_id", userID).
|
||||
Msg("Failed to find shared rooms to invalidate group sessions")
|
||||
return
|
||||
}
|
||||
for _, roomID := range rooms {
|
||||
mach.machOrContextLog(ctx).Debug().
|
||||
Str("user_id", userID.String()).
|
||||
Str("room_id", roomID.String()).
|
||||
Msg("Invalidating group session in room due to device change notification")
|
||||
err := mach.CryptoStore.RemoveOutboundGroupSession(roomID)
|
||||
err = mach.CryptoStore.RemoveOutboundGroupSession(ctx, roomID)
|
||||
if err != nil {
|
||||
mach.Log.Warn().Err(err).
|
||||
mach.machOrContextLog(ctx).Err(err).
|
||||
Str("user_id", userID.String()).
|
||||
Str("room_id", roomID.String()).
|
||||
Msg("Failed to invalidate outbound group session")
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@ -84,7 +84,7 @@ func parseMessageIndex(ciphertext []byte) (uint, error) {
|
|||
func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) {
|
||||
mach.megolmEncryptLock.Lock()
|
||||
defer mach.megolmEncryptLock.Unlock()
|
||||
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
|
||||
session, err := mach.CryptoStore.GetOutboundGroupSession(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
|
||||
} else if session == nil {
|
||||
|
@ -116,7 +116,7 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID
|
|||
log = log.With().Uint("message_index", idx).Logger()
|
||||
}
|
||||
log.Debug().Msg("Encrypted event successfully")
|
||||
err = mach.CryptoStore.UpdateOutboundGroupSession(session)
|
||||
err = mach.CryptoStore.UpdateOutboundGroupSession(ctx, session)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to update megolm session in crypto store after encrypting")
|
||||
}
|
||||
|
@ -137,7 +137,13 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID
|
|||
}
|
||||
|
||||
func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) *OutboundGroupSession {
|
||||
session := NewOutboundGroupSession(roomID, mach.StateStore.GetEncryptionEvent(roomID))
|
||||
encryptionEvent, err := mach.StateStore.GetEncryptionEvent(ctx, roomID)
|
||||
if err != nil {
|
||||
mach.machOrContextLog(ctx).Err(err).
|
||||
Stringer("room_id", roomID).
|
||||
Msg("Failed to get encryption event in room")
|
||||
}
|
||||
session := NewOutboundGroupSession(roomID, encryptionEvent)
|
||||
if !mach.DontStoreOutboundKeys {
|
||||
signingKey, idKey := mach.account.Keys()
|
||||
mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false)
|
||||
|
@ -165,7 +171,7 @@ func strishArray[T ~string](arr []T) []string {
|
|||
func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, users []id.UserID) error {
|
||||
mach.megolmEncryptLock.Lock()
|
||||
defer mach.megolmEncryptLock.Unlock()
|
||||
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
|
||||
session, err := mach.CryptoStore.GetOutboundGroupSession(ctx, roomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get previous outbound group session: %w", err)
|
||||
} else if session != nil && session.Shared && !session.Expired() {
|
||||
|
@ -192,7 +198,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
|
|||
|
||||
for _, userID := range users {
|
||||
log := log.With().Str("target_user_id", userID.String()).Logger()
|
||||
devices, err := mach.CryptoStore.GetDevices(userID)
|
||||
devices, err := mach.CryptoStore.GetDevices(ctx, userID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get devices of user")
|
||||
} else if devices == nil {
|
||||
|
@ -292,7 +298,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
|
|||
|
||||
log.Debug().Msg("Group session successfully shared")
|
||||
session.Shared = true
|
||||
return mach.CryptoStore.AddOutboundGroupSession(session)
|
||||
return mach.CryptoStore.AddOutboundGroupSession(ctx, session)
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error {
|
||||
|
@ -367,7 +373,7 @@ func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *Out
|
|||
Reason: "This device does not encrypt messages for unverified devices",
|
||||
}}
|
||||
session.Users[userKey] = OGSIgnored
|
||||
} else if deviceSession, err := mach.CryptoStore.GetLatestSession(device.IdentityKey); err != nil {
|
||||
} else if deviceSession, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get olm session to encrypt group session")
|
||||
} else if deviceSession == nil {
|
||||
log.Warn().Err(err).Msg("Didn't find olm session to encrypt group session")
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@ -38,7 +38,7 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
|
|||
Str("olm_session_description", session.Describe()).
|
||||
Msg("Encrypting olm message")
|
||||
msgType, ciphertext := session.Encrypt(plaintext)
|
||||
err = mach.CryptoStore.UpdateSession(recipient.IdentityKey, session)
|
||||
err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting")
|
||||
}
|
||||
|
@ -54,8 +54,8 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
|
|||
}
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) shouldCreateNewSession(identityKey id.IdentityKey) bool {
|
||||
if !mach.CryptoStore.HasSession(identityKey) {
|
||||
func (mach *OlmMachine) shouldCreateNewSession(ctx context.Context, identityKey id.IdentityKey) bool {
|
||||
if !mach.CryptoStore.HasSession(ctx, identityKey) {
|
||||
return true
|
||||
}
|
||||
mach.devicesToUnwedgeLock.Lock()
|
||||
|
@ -72,7 +72,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id
|
|||
for userID, devices := range input {
|
||||
request[userID] = make(map[id.DeviceID]id.KeyAlgorithm)
|
||||
for deviceID, identity := range devices {
|
||||
if mach.shouldCreateNewSession(identity.IdentityKey) {
|
||||
if mach.shouldCreateNewSession(ctx, identity.IdentityKey) {
|
||||
request[userID][deviceID] = id.KeyAlgorithmSignedCurve25519
|
||||
}
|
||||
}
|
||||
|
@ -117,7 +117,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id
|
|||
log.Error().Err(err).Msg("Failed to create outbound session with claimed one-time key")
|
||||
} else {
|
||||
wrapped := wrapSession(sess)
|
||||
err = mach.CryptoStore.AddSession(identity.IdentityKey, wrapped)
|
||||
err = mach.CryptoStore.AddSession(ctx, identity.IdentityKey, wrapped)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to store created outbound session")
|
||||
} else {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@ -8,6 +8,7 @@ package crypto
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
|
@ -91,7 +92,7 @@ func decryptKeyExport(passphrase string, exportData []byte) ([]ExportedSession,
|
|||
return sessionsJSON, nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, error) {
|
||||
func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session ExportedSession) (bool, error) {
|
||||
if session.Algorithm != id.AlgorithmMegolmV1 {
|
||||
return false, ErrInvalidExportedAlgorithm
|
||||
}
|
||||
|
@ -112,12 +113,12 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er
|
|||
|
||||
ReceivedAt: time.Now().UTC(),
|
||||
}
|
||||
existingIGS, _ := mach.CryptoStore.GetGroupSession(igs.RoomID, igs.SenderKey, igs.ID())
|
||||
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID())
|
||||
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
|
||||
// We already have an equivalent or better session in the store, so don't override it.
|
||||
return false, nil
|
||||
}
|
||||
err = mach.CryptoStore.PutGroupSession(igs.RoomID, igs.SenderKey, igs.ID(), igs)
|
||||
err = mach.CryptoStore.PutGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID(), igs)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to store imported session: %w", err)
|
||||
}
|
||||
|
@ -127,7 +128,7 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er
|
|||
|
||||
// ImportKeys imports data that was exported with the format specified in the Matrix spec.
|
||||
// See https://spec.matrix.org/v1.2/client-server-api/#key-exports
|
||||
func (mach *OlmMachine) ImportKeys(passphrase string, data []byte) (int, int, error) {
|
||||
func (mach *OlmMachine) ImportKeys(ctx context.Context, passphrase string, data []byte) (int, int, error) {
|
||||
exportData, err := decodeKeyExport(data)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
|
@ -143,8 +144,11 @@ func (mach *OlmMachine) ImportKeys(passphrase string, data []byte) (int, int, er
|
|||
Str("room_id", session.RoomID.String()).
|
||||
Str("session_id", session.SessionID.String()).
|
||||
Logger()
|
||||
imported, err := mach.importExportedRoomKey(session)
|
||||
imported, err := mach.importExportedRoomKey(ctx, session)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return count, len(sessions), ctx.Err()
|
||||
}
|
||||
log.Error().Err(err).Msg("Failed to import Megolm session from file")
|
||||
} else if imported {
|
||||
log.Debug().Msg("Imported Megolm session from file")
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// Copyright (c) 2020 Nikos Filippakis
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@ -152,7 +152,10 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
|
|||
Msg("Mismatched session ID while creating inbound group session from forward")
|
||||
return false
|
||||
}
|
||||
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
|
||||
config, err := mach.StateStore.GetEncryptionEvent(ctx, content.RoomID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get encryption event for room")
|
||||
}
|
||||
var maxAge time.Duration
|
||||
var maxMessages int
|
||||
if config != nil {
|
||||
|
@ -178,7 +181,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
|
|||
MaxMessages: maxMessages,
|
||||
IsScheduled: content.IsScheduled,
|
||||
}
|
||||
err = mach.CryptoStore.PutGroupSession(content.RoomID, content.SenderKey, content.SessionID, igs)
|
||||
err = mach.CryptoStore.PutGroupSession(ctx, content.RoomID, content.SenderKey, content.SessionID, igs)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to store new inbound group session")
|
||||
return false
|
||||
|
@ -274,7 +277,7 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User
|
|||
return
|
||||
}
|
||||
|
||||
igs, err := mach.CryptoStore.GetGroupSession(content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID)
|
||||
igs, err := mach.CryptoStore.GetGroupSession(ctx, content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrGroupSessionWithheld) {
|
||||
log.Debug().Err(err).Msg("Requested group session not available")
|
||||
|
@ -331,7 +334,7 @@ func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.Us
|
|||
Int("first_message_index", content.FirstMessageIndex).
|
||||
Logger()
|
||||
|
||||
sess, err := mach.CryptoStore.GetGroupSession(content.RoomID, "", content.SessionID)
|
||||
sess, err := mach.CryptoStore.GetGroupSession(ctx, content.RoomID, "", content.SessionID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrGroupSessionWithheld) {
|
||||
log.Debug().Err(err).Msg("Acked group session was already redacted")
|
||||
|
@ -351,7 +354,7 @@ func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.Us
|
|||
isInbound := sess.SenderKey == mach.OwnIdentity().IdentityKey
|
||||
if isInbound && mach.DeleteOutboundKeysOnAck && content.FirstMessageIndex == 0 {
|
||||
log.Debug().Msg("Redacting inbound copy of outbound group session after ack")
|
||||
err = mach.CryptoStore.RedactGroupSession(content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked")
|
||||
err = mach.CryptoStore.RedactGroupSession(ctx, content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to redact group session")
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@ -80,11 +80,11 @@ type OlmMachine struct {
|
|||
// StateStore is used by OlmMachine to get room state information that's needed for encryption.
|
||||
type StateStore interface {
|
||||
// IsEncrypted returns whether a room is encrypted.
|
||||
IsEncrypted(id.RoomID) bool
|
||||
IsEncrypted(context.Context, id.RoomID) (bool, error)
|
||||
// GetEncryptionEvent returns the encryption event's content for an encrypted room.
|
||||
GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent
|
||||
GetEncryptionEvent(context.Context, id.RoomID) (*event.EncryptionEventContent, error)
|
||||
// FindSharedRooms returns the encrypted rooms that another user is also in for a user ID.
|
||||
FindSharedRooms(id.UserID) []id.RoomID
|
||||
FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, error)
|
||||
}
|
||||
|
||||
// NewOlmMachine creates an OlmMachine with the given client, logger and stores.
|
||||
|
@ -131,8 +131,8 @@ func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger {
|
|||
|
||||
// Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created.
|
||||
// This must be called before using the machine.
|
||||
func (mach *OlmMachine) Load() (err error) {
|
||||
mach.account, err = mach.CryptoStore.GetAccount()
|
||||
func (mach *OlmMachine) Load(ctx context.Context) (err error) {
|
||||
mach.account, err = mach.CryptoStore.GetAccount(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -143,15 +143,15 @@ func (mach *OlmMachine) Load() (err error) {
|
|||
}
|
||||
|
||||
func (mach *OlmMachine) saveAccount() {
|
||||
err := mach.CryptoStore.PutAccount(mach.account)
|
||||
err := mach.CryptoStore.PutAccount(context.TODO(), mach.account)
|
||||
if err != nil {
|
||||
mach.Log.Error().Err(err).Msg("Failed to save account")
|
||||
}
|
||||
}
|
||||
|
||||
// FlushStore calls the Flush method of the CryptoStore.
|
||||
func (mach *OlmMachine) FlushStore() error {
|
||||
return mach.CryptoStore.Flush()
|
||||
func (mach *OlmMachine) FlushStore(ctx context.Context) error {
|
||||
return mach.CryptoStore.Flush(ctx)
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() {
|
||||
|
@ -284,7 +284,12 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
|
|||
//
|
||||
// client.Syncer.(mautrix.ExtensibleSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent)
|
||||
func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Event) {
|
||||
if !mach.StateStore.IsEncrypted(evt.RoomID) {
|
||||
ctx := context.TODO()
|
||||
if isEncrypted, err := mach.StateStore.IsEncrypted(ctx, evt.RoomID); err != nil {
|
||||
mach.machOrContextLog(ctx).Err(err).Stringer("room_id", evt.RoomID).
|
||||
Msg("Failed to check if room is encrypted to handle member event")
|
||||
return
|
||||
} else if !isEncrypted {
|
||||
return
|
||||
}
|
||||
content := evt.Content.AsMember()
|
||||
|
@ -311,7 +316,7 @@ func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Even
|
|||
Str("prev_membership", string(prevContent.Membership)).
|
||||
Str("new_membership", string(content.Membership)).
|
||||
Msg("Got membership state change, invalidating group session in room")
|
||||
err := mach.CryptoStore.RemoveOutboundGroupSession(evt.RoomID)
|
||||
err := mach.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID)
|
||||
if err != nil {
|
||||
mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session")
|
||||
}
|
||||
|
@ -405,7 +410,7 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
|
|||
// GetOrFetchDevice attempts to retrieve the device identity for the given device from the store
|
||||
// and if it's not found it asks the server for it.
|
||||
func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
|
||||
device, err := mach.CryptoStore.GetDevice(userID, deviceID)
|
||||
device, err := mach.CryptoStore.GetDevice(ctx, userID, deviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get sender device from store: %w", err)
|
||||
} else if device != nil {
|
||||
|
@ -425,7 +430,7 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID,
|
|||
// store and if it's not found it asks the server for it. This returns nil if the server doesn't return a device with
|
||||
// the given identity key.
|
||||
func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
|
||||
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(userID, identityKey)
|
||||
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(ctx, userID, identityKey)
|
||||
if err != nil || deviceIdentity != nil {
|
||||
return deviceIdentity, err
|
||||
}
|
||||
|
@ -455,7 +460,7 @@ func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.De
|
|||
mach.olmLock.Lock()
|
||||
defer mach.olmLock.Unlock()
|
||||
|
||||
olmSess, err := mach.CryptoStore.GetLatestSession(device.IdentityKey)
|
||||
olmSess, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -499,7 +504,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
|
|||
Msg("Mismatched session ID while creating inbound group session")
|
||||
return
|
||||
}
|
||||
err = mach.CryptoStore.PutGroupSession(roomID, senderKey, sessionID, igs)
|
||||
err = mach.CryptoStore.PutGroupSession(ctx, roomID, senderKey, sessionID, igs)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
|
||||
return
|
||||
|
@ -525,7 +530,7 @@ func (mach *OlmMachine) markSessionReceived(id id.SessionID) {
|
|||
}
|
||||
|
||||
// WaitForSession waits for the given Megolm session to arrive.
|
||||
func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
||||
func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
||||
mach.keyWaitersLock.Lock()
|
||||
ch, ok := mach.keyWaiters[sessionID]
|
||||
if !ok {
|
||||
|
@ -534,7 +539,7 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey,
|
|||
}
|
||||
mach.keyWaitersLock.Unlock()
|
||||
// Handle race conditions where a session appears between the failed decryption and WaitForSession call.
|
||||
sess, err := mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID)
|
||||
sess, err := mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID)
|
||||
if sess != nil || errors.Is(err, ErrGroupSessionWithheld) {
|
||||
return true
|
||||
}
|
||||
|
@ -542,10 +547,12 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey,
|
|||
case <-ch:
|
||||
return true
|
||||
case <-time.After(timeout):
|
||||
sess, err = mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID)
|
||||
sess, err = mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID)
|
||||
// Check if the session somehow appeared in the store without telling us
|
||||
// We accept withheld sessions as received, as then the decryption attempt will show the error.
|
||||
return sess != nil || errors.Is(err, ErrGroupSessionWithheld)
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -568,7 +575,10 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve
|
|||
return
|
||||
}
|
||||
|
||||
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
|
||||
config, err := mach.StateStore.GetEncryptionEvent(ctx, content.RoomID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get encryption event for room")
|
||||
}
|
||||
var maxAge time.Duration
|
||||
var maxMessages int
|
||||
if config != nil {
|
||||
|
@ -589,7 +599,7 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve
|
|||
}
|
||||
if mach.DeletePreviousKeysOnReceive && !content.IsScheduled {
|
||||
log.Debug().Msg("Redacting previous megolm sessions from sender in room")
|
||||
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(content.RoomID, evt.SenderKey, "received new key from device")
|
||||
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(ctx, content.RoomID, evt.SenderKey, "received new key from device")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to redact previous megolm sessions")
|
||||
} else {
|
||||
|
@ -606,7 +616,7 @@ func (mach *OlmMachine) handleRoomKeyWithheld(ctx context.Context, content *even
|
|||
zerolog.Ctx(ctx).Debug().Interface("content", content).Msg("Non-megolm room key withheld event")
|
||||
return
|
||||
}
|
||||
err := mach.CryptoStore.PutWithheldGroupSession(*content)
|
||||
err := mach.CryptoStore.PutWithheldGroupSession(ctx, *content)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to save room key withheld event")
|
||||
}
|
||||
|
@ -662,7 +672,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
|
|||
func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) {
|
||||
log := mach.Log.With().Str("action", "redact expired sessions").Logger()
|
||||
for {
|
||||
sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions()
|
||||
sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to redact expired megolm sessions")
|
||||
} else if len(sessionIDs) > 0 {
|
||||
|
|
|
@ -20,18 +20,18 @@ import (
|
|||
|
||||
type mockStateStore struct{}
|
||||
|
||||
func (mockStateStore) IsEncrypted(id.RoomID) bool {
|
||||
return true
|
||||
func (mockStateStore) IsEncrypted(context.Context, id.RoomID) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (mockStateStore) GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent {
|
||||
func (mockStateStore) GetEncryptionEvent(context.Context, id.RoomID) (*event.EncryptionEventContent, error) {
|
||||
return &event.EncryptionEventContent{
|
||||
RotationPeriodMessages: 3,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (mockStateStore) FindSharedRooms(id.UserID) []id.RoomID {
|
||||
return []id.RoomID{"room1"}
|
||||
func (mockStateStore) FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, error) {
|
||||
return []id.RoomID{"room1"}, nil
|
||||
}
|
||||
|
||||
func newMachine(t *testing.T, userID id.UserID) *OlmMachine {
|
||||
|
@ -47,7 +47,7 @@ func newMachine(t *testing.T, userID id.UserID) *OlmMachine {
|
|||
}
|
||||
|
||||
machine := NewOlmMachine(client, nil, gobStore, mockStateStore{})
|
||||
if err := machine.Load(); err != nil {
|
||||
if err := machine.Load(context.TODO()); err != nil {
|
||||
t.Fatalf("Error creating account: %v", err)
|
||||
}
|
||||
|
||||
|
@ -57,7 +57,7 @@ func newMachine(t *testing.T, userID id.UserID) *OlmMachine {
|
|||
func TestRatchetMegolmSession(t *testing.T) {
|
||||
mach := newMachine(t, "user1")
|
||||
outSess := mach.newOutboundGroupSession(context.TODO(), "meow")
|
||||
inSess, err := mach.CryptoStore.GetGroupSession("meow", mach.OwnIdentity().IdentityKey, outSess.ID())
|
||||
inSess, err := mach.CryptoStore.GetGroupSession(context.TODO(), "meow", mach.OwnIdentity().IdentityKey, outSess.ID())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(0), inSess.Internal.FirstKnownIndex())
|
||||
err = inSess.RatchetTo(10)
|
||||
|
@ -85,7 +85,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
|
|||
}
|
||||
|
||||
// store sender device identity in receiving machine store
|
||||
machineIn.CryptoStore.PutDevices("user1", map[id.DeviceID]*id.Device{
|
||||
machineIn.CryptoStore.PutDevices(context.TODO(), "user1", map[id.DeviceID]*id.Device{
|
||||
"device1": {
|
||||
UserID: "user1",
|
||||
DeviceID: "device1",
|
||||
|
@ -97,7 +97,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
|
|||
// create & store outbound megolm session for sending the event later
|
||||
megolmOutSession := machineOut.newOutboundGroupSession(context.TODO(), "room1")
|
||||
megolmOutSession.Shared = true
|
||||
machineOut.CryptoStore.AddOutboundGroupSession(megolmOutSession)
|
||||
machineOut.CryptoStore.AddOutboundGroupSession(context.TODO(), megolmOutSession)
|
||||
|
||||
// encrypt m.room_key event with olm session
|
||||
deviceIdentity := &id.Device{
|
||||
|
@ -125,7 +125,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Errorf("Error creating inbound megolm session: %v", err)
|
||||
}
|
||||
if err = machineIn.CryptoStore.PutGroupSession("room1", senderKey, igs.ID(), igs); err != nil {
|
||||
if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), "room1", senderKey, igs.ID(), igs); err != nil {
|
||||
t.Errorf("Error storing inbound megolm session: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2022 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@ -27,7 +27,7 @@ import (
|
|||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
var PostgresArrayWrapper func(interface{}) interface {
|
||||
var PostgresArrayWrapper func(any) interface {
|
||||
driver.Valuer
|
||||
sql.Scanner
|
||||
}
|
||||
|
@ -62,21 +62,21 @@ func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, accountID
|
|||
}
|
||||
|
||||
// Flush does nothing for this implementation as data is already persisted in the database.
|
||||
func (store *SQLCryptoStore) Flush() error {
|
||||
func (store *SQLCryptoStore) Flush(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// PutNextBatch stores the next sync batch token for the current account.
|
||||
func (store *SQLCryptoStore) PutNextBatch(ctx context.Context, nextBatch string) error {
|
||||
store.SyncToken = nextBatch
|
||||
_, err := store.DB.ExecContext(ctx, `UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2`, store.SyncToken, store.AccountID)
|
||||
_, err := store.DB.Exec(ctx, `UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2`, store.SyncToken, store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetNextBatch retrieves the next sync batch token for the current account.
|
||||
func (store *SQLCryptoStore) GetNextBatch(ctx context.Context) (string, error) {
|
||||
if store.SyncToken == "" {
|
||||
err := store.DB.
|
||||
err := store.DB.Conn(ctx).
|
||||
QueryRowContext(ctx, "SELECT sync_token FROM crypto_account WHERE account_id=$1", store.AccountID).
|
||||
Scan(&store.SyncToken)
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
|
@ -111,20 +111,19 @@ func (store *SQLCryptoStore) LoadNextBatch(ctx context.Context, _ id.UserID) (st
|
|||
return nb, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) FindDeviceID() (deviceID id.DeviceID) {
|
||||
err := store.DB.QueryRow("SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
// TODO return error
|
||||
store.DB.Log.Warn("Failed to scan device ID: %v", err)
|
||||
func (store *SQLCryptoStore) FindDeviceID(ctx context.Context) (deviceID id.DeviceID, err error) {
|
||||
err = store.DB.QueryRow(ctx, "SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// PutAccount stores an OlmAccount in the database.
|
||||
func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error {
|
||||
func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount) error {
|
||||
store.Account = account
|
||||
bytes := account.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.DB.Exec(`
|
||||
_, err := store.DB.Exec(ctx, `
|
||||
INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id) VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token,
|
||||
account=excluded.account, account_id=excluded.account_id
|
||||
|
@ -133,9 +132,9 @@ func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error {
|
|||
}
|
||||
|
||||
// GetAccount retrieves an OlmAccount from the database.
|
||||
func (store *SQLCryptoStore) GetAccount() (*OlmAccount, error) {
|
||||
func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) {
|
||||
if store.Account == nil {
|
||||
row := store.DB.QueryRow("SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID)
|
||||
row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID)
|
||||
acc := &OlmAccount{Internal: *olm.NewBlankAccount()}
|
||||
var accountBytes []byte
|
||||
err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes)
|
||||
|
@ -154,7 +153,7 @@ func (store *SQLCryptoStore) GetAccount() (*OlmAccount, error) {
|
|||
}
|
||||
|
||||
// HasSession returns whether there is an Olm session for the given sender key.
|
||||
func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
|
||||
func (store *SQLCryptoStore) HasSession(ctx context.Context, key id.SenderKey) bool {
|
||||
store.olmSessionCacheLock.Lock()
|
||||
cache, ok := store.olmSessionCache[key]
|
||||
store.olmSessionCacheLock.Unlock()
|
||||
|
@ -162,17 +161,17 @@ func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
|
|||
return true
|
||||
}
|
||||
var sessionID id.SessionID
|
||||
err := store.DB.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 LIMIT 1",
|
||||
err := store.DB.QueryRow(ctx, "SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 LIMIT 1",
|
||||
key, store.AccountID).Scan(&sessionID)
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false
|
||||
}
|
||||
return len(sessionID) > 0
|
||||
}
|
||||
|
||||
// GetSessions returns all the known Olm sessions for a sender key.
|
||||
func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (OlmSessionList, error) {
|
||||
rows, err := store.DB.Query("SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC",
|
||||
func (store *SQLCryptoStore) GetSessions(ctx context.Context, key id.SenderKey) (OlmSessionList, error) {
|
||||
rows, err := store.DB.Query(ctx, "SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC",
|
||||
key, store.AccountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -212,11 +211,11 @@ func (store *SQLCryptoStore) getOlmSessionCache(key id.SenderKey) map[id.Session
|
|||
}
|
||||
|
||||
// GetLatestSession retrieves the Olm session for a given sender key from the database that has the largest ID.
|
||||
func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, error) {
|
||||
func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.SenderKey) (*OlmSession, error) {
|
||||
store.olmSessionCacheLock.Lock()
|
||||
defer store.olmSessionCacheLock.Unlock()
|
||||
|
||||
row := store.DB.QueryRow("SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1",
|
||||
row := store.DB.QueryRow(ctx, "SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1",
|
||||
key, store.AccountID)
|
||||
|
||||
sess := OlmSession{Internal: *olm.NewBlankSession()}
|
||||
|
@ -224,7 +223,7 @@ func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, er
|
|||
var sessionID id.SessionID
|
||||
|
||||
err := row.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.LastEncryptedTime, &sess.LastDecryptedTime)
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
|
@ -242,20 +241,20 @@ func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, er
|
|||
}
|
||||
|
||||
// AddSession persists an Olm session for a sender in the database.
|
||||
func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *OlmSession) error {
|
||||
func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, session *OlmSession) error {
|
||||
store.olmSessionCacheLock.Lock()
|
||||
defer store.olmSessionCacheLock.Unlock()
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.DB.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
|
||||
_, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
|
||||
session.ID(), key, sessionBytes, session.CreationTime, session.LastEncryptedTime, session.LastDecryptedTime, store.AccountID)
|
||||
store.getOlmSessionCache(key)[session.ID()] = session
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateSession replaces the Olm session for a sender in the database.
|
||||
func (store *SQLCryptoStore) UpdateSession(_ id.SenderKey, session *OlmSession) error {
|
||||
func (store *SQLCryptoStore) UpdateSession(ctx context.Context, _ id.SenderKey, session *OlmSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.DB.Exec("UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5",
|
||||
_, err := store.DB.Exec(ctx, "UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5",
|
||||
sessionBytes, session.LastEncryptedTime, session.LastDecryptedTime, session.ID(), store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
@ -275,14 +274,14 @@ func datePtr(t time.Time) *time.Time {
|
|||
}
|
||||
|
||||
// PutGroupSession stores an inbound Megolm group session for a room, sender and session.
|
||||
func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error {
|
||||
func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
forwardingChains := strings.Join(session.ForwardingChains, ",")
|
||||
ratchetSafety, err := json.Marshal(&session.RatchetSafety)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal ratchet safety info: %w", err)
|
||||
}
|
||||
_, err = store.DB.Exec(`
|
||||
_, err = store.DB.Exec(ctx, `
|
||||
INSERT INTO crypto_megolm_inbound_session (
|
||||
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
|
||||
ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id
|
||||
|
@ -301,19 +300,19 @@ func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.Send
|
|||
}
|
||||
|
||||
// GetGroupSession retrieves an inbound Megolm group session for a room, sender and session.
|
||||
func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
|
||||
func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
|
||||
var senderKeyDB, signingKey, forwardingChains, withheldCode, withheldReason sql.NullString
|
||||
var sessionBytes, ratchetSafetyBytes []byte
|
||||
var receivedAt sql.NullTime
|
||||
var maxAge, maxMessages sql.NullInt64
|
||||
var isScheduled bool
|
||||
err := store.DB.QueryRow(`
|
||||
err := store.DB.QueryRow(ctx, `
|
||||
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
FROM crypto_megolm_inbound_session
|
||||
WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`,
|
||||
roomID, senderKey, sessionID, store.AccountID,
|
||||
).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
|
@ -327,22 +326,7 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
|
|||
Reason: withheldReason.String,
|
||||
}
|
||||
}
|
||||
igs := olm.NewBlankInboundGroupSession()
|
||||
err = igs.Unpickle(sessionBytes, store.PickleKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var chains []string
|
||||
if forwardingChains.String != "" {
|
||||
chains = strings.Split(forwardingChains.String, ",")
|
||||
}
|
||||
var rs RatchetSafety
|
||||
if len(ratchetSafetyBytes) > 0 {
|
||||
err = json.Unmarshal(ratchetSafetyBytes, &rs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
|
||||
}
|
||||
}
|
||||
igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String)
|
||||
if senderKey == "" {
|
||||
senderKey = id.Curve25519(senderKeyDB.String)
|
||||
}
|
||||
|
@ -360,8 +344,8 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error {
|
||||
_, err := store.DB.Exec(`
|
||||
func (store *SQLCryptoStore) RedactGroupSession(ctx context.Context, _ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error {
|
||||
_, err := store.DB.Exec(ctx, `
|
||||
UPDATE crypto_megolm_inbound_session
|
||||
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
|
||||
WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL
|
||||
|
@ -369,27 +353,24 @@ func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, ses
|
|||
return err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
|
||||
func (store *SQLCryptoStore) RedactGroupSessions(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
|
||||
if roomID == "" && senderKey == "" {
|
||||
return nil, fmt.Errorf("room ID or sender key must be provided for redacting sessions")
|
||||
}
|
||||
res, err := store.DB.Query(`
|
||||
res, err := store.DB.Query(ctx, `
|
||||
UPDATE crypto_megolm_inbound_session
|
||||
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
|
||||
WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5
|
||||
AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL
|
||||
RETURNING session_id
|
||||
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, roomID, senderKey, store.AccountID)
|
||||
var sessionIDs []id.SessionID
|
||||
for res.Next() {
|
||||
var sessionID id.SessionID
|
||||
_ = res.Scan(&sessionID)
|
||||
sessionIDs = append(sessionIDs, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sessionIDs, err
|
||||
return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList()
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error) {
|
||||
func (store *SQLCryptoStore) RedactExpiredGroupSessions(ctx context.Context) ([]id.SessionID, error) {
|
||||
var query string
|
||||
switch store.DB.Dialect {
|
||||
case dbutil.Postgres:
|
||||
|
@ -413,46 +394,40 @@ func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error
|
|||
default:
|
||||
return nil, fmt.Errorf("unsupported dialect")
|
||||
}
|
||||
res, err := store.DB.Query(query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID)
|
||||
var sessionIDs []id.SessionID
|
||||
for res.Next() {
|
||||
var sessionID id.SessionID
|
||||
_ = res.Scan(&sessionID)
|
||||
sessionIDs = append(sessionIDs, sessionID)
|
||||
res, err := store.DB.Query(ctx, query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sessionIDs, err
|
||||
return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList()
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) RedactOutdatedGroupSessions() ([]id.SessionID, error) {
|
||||
res, err := store.DB.Query(`
|
||||
func (store *SQLCryptoStore) RedactOutdatedGroupSessions(ctx context.Context) ([]id.SessionID, error) {
|
||||
res, err := store.DB.Query(ctx, `
|
||||
UPDATE crypto_megolm_inbound_session
|
||||
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
|
||||
WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL
|
||||
RETURNING session_id
|
||||
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: outdated", store.AccountID)
|
||||
var sessionIDs []id.SessionID
|
||||
for res.Next() {
|
||||
var sessionID id.SessionID
|
||||
_ = res.Scan(&sessionID)
|
||||
sessionIDs = append(sessionIDs, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sessionIDs, err
|
||||
return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList()
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
|
||||
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
|
||||
func (store *SQLCryptoStore) PutWithheldGroupSession(ctx context.Context, content event.RoomKeyWithheldEventContent) error {
|
||||
_, err := store.DB.Exec(ctx, "INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
|
||||
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
|
||||
func (store *SQLCryptoStore) GetWithheldGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
|
||||
var code, reason sql.NullString
|
||||
err := store.DB.QueryRow(`
|
||||
err := store.DB.QueryRow(ctx, `
|
||||
SELECT withheld_code, withheld_reason FROM crypto_megolm_inbound_session
|
||||
WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`,
|
||||
roomID, senderKey, sessionID, store.AccountID,
|
||||
).Scan(&code, &reason)
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil || !code.Valid {
|
||||
return nil, err
|
||||
|
@ -467,82 +442,79 @@ func (store *SQLCryptoStore) GetWithheldGroupSession(roomID id.RoomID, senderKey
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*InboundGroupSession, err error) {
|
||||
for rows.Next() {
|
||||
var roomID id.RoomID
|
||||
var signingKey, senderKey, forwardingChains sql.NullString
|
||||
var sessionBytes, ratchetSafetyBytes []byte
|
||||
var receivedAt sql.NullTime
|
||||
var maxAge, maxMessages sql.NullInt64
|
||||
var isScheduled bool
|
||||
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
|
||||
func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs *olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) {
|
||||
igs = olm.NewBlankInboundGroupSession()
|
||||
err = igs.Unpickle(sessionBytes, store.PickleKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if forwardingChains != "" {
|
||||
chains = strings.Split(forwardingChains, ",")
|
||||
}
|
||||
var rs RatchetSafety
|
||||
if len(ratchetSafetyBytes) > 0 {
|
||||
err = json.Unmarshal(ratchetSafetyBytes, &rs)
|
||||
if err != nil {
|
||||
return
|
||||
err = fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
|
||||
}
|
||||
igs := olm.NewBlankInboundGroupSession()
|
||||
err = igs.Unpickle(sessionBytes, store.PickleKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var chains []string
|
||||
if forwardingChains.String != "" {
|
||||
chains = strings.Split(forwardingChains.String, ",")
|
||||
}
|
||||
var rs RatchetSafety
|
||||
if len(ratchetSafetyBytes) > 0 {
|
||||
err = json.Unmarshal(ratchetSafetyBytes, &rs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
|
||||
}
|
||||
}
|
||||
result = append(result, &InboundGroupSession{
|
||||
Internal: *igs,
|
||||
SigningKey: id.Ed25519(signingKey.String),
|
||||
SenderKey: id.Curve25519(senderKey.String),
|
||||
RoomID: roomID,
|
||||
ForwardingChains: chains,
|
||||
RatchetSafety: rs,
|
||||
ReceivedAt: receivedAt.Time,
|
||||
MaxAge: maxAge.Int64,
|
||||
MaxMessages: int(maxMessages.Int64),
|
||||
IsScheduled: isScheduled,
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
|
||||
rows, err := store.DB.Query(`
|
||||
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*InboundGroupSession, error) {
|
||||
var roomID id.RoomID
|
||||
var signingKey, senderKey, forwardingChains sql.NullString
|
||||
var sessionBytes, ratchetSafetyBytes []byte
|
||||
var receivedAt sql.NullTime
|
||||
var maxAge, maxMessages sql.NullInt64
|
||||
var isScheduled bool
|
||||
err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String)
|
||||
return &InboundGroupSession{
|
||||
Internal: *igs,
|
||||
SigningKey: id.Ed25519(signingKey.String),
|
||||
SenderKey: id.Curve25519(senderKey.String),
|
||||
RoomID: roomID,
|
||||
ForwardingChains: chains,
|
||||
RatchetSafety: rs,
|
||||
ReceivedAt: receivedAt.Time,
|
||||
MaxAge: maxAge.Int64,
|
||||
MaxMessages: int(maxMessages.Int64),
|
||||
IsScheduled: isScheduled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) {
|
||||
rows, err := store.DB.Query(ctx, `
|
||||
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
|
||||
roomID, store.AccountID,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return []*InboundGroupSession{}, nil
|
||||
} else if err != nil {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return store.scanGroupSessionList(rows)
|
||||
return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList()
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
|
||||
rows, err := store.DB.Query(`
|
||||
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) ([]*InboundGroupSession, error) {
|
||||
rows, err := store.DB.Query(ctx, `
|
||||
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`,
|
||||
store.AccountID,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return []*InboundGroupSession{}, nil
|
||||
} else if err != nil {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return store.scanGroupSessionList(rows)
|
||||
return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList()
|
||||
}
|
||||
|
||||
// AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices.
|
||||
func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSession) error {
|
||||
func (store *SQLCryptoStore) AddOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.DB.Exec(`
|
||||
_, err := store.DB.Exec(ctx, `
|
||||
INSERT INTO crypto_megolm_outbound_session
|
||||
(room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used, account_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
|
@ -556,24 +528,24 @@ func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSessi
|
|||
}
|
||||
|
||||
// UpdateOutboundGroupSession replaces an outbound Megolm session with for same room and session ID.
|
||||
func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *OutboundGroupSession) error {
|
||||
func (store *SQLCryptoStore) UpdateOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.DB.Exec("UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6",
|
||||
_, err := store.DB.Exec(ctx, "UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6",
|
||||
sessionBytes, session.MessageCount, session.LastEncryptedTime, session.RoomID, session.ID(), store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetOutboundGroupSession retrieves the outbound Megolm session for the given room ID.
|
||||
func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
|
||||
func (store *SQLCryptoStore) GetOutboundGroupSession(ctx context.Context, roomID id.RoomID) (*OutboundGroupSession, error) {
|
||||
var ogs OutboundGroupSession
|
||||
var sessionBytes []byte
|
||||
var maxAgeMS int64
|
||||
err := store.DB.QueryRow(`
|
||||
err := store.DB.QueryRow(ctx, `
|
||||
SELECT session, shared, max_messages, message_count, max_age, created_at, last_used
|
||||
FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2`,
|
||||
roomID, store.AccountID,
|
||||
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &maxAgeMS, &ogs.CreationTime, &ogs.LastEncryptedTime)
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
|
@ -590,8 +562,8 @@ func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*Outboun
|
|||
}
|
||||
|
||||
// RemoveOutboundGroupSession removes the outbound Megolm session for the given room ID.
|
||||
func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
|
||||
_, err := store.DB.Exec("DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2",
|
||||
func (store *SQLCryptoStore) RemoveOutboundGroupSession(ctx context.Context, roomID id.RoomID) error {
|
||||
_, err := store.DB.Exec(ctx, "DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2",
|
||||
roomID, store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
@ -608,7 +580,7 @@ func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey
|
|||
`
|
||||
var expectedEventID id.EventID
|
||||
var expectedTimestamp int64
|
||||
err := store.DB.QueryRowContext(ctx, validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
|
||||
err := store.DB.QueryRow(ctx, validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
|
||||
if err != nil {
|
||||
return false, err
|
||||
} else if expectedEventID != eventID || expectedTimestamp != timestamp {
|
||||
|
@ -623,69 +595,58 @@ func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey
|
|||
return true, nil
|
||||
}
|
||||
|
||||
func scanDevice(rows dbutil.Scannable) (*id.Device, error) {
|
||||
var device id.Device
|
||||
err := rows.Scan(&device.UserID, &device.DeviceID, &device.IdentityKey, &device.SigningKey, &device.Trust, &device.Deleted, &device.Name)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &device, nil
|
||||
}
|
||||
|
||||
// GetDevices returns a map of device IDs to device identities, including the identity and signing keys, for a given user ID.
|
||||
func (store *SQLCryptoStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) {
|
||||
func (store *SQLCryptoStore) GetDevices(ctx context.Context, userID id.UserID) (map[id.DeviceID]*id.Device, error) {
|
||||
var ignore id.UserID
|
||||
err := store.DB.QueryRow("SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore)
|
||||
if err == sql.ErrNoRows {
|
||||
err := store.DB.QueryRow(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := store.DB.Query("SELECT device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND deleted=false", userID)
|
||||
rows, err := store.DB.Query(ctx, "SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND deleted=false", userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data := make(map[id.DeviceID]*id.Device)
|
||||
for rows.Next() {
|
||||
var identity id.Device
|
||||
err := rows.Scan(&identity.DeviceID, &identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identity.UserID = userID
|
||||
data[identity.DeviceID] = &identity
|
||||
err = dbutil.NewRowIter(rows, scanDevice).Iter(func(device *id.Device) (bool, error) {
|
||||
data[device.DeviceID] = device
|
||||
return true, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// GetDevice returns the device dentity for a given user and device ID.
|
||||
func (store *SQLCryptoStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
|
||||
var identity id.Device
|
||||
err := store.DB.QueryRow(`
|
||||
SELECT identity_key, signing_key, trust, deleted, name
|
||||
func (store *SQLCryptoStore) GetDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
|
||||
return scanDevice(store.DB.QueryRow(ctx, `
|
||||
SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name
|
||||
FROM crypto_device WHERE user_id=$1 AND device_id=$2`,
|
||||
userID, deviceID,
|
||||
).Scan(&identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
identity.UserID = userID
|
||||
identity.DeviceID = deviceID
|
||||
return &identity, nil
|
||||
))
|
||||
}
|
||||
|
||||
// FindDeviceByKey finds a specific device by its sender key.
|
||||
func (store *SQLCryptoStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
|
||||
var identity id.Device
|
||||
err := store.DB.QueryRow(`
|
||||
SELECT device_id, signing_key, trust, deleted, name
|
||||
func (store *SQLCryptoStore) FindDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
|
||||
return scanDevice(store.DB.QueryRow(ctx, `
|
||||
SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name
|
||||
FROM crypto_device WHERE user_id=$1 AND identity_key=$2`,
|
||||
userID, identityKey,
|
||||
).Scan(&identity.DeviceID, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
identity.UserID = userID
|
||||
identity.IdentityKey = identityKey
|
||||
return &identity, nil
|
||||
))
|
||||
}
|
||||
|
||||
const deviceInsertQuery = `
|
||||
|
@ -698,106 +659,84 @@ ON CONFLICT (user_id, device_id) DO UPDATE
|
|||
var deviceMassInsertTemplate = strings.ReplaceAll(deviceInsertQuery, "($1, $2, $3, $4, $5, $6, $7)", "%s")
|
||||
|
||||
// PutDevice stores a single device for a user, replacing it if it exists already.
|
||||
func (store *SQLCryptoStore) PutDevice(userID id.UserID, device *id.Device) error {
|
||||
_, err := store.DB.Exec(deviceInsertQuery,
|
||||
func (store *SQLCryptoStore) PutDevice(ctx context.Context, userID id.UserID, device *id.Device) error {
|
||||
_, err := store.DB.Exec(ctx, deviceInsertQuery,
|
||||
userID, device.DeviceID, device.IdentityKey, device.SigningKey, device.Trust, device.Deleted, device.Name)
|
||||
return err
|
||||
}
|
||||
|
||||
// PutDevices stores the device identity information for the given user ID.
|
||||
func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error {
|
||||
tx, err := store.DB.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add user to tracked users list: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec("UPDATE crypto_device SET deleted=true WHERE user_id=$1", userID)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("failed to delete old devices: %w", err)
|
||||
}
|
||||
if len(devices) == 0 {
|
||||
err = tx.Commit()
|
||||
func (store *SQLCryptoStore) PutDevices(ctx context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error {
|
||||
return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
_, err := store.DB.Exec(ctx, "INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit changes (no devices added): %w", err)
|
||||
return fmt.Errorf("failed to add user to tracked users list: %w", err)
|
||||
}
|
||||
|
||||
_, err = store.DB.Exec(ctx, "UPDATE crypto_device SET deleted=true WHERE user_id=$1", userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete old devices: %w", err)
|
||||
}
|
||||
if len(devices) == 0 {
|
||||
return nil
|
||||
}
|
||||
deviceBatchLen := 5 // how many devices will be inserted per query
|
||||
deviceIDs := make([]id.DeviceID, 0, len(devices))
|
||||
for deviceID := range devices {
|
||||
deviceIDs = append(deviceIDs, deviceID)
|
||||
}
|
||||
const valueStringFormat = "($1, $%d, $%d, $%d, $%d, $%d, $%d)"
|
||||
for batchDeviceIdx := 0; batchDeviceIdx < len(deviceIDs); batchDeviceIdx += deviceBatchLen {
|
||||
var batchDevices []id.DeviceID
|
||||
if batchDeviceIdx+deviceBatchLen < len(deviceIDs) {
|
||||
batchDevices = deviceIDs[batchDeviceIdx : batchDeviceIdx+deviceBatchLen]
|
||||
} else {
|
||||
batchDevices = deviceIDs[batchDeviceIdx:]
|
||||
}
|
||||
values := make([]interface{}, 1, len(devices)*6+1)
|
||||
values[0] = userID
|
||||
valueStrings := make([]string, 0, len(devices))
|
||||
i := 2
|
||||
for _, deviceID := range batchDevices {
|
||||
identity := devices[deviceID]
|
||||
values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name)
|
||||
valueStrings = append(valueStrings, fmt.Sprintf(valueStringFormat, i, i+1, i+2, i+3, i+4, i+5))
|
||||
i += 6
|
||||
}
|
||||
valueString := strings.Join(valueStrings, ",")
|
||||
_, err = store.DB.Exec(ctx, fmt.Sprintf(deviceMassInsertTemplate, valueString), values...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert new devices: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
deviceBatchLen := 5 // how many devices will be inserted per query
|
||||
deviceIDs := make([]id.DeviceID, 0, len(devices))
|
||||
for deviceID := range devices {
|
||||
deviceIDs = append(deviceIDs, deviceID)
|
||||
}
|
||||
const valueStringFormat = "($1, $%d, $%d, $%d, $%d, $%d, $%d)"
|
||||
for batchDeviceIdx := 0; batchDeviceIdx < len(deviceIDs); batchDeviceIdx += deviceBatchLen {
|
||||
var batchDevices []id.DeviceID
|
||||
if batchDeviceIdx+deviceBatchLen < len(deviceIDs) {
|
||||
batchDevices = deviceIDs[batchDeviceIdx : batchDeviceIdx+deviceBatchLen]
|
||||
} else {
|
||||
batchDevices = deviceIDs[batchDeviceIdx:]
|
||||
}
|
||||
values := make([]interface{}, 1, len(devices)*6+1)
|
||||
values[0] = userID
|
||||
valueStrings := make([]string, 0, len(devices))
|
||||
i := 2
|
||||
for _, deviceID := range batchDevices {
|
||||
identity := devices[deviceID]
|
||||
values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name)
|
||||
valueStrings = append(valueStrings, fmt.Sprintf(valueStringFormat, i, i+1, i+2, i+3, i+4, i+5))
|
||||
i += 6
|
||||
}
|
||||
valueString := strings.Join(valueStrings, ",")
|
||||
_, err = tx.Exec(fmt.Sprintf(deviceMassInsertTemplate, valueString), values...)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("failed to insert new devices: %w", err)
|
||||
}
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit changes: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// FilterTrackedUsers finds all the user IDs out of the given ones for which the database contains identity information.
|
||||
func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
|
||||
func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id.UserID) ([]id.UserID, error) {
|
||||
var rows dbutil.Rows
|
||||
var err error
|
||||
if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil {
|
||||
rows, err = store.DB.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users))
|
||||
rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users))
|
||||
} else {
|
||||
queryString := make([]string, len(users))
|
||||
params := make([]interface{}, len(users))
|
||||
for i, user := range users {
|
||||
queryString[i] = fmt.Sprintf("$%d", i+1)
|
||||
queryString[i] = fmt.Sprintf("?%d", i+1)
|
||||
params[i] = user
|
||||
}
|
||||
rows, err = store.DB.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...)
|
||||
rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...)
|
||||
}
|
||||
if err != nil {
|
||||
return users, err
|
||||
}
|
||||
var ptr int
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&users[ptr])
|
||||
if err != nil {
|
||||
return users, err
|
||||
} else {
|
||||
ptr++
|
||||
}
|
||||
}
|
||||
return users[:ptr], nil
|
||||
return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList()
|
||||
}
|
||||
|
||||
// PutCrossSigningKey stores a cross-signing key of some user along with its usage.
|
||||
func (store *SQLCryptoStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
|
||||
_, err := store.DB.Exec(`
|
||||
func (store *SQLCryptoStore) PutCrossSigningKey(ctx context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
|
||||
_, err := store.DB.Exec(ctx, `
|
||||
INSERT INTO crypto_cross_signing_keys (user_id, usage, key, first_seen_key) VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (user_id, usage) DO UPDATE SET key=excluded.key
|
||||
`, userID, usage, key, key)
|
||||
|
@ -805,8 +744,8 @@ func (store *SQLCryptoStore) PutCrossSigningKey(userID id.UserID, usage id.Cross
|
|||
}
|
||||
|
||||
// GetCrossSigningKeys retrieves a user's stored cross-signing keys.
|
||||
func (store *SQLCryptoStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
|
||||
rows, err := store.DB.Query("SELECT usage, key, first_seen_key FROM crypto_cross_signing_keys WHERE user_id=$1", userID)
|
||||
func (store *SQLCryptoStore) GetCrossSigningKeys(ctx context.Context, userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
|
||||
rows, err := store.DB.Query(ctx, "SELECT usage, key, first_seen_key FROM crypto_cross_signing_keys WHERE user_id=$1", userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -825,8 +764,8 @@ func (store *SQLCryptoStore) GetCrossSigningKeys(userID id.UserID) (map[id.Cross
|
|||
}
|
||||
|
||||
// PutSignature stores a signature of a cross-signing or device key along with the signer's user ID and key.
|
||||
func (store *SQLCryptoStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
|
||||
_, err := store.DB.Exec(`
|
||||
func (store *SQLCryptoStore) PutSignature(ctx context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
|
||||
_, err := store.DB.Exec(ctx, `
|
||||
INSERT INTO crypto_cross_signing_signatures (signed_user_id, signed_key, signer_user_id, signer_key, signature) VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (signed_user_id, signed_key, signer_user_id, signer_key) DO UPDATE SET signature=excluded.signature
|
||||
`, signedUserID, signedKey, signerUserID, signerKey, signature)
|
||||
|
@ -834,8 +773,8 @@ func (store *SQLCryptoStore) PutSignature(signedUserID id.UserID, signedKey id.E
|
|||
}
|
||||
|
||||
// GetSignaturesForKeyBy retrieves the stored signatures for a given cross-signing or device key, by the given signer.
|
||||
func (store *SQLCryptoStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
|
||||
rows, err := store.DB.Query("SELECT signer_key, signature FROM crypto_cross_signing_signatures WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3", userID, key, signerID)
|
||||
func (store *SQLCryptoStore) GetSignaturesForKeyBy(ctx context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
|
||||
rows, err := store.DB.Query(ctx, "SELECT signer_key, signature FROM crypto_cross_signing_signatures WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3", userID, key, signerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -854,18 +793,18 @@ func (store *SQLCryptoStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25
|
|||
}
|
||||
|
||||
// IsKeySignedBy returns whether a cross-signing or device key is signed by the given signer.
|
||||
func (store *SQLCryptoStore) IsKeySignedBy(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519) (isSigned bool, err error) {
|
||||
func (store *SQLCryptoStore) IsKeySignedBy(ctx context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519) (isSigned bool, err error) {
|
||||
q := `SELECT EXISTS(
|
||||
SELECT 1 FROM crypto_cross_signing_signatures
|
||||
WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3 AND signer_key=$4
|
||||
)`
|
||||
err = store.DB.QueryRow(q, signedUserID, signedKey, signerUserID, signerKey).Scan(&isSigned)
|
||||
err = store.DB.QueryRow(ctx, q, signedUserID, signedKey, signerUserID, signerKey).Scan(&isSigned)
|
||||
return
|
||||
}
|
||||
|
||||
// DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted.
|
||||
func (store *SQLCryptoStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) {
|
||||
res, err := store.DB.Exec("DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2", userID, key)
|
||||
func (store *SQLCryptoStore) DropSignaturesByKey(ctx context.Context, userID id.UserID, key id.Ed25519) (int64, error) {
|
||||
res, err := store.DB.Exec(ctx, "DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2", userID, key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
package sql_store_upgrade
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"fmt"
|
||||
|
||||
|
@ -21,7 +22,7 @@ const VersionTableName = "crypto_version"
|
|||
var fs embed.FS
|
||||
|
||||
func init() {
|
||||
Table.Register(-1, 3, 0, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error {
|
||||
Table.Register(-1, 3, 0, "Unsupported version", false, func(ctx context.Context, database *dbutil.Database) error {
|
||||
return fmt.Errorf("upgrading from versions 1 and 2 of the crypto store is no longer supported in mautrix-go v0.12+")
|
||||
})
|
||||
Table.RegisterFS(fs)
|
||||
|
|
138
crypto/store.go
138
crypto/store.go
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2022 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@ -26,64 +26,64 @@ var ErrGroupSessionWithheld error = &event.RoomKeyWithheldEventContent{}
|
|||
type Store interface {
|
||||
// Flush ensures that everything in the store is persisted to disk.
|
||||
// This doesn't have to do anything, e.g. for database-backed implementations that persist everything immediately.
|
||||
Flush() error
|
||||
Flush(context.Context) error
|
||||
|
||||
// PutAccount updates the OlmAccount in the store.
|
||||
PutAccount(*OlmAccount) error
|
||||
PutAccount(context.Context, *OlmAccount) error
|
||||
// GetAccount returns the OlmAccount in the store that was previously inserted with PutAccount.
|
||||
GetAccount() (*OlmAccount, error)
|
||||
GetAccount(ctx context.Context) (*OlmAccount, error)
|
||||
|
||||
// AddSession inserts an Olm session into the store.
|
||||
AddSession(id.SenderKey, *OlmSession) error
|
||||
AddSession(context.Context, id.SenderKey, *OlmSession) error
|
||||
// HasSession returns whether or not the store has an Olm session with the given sender key.
|
||||
HasSession(id.SenderKey) bool
|
||||
HasSession(context.Context, id.SenderKey) bool
|
||||
// GetSessions returns all Olm sessions in the store with the given sender key.
|
||||
GetSessions(id.SenderKey) (OlmSessionList, error)
|
||||
GetSessions(context.Context, id.SenderKey) (OlmSessionList, error)
|
||||
// GetLatestSession returns the session with the highest session ID (lexiographically sorting).
|
||||
// It's usually safe to return the most recently added session if sorting by session ID is too difficult.
|
||||
GetLatestSession(id.SenderKey) (*OlmSession, error)
|
||||
GetLatestSession(context.Context, id.SenderKey) (*OlmSession, error)
|
||||
// UpdateSession updates a session that has previously been inserted with AddSession.
|
||||
UpdateSession(id.SenderKey, *OlmSession) error
|
||||
UpdateSession(context.Context, id.SenderKey, *OlmSession) error
|
||||
|
||||
// PutGroupSession inserts an inbound Megolm session into the store. If an earlier withhold event has been inserted
|
||||
// with PutWithheldGroupSession, this call should replace that. However, PutWithheldGroupSession must not replace
|
||||
// sessions inserted with this call.
|
||||
PutGroupSession(id.RoomID, id.SenderKey, id.SessionID, *InboundGroupSession) error
|
||||
PutGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, *InboundGroupSession) error
|
||||
// GetGroupSession gets an inbound Megolm session from the store. If the group session has been withheld
|
||||
// (i.e. a room key withheld event has been saved with PutWithheldGroupSession), this should return the
|
||||
// ErrGroupSessionWithheld error. The caller may use GetWithheldGroupSession to find more details.
|
||||
GetGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error)
|
||||
GetGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error)
|
||||
// RedactGroupSession removes the session data for the given inbound Megolm session from the store.
|
||||
RedactGroupSession(id.RoomID, id.SenderKey, id.SessionID, string) error
|
||||
RedactGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, string) error
|
||||
// RedactGroupSessions removes the session data for all inbound Megolm sessions from a specific device and/or in a specific room.
|
||||
RedactGroupSessions(id.RoomID, id.SenderKey, string) ([]id.SessionID, error)
|
||||
RedactGroupSessions(context.Context, id.RoomID, id.SenderKey, string) ([]id.SessionID, error)
|
||||
// RedactExpiredGroupSessions removes the session data for all inbound Megolm sessions that have expired.
|
||||
RedactExpiredGroupSessions() ([]id.SessionID, error)
|
||||
RedactExpiredGroupSessions(context.Context) ([]id.SessionID, error)
|
||||
// RedactOutdatedGroupSessions removes the session data for all inbound Megolm sessions that are lacking the expiration metadata.
|
||||
RedactOutdatedGroupSessions() ([]id.SessionID, error)
|
||||
RedactOutdatedGroupSessions(context.Context) ([]id.SessionID, error)
|
||||
// PutWithheldGroupSession tells the store that a specific Megolm session was withheld.
|
||||
PutWithheldGroupSession(event.RoomKeyWithheldEventContent) error
|
||||
PutWithheldGroupSession(context.Context, event.RoomKeyWithheldEventContent) error
|
||||
// GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession.
|
||||
GetWithheldGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*event.RoomKeyWithheldEventContent, error)
|
||||
GetWithheldGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID) (*event.RoomKeyWithheldEventContent, error)
|
||||
|
||||
// GetGroupSessionsForRoom gets all the inbound Megolm sessions for a specific room. This is used for creating key
|
||||
// export files. Unlike GetGroupSession, this should not return any errors about withheld keys.
|
||||
GetGroupSessionsForRoom(id.RoomID) ([]*InboundGroupSession, error)
|
||||
GetGroupSessionsForRoom(context.Context, id.RoomID) ([]*InboundGroupSession, error)
|
||||
// GetAllGroupSessions gets all the inbound Megolm sessions in the store. This is used for creating key export
|
||||
// files. Unlike GetGroupSession, this should not return any errors about withheld keys.
|
||||
GetAllGroupSessions() ([]*InboundGroupSession, error)
|
||||
GetAllGroupSessions(context.Context) ([]*InboundGroupSession, error)
|
||||
|
||||
// AddOutboundGroupSession inserts the given outbound Megolm session into the store.
|
||||
//
|
||||
// The store should index inserted sessions by the RoomID field to support getting and removing sessions.
|
||||
// There will only be one outbound session per room ID at a time.
|
||||
AddOutboundGroupSession(*OutboundGroupSession) error
|
||||
AddOutboundGroupSession(context.Context, *OutboundGroupSession) error
|
||||
// UpdateOutboundGroupSession updates the given outbound Megolm session in the store.
|
||||
UpdateOutboundGroupSession(*OutboundGroupSession) error
|
||||
UpdateOutboundGroupSession(context.Context, *OutboundGroupSession) error
|
||||
// GetOutboundGroupSession gets the stored outbound Megolm session for the given room ID from the store.
|
||||
GetOutboundGroupSession(id.RoomID) (*OutboundGroupSession, error)
|
||||
GetOutboundGroupSession(context.Context, id.RoomID) (*OutboundGroupSession, error)
|
||||
// RemoveOutboundGroupSession removes the stored outbound Megolm session for the given room ID.
|
||||
RemoveOutboundGroupSession(id.RoomID) error
|
||||
RemoveOutboundGroupSession(context.Context, id.RoomID) error
|
||||
|
||||
// ValidateMessageIndex validates that the given message details aren't from a replay attack.
|
||||
//
|
||||
|
@ -96,29 +96,29 @@ type Store interface {
|
|||
ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error)
|
||||
|
||||
// GetDevices returns a map from device ID to id.Device struct containing all devices of a given user.
|
||||
GetDevices(id.UserID) (map[id.DeviceID]*id.Device, error)
|
||||
GetDevices(context.Context, id.UserID) (map[id.DeviceID]*id.Device, error)
|
||||
// GetDevice returns a specific device of a given user.
|
||||
GetDevice(id.UserID, id.DeviceID) (*id.Device, error)
|
||||
GetDevice(context.Context, id.UserID, id.DeviceID) (*id.Device, error)
|
||||
// PutDevice stores a single device for a user, replacing it if it exists already.
|
||||
PutDevice(id.UserID, *id.Device) error
|
||||
PutDevice(context.Context, id.UserID, *id.Device) error
|
||||
// PutDevices overrides the stored device list for the given user with the given list.
|
||||
PutDevices(id.UserID, map[id.DeviceID]*id.Device) error
|
||||
PutDevices(context.Context, id.UserID, map[id.DeviceID]*id.Device) error
|
||||
// FindDeviceByKey finds a specific device by its identity key.
|
||||
FindDeviceByKey(id.UserID, id.IdentityKey) (*id.Device, error)
|
||||
FindDeviceByKey(context.Context, id.UserID, id.IdentityKey) (*id.Device, error)
|
||||
// FilterTrackedUsers returns a filtered version of the given list that only includes user IDs whose device lists
|
||||
// have been stored with PutDevices. A user is considered tracked even if the PutDevices list was empty.
|
||||
FilterTrackedUsers([]id.UserID) ([]id.UserID, error)
|
||||
FilterTrackedUsers(context.Context, []id.UserID) ([]id.UserID, error)
|
||||
|
||||
// PutCrossSigningKey stores a cross-signing key of some user along with its usage.
|
||||
PutCrossSigningKey(id.UserID, id.CrossSigningUsage, id.Ed25519) error
|
||||
PutCrossSigningKey(context.Context, id.UserID, id.CrossSigningUsage, id.Ed25519) error
|
||||
// GetCrossSigningKeys retrieves a user's stored cross-signing keys.
|
||||
GetCrossSigningKeys(id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error)
|
||||
GetCrossSigningKeys(context.Context, id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error)
|
||||
// PutSignature stores a signature of a cross-signing or device key along with the signer's user ID and key.
|
||||
PutSignature(signedUser id.UserID, signedKey id.Ed25519, signerUser id.UserID, signerKey id.Ed25519, signature string) error
|
||||
PutSignature(ctx context.Context, signedUser id.UserID, signedKey id.Ed25519, signerUser id.UserID, signerKey id.Ed25519, signature string) error
|
||||
// IsKeySignedBy returns whether a cross-signing or device key is signed by the given signer.
|
||||
IsKeySignedBy(userID id.UserID, key id.Ed25519, signedByUser id.UserID, signedByKey id.Ed25519) (bool, error)
|
||||
IsKeySignedBy(ctx context.Context, userID id.UserID, key id.Ed25519, signedByUser id.UserID, signedByKey id.Ed25519) (bool, error)
|
||||
// DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted.
|
||||
DropSignaturesByKey(id.UserID, id.Ed25519) (int64, error)
|
||||
DropSignaturesByKey(context.Context, id.UserID, id.Ed25519) (int64, error)
|
||||
}
|
||||
|
||||
type messageIndexKey struct {
|
||||
|
@ -170,18 +170,18 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore {
|
|||
}
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) Flush() error {
|
||||
func (gs *MemoryStore) Flush(_ context.Context) error {
|
||||
gs.lock.Lock()
|
||||
err := gs.save()
|
||||
gs.lock.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetAccount() (*OlmAccount, error) {
|
||||
func (gs *MemoryStore) GetAccount(_ context.Context) (*OlmAccount, error) {
|
||||
return gs.Account, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) PutAccount(account *OlmAccount) error {
|
||||
func (gs *MemoryStore) PutAccount(_ context.Context, account *OlmAccount) error {
|
||||
gs.lock.Lock()
|
||||
gs.Account = account
|
||||
err := gs.save()
|
||||
|
@ -189,7 +189,7 @@ func (gs *MemoryStore) PutAccount(account *OlmAccount) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, error) {
|
||||
func (gs *MemoryStore) GetSessions(_ context.Context, senderKey id.SenderKey) (OlmSessionList, error) {
|
||||
gs.lock.Lock()
|
||||
sessions, ok := gs.Sessions[senderKey]
|
||||
if !ok {
|
||||
|
@ -200,7 +200,7 @@ func (gs *MemoryStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, erro
|
|||
return sessions, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) AddSession(senderKey id.SenderKey, session *OlmSession) error {
|
||||
func (gs *MemoryStore) AddSession(_ context.Context, senderKey id.SenderKey, session *OlmSession) error {
|
||||
gs.lock.Lock()
|
||||
sessions, _ := gs.Sessions[senderKey]
|
||||
gs.Sessions[senderKey] = append(sessions, session)
|
||||
|
@ -210,19 +210,19 @@ func (gs *MemoryStore) AddSession(senderKey id.SenderKey, session *OlmSession) e
|
|||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) UpdateSession(_ id.SenderKey, _ *OlmSession) error {
|
||||
func (gs *MemoryStore) UpdateSession(_ context.Context, _ id.SenderKey, _ *OlmSession) error {
|
||||
// we don't need to do anything here because the session is a pointer and already stored in our map
|
||||
return gs.save()
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) HasSession(senderKey id.SenderKey) bool {
|
||||
func (gs *MemoryStore) HasSession(_ context.Context, senderKey id.SenderKey) bool {
|
||||
gs.lock.RLock()
|
||||
sessions, ok := gs.Sessions[senderKey]
|
||||
gs.lock.RUnlock()
|
||||
return ok && len(sessions) > 0 && !sessions[0].Expired()
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetLatestSession(senderKey id.SenderKey) (*OlmSession, error) {
|
||||
func (gs *MemoryStore) GetLatestSession(_ context.Context, senderKey id.SenderKey) (*OlmSession, error) {
|
||||
gs.lock.RLock()
|
||||
sessions, ok := gs.Sessions[senderKey]
|
||||
gs.lock.RUnlock()
|
||||
|
@ -246,7 +246,7 @@ func (gs *MemoryStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey
|
|||
return sender
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error {
|
||||
func (gs *MemoryStore) PutGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error {
|
||||
gs.lock.Lock()
|
||||
gs.getGroupSessions(roomID, senderKey)[sessionID] = igs
|
||||
err := gs.save()
|
||||
|
@ -254,7 +254,7 @@ func (gs *MemoryStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey,
|
|||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
|
||||
func (gs *MemoryStore) GetGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
|
||||
gs.lock.Lock()
|
||||
session, ok := gs.getGroupSessions(roomID, senderKey)[sessionID]
|
||||
if !ok {
|
||||
|
@ -269,7 +269,7 @@ func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey,
|
|||
return session, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error {
|
||||
func (gs *MemoryStore) RedactGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error {
|
||||
gs.lock.Lock()
|
||||
delete(gs.getGroupSessions(roomID, senderKey), sessionID)
|
||||
err := gs.save()
|
||||
|
@ -277,7 +277,7 @@ func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderK
|
|||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
|
||||
func (gs *MemoryStore) RedactGroupSessions(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
|
||||
gs.lock.Lock()
|
||||
var sessionIDs []id.SessionID
|
||||
if roomID != "" && senderKey != "" {
|
||||
|
@ -315,11 +315,11 @@ func (gs *MemoryStore) RedactGroupSessions(roomID id.RoomID, senderKey id.Sender
|
|||
return sessionIDs, err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) RedactExpiredGroupSessions() ([]id.SessionID, error) {
|
||||
func (gs *MemoryStore) RedactExpiredGroupSessions(_ context.Context) ([]id.SessionID, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) RedactOutdatedGroupSessions() ([]id.SessionID, error) {
|
||||
func (gs *MemoryStore) RedactOutdatedGroupSessions(_ context.Context) ([]id.SessionID, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
|
@ -337,7 +337,7 @@ func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.S
|
|||
return sender
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
|
||||
func (gs *MemoryStore) PutWithheldGroupSession(_ context.Context, content event.RoomKeyWithheldEventContent) error {
|
||||
gs.lock.Lock()
|
||||
gs.getWithheldGroupSessions(content.RoomID, content.SenderKey)[content.SessionID] = &content
|
||||
err := gs.save()
|
||||
|
@ -345,7 +345,7 @@ func (gs *MemoryStore) PutWithheldGroupSession(content event.RoomKeyWithheldEven
|
|||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
|
||||
func (gs *MemoryStore) GetWithheldGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
|
||||
gs.lock.Lock()
|
||||
session, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID]
|
||||
gs.lock.Unlock()
|
||||
|
@ -355,7 +355,7 @@ func (gs *MemoryStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.Se
|
|||
return session, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
|
||||
func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) {
|
||||
gs.lock.Lock()
|
||||
defer gs.lock.Unlock()
|
||||
room, ok := gs.GroupSessions[roomID]
|
||||
|
@ -371,7 +371,7 @@ func (gs *MemoryStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGrou
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
|
||||
func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) ([]*InboundGroupSession, error) {
|
||||
gs.lock.Lock()
|
||||
var result []*InboundGroupSession
|
||||
for _, room := range gs.GroupSessions {
|
||||
|
@ -385,7 +385,7 @@ func (gs *MemoryStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) AddOutboundGroupSession(session *OutboundGroupSession) error {
|
||||
func (gs *MemoryStore) AddOutboundGroupSession(_ context.Context, session *OutboundGroupSession) error {
|
||||
gs.lock.Lock()
|
||||
gs.OutGroupSessions[session.RoomID] = session
|
||||
err := gs.save()
|
||||
|
@ -393,12 +393,12 @@ func (gs *MemoryStore) AddOutboundGroupSession(session *OutboundGroupSession) er
|
|||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) UpdateOutboundGroupSession(_ *OutboundGroupSession) error {
|
||||
func (gs *MemoryStore) UpdateOutboundGroupSession(_ context.Context, _ *OutboundGroupSession) error {
|
||||
// we don't need to do anything here because the session is a pointer and already stored in our map
|
||||
return gs.save()
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
|
||||
func (gs *MemoryStore) GetOutboundGroupSession(_ context.Context, roomID id.RoomID) (*OutboundGroupSession, error) {
|
||||
gs.lock.RLock()
|
||||
session, ok := gs.OutGroupSessions[roomID]
|
||||
gs.lock.RUnlock()
|
||||
|
@ -408,7 +408,7 @@ func (gs *MemoryStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroup
|
|||
return session, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
|
||||
func (gs *MemoryStore) RemoveOutboundGroupSession(_ context.Context, roomID id.RoomID) error {
|
||||
gs.lock.Lock()
|
||||
session, ok := gs.OutGroupSessions[roomID]
|
||||
if !ok || session == nil {
|
||||
|
@ -443,7 +443,7 @@ func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.Send
|
|||
return true, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) {
|
||||
func (gs *MemoryStore) GetDevices(_ context.Context, userID id.UserID) (map[id.DeviceID]*id.Device, error) {
|
||||
gs.lock.RLock()
|
||||
devices, ok := gs.Devices[userID]
|
||||
if !ok {
|
||||
|
@ -453,7 +453,7 @@ func (gs *MemoryStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device,
|
|||
return devices, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
|
||||
func (gs *MemoryStore) GetDevice(_ context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
|
||||
gs.lock.RLock()
|
||||
defer gs.lock.RUnlock()
|
||||
devices, ok := gs.Devices[userID]
|
||||
|
@ -467,7 +467,7 @@ func (gs *MemoryStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.De
|
|||
return device, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
|
||||
func (gs *MemoryStore) FindDeviceByKey(_ context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
|
||||
gs.lock.RLock()
|
||||
defer gs.lock.RUnlock()
|
||||
devices, ok := gs.Devices[userID]
|
||||
|
@ -482,7 +482,7 @@ func (gs *MemoryStore) FindDeviceByKey(userID id.UserID, identityKey id.Identity
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) PutDevice(userID id.UserID, device *id.Device) error {
|
||||
func (gs *MemoryStore) PutDevice(_ context.Context, userID id.UserID, device *id.Device) error {
|
||||
gs.lock.Lock()
|
||||
devices, ok := gs.Devices[userID]
|
||||
if !ok {
|
||||
|
@ -495,7 +495,7 @@ func (gs *MemoryStore) PutDevice(userID id.UserID, device *id.Device) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error {
|
||||
func (gs *MemoryStore) PutDevices(_ context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error {
|
||||
gs.lock.Lock()
|
||||
gs.Devices[userID] = devices
|
||||
err := gs.save()
|
||||
|
@ -503,7 +503,7 @@ func (gs *MemoryStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.
|
|||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
|
||||
func (gs *MemoryStore) FilterTrackedUsers(_ context.Context, users []id.UserID) ([]id.UserID, error) {
|
||||
gs.lock.RLock()
|
||||
var ptr int
|
||||
for _, userID := range users {
|
||||
|
@ -517,7 +517,7 @@ func (gs *MemoryStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error
|
|||
return users[:ptr], nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
|
||||
func (gs *MemoryStore) PutCrossSigningKey(_ context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
|
||||
gs.lock.RLock()
|
||||
userKeys, ok := gs.CrossSigningKeys[userID]
|
||||
if !ok {
|
||||
|
@ -539,7 +539,7 @@ func (gs *MemoryStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSignin
|
|||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
|
||||
func (gs *MemoryStore) GetCrossSigningKeys(_ context.Context, userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
|
||||
gs.lock.RLock()
|
||||
defer gs.lock.RUnlock()
|
||||
keys, ok := gs.CrossSigningKeys[userID]
|
||||
|
@ -549,7 +549,7 @@ func (gs *MemoryStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSignin
|
|||
return keys, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
|
||||
func (gs *MemoryStore) PutSignature(_ context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
|
||||
gs.lock.RLock()
|
||||
signedUserSigs, ok := gs.KeySignatures[signedUserID]
|
||||
if !ok {
|
||||
|
@ -572,7 +572,7 @@ func (gs *MemoryStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519
|
|||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
|
||||
func (gs *MemoryStore) GetSignaturesForKeyBy(_ context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
|
||||
gs.lock.RLock()
|
||||
defer gs.lock.RUnlock()
|
||||
userKeys, ok := gs.KeySignatures[userID]
|
||||
|
@ -590,8 +590,8 @@ func (gs *MemoryStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, s
|
|||
return sigsBySigner, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) {
|
||||
sigs, err := gs.GetSignaturesForKeyBy(userID, key, signerID)
|
||||
func (gs *MemoryStore) IsKeySignedBy(ctx context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) {
|
||||
sigs, err := gs.GetSignaturesForKeyBy(ctx, userID, key, signerID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -599,7 +599,7 @@ func (gs *MemoryStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID
|
|||
return ok, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) {
|
||||
func (gs *MemoryStore) DropSignaturesByKey(_ context.Context, userID id.UserID, key id.Ed25519) (int64, error) {
|
||||
var count int64
|
||||
gs.lock.RLock()
|
||||
for _, userSigs := range gs.KeySignatures {
|
||||
|
|
|
@ -36,7 +36,7 @@ func getCryptoStores(t *testing.T) map[string]Store {
|
|||
t.Fatalf("Error opening db: %v", err)
|
||||
}
|
||||
sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test"))
|
||||
if err = sqlStore.DB.Upgrade(); err != nil {
|
||||
if err = sqlStore.DB.Upgrade(context.TODO()); err != nil {
|
||||
t.Fatalf("Error creating tables: %v", err)
|
||||
}
|
||||
|
||||
|
@ -65,8 +65,8 @@ func TestPutAccount(t *testing.T) {
|
|||
for storeName, store := range stores {
|
||||
t.Run(storeName, func(t *testing.T) {
|
||||
acc := NewOlmAccount()
|
||||
store.PutAccount(acc)
|
||||
retrieved, err := store.GetAccount()
|
||||
store.PutAccount(context.TODO(), acc)
|
||||
retrieved, err := store.GetAccount(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatalf("Error retrieving account: %v", err)
|
||||
}
|
||||
|
@ -105,7 +105,7 @@ func TestStoreOlmSession(t *testing.T) {
|
|||
stores := getCryptoStores(t)
|
||||
for storeName, store := range stores {
|
||||
t.Run(storeName, func(t *testing.T) {
|
||||
if store.HasSession(olmSessID) {
|
||||
if store.HasSession(context.TODO(), olmSessID) {
|
||||
t.Error("Found Olm session before inserting it")
|
||||
}
|
||||
olmInternal, err := olm.SessionFromPickled([]byte(olmPickled), []byte("test"))
|
||||
|
@ -117,15 +117,15 @@ func TestStoreOlmSession(t *testing.T) {
|
|||
id: olmSessID,
|
||||
Internal: *olmInternal,
|
||||
}
|
||||
err = store.AddSession(olmSessID, &olmSess)
|
||||
err = store.AddSession(context.TODO(), olmSessID, &olmSess)
|
||||
if err != nil {
|
||||
t.Errorf("Error storing Olm session: %v", err)
|
||||
}
|
||||
if !store.HasSession(olmSessID) {
|
||||
if !store.HasSession(context.TODO(), olmSessID) {
|
||||
t.Error("Not found Olm session after inserting it")
|
||||
}
|
||||
|
||||
retrieved, err := store.GetLatestSession(olmSessID)
|
||||
retrieved, err := store.GetLatestSession(context.TODO(), olmSessID)
|
||||
if err != nil {
|
||||
t.Errorf("Failed retrieving Olm session: %v", err)
|
||||
}
|
||||
|
@ -158,12 +158,12 @@ func TestStoreMegolmSession(t *testing.T) {
|
|||
RoomID: "room1",
|
||||
}
|
||||
|
||||
err = store.PutGroupSession("room1", acc.IdentityKey(), igs.ID(), igs)
|
||||
err = store.PutGroupSession(context.TODO(), "room1", acc.IdentityKey(), igs.ID(), igs)
|
||||
if err != nil {
|
||||
t.Errorf("Error storing inbound group session: %v", err)
|
||||
}
|
||||
|
||||
retrieved, err := store.GetGroupSession("room1", acc.IdentityKey(), igs.ID())
|
||||
retrieved, err := store.GetGroupSession(context.TODO(), "room1", acc.IdentityKey(), igs.ID())
|
||||
if err != nil {
|
||||
t.Errorf("Error retrieving inbound group session: %v", err)
|
||||
}
|
||||
|
@ -179,7 +179,7 @@ func TestStoreOutboundMegolmSession(t *testing.T) {
|
|||
stores := getCryptoStores(t)
|
||||
for storeName, store := range stores {
|
||||
t.Run(storeName, func(t *testing.T) {
|
||||
sess, err := store.GetOutboundGroupSession("room1")
|
||||
sess, err := store.GetOutboundGroupSession(context.TODO(), "room1")
|
||||
if sess != nil {
|
||||
t.Error("Got outbound session before inserting")
|
||||
}
|
||||
|
@ -188,12 +188,12 @@ func TestStoreOutboundMegolmSession(t *testing.T) {
|
|||
}
|
||||
|
||||
outbound := NewOutboundGroupSession("room1", nil)
|
||||
err = store.AddOutboundGroupSession(outbound)
|
||||
err = store.AddOutboundGroupSession(context.TODO(), outbound)
|
||||
if err != nil {
|
||||
t.Errorf("Error inserting outbound session: %v", err)
|
||||
}
|
||||
|
||||
sess, err = store.GetOutboundGroupSession("room1")
|
||||
sess, err = store.GetOutboundGroupSession(context.TODO(), "room1")
|
||||
if sess == nil {
|
||||
t.Error("Did not get outbound session after inserting")
|
||||
}
|
||||
|
@ -201,12 +201,12 @@ func TestStoreOutboundMegolmSession(t *testing.T) {
|
|||
t.Errorf("Error retrieving outbound session: %v", err)
|
||||
}
|
||||
|
||||
err = store.RemoveOutboundGroupSession("room1")
|
||||
err = store.RemoveOutboundGroupSession(context.TODO(), "room1")
|
||||
if err != nil {
|
||||
t.Errorf("Error deleting outbound session: %v", err)
|
||||
}
|
||||
|
||||
sess, err = store.GetOutboundGroupSession("room1")
|
||||
sess, err = store.GetOutboundGroupSession(context.TODO(), "room1")
|
||||
if sess != nil {
|
||||
t.Error("Got outbound session after deleting")
|
||||
}
|
||||
|
@ -232,11 +232,11 @@ func TestStoreDevices(t *testing.T) {
|
|||
SigningKey: acc.SigningKey(),
|
||||
}
|
||||
}
|
||||
err := store.PutDevices("user1", deviceMap)
|
||||
err := store.PutDevices(context.TODO(), "user1", deviceMap)
|
||||
if err != nil {
|
||||
t.Errorf("Error string devices: %v", err)
|
||||
}
|
||||
devs, err := store.GetDevices("user1")
|
||||
devs, err := store.GetDevices(context.TODO(), "user1")
|
||||
if err != nil {
|
||||
t.Errorf("Error getting devices: %v", err)
|
||||
}
|
||||
|
@ -250,7 +250,7 @@ func TestStoreDevices(t *testing.T) {
|
|||
t.Errorf("Last device identity key does not match")
|
||||
}
|
||||
|
||||
filtered, err := store.FilterTrackedUsers([]id.UserID{"user0", "user1", "user2"})
|
||||
filtered, err := store.FilterTrackedUsers(context.TODO(), []id.UserID{"user0", "user1", "user2"})
|
||||
if err != nil {
|
||||
t.Errorf("Error filtering tracked users: %v", err)
|
||||
} else if len(filtered) != 1 || filtered[0] != "user1" {
|
||||
|
|
|
@ -507,7 +507,7 @@ func (mach *OlmMachine) handleVerificationMAC(ctx context.Context, userID id.Use
|
|||
|
||||
// we can finally trust this device
|
||||
device.Trust = id.TrustStateVerified
|
||||
err = mach.CryptoStore.PutDevice(device.UserID, device)
|
||||
err = mach.CryptoStore.PutDevice(ctx, device.UserID, device)
|
||||
if err != nil {
|
||||
mach.Log.Warn().Msgf("Failed to put device after verifying: %v", err)
|
||||
}
|
||||
|
@ -521,7 +521,7 @@ func (mach *OlmMachine) handleVerificationMAC(ctx context.Context, userID id.Use
|
|||
mach.Log.Debug().Msgf("Cross-signed own device %v after SAS verification", device.DeviceID)
|
||||
}
|
||||
} else {
|
||||
masterKey, err := mach.fetchMasterKey(device, content, verState, transactionID)
|
||||
masterKey, err := mach.fetchMasterKey(ctx, device, content, verState, transactionID)
|
||||
if err != nil {
|
||||
mach.Log.Warn().Msgf("Failed to fetch %s's master key: %v", device.UserID, err)
|
||||
} else {
|
||||
|
|
2
go.mod
2
go.mod
|
@ -12,7 +12,7 @@ require (
|
|||
github.com/tidwall/gjson v1.17.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/yuin/goldmark v1.6.0
|
||||
go.mau.fi/util v0.2.2-0.20231228160822-a6d40c214e80
|
||||
go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894
|
||||
go.mau.fi/zeroconfig v0.1.2
|
||||
golang.org/x/crypto v0.17.0
|
||||
golang.org/x/exp v0.0.0-20231226003508-02704c960a9b
|
||||
|
|
4
go.sum
4
go.sum
|
@ -36,8 +36,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
|||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68=
|
||||
github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
go.mau.fi/util v0.2.2-0.20231228160822-a6d40c214e80 h1:zcfIxHgzZpgGSJv/FUVbOjO4ZWa12En4TGhxgUI/QH0=
|
||||
go.mau.fi/util v0.2.2-0.20231228160822-a6d40c214e80/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs=
|
||||
go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894 h1:CuR5LDSxBQLETorfwJ9vRtySeLHjMvJ7//lnCMw7Dy8=
|
||||
go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs=
|
||||
go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto=
|
||||
go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
|
||||
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) 2022 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@ -7,6 +7,7 @@
|
|||
package sqlstatestore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
|
@ -15,6 +16,7 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
|
@ -44,26 +46,28 @@ func NewSQLStateStore(db *dbutil.Database, log dbutil.DatabaseLogger, isBridge b
|
|||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsRegistered(userID id.UserID) bool {
|
||||
func (store *SQLStateStore) IsRegistered(ctx context.Context, userID id.UserID) (bool, error) {
|
||||
var isRegistered bool
|
||||
err := store.
|
||||
QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID).
|
||||
QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID).
|
||||
Scan(&isRegistered)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to scan registration existence for %s: %v", userID, err)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return isRegistered
|
||||
return isRegistered, err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) MarkRegistered(userID id.UserID) {
|
||||
_, err := store.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to mark %s as registered: %v", userID, err)
|
||||
}
|
||||
func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID) error {
|
||||
_, err := store.Exec(ctx, "INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID, memberships ...event.Membership) map[id.UserID]*event.MemberEventContent {
|
||||
members := make(map[id.UserID]*event.MemberEventContent)
|
||||
type Member struct {
|
||||
id.UserID
|
||||
event.MemberEventContent
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) (map[id.UserID]*event.MemberEventContent, error) {
|
||||
args := make([]any, len(memberships)+1)
|
||||
args[0] = roomID
|
||||
query := "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1"
|
||||
|
@ -75,25 +79,26 @@ func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID, memberships ...even
|
|||
}
|
||||
query = fmt.Sprintf("%s AND membership IN (%s)", query, strings.Join(placeholders, ","))
|
||||
}
|
||||
rows, err := store.Query(query, args...)
|
||||
rows, err := store.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return members
|
||||
return nil, err
|
||||
}
|
||||
var userID id.UserID
|
||||
var member event.MemberEventContent
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to scan member in %s: %v", roomID, err)
|
||||
} else {
|
||||
members[userID] = &member
|
||||
}
|
||||
}
|
||||
return members
|
||||
members := make(map[id.UserID]*event.MemberEventContent)
|
||||
return members, dbutil.NewRowIter(rows, func(row dbutil.Scannable) (ret Member, err error) {
|
||||
err = row.Scan(&ret.UserID, &ret.Membership, &ret.Displayname, &ret.AvatarURL)
|
||||
return
|
||||
}).Iter(func(m Member) (bool, error) {
|
||||
members[m.UserID] = &m.MemberEventContent
|
||||
return true, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (members []id.UserID, err error) {
|
||||
memberMap := store.GetRoomMembers(roomID, event.MembershipJoin, event.MembershipInvite)
|
||||
func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) (members []id.UserID, err error) {
|
||||
var memberMap map[id.UserID]*event.MemberEventContent
|
||||
memberMap, err = store.GetRoomMembers(ctx, roomID, event.MembershipJoin, event.MembershipInvite)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
members = make([]id.UserID, len(memberMap))
|
||||
i := 0
|
||||
for userID := range memberMap {
|
||||
|
@ -103,37 +108,39 @@ func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (mem
|
|||
return
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
|
||||
membership := event.MembershipLeave
|
||||
err := store.
|
||||
QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
|
||||
func (store *SQLStateStore) GetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID) (membership event.Membership, err error) {
|
||||
err = store.
|
||||
QueryRow(ctx, "SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
|
||||
Scan(&membership)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan membership of %s in %s: %v", userID, roomID, err)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
membership = event.MembershipLeave
|
||||
err = nil
|
||||
}
|
||||
return membership
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
|
||||
member, ok := store.TryGetMember(roomID, userID)
|
||||
if !ok {
|
||||
member.Membership = event.MembershipLeave
|
||||
func (store *SQLStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
|
||||
member, err := store.TryGetMember(ctx, roomID, userID)
|
||||
if member == nil && err == nil {
|
||||
member = &event.MemberEventContent{Membership: event.MembershipLeave}
|
||||
}
|
||||
return member
|
||||
return member, err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) {
|
||||
func (store *SQLStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
|
||||
var member event.MemberEventContent
|
||||
err := store.
|
||||
QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
|
||||
QueryRow(ctx, "SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
|
||||
Scan(&member.Membership, &member.Displayname, &member.AvatarURL)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan member info of %s in %s: %v", userID, roomID, err)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &member, err == nil
|
||||
return &member, nil
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) {
|
||||
func (store *SQLStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) ([]id.RoomID, error) {
|
||||
query := `
|
||||
SELECT room_id FROM mx_user_profile
|
||||
LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id
|
||||
|
@ -141,38 +148,32 @@ func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID
|
|||
`
|
||||
if !store.IsBridge {
|
||||
query = `
|
||||
SELECT mx_user_profile.room_id FROM mx_user_profile
|
||||
LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id
|
||||
WHERE mx_user_profile.user_id=$1 AND mx_room_state.encryption IS NOT NULL
|
||||
`
|
||||
SELECT mx_user_profile.room_id FROM mx_user_profile
|
||||
LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id
|
||||
WHERE mx_user_profile.user_id=$1 AND mx_room_state.encryption IS NOT NULL
|
||||
`
|
||||
}
|
||||
rows, err := store.Query(query, userID)
|
||||
rows, err := store.Query(ctx, query, userID)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to query shared rooms with %s: %v", userID, err)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
for rows.Next() {
|
||||
var roomID id.RoomID
|
||||
err = rows.Scan(&roomID)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to scan room ID: %v", err)
|
||||
} else {
|
||||
rooms = append(rooms, roomID)
|
||||
}
|
||||
return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList()
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(ctx, roomID, userID, "join")
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(ctx, roomID, userID, "join", "invite")
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
|
||||
membership, err := store.GetMembership(ctx, roomID, userID)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get membership")
|
||||
return false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join")
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join", "invite")
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
|
||||
membership := store.GetMembership(roomID, userID)
|
||||
for _, allowedMembership := range allowedMemberships {
|
||||
if allowedMembership == membership {
|
||||
return true
|
||||
|
@ -181,27 +182,23 @@ func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, all
|
|||
return false
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
|
||||
_, err := store.Exec(`
|
||||
func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
|
||||
_, err := store.Exec(ctx, `
|
||||
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, '', '')
|
||||
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership
|
||||
`, roomID, userID, membership)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to set membership of %s in %s to %s: %v", userID, roomID, membership, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
|
||||
_, err := store.Exec(`
|
||||
func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
|
||||
_, err := store.Exec(ctx, `
|
||||
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url
|
||||
`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to set membership of %s in %s to %s: %v", userID, roomID, member, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) {
|
||||
func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error {
|
||||
query := "DELETE FROM mx_user_profile WHERE room_id=$1"
|
||||
params := make([]any, len(memberships)+1)
|
||||
params[0] = roomID
|
||||
|
@ -213,109 +210,85 @@ func (store *SQLStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...
|
|||
}
|
||||
query += fmt.Sprintf(" AND membership IN (%s)", strings.Join(placeholders, ","))
|
||||
}
|
||||
_, err := store.Exec(query, params...)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to clear cached members of %s: %v", roomID, err)
|
||||
}
|
||||
_, err := store.Exec(ctx, query, params...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) {
|
||||
func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
|
||||
contentBytes, err := json.Marshal(content)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to marshal encryption config of %s: %v", roomID, err)
|
||||
return
|
||||
return fmt.Errorf("failed to marshal content JSON: %w", err)
|
||||
}
|
||||
_, err = store.Exec(`
|
||||
_, err = store.Exec(ctx, `
|
||||
INSERT INTO mx_room_state (room_id, encryption) VALUES ($1, $2)
|
||||
ON CONFLICT (room_id) DO UPDATE SET encryption=excluded.encryption
|
||||
`, roomID, contentBytes)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to store encryption config of %s: %v", roomID, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
|
||||
func (store *SQLStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) {
|
||||
var data []byte
|
||||
err := store.
|
||||
QueryRow("SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID).
|
||||
QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID).
|
||||
Scan(&data)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan encryption config of %s: %v", roomID, err)
|
||||
}
|
||||
return nil
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
} else if data == nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
content := &event.EncryptionEventContent{}
|
||||
err = json.Unmarshal(data, content)
|
||||
var content event.EncryptionEventContent
|
||||
err = json.Unmarshal(data, &content)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to parse encryption config of %s: %v", roomID, err)
|
||||
return nil
|
||||
return nil, fmt.Errorf("failed to parse content JSON: %w", err)
|
||||
}
|
||||
return content
|
||||
return &content, nil
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsEncrypted(roomID id.RoomID) bool {
|
||||
cfg := store.GetEncryptionEvent(roomID)
|
||||
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1
|
||||
func (store *SQLStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) {
|
||||
cfg, err := store.GetEncryptionEvent(ctx, roomID)
|
||||
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
|
||||
levelsBytes, err := json.Marshal(levels)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to marshal power levels of %s: %v", roomID, err)
|
||||
return
|
||||
}
|
||||
_, err = store.Exec(`
|
||||
func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
|
||||
_, err := store.Exec(ctx, `
|
||||
INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2)
|
||||
ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels
|
||||
`, roomID, levelsBytes)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to store power levels of %s: %v", roomID, err)
|
||||
}
|
||||
`, roomID, dbutil.JSON{Data: levels})
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
|
||||
var data []byte
|
||||
err := store.
|
||||
QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
|
||||
Scan(&data)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan power levels of %s: %v", roomID, err)
|
||||
}
|
||||
return
|
||||
} else if data == nil {
|
||||
return
|
||||
}
|
||||
levels = &event.PowerLevelsEventContent{}
|
||||
err = json.Unmarshal(data, levels)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to parse power levels of %s: %v", roomID, err)
|
||||
return nil
|
||||
func (store *SQLStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
|
||||
err = store.
|
||||
QueryRow(ctx, "SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
|
||||
Scan(&dbutil.JSON{Data: &levels})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
|
||||
func (store *SQLStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) {
|
||||
if store.Dialect == dbutil.Postgres {
|
||||
var powerLevel int
|
||||
err := store.
|
||||
QueryRow(`
|
||||
QueryRow(ctx, `
|
||||
SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
|
||||
FROM mx_room_state WHERE room_id=$1
|
||||
`, roomID, userID).
|
||||
Scan(&powerLevel)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan power level of %s in %s: %v", userID, roomID, err)
|
||||
return powerLevel, err
|
||||
} else {
|
||||
levels, err := store.GetPowerLevels(ctx, roomID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return powerLevel
|
||||
return levels.GetUserLevel(userID), nil
|
||||
}
|
||||
return store.GetPowerLevels(roomID).GetUserLevel(userID)
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
|
||||
func (store *SQLStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) {
|
||||
if store.Dialect == dbutil.Postgres {
|
||||
defaultType := "events_default"
|
||||
defaultValue := 0
|
||||
|
@ -325,23 +298,26 @@ func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType
|
|||
}
|
||||
var powerLevel int
|
||||
err := store.
|
||||
QueryRow(`
|
||||
QueryRow(ctx, `
|
||||
SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4)
|
||||
FROM mx_room_state WHERE room_id=$1
|
||||
`, roomID, eventType.Type, defaultType, defaultValue).
|
||||
Scan(&powerLevel)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
|
||||
}
|
||||
return defaultValue
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
powerLevel = defaultValue
|
||||
}
|
||||
return powerLevel
|
||||
return powerLevel, err
|
||||
} else {
|
||||
levels, err := store.GetPowerLevels(ctx, roomID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return levels.GetEventLevel(eventType), nil
|
||||
}
|
||||
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
|
||||
func (store *SQLStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) {
|
||||
if store.Dialect == dbutil.Postgres {
|
||||
defaultType := "events_default"
|
||||
defaultValue := 0
|
||||
|
@ -351,19 +327,22 @@ func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, ev
|
|||
}
|
||||
var hasPower bool
|
||||
err := store.
|
||||
QueryRow(`SELECT
|
||||
QueryRow(ctx, `SELECT
|
||||
COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
|
||||
>=
|
||||
COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5)
|
||||
FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue).
|
||||
Scan(&hasPower)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
|
||||
}
|
||||
return defaultValue == 0
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
hasPower = defaultValue == 0
|
||||
}
|
||||
return hasPower
|
||||
return hasPower, err
|
||||
} else {
|
||||
levels, err := store.GetPowerLevels(ctx, roomID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil
|
||||
}
|
||||
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
|
||||
}
|
||||
|
|
|
@ -1,19 +1,20 @@
|
|||
package sqlstatestore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
)
|
||||
|
||||
func init() {
|
||||
UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", true, func(tx dbutil.Execable, db *dbutil.Database) error {
|
||||
portalExists, err := db.TableExists(tx, "portal")
|
||||
UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", true, func(ctx context.Context, db *dbutil.Database) error {
|
||||
portalExists, err := db.TableExists(ctx, "portal")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if portal table exists")
|
||||
}
|
||||
if portalExists {
|
||||
_, err = tx.Exec(`
|
||||
_, err = db.Exec(ctx, `
|
||||
INSERT INTO mx_room_state (room_id, encryption)
|
||||
SELECT portal.mxid, '{"resync":true}' FROM portal WHERE portal.encrypted=true AND portal.mxid IS NOT NULL
|
||||
ON CONFLICT (room_id) DO UPDATE
|
||||
|
|
134
statestore.go
134
statestore.go
|
@ -7,33 +7,37 @@
|
|||
package mautrix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/exerrors"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// StateStore is an interface for storing basic room state information.
|
||||
type StateStore interface {
|
||||
IsInRoom(roomID id.RoomID, userID id.UserID) bool
|
||||
IsInvited(roomID id.RoomID, userID id.UserID) bool
|
||||
IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool
|
||||
GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent
|
||||
TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool)
|
||||
SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership)
|
||||
SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent)
|
||||
ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership)
|
||||
IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool
|
||||
IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool
|
||||
IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool
|
||||
GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error)
|
||||
TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error)
|
||||
SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error
|
||||
SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error
|
||||
ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error
|
||||
|
||||
SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent)
|
||||
GetPowerLevels(roomID id.RoomID) *event.PowerLevelsEventContent
|
||||
SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error
|
||||
GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error)
|
||||
|
||||
SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent)
|
||||
IsEncrypted(roomID id.RoomID) bool
|
||||
SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error
|
||||
IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error)
|
||||
|
||||
GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ([]id.UserID, error)
|
||||
GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error)
|
||||
}
|
||||
|
||||
func UpdateStateStore(store StateStore, evt *event.Event) {
|
||||
func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) {
|
||||
if store == nil || evt == nil || evt.StateKey == nil {
|
||||
return
|
||||
}
|
||||
|
@ -41,13 +45,20 @@ func UpdateStateStore(store StateStore, evt *event.Event) {
|
|||
if evt.Type != event.StateMember && evt.GetStateKey() != "" {
|
||||
return
|
||||
}
|
||||
var err error
|
||||
switch content := evt.Content.Parsed.(type) {
|
||||
case *event.MemberEventContent:
|
||||
store.SetMember(evt.RoomID, id.UserID(evt.GetStateKey()), content)
|
||||
err = store.SetMember(ctx, evt.RoomID, id.UserID(evt.GetStateKey()), content)
|
||||
case *event.PowerLevelsEventContent:
|
||||
store.SetPowerLevels(evt.RoomID, content)
|
||||
err = store.SetPowerLevels(ctx, evt.RoomID, content)
|
||||
case *event.EncryptionEventContent:
|
||||
store.SetEncryptionEvent(evt.RoomID, content)
|
||||
err = store.SetEncryptionEvent(ctx, evt.RoomID, content)
|
||||
}
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Warn().Err(err).
|
||||
Stringer("event_id", evt.ID).
|
||||
Str("event_type", evt.Type.Type).
|
||||
Msg("Failed to update state store")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -57,7 +68,7 @@ func UpdateStateStore(store StateStore, evt *event.Event) {
|
|||
//
|
||||
// DefaultSyncer.ParseEventContent must also be true for this to work (which it is by default).
|
||||
func (cli *Client) StateStoreSyncHandler(_ EventSource, evt *event.Event) {
|
||||
UpdateStateStore(cli.StateStore, evt)
|
||||
UpdateStateStore(cli.Log.WithContext(context.TODO()), cli.StateStore, evt)
|
||||
}
|
||||
|
||||
type MemoryStateStore struct {
|
||||
|
@ -81,20 +92,21 @@ func NewMemoryStateStore() StateStore {
|
|||
}
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) IsRegistered(userID id.UserID) bool {
|
||||
func (store *MemoryStateStore) IsRegistered(_ context.Context, userID id.UserID) (bool, error) {
|
||||
store.registrationsLock.RLock()
|
||||
defer store.registrationsLock.RUnlock()
|
||||
registered, ok := store.Registrations[userID]
|
||||
return ok && registered
|
||||
return ok && registered, nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) MarkRegistered(userID id.UserID) {
|
||||
func (store *MemoryStateStore) MarkRegistered(_ context.Context, userID id.UserID) error {
|
||||
store.registrationsLock.Lock()
|
||||
defer store.registrationsLock.Unlock()
|
||||
store.Registrations[userID] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent {
|
||||
func (store *MemoryStateStore) GetRoomMembers(_ context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
|
||||
store.membersLock.RLock()
|
||||
members, ok := store.Members[roomID]
|
||||
store.membersLock.RUnlock()
|
||||
|
@ -104,11 +116,14 @@ func (store *MemoryStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*e
|
|||
store.Members[roomID] = members
|
||||
store.membersLock.Unlock()
|
||||
}
|
||||
return members
|
||||
return members, nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ([]id.UserID, error) {
|
||||
members := store.GetRoomMembers(roomID)
|
||||
func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) {
|
||||
members, err := store.GetRoomMembers(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids := make([]id.UserID, 0, len(members))
|
||||
for id := range members {
|
||||
ids = append(ids, id)
|
||||
|
@ -116,39 +131,39 @@ func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (
|
|||
return ids, nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
|
||||
return store.GetMember(roomID, userID).Membership
|
||||
func (store *MemoryStateStore) GetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID) (event.Membership, error) {
|
||||
return exerrors.Must(store.GetMember(ctx, roomID, userID)).Membership, nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
|
||||
member, ok := store.TryGetMember(roomID, userID)
|
||||
if !ok {
|
||||
func (store *MemoryStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
|
||||
member, err := store.TryGetMember(ctx, roomID, userID)
|
||||
if member == nil && err == nil {
|
||||
member = &event.MemberEventContent{Membership: event.MembershipLeave}
|
||||
}
|
||||
return member
|
||||
return member, err
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, ok bool) {
|
||||
func (store *MemoryStateStore) TryGetMember(_ context.Context, roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, err error) {
|
||||
store.membersLock.RLock()
|
||||
defer store.membersLock.RUnlock()
|
||||
members, membersOk := store.Members[roomID]
|
||||
if !membersOk {
|
||||
return
|
||||
}
|
||||
member, ok = members[userID]
|
||||
member = members[userID]
|
||||
return
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join")
|
||||
func (store *MemoryStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(ctx, roomID, userID, "join")
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join", "invite")
|
||||
func (store *MemoryStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(ctx, roomID, userID, "join", "invite")
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
|
||||
membership := store.GetMembership(roomID, userID)
|
||||
func (store *MemoryStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
|
||||
membership := exerrors.Must(store.GetMembership(ctx, roomID, userID))
|
||||
for _, allowedMembership := range allowedMemberships {
|
||||
if allowedMembership == membership {
|
||||
return true
|
||||
|
@ -157,7 +172,7 @@ func (store *MemoryStateStore) IsMembership(roomID id.RoomID, userID id.UserID,
|
|||
return false
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
|
||||
func (store *MemoryStateStore) SetMembership(_ context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
|
||||
store.membersLock.Lock()
|
||||
members, ok := store.Members[roomID]
|
||||
if !ok {
|
||||
|
@ -175,9 +190,10 @@ func (store *MemoryStateStore) SetMembership(roomID id.RoomID, userID id.UserID,
|
|||
}
|
||||
store.Members[roomID] = members
|
||||
store.membersLock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
|
||||
func (store *MemoryStateStore) SetMember(_ context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
|
||||
store.membersLock.Lock()
|
||||
members, ok := store.Members[roomID]
|
||||
if !ok {
|
||||
|
@ -189,14 +205,15 @@ func (store *MemoryStateStore) SetMember(roomID id.RoomID, userID id.UserID, mem
|
|||
}
|
||||
store.Members[roomID] = members
|
||||
store.membersLock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) {
|
||||
func (store *MemoryStateStore) ClearCachedMembers(_ context.Context, roomID id.RoomID, memberships ...event.Membership) error {
|
||||
store.membersLock.Lock()
|
||||
defer store.membersLock.Unlock()
|
||||
members, ok := store.Members[roomID]
|
||||
if !ok {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
for userID, member := range members {
|
||||
for _, membership := range memberships {
|
||||
|
@ -206,46 +223,49 @@ func (store *MemoryStateStore) ClearCachedMembers(roomID id.RoomID, memberships
|
|||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
|
||||
func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
|
||||
store.powerLevelsLock.Lock()
|
||||
store.PowerLevels[roomID] = levels
|
||||
store.powerLevelsLock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
|
||||
func (store *MemoryStateStore) GetPowerLevels(_ context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
|
||||
store.powerLevelsLock.RLock()
|
||||
levels = store.PowerLevels[roomID]
|
||||
store.powerLevelsLock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
|
||||
return store.GetPowerLevels(roomID).GetUserLevel(userID)
|
||||
func (store *MemoryStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) {
|
||||
return exerrors.Must(store.GetPowerLevels(ctx, roomID)).GetUserLevel(userID), nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
|
||||
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
|
||||
func (store *MemoryStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) {
|
||||
return exerrors.Must(store.GetPowerLevels(ctx, roomID)).GetEventLevel(eventType), nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
|
||||
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
|
||||
func (store *MemoryStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) {
|
||||
return exerrors.Must(store.GetPowerLevel(ctx, roomID, userID)) >= exerrors.Must(store.GetPowerLevelRequirement(ctx, roomID, eventType)), nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) {
|
||||
func (store *MemoryStateStore) SetEncryptionEvent(_ context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
|
||||
store.encryptionLock.Lock()
|
||||
store.Encryption[roomID] = content
|
||||
store.encryptionLock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
|
||||
func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) {
|
||||
store.encryptionLock.RLock()
|
||||
defer store.encryptionLock.RUnlock()
|
||||
return store.Encryption[roomID]
|
||||
return store.Encryption[roomID], nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) IsEncrypted(roomID id.RoomID) bool {
|
||||
cfg := store.GetEncryptionEvent(roomID)
|
||||
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1
|
||||
func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) {
|
||||
cfg, err := store.GetEncryptionEvent(ctx, roomID)
|
||||
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue