crypto: fix usages of Store interface

Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
pull/218/head
Sumner Evans 2024-05-14 12:31:46 -06:00
parent a87716a358
commit de0347db00
No known key found for this signature in database
7 changed files with 19 additions and 19 deletions

View File

@ -192,7 +192,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
mach.megolmDecryptLock.Lock()
defer mach.megolmDecryptLock.Unlock()
sess, err := mach.CryptoStore.GetGroupSession(ctx, encryptionRoomID, content.SenderKey, content.SessionID)
sess, err := mach.CryptoStore.GetGroupSession(ctx, encryptionRoomID, content.SessionID)
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
} else if sess == nil {
@ -254,7 +254,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
Int("max_messages", sess.MaxMessages).
Logger()
if sess.MaxMessages > 0 && int(ratchetTargetIndex) >= sess.MaxMessages && len(sess.RatchetSafety.MissedIndices) == 0 && mach.DeleteFullyUsedKeysOnDecrypt {
err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached")
err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached")
if err != nil {
log.Err(err).Msg("Failed to delete fully used session")
return sess, plaintext, messageIndex, RatchetError
@ -265,14 +265,14 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
log.Err(err).Msg("Failed to ratchet session")
return sess, plaintext, messageIndex, RatchetError
} else if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
} else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
log.Err(err).Msg("Failed to store ratcheted session")
return sess, plaintext, messageIndex, RatchetError
} else {
log.Info().Msg("Ratcheted session forward")
}
} else if didModify {
if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
log.Err(err).Msg("Failed to store updated ratchet safety data")
return sess, plaintext, messageIndex, RatchetError
} else {

View File

@ -177,7 +177,7 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.
MaxMessages: maxMessages,
KeyBackupVersion: version,
}
err = mach.CryptoStore.PutGroupSession(ctx, roomID, keyBackupData.SenderKey, sessionID, igs)
err = mach.CryptoStore.PutGroupSession(ctx, igs)
if err != nil {
return fmt.Errorf("failed to store new inbound group session: %w", err)
}

View File

@ -113,12 +113,12 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor
ReceivedAt: time.Now().UTC(),
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID())
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
// We already have an equivalent or better session in the store, so don't override it.
return false, nil
}
err = mach.CryptoStore.PutGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID(), igs)
err = mach.CryptoStore.PutGroupSession(ctx, igs)
if err != nil {
return false, fmt.Errorf("failed to store imported session: %w", err)
}

View File

@ -184,12 +184,12 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
MaxMessages: maxMessages,
IsScheduled: content.IsScheduled,
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID())
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
// We already have an equivalent or better session in the store, so don't override it.
return false
}
err = mach.CryptoStore.PutGroupSession(ctx, content.RoomID, content.SenderKey, content.SessionID, igs)
err = mach.CryptoStore.PutGroupSession(ctx, igs)
if err != nil {
log.Error().Err(err).Msg("Failed to store new inbound group session")
return false
@ -308,7 +308,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User
return
}
igs, err := mach.CryptoStore.GetGroupSession(ctx, content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID)
igs, err := mach.CryptoStore.GetGroupSession(ctx, content.Body.RoomID, content.Body.SessionID)
if err != nil {
if errors.Is(err, ErrGroupSessionWithheld) {
log.Debug().Err(err).Msg("Requested group session not available")
@ -365,7 +365,7 @@ func (mach *OlmMachine) HandleBeeperRoomKeyAck(ctx context.Context, sender id.Us
Int("first_message_index", content.FirstMessageIndex).
Logger()
sess, err := mach.CryptoStore.GetGroupSession(ctx, content.RoomID, "", content.SessionID)
sess, err := mach.CryptoStore.GetGroupSession(ctx, content.RoomID, content.SessionID)
if err != nil {
if errors.Is(err, ErrGroupSessionWithheld) {
log.Debug().Err(err).Msg("Acked group session was already redacted")
@ -385,7 +385,7 @@ func (mach *OlmMachine) HandleBeeperRoomKeyAck(ctx context.Context, sender id.Us
isInbound := sess.SenderKey == mach.OwnIdentity().IdentityKey
if isInbound && mach.DeleteOutboundKeysOnAck && content.FirstMessageIndex == 0 {
log.Debug().Msg("Redacting inbound copy of outbound group session after ack")
err = mach.CryptoStore.RedactGroupSession(ctx, content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked")
err = mach.CryptoStore.RedactGroupSession(ctx, content.RoomID, content.SessionID, "outbound session acked")
if err != nil {
log.Err(err).Msg("Failed to redact group session")
}

View File

@ -517,7 +517,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
Msg("Mismatched session ID while creating inbound group session")
return fmt.Errorf("mismatched session ID while creating inbound group session")
}
err = mach.CryptoStore.PutGroupSession(ctx, roomID, senderKey, sessionID, igs)
err = mach.CryptoStore.PutGroupSession(ctx, igs)
if err != nil {
log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
return fmt.Errorf("failed to store new inbound group session: %w", err)
@ -557,7 +557,7 @@ func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, se
}
mach.keyWaitersLock.Unlock()
// Handle race conditions where a session appears between the failed decryption and WaitForSession call.
sess, err := mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID)
sess, err := mach.CryptoStore.GetGroupSession(ctx, roomID, sessionID)
if sess != nil || errors.Is(err, ErrGroupSessionWithheld) {
return true
}
@ -565,7 +565,7 @@ func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, se
case <-ch:
return true
case <-time.After(timeout):
sess, err = mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID)
sess, err = mach.CryptoStore.GetGroupSession(ctx, roomID, sessionID)
// Check if the session somehow appeared in the store without telling us
// We accept withheld sessions as received, as then the decryption attempt will show the error.
return sess != nil || errors.Is(err, ErrGroupSessionWithheld)

View File

@ -58,7 +58,7 @@ func TestRatchetMegolmSession(t *testing.T) {
mach := newMachine(t, "user1")
outSess, err := mach.newOutboundGroupSession(context.TODO(), "meow")
assert.NoError(t, err)
inSess, err := mach.CryptoStore.GetGroupSession(context.TODO(), "meow", mach.OwnIdentity().IdentityKey, outSess.ID())
inSess, err := mach.CryptoStore.GetGroupSession(context.TODO(), "meow", outSess.ID())
require.NoError(t, err)
assert.Equal(t, uint32(0), inSess.Internal.FirstKnownIndex())
err = inSess.RatchetTo(10)
@ -130,7 +130,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
if err != nil {
t.Errorf("Error creating inbound megolm session: %v", err)
}
if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), "room1", senderKey, igs.ID(), igs); err != nil {
if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs); err != nil {
t.Errorf("Error storing inbound megolm session: %v", err)
}
}

View File

@ -158,12 +158,12 @@ func TestStoreMegolmSession(t *testing.T) {
RoomID: "room1",
}
err = store.PutGroupSession(context.TODO(), "room1", acc.IdentityKey(), igs.ID(), igs)
err = store.PutGroupSession(context.TODO(), igs)
if err != nil {
t.Errorf("Error storing inbound group session: %v", err)
}
retrieved, err := store.GetGroupSession(context.TODO(), "room1", acc.IdentityKey(), igs.ID())
retrieved, err := store.GetGroupSession(context.TODO(), "room1", igs.ID())
if err != nil {
t.Errorf("Error retrieving inbound group session: %v", err)
}