mirror of https://github.com/mautrix/go.git
crypto: fix usages of Store interface
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>pull/218/head
parent
a87716a358
commit
de0347db00
|
@ -192,7 +192,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
|
|||
mach.megolmDecryptLock.Lock()
|
||||
defer mach.megolmDecryptLock.Unlock()
|
||||
|
||||
sess, err := mach.CryptoStore.GetGroupSession(ctx, encryptionRoomID, content.SenderKey, content.SessionID)
|
||||
sess, err := mach.CryptoStore.GetGroupSession(ctx, encryptionRoomID, content.SessionID)
|
||||
if err != nil {
|
||||
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
|
||||
} else if sess == nil {
|
||||
|
@ -254,7 +254,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
|
|||
Int("max_messages", sess.MaxMessages).
|
||||
Logger()
|
||||
if sess.MaxMessages > 0 && int(ratchetTargetIndex) >= sess.MaxMessages && len(sess.RatchetSafety.MissedIndices) == 0 && mach.DeleteFullyUsedKeysOnDecrypt {
|
||||
err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached")
|
||||
err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to delete fully used session")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
|
@ -265,14 +265,14 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
|
|||
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
|
||||
log.Err(err).Msg("Failed to ratchet session")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
|
||||
} else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
|
||||
log.Err(err).Msg("Failed to store ratcheted session")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else {
|
||||
log.Info().Msg("Ratcheted session forward")
|
||||
}
|
||||
} else if didModify {
|
||||
if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
|
||||
if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil {
|
||||
log.Err(err).Msg("Failed to store updated ratchet safety data")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else {
|
||||
|
|
|
@ -177,7 +177,7 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.
|
|||
MaxMessages: maxMessages,
|
||||
KeyBackupVersion: version,
|
||||
}
|
||||
err = mach.CryptoStore.PutGroupSession(ctx, roomID, keyBackupData.SenderKey, sessionID, igs)
|
||||
err = mach.CryptoStore.PutGroupSession(ctx, igs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store new inbound group session: %w", err)
|
||||
}
|
||||
|
|
|
@ -113,12 +113,12 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor
|
|||
|
||||
ReceivedAt: time.Now().UTC(),
|
||||
}
|
||||
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID())
|
||||
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
|
||||
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
|
||||
// We already have an equivalent or better session in the store, so don't override it.
|
||||
return false, nil
|
||||
}
|
||||
err = mach.CryptoStore.PutGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID(), igs)
|
||||
err = mach.CryptoStore.PutGroupSession(ctx, igs)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to store imported session: %w", err)
|
||||
}
|
||||
|
|
|
@ -184,12 +184,12 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
|
|||
MaxMessages: maxMessages,
|
||||
IsScheduled: content.IsScheduled,
|
||||
}
|
||||
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID())
|
||||
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
|
||||
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
|
||||
// We already have an equivalent or better session in the store, so don't override it.
|
||||
return false
|
||||
}
|
||||
err = mach.CryptoStore.PutGroupSession(ctx, content.RoomID, content.SenderKey, content.SessionID, igs)
|
||||
err = mach.CryptoStore.PutGroupSession(ctx, igs)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to store new inbound group session")
|
||||
return false
|
||||
|
@ -308,7 +308,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User
|
|||
return
|
||||
}
|
||||
|
||||
igs, err := mach.CryptoStore.GetGroupSession(ctx, content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID)
|
||||
igs, err := mach.CryptoStore.GetGroupSession(ctx, content.Body.RoomID, content.Body.SessionID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrGroupSessionWithheld) {
|
||||
log.Debug().Err(err).Msg("Requested group session not available")
|
||||
|
@ -365,7 +365,7 @@ func (mach *OlmMachine) HandleBeeperRoomKeyAck(ctx context.Context, sender id.Us
|
|||
Int("first_message_index", content.FirstMessageIndex).
|
||||
Logger()
|
||||
|
||||
sess, err := mach.CryptoStore.GetGroupSession(ctx, content.RoomID, "", content.SessionID)
|
||||
sess, err := mach.CryptoStore.GetGroupSession(ctx, content.RoomID, content.SessionID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrGroupSessionWithheld) {
|
||||
log.Debug().Err(err).Msg("Acked group session was already redacted")
|
||||
|
@ -385,7 +385,7 @@ func (mach *OlmMachine) HandleBeeperRoomKeyAck(ctx context.Context, sender id.Us
|
|||
isInbound := sess.SenderKey == mach.OwnIdentity().IdentityKey
|
||||
if isInbound && mach.DeleteOutboundKeysOnAck && content.FirstMessageIndex == 0 {
|
||||
log.Debug().Msg("Redacting inbound copy of outbound group session after ack")
|
||||
err = mach.CryptoStore.RedactGroupSession(ctx, content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked")
|
||||
err = mach.CryptoStore.RedactGroupSession(ctx, content.RoomID, content.SessionID, "outbound session acked")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to redact group session")
|
||||
}
|
||||
|
|
|
@ -517,7 +517,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
|
|||
Msg("Mismatched session ID while creating inbound group session")
|
||||
return fmt.Errorf("mismatched session ID while creating inbound group session")
|
||||
}
|
||||
err = mach.CryptoStore.PutGroupSession(ctx, roomID, senderKey, sessionID, igs)
|
||||
err = mach.CryptoStore.PutGroupSession(ctx, igs)
|
||||
if err != nil {
|
||||
log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
|
||||
return fmt.Errorf("failed to store new inbound group session: %w", err)
|
||||
|
@ -557,7 +557,7 @@ func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, se
|
|||
}
|
||||
mach.keyWaitersLock.Unlock()
|
||||
// Handle race conditions where a session appears between the failed decryption and WaitForSession call.
|
||||
sess, err := mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID)
|
||||
sess, err := mach.CryptoStore.GetGroupSession(ctx, roomID, sessionID)
|
||||
if sess != nil || errors.Is(err, ErrGroupSessionWithheld) {
|
||||
return true
|
||||
}
|
||||
|
@ -565,7 +565,7 @@ func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, se
|
|||
case <-ch:
|
||||
return true
|
||||
case <-time.After(timeout):
|
||||
sess, err = mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID)
|
||||
sess, err = mach.CryptoStore.GetGroupSession(ctx, roomID, sessionID)
|
||||
// Check if the session somehow appeared in the store without telling us
|
||||
// We accept withheld sessions as received, as then the decryption attempt will show the error.
|
||||
return sess != nil || errors.Is(err, ErrGroupSessionWithheld)
|
||||
|
|
|
@ -58,7 +58,7 @@ func TestRatchetMegolmSession(t *testing.T) {
|
|||
mach := newMachine(t, "user1")
|
||||
outSess, err := mach.newOutboundGroupSession(context.TODO(), "meow")
|
||||
assert.NoError(t, err)
|
||||
inSess, err := mach.CryptoStore.GetGroupSession(context.TODO(), "meow", mach.OwnIdentity().IdentityKey, outSess.ID())
|
||||
inSess, err := mach.CryptoStore.GetGroupSession(context.TODO(), "meow", outSess.ID())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(0), inSess.Internal.FirstKnownIndex())
|
||||
err = inSess.RatchetTo(10)
|
||||
|
@ -130,7 +130,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Errorf("Error creating inbound megolm session: %v", err)
|
||||
}
|
||||
if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), "room1", senderKey, igs.ID(), igs); err != nil {
|
||||
if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs); err != nil {
|
||||
t.Errorf("Error storing inbound megolm session: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -158,12 +158,12 @@ func TestStoreMegolmSession(t *testing.T) {
|
|||
RoomID: "room1",
|
||||
}
|
||||
|
||||
err = store.PutGroupSession(context.TODO(), "room1", acc.IdentityKey(), igs.ID(), igs)
|
||||
err = store.PutGroupSession(context.TODO(), igs)
|
||||
if err != nil {
|
||||
t.Errorf("Error storing inbound group session: %v", err)
|
||||
}
|
||||
|
||||
retrieved, err := store.GetGroupSession(context.TODO(), "room1", acc.IdentityKey(), igs.ID())
|
||||
retrieved, err := store.GetGroupSession(context.TODO(), "room1", igs.ID())
|
||||
if err != nil {
|
||||
t.Errorf("Error retrieving inbound group session: %v", err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue