crypto/store: don't rely on sender key for storing and lookups

* Fixes compatibility with the Store interface
* Increases the usage of "defer"s for "gs.lock.Unlock" and
  "gs.lock.RUnlock"
* Increases the usage of "golang.org/x/exp/maps"

Signed-off-by: Sumner Evans <sumner.evans@automattic.com>

Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
pull/218/head
Sumner Evans 2024-05-14 12:29:30 -06:00
parent d0de43f395
commit a87716a358
No known key found for this signature in database
1 changed files with 83 additions and 124 deletions

View File

@ -13,6 +13,7 @@ import (
"sync"
"go.mau.fi/util/dbutil"
"golang.org/x/exp/maps"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
@ -160,8 +161,8 @@ type MemoryStore struct {
Account *OlmAccount
Sessions map[id.SenderKey]OlmSessionList
GroupSessions map[id.RoomID]map[id.SenderKey]map[id.SessionID]*InboundGroupSession
WithheldGroupSessions map[id.RoomID]map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent
GroupSessions map[id.RoomID]map[id.SessionID]*InboundGroupSession
WithheldGroupSessions map[id.RoomID]map[id.SessionID]*event.RoomKeyWithheldEventContent
OutGroupSessions map[id.RoomID]*OutboundGroupSession
SharedGroupSessions map[id.UserID]map[id.IdentityKey]map[id.SessionID]struct{}
MessageIndices map[messageIndexKey]messageIndexValue
@ -182,8 +183,8 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore {
save: saveCallback,
Sessions: make(map[id.SenderKey]OlmSessionList),
GroupSessions: make(map[id.RoomID]map[id.SenderKey]map[id.SessionID]*InboundGroupSession),
WithheldGroupSessions: make(map[id.RoomID]map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent),
GroupSessions: make(map[id.RoomID]map[id.SessionID]*InboundGroupSession),
WithheldGroupSessions: make(map[id.RoomID]map[id.SessionID]*event.RoomKeyWithheldEventContent),
OutGroupSessions: make(map[id.RoomID]*OutboundGroupSession),
SharedGroupSessions: make(map[id.UserID]map[id.IdentityKey]map[id.SessionID]struct{}),
MessageIndices: make(map[messageIndexKey]messageIndexValue),
@ -197,9 +198,8 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore {
func (gs *MemoryStore) Flush(_ context.Context) error {
gs.lock.Lock()
err := gs.save()
gs.lock.Unlock()
return err
defer gs.lock.Unlock()
return gs.save()
}
func (gs *MemoryStore) GetAccount(_ context.Context) (*OlmAccount, error) {
@ -208,31 +208,29 @@ func (gs *MemoryStore) GetAccount(_ context.Context) (*OlmAccount, error) {
func (gs *MemoryStore) PutAccount(_ context.Context, account *OlmAccount) error {
gs.lock.Lock()
defer gs.lock.Unlock()
gs.Account = account
err := gs.save()
gs.lock.Unlock()
return err
return gs.save()
}
func (gs *MemoryStore) GetSessions(_ context.Context, senderKey id.SenderKey) (OlmSessionList, error) {
gs.lock.Lock()
defer gs.lock.Unlock()
sessions, ok := gs.Sessions[senderKey]
if !ok {
sessions = []*OlmSession{}
gs.Sessions[senderKey] = sessions
}
gs.lock.Unlock()
return sessions, nil
}
func (gs *MemoryStore) AddSession(_ context.Context, senderKey id.SenderKey, session *OlmSession) error {
gs.lock.Lock()
sessions, _ := gs.Sessions[senderKey]
defer gs.lock.Unlock()
sessions := gs.Sessions[senderKey]
gs.Sessions[senderKey] = append(sessions, session)
sort.Sort(gs.Sessions[senderKey])
err := gs.save()
gs.lock.Unlock()
return err
return gs.save()
}
func (gs *MemoryStore) UpdateSession(_ context.Context, _ id.SenderKey, _ *OlmSession) error {
@ -242,102 +240,86 @@ func (gs *MemoryStore) UpdateSession(_ context.Context, _ id.SenderKey, _ *OlmSe
func (gs *MemoryStore) HasSession(_ context.Context, senderKey id.SenderKey) bool {
gs.lock.RLock()
defer gs.lock.RUnlock()
sessions, ok := gs.Sessions[senderKey]
gs.lock.RUnlock()
return ok && len(sessions) > 0 && !sessions[0].Expired()
}
func (gs *MemoryStore) GetLatestSession(_ context.Context, senderKey id.SenderKey) (*OlmSession, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
sessions, ok := gs.Sessions[senderKey]
gs.lock.RUnlock()
if !ok || len(sessions) == 0 {
return nil, nil
}
return sessions[0], nil
}
func (gs *MemoryStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*InboundGroupSession {
func (gs *MemoryStore) getGroupSessions(roomID id.RoomID) map[id.SessionID]*InboundGroupSession {
room, ok := gs.GroupSessions[roomID]
if !ok {
room = make(map[id.SenderKey]map[id.SessionID]*InboundGroupSession)
room = make(map[id.SessionID]*InboundGroupSession)
gs.GroupSessions[roomID] = room
}
sender, ok := room[senderKey]
if !ok {
sender = make(map[id.SessionID]*InboundGroupSession)
room[senderKey] = sender
}
return sender
return room
}
func (gs *MemoryStore) PutGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error {
func (gs *MemoryStore) PutGroupSession(_ context.Context, igs *InboundGroupSession) error {
gs.lock.Lock()
gs.getGroupSessions(roomID, senderKey)[sessionID] = igs
err := gs.save()
gs.lock.Unlock()
return err
defer gs.lock.Unlock()
gs.getGroupSessions(igs.RoomID)[igs.ID()] = igs
return gs.save()
}
func (gs *MemoryStore) GetGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
func (gs *MemoryStore) GetGroupSession(_ context.Context, roomID id.RoomID, sessionID id.SessionID) (*InboundGroupSession, error) {
gs.lock.Lock()
session, ok := gs.getGroupSessions(roomID, senderKey)[sessionID]
defer gs.lock.Unlock()
session, ok := gs.getGroupSessions(roomID)[sessionID]
if !ok {
withheld, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID]
gs.lock.Unlock()
withheld, ok := gs.getWithheldGroupSessions(roomID)[sessionID]
if ok {
return nil, fmt.Errorf("%w (%s)", ErrGroupSessionWithheld, withheld.Code)
}
return nil, nil
}
gs.lock.Unlock()
return session, nil
}
func (gs *MemoryStore) RedactGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error {
func (gs *MemoryStore) RedactGroupSession(_ context.Context, roomID id.RoomID, sessionID id.SessionID, reason string) error {
gs.lock.Lock()
delete(gs.getGroupSessions(roomID, senderKey), sessionID)
err := gs.save()
gs.lock.Unlock()
return err
defer gs.lock.Unlock()
delete(gs.getGroupSessions(roomID), sessionID)
return gs.save()
}
func (gs *MemoryStore) RedactGroupSessions(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
gs.lock.Lock()
defer gs.lock.Unlock()
var sessionIDs []id.SessionID
if roomID != "" && senderKey != "" {
sessions := gs.getGroupSessions(roomID, senderKey)
for sessionID := range sessions {
sessionIDs = append(sessionIDs, sessionID)
delete(sessions, sessionID)
sessions := gs.getGroupSessions(roomID)
for sessionID, session := range sessions {
if session.SenderKey == senderKey {
sessionIDs = append(sessionIDs, sessionID)
delete(sessions, sessionID)
}
}
} else if senderKey != "" {
for _, room := range gs.GroupSessions {
sessions, ok := room[senderKey]
if ok {
for sessionID := range sessions {
for sessionID, session := range room {
if session.SenderKey == senderKey {
sessionIDs = append(sessionIDs, sessionID)
delete(room, sessionID)
}
delete(room, senderKey)
}
}
} else if roomID != "" {
room, ok := gs.GroupSessions[roomID]
if ok {
for senderKey := range room {
sessions := room[senderKey]
for sessionID := range sessions {
sessionIDs = append(sessionIDs, sessionID)
}
}
delete(gs.GroupSessions, roomID)
}
sessionIDs = maps.Keys(gs.GroupSessions[roomID])
delete(gs.GroupSessions, roomID)
} else {
return nil, fmt.Errorf("room ID or sender key must be provided for redacting sessions")
}
err := gs.save()
gs.lock.Unlock()
return sessionIDs, err
return sessionIDs, gs.save()
}
func (gs *MemoryStore) RedactExpiredGroupSessions(_ context.Context) ([]id.SessionID, error) {
@ -348,32 +330,26 @@ func (gs *MemoryStore) RedactOutdatedGroupSessions(_ context.Context) ([]id.Sess
return nil, fmt.Errorf("not implemented")
}
func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*event.RoomKeyWithheldEventContent {
func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID) map[id.SessionID]*event.RoomKeyWithheldEventContent {
room, ok := gs.WithheldGroupSessions[roomID]
if !ok {
room = make(map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent)
room = make(map[id.SessionID]*event.RoomKeyWithheldEventContent)
gs.WithheldGroupSessions[roomID] = room
}
sender, ok := room[senderKey]
if !ok {
sender = make(map[id.SessionID]*event.RoomKeyWithheldEventContent)
room[senderKey] = sender
}
return sender
return room
}
func (gs *MemoryStore) PutWithheldGroupSession(_ context.Context, content event.RoomKeyWithheldEventContent) error {
gs.lock.Lock()
gs.getWithheldGroupSessions(content.RoomID, content.SenderKey)[content.SessionID] = &content
err := gs.save()
gs.lock.Unlock()
return err
defer gs.lock.Unlock()
gs.getWithheldGroupSessions(content.RoomID)[content.SessionID] = &content
return gs.save()
}
func (gs *MemoryStore) GetWithheldGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
func (gs *MemoryStore) GetWithheldGroupSession(_ context.Context, roomID id.RoomID, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
gs.lock.Lock()
session, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID]
gs.lock.Unlock()
defer gs.lock.Unlock()
session, ok := gs.getWithheldGroupSessions(roomID)[sessionID]
if !ok {
return nil, nil
}
@ -387,51 +363,38 @@ func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.Room
if !ok {
return nil
}
var result []*InboundGroupSession
for _, sessions := range room {
for _, session := range sessions {
result = append(result, session)
}
}
return dbutil.NewSliceIter[*InboundGroupSession](result)
return dbutil.NewSliceIter(maps.Values(room))
}
func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) dbutil.RowIter[*InboundGroupSession] {
gs.lock.Lock()
defer gs.lock.Unlock()
var result []*InboundGroupSession
for _, room := range gs.GroupSessions {
for _, sessions := range room {
for _, session := range sessions {
result = append(result, session)
}
}
result = append(result, maps.Values(room)...)
}
gs.lock.Unlock()
return dbutil.NewSliceIter[*InboundGroupSession](result)
return dbutil.NewSliceIter(result)
}
func (gs *MemoryStore) GetGroupSessionsWithoutKeyBackupVersion(_ context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] {
gs.lock.Lock()
defer gs.lock.Unlock()
var result []*InboundGroupSession
for _, room := range gs.GroupSessions {
for _, sessions := range room {
for _, session := range sessions {
if session.KeyBackupVersion != version {
result = append(result, session)
}
for _, session := range room {
if session.KeyBackupVersion != version {
result = append(result, session)
}
}
}
gs.lock.Unlock()
return dbutil.NewSliceIter[*InboundGroupSession](result)
return dbutil.NewSliceIter(result)
}
func (gs *MemoryStore) AddOutboundGroupSession(_ context.Context, session *OutboundGroupSession) error {
gs.lock.Lock()
defer gs.lock.Unlock()
gs.OutGroupSessions[session.RoomID] = session
err := gs.save()
gs.lock.Unlock()
return err
return gs.save()
}
func (gs *MemoryStore) UpdateOutboundGroupSession(_ context.Context, _ *OutboundGroupSession) error {
@ -441,8 +404,8 @@ func (gs *MemoryStore) UpdateOutboundGroupSession(_ context.Context, _ *Outbound
func (gs *MemoryStore) GetOutboundGroupSession(_ context.Context, roomID id.RoomID) (*OutboundGroupSession, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
session, ok := gs.OutGroupSessions[roomID]
gs.lock.RUnlock()
if !ok {
return nil, nil
}
@ -451,18 +414,18 @@ func (gs *MemoryStore) GetOutboundGroupSession(_ context.Context, roomID id.Room
func (gs *MemoryStore) RemoveOutboundGroupSession(_ context.Context, roomID id.RoomID) error {
gs.lock.Lock()
defer gs.lock.Unlock()
session, ok := gs.OutGroupSessions[roomID]
if !ok || session == nil {
gs.lock.Unlock()
return nil
}
delete(gs.OutGroupSessions, roomID)
gs.lock.Unlock()
return nil
}
func (gs *MemoryStore) MarkOutboundGroupSessionShared(_ context.Context, userID id.UserID, identityKey id.IdentityKey, sessionID id.SessionID) error {
gs.lock.Lock()
defer gs.lock.Unlock()
if _, ok := gs.SharedGroupSessions[userID]; !ok {
gs.SharedGroupSessions[userID] = make(map[id.IdentityKey]map[id.SessionID]struct{})
@ -475,7 +438,6 @@ func (gs *MemoryStore) MarkOutboundGroupSessionShared(_ context.Context, userID
identities[identityKey][sessionID] = struct{}{}
gs.lock.Unlock()
return nil
}
@ -521,11 +483,11 @@ func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.Send
func (gs *MemoryStore) GetDevices(_ context.Context, userID id.UserID) (map[id.DeviceID]*id.Device, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
devices, ok := gs.Devices[userID]
if !ok {
devices = nil
}
gs.lock.RUnlock()
return devices, nil
}
@ -560,30 +522,30 @@ func (gs *MemoryStore) FindDeviceByKey(_ context.Context, userID id.UserID, iden
func (gs *MemoryStore) PutDevice(_ context.Context, userID id.UserID, device *id.Device) error {
gs.lock.Lock()
defer gs.lock.Unlock()
devices, ok := gs.Devices[userID]
if !ok {
devices = make(map[id.DeviceID]*id.Device)
gs.Devices[userID] = devices
}
devices[device.DeviceID] = device
err := gs.save()
gs.lock.Unlock()
return err
return gs.save()
}
func (gs *MemoryStore) PutDevices(_ context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error {
gs.lock.Lock()
defer gs.lock.Unlock()
gs.Devices[userID] = devices
err := gs.save()
if err == nil {
delete(gs.OutdatedUsers, userID)
}
gs.lock.Unlock()
return err
}
func (gs *MemoryStore) FilterTrackedUsers(_ context.Context, users []id.UserID) ([]id.UserID, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
var ptr int
for _, userID := range users {
_, ok := gs.Devices[userID]
@ -592,33 +554,33 @@ func (gs *MemoryStore) FilterTrackedUsers(_ context.Context, users []id.UserID)
ptr++
}
}
gs.lock.RUnlock()
return users[:ptr], nil
}
func (gs *MemoryStore) MarkTrackedUsersOutdated(_ context.Context, users []id.UserID) error {
gs.lock.Lock()
defer gs.lock.Unlock()
for _, userID := range users {
if _, ok := gs.Devices[userID]; ok {
gs.OutdatedUsers[userID] = struct{}{}
}
}
gs.lock.Unlock()
return nil
}
func (gs *MemoryStore) GetOutdatedTrackedUsers(_ context.Context) ([]id.UserID, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
users := make([]id.UserID, 0, len(gs.OutdatedUsers))
for userID := range gs.OutdatedUsers {
users = append(users, userID)
}
gs.lock.RUnlock()
return users, nil
}
func (gs *MemoryStore) PutCrossSigningKey(_ context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
gs.lock.RLock()
defer gs.lock.RUnlock()
userKeys, ok := gs.CrossSigningKeys[userID]
if !ok {
userKeys = make(map[id.CrossSigningUsage]id.CrossSigningKey)
@ -635,7 +597,6 @@ func (gs *MemoryStore) PutCrossSigningKey(_ context.Context, userID id.UserID, u
}
}
err := gs.save()
gs.lock.RUnlock()
return err
}
@ -651,6 +612,7 @@ func (gs *MemoryStore) GetCrossSigningKeys(_ context.Context, userID id.UserID)
func (gs *MemoryStore) PutSignature(_ context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
gs.lock.RLock()
defer gs.lock.RUnlock()
signedUserSigs, ok := gs.KeySignatures[signedUserID]
if !ok {
signedUserSigs = make(map[id.Ed25519]map[id.UserID]map[id.Ed25519]string)
@ -667,9 +629,7 @@ func (gs *MemoryStore) PutSignature(_ context.Context, signedUserID id.UserID, s
signaturesForKey[signerUserID] = signedByUser
}
signedByUser[signerKey] = signature
err := gs.save()
gs.lock.RUnlock()
return err
return gs.save()
}
func (gs *MemoryStore) GetSignaturesForKeyBy(_ context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
@ -700,8 +660,9 @@ func (gs *MemoryStore) IsKeySignedBy(ctx context.Context, userID id.UserID, key
}
func (gs *MemoryStore) DropSignaturesByKey(_ context.Context, userID id.UserID, key id.Ed25519) (int64, error) {
var count int64
gs.lock.RLock()
defer gs.lock.RUnlock()
var count int64
for _, userSigs := range gs.KeySignatures {
for _, keySigs := range userSigs {
if signedBySigner, ok := keySigs[userID]; ok {
@ -712,27 +673,25 @@ func (gs *MemoryStore) DropSignaturesByKey(_ context.Context, userID id.UserID,
}
}
}
gs.lock.RUnlock()
return count, nil
}
func (gs *MemoryStore) PutSecret(_ context.Context, name id.Secret, value string) error {
gs.lock.Lock()
defer gs.lock.Unlock()
gs.Secrets[name] = value
gs.lock.Unlock()
return nil
}
func (gs *MemoryStore) GetSecret(_ context.Context, name id.Secret) (value string, _ error) {
func (gs *MemoryStore) GetSecret(_ context.Context, name id.Secret) (string, error) {
gs.lock.RLock()
value = gs.Secrets[name]
gs.lock.RUnlock()
return
defer gs.lock.RUnlock()
return gs.Secrets[name], nil
}
func (gs *MemoryStore) DeleteSecret(_ context.Context, name id.Secret) error {
gs.lock.Lock()
defer gs.lock.Unlock()
delete(gs.Secrets, name)
gs.lock.Unlock()
return nil
}