Add contexts to event handlers

pull/156/head
Tulir Asokan 2024-01-13 18:56:12 +02:00
parent a3883fcf6f
commit 308e3583b0
12 changed files with 196 additions and 186 deletions

View File

@ -4,6 +4,7 @@
functions.
* **Breaking change *(everything)*** Added context parameters to all functions
(started by [@recht] in [#144]).
* *(client)* Moved `EventSource` to `event.Source`.
* *(crypto)* Added experimental pure Go Olm implementation to replace libolm
(thanks to [@DerLukas15] in [#106]).
* You can use the `goolm` build tag to the new implementation.

View File

@ -7,6 +7,7 @@
package appservice
import (
"context"
"encoding/json"
"runtime/debug"
"time"
@ -25,9 +26,9 @@ const (
Sync
)
type EventHandler = func(evt *event.Event)
type OTKHandler = func(otk *mautrix.OTKCount)
type DeviceListHandler = func(lists *mautrix.DeviceLists, since string)
type EventHandler = func(ctx context.Context, evt *event.Event)
type OTKHandler = func(ctx context.Context, otk *mautrix.OTKCount)
type DeviceListHandler = func(ctx context.Context, lists *mautrix.DeviceLists, since string)
type EventProcessor struct {
ExecMode ExecMode
@ -97,34 +98,34 @@ func (ep *EventProcessor) recoverFunc(data interface{}) {
}
}
func (ep *EventProcessor) callHandler(handler EventHandler, evt *event.Event) {
func (ep *EventProcessor) callHandler(ctx context.Context, handler EventHandler, evt *event.Event) {
defer ep.recoverFunc(evt)
handler(evt)
handler(ctx, evt)
}
func (ep *EventProcessor) callOTKHandler(handler OTKHandler, otk *mautrix.OTKCount) {
func (ep *EventProcessor) callOTKHandler(ctx context.Context, handler OTKHandler, otk *mautrix.OTKCount) {
defer ep.recoverFunc(otk)
handler(otk)
handler(ctx, otk)
}
func (ep *EventProcessor) callDeviceListHandler(handler DeviceListHandler, dl *mautrix.DeviceLists) {
func (ep *EventProcessor) callDeviceListHandler(ctx context.Context, handler DeviceListHandler, dl *mautrix.DeviceLists) {
defer ep.recoverFunc(dl)
handler(dl, "")
handler(ctx, dl, "")
}
func (ep *EventProcessor) DispatchOTK(otk *mautrix.OTKCount) {
func (ep *EventProcessor) DispatchOTK(ctx context.Context, otk *mautrix.OTKCount) {
for _, handler := range ep.otkHandlers {
go ep.callOTKHandler(handler, otk)
go ep.callOTKHandler(ctx, handler, otk)
}
}
func (ep *EventProcessor) DispatchDeviceList(dl *mautrix.DeviceLists) {
func (ep *EventProcessor) DispatchDeviceList(ctx context.Context, dl *mautrix.DeviceLists) {
for _, handler := range ep.deviceListHandlers {
go ep.callDeviceListHandler(handler, dl)
go ep.callDeviceListHandler(ctx, handler, dl)
}
}
func (ep *EventProcessor) Dispatch(evt *event.Event) {
func (ep *EventProcessor) Dispatch(ctx context.Context, evt *event.Event) {
handlers, ok := ep.handlers[evt.Type]
if !ok {
return
@ -132,25 +133,25 @@ func (ep *EventProcessor) Dispatch(evt *event.Event) {
switch ep.ExecMode {
case AsyncHandlers:
for _, handler := range handlers {
go ep.callHandler(handler, evt)
go ep.callHandler(ctx, handler, evt)
}
case AsyncLoop:
go func() {
for _, handler := range handlers {
ep.callHandler(handler, evt)
ep.callHandler(ctx, handler, evt)
}
}()
case Sync:
if ep.ExecSyncWarnTime == 0 && ep.ExecSyncTimeout == 0 {
for _, handler := range handlers {
ep.callHandler(handler, evt)
ep.callHandler(ctx, handler, evt)
}
return
}
doneChan := make(chan struct{})
go func() {
for _, handler := range handlers {
ep.callHandler(handler, evt)
ep.callHandler(ctx, handler, evt)
}
close(doneChan)
}()
@ -172,35 +173,35 @@ func (ep *EventProcessor) Dispatch(evt *event.Event) {
}
}
}
func (ep *EventProcessor) startEvents() {
func (ep *EventProcessor) startEvents(ctx context.Context) {
for {
select {
case evt := <-ep.as.Events:
ep.Dispatch(evt)
ep.Dispatch(ctx, evt)
case <-ep.stop:
return
}
}
}
func (ep *EventProcessor) startEncryption() {
func (ep *EventProcessor) startEncryption(ctx context.Context) {
for {
select {
case evt := <-ep.as.ToDeviceEvents:
ep.Dispatch(evt)
ep.Dispatch(ctx, evt)
case otk := <-ep.as.OTKCounts:
ep.DispatchOTK(otk)
ep.DispatchOTK(ctx, otk)
case dl := <-ep.as.DeviceLists:
ep.DispatchDeviceList(dl)
ep.DispatchDeviceList(ctx, dl)
case <-ep.stop:
return
}
}
}
func (ep *EventProcessor) Start() {
go ep.startEvents()
go ep.startEncryption()
func (ep *EventProcessor) Start(ctx context.Context) {
go ep.startEvents(ctx)
go ep.startEncryption(ctx)
}
func (ep *EventProcessor) Stop() {

View File

@ -214,7 +214,7 @@ type Bridge struct {
}
type Crypto interface {
HandleMemberEvent(*event.Event)
HandleMemberEvent(context.Context, *event.Event)
Decrypt(context.Context, *event.Event) (*event.Event, error)
Encrypt(context.Context, id.RoomID, event.Type, *event.Content) error
WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
@ -321,7 +321,7 @@ func (br *Bridge) ensureConnection(ctx context.Context) {
if errors.Is(err, mautrix.MUnknownToken) {
br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?")
} else if errors.Is(err, mautrix.MExclusive) {
br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was accepted, but the /register request was not. Are the homeserver domain and username template in the config correct, and do they match the values in the registration?")
br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was accepted, but the /register request was not. Are the homeserver domain, bot username and username template in the config correct, and do they match the values in the registration?")
} else {
br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("/whoami request failed with unknown error")
}
@ -674,7 +674,7 @@ func (br *Bridge) start() {
}
br.ZLog.Debug().Msg("Checking connection to homeserver")
ctx := context.Background()
ctx := br.ZLog.WithContext(context.Background())
br.ensureConnection(ctx)
go br.fetchMediaConfig(ctx)
@ -687,7 +687,7 @@ func (br *Bridge) start() {
}
br.ZLog.Debug().Msg("Starting event processor")
br.EventProcessor.Start()
br.EventProcessor.Start(ctx)
go br.UpdateBotProfile(ctx)
if br.Crypto != nil {

View File

@ -425,10 +425,10 @@ func (helper *CryptoHelper) ResetSession(ctx context.Context, roomID id.RoomID)
}
}
func (helper *CryptoHelper) HandleMemberEvent(evt *event.Event) {
func (helper *CryptoHelper) HandleMemberEvent(ctx context.Context, evt *event.Event) {
helper.lock.RLock()
defer helper.lock.RUnlock()
helper.mach.HandleMemberEvent(0, evt)
helper.mach.HandleMemberEvent(ctx, evt)
}
// ShareKeys uploads the given number of one-time-keys to the server.
@ -440,7 +440,7 @@ type cryptoSyncer struct {
*crypto.OlmMachine
}
func (syncer *cryptoSyncer) ProcessResponse(resp *mautrix.RespSync, since string) error {
func (syncer *cryptoSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error {
done := make(chan struct{})
go func() {
defer func() {
@ -454,7 +454,7 @@ func (syncer *cryptoSyncer) ProcessResponse(resp *mautrix.RespSync, since string
done <- struct{}{}
}()
syncer.Log.Trace().Str("since", since).Msg("Starting sync response handling")
syncer.ProcessSyncResponse(resp, since)
syncer.ProcessSyncResponse(ctx, resp, since)
syncer.Log.Trace().Str("since", since).Msg("Successfully handled sync response")
}()
select {

View File

@ -68,13 +68,13 @@ func NewMatrixHandler(br *Bridge) *MatrixHandler {
return handler
}
func (mx *MatrixHandler) sendBridgeCheckpoint(evt *event.Event) {
func (mx *MatrixHandler) sendBridgeCheckpoint(_ context.Context, evt *event.Event) {
if !evt.Mautrix.CheckpointSent {
go mx.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepBridge, 0)
}
}
func (mx *MatrixHandler) HandleEncryption(evt *event.Event) {
func (mx *MatrixHandler) HandleEncryption(ctx context.Context, evt *event.Event) {
defer mx.TrackEventDuration(evt.Type)()
if evt.Content.AsEncryption().Algorithm != id.AlgorithmMegolmV1 {
return
@ -87,7 +87,7 @@ func (mx *MatrixHandler) HandleEncryption(evt *event.Event) {
Msg("Encryption was enabled in room")
portal.MarkEncrypted()
if portal.IsPrivateChat() {
err := mx.as.BotIntent().EnsureJoined(context.TODO(), evt.RoomID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client})
err := mx.as.BotIntent().EnsureJoined(ctx, evt.RoomID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client})
if err != nil {
mx.log.Err(err).
Str("room_id", evt.RoomID.String()).
@ -232,15 +232,14 @@ func (mx *MatrixHandler) HandleGhostInvite(ctx context.Context, evt *event.Event
}
}
func (mx *MatrixHandler) HandleMembership(evt *event.Event) {
func (mx *MatrixHandler) HandleMembership(ctx context.Context, evt *event.Event) {
if evt.Sender == mx.bridge.Bot.UserID || mx.bridge.Child.IsGhost(evt.Sender) {
return
}
defer mx.TrackEventDuration(evt.Type)()
ctx := context.TODO()
if mx.bridge.Crypto != nil {
mx.bridge.Crypto.HandleMemberEvent(evt)
mx.bridge.Crypto.HandleMemberEvent(ctx, evt)
}
log := mx.log.With().
@ -300,7 +299,7 @@ func (mx *MatrixHandler) HandleMembership(evt *event.Event) {
// TODO kicking/inviting non-ghost users users
}
func (mx *MatrixHandler) HandleRoomMetadata(evt *event.Event) {
func (mx *MatrixHandler) HandleRoomMetadata(ctx context.Context, evt *event.Event) {
defer mx.TrackEventDuration(evt.Type)()
if mx.shouldIgnoreEvent(evt) {
return
@ -469,20 +468,20 @@ func (mx *MatrixHandler) postDecrypt(ctx context.Context, original, decrypted *e
mx.bridge.SendMessageSuccessCheckpoint(decrypted, status.MsgStepDecrypted, retryCount)
decrypted.Mautrix.CheckpointSent = true
decrypted.Mautrix.DecryptionDuration = duration
mx.bridge.EventProcessor.Dispatch(decrypted)
decrypted.Mautrix.EventSource |= event.SourceDecrypted
mx.bridge.EventProcessor.Dispatch(ctx, decrypted)
if errorEventID != "" {
_, _ = mx.bridge.Bot.RedactEvent(ctx, decrypted.RoomID, errorEventID)
}
}
func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) {
func (mx *MatrixHandler) HandleEncrypted(ctx context.Context, evt *event.Event) {
defer mx.TrackEventDuration(evt.Type)()
if mx.shouldIgnoreEvent(evt) {
return
}
content := evt.Content.AsEncrypted()
ctx := context.TODO()
log := mx.log.With().
log := zerolog.Ctx(ctx).With().
Str("event_id", evt.ID.String()).
Str("session_id", content.SessionID.String()).
Logger()
@ -546,14 +545,14 @@ func (mx *MatrixHandler) waitLongerForSession(ctx context.Context, evt *event.Ev
mx.postDecrypt(ctx, evt, decrypted, 2, errorEventID, time.Since(decryptionStart))
}
func (mx *MatrixHandler) HandleMessage(evt *event.Event) {
func (mx *MatrixHandler) HandleMessage(ctx context.Context, evt *event.Event) {
defer mx.TrackEventDuration(evt.Type)()
log := mx.log.With().
log := zerolog.Ctx(ctx).With().
Str("event_id", evt.ID.String()).
Str("room_id", evt.RoomID.String()).
Str("sender", evt.Sender.String()).
Logger()
ctx := log.WithContext(context.TODO())
ctx = log.WithContext(ctx)
if mx.shouldIgnoreEvent(evt) {
return
} else if !evt.Mautrix.WasEncrypted && mx.bridge.Config.Bridge.GetEncryptionConfig().Require {
@ -604,7 +603,7 @@ func (mx *MatrixHandler) HandleMessage(evt *event.Event) {
}
}
func (mx *MatrixHandler) HandleReaction(evt *event.Event) {
func (mx *MatrixHandler) HandleReaction(_ context.Context, evt *event.Event) {
defer mx.TrackEventDuration(evt.Type)()
if mx.shouldIgnoreEvent(evt) {
return
@ -623,7 +622,7 @@ func (mx *MatrixHandler) HandleReaction(evt *event.Event) {
}
}
func (mx *MatrixHandler) HandleRedaction(evt *event.Event) {
func (mx *MatrixHandler) HandleRedaction(_ context.Context, evt *event.Event) {
defer mx.TrackEventDuration(evt.Type)()
if mx.shouldIgnoreEvent(evt) {
return
@ -642,7 +641,7 @@ func (mx *MatrixHandler) HandleRedaction(evt *event.Event) {
}
}
func (mx *MatrixHandler) HandleReceipt(evt *event.Event) {
func (mx *MatrixHandler) HandleReceipt(_ context.Context, evt *event.Event) {
portal := mx.bridge.Child.GetIPortal(evt.RoomID)
if portal == nil {
return
@ -676,7 +675,7 @@ func (mx *MatrixHandler) HandleReceipt(evt *event.Event) {
}
}
func (mx *MatrixHandler) HandleTyping(evt *event.Event) {
func (mx *MatrixHandler) HandleTyping(_ context.Context, evt *event.Event) {
portal := mx.bridge.Child.GetIPortal(evt.RoomID)
if portal == nil {
return

View File

@ -236,8 +236,11 @@ func (cli *Client) SyncWithContext(ctx context.Context) error {
// Save the token now *before* processing it. This means it's possible
// to not process some events, but it means that we won't get constantly stuck processing
// a malformed/buggy event which keeps making us panic.
cli.Store.SaveNextBatch(ctx, cli.UserID, resSync.NextBatch)
if err = cli.Syncer.ProcessResponse(resSync, nextBatch); err != nil {
err = cli.Store.SaveNextBatch(ctx, cli.UserID, resSync.NextBatch)
if err != nil {
return err
}
if err = cli.Syncer.ProcessResponse(ctx, resSync, nextBatch); err != nil {
return err
}

View File

@ -245,17 +245,18 @@ var NoSessionFound = crypto.NoSessionFound
const initialSessionWaitTimeout = 3 * time.Second
const extendedSessionWaitTimeout = 22 * time.Second
func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event.Event) {
func (helper *CryptoHelper) HandleEncrypted(ctx context.Context, evt *event.Event) {
if helper == nil {
return
}
content := evt.Content.AsEncrypted()
// TODO use context log instead of helper?
log := helper.log.With().
Str("event_id", evt.ID.String()).
Str("session_id", content.SessionID.String()).
Logger()
log.Debug().Msg("Decrypting received event")
ctx := log.WithContext(context.TODO())
ctx = log.WithContext(ctx)
decrypted, err := helper.Decrypt(ctx, evt)
if errors.Is(err, NoSessionFound) {
@ -266,7 +267,7 @@ func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event.
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
decrypted, err = helper.Decrypt(ctx, evt)
} else {
go helper.waitLongerForSession(ctx, log, src, evt)
go helper.waitLongerForSession(ctx, log, evt)
return
}
}
@ -275,11 +276,12 @@ func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event.
helper.DecryptErrorCallback(evt, err)
return
}
helper.postDecrypt(src, decrypted)
helper.postDecrypt(ctx, decrypted)
}
func (helper *CryptoHelper) postDecrypt(src mautrix.EventSource, decrypted *event.Event) {
helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(src|mautrix.EventSourceDecrypted, decrypted)
func (helper *CryptoHelper) postDecrypt(ctx context.Context, decrypted *event.Event) {
decrypted.Mautrix.EventSource |= event.SourceDecrypted
helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(ctx, decrypted)
}
func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
@ -309,7 +311,7 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID
}
}
func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, src mautrix.EventSource, evt *event.Event) {
func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, evt *event.Event) {
content := evt.Content.AsEncrypted()
log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...")
@ -329,7 +331,7 @@ func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolo
return
}
helper.postDecrypt(src, decrypted)
helper.postDecrypt(ctx, decrypted)
}
func (helper *CryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {

View File

@ -197,9 +197,9 @@ func (mach *OlmMachine) OwnIdentity() *id.Device {
}
type asEventProcessor interface {
On(evtType event.Type, handler func(evt *event.Event))
OnOTK(func(otk *mautrix.OTKCount))
OnDeviceList(func(lists *mautrix.DeviceLists, since string))
On(evtType event.Type, handler func(ctx context.Context, evt *event.Event))
OnOTK(func(ctx context.Context, otk *mautrix.OTKCount))
OnDeviceList(func(ctx context.Context, lists *mautrix.DeviceLists, since string))
}
func (mach *OlmMachine) AddAppserviceListener(ep asEventProcessor) {
@ -220,7 +220,7 @@ func (mach *OlmMachine) AddAppserviceListener(ep asEventProcessor) {
mach.Log.Debug().Msg("Added listeners for encryption data coming from appservice transactions")
}
func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string) {
func (mach *OlmMachine) HandleDeviceLists(ctx context.Context, dl *mautrix.DeviceLists, since string) {
if len(dl.Changed) > 0 {
traceID := time.Now().Format("15:04:05.000000")
mach.Log.Debug().
@ -228,15 +228,15 @@ func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string)
Interface("changes", dl.Changed).
Msg("Device list changes in /sync")
if mach.DisableKeyFetching {
mach.CryptoStore.MarkTrackedUsersOutdated(context.TODO(), dl.Changed)
mach.CryptoStore.MarkTrackedUsersOutdated(ctx, dl.Changed)
} else {
mach.FetchKeys(context.TODO(), dl.Changed, false)
mach.FetchKeys(ctx, dl.Changed, false)
}
mach.Log.Debug().Str("trace_id", traceID).Msg("Finished handling device list changes")
}
}
func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) {
func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.OTKCount) {
if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) {
// TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions
mach.Log.Warn().
@ -250,7 +250,7 @@ func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) {
if otkCount.SignedCurve25519 < int(minCount) {
traceID := time.Now().Format("15:04:05.000000")
log := mach.Log.With().Str("trace_id", traceID).Logger()
ctx := log.WithContext(context.TODO())
ctx = log.WithContext(ctx)
log.Debug().
Int("keys_left", otkCount.Curve25519).
Msg("Sync response said we have less than 50 signed curve25519 keys left, sharing new ones...")
@ -268,8 +268,8 @@ func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) {
// This can be easily registered into a mautrix client using .OnSync():
//
// client.Syncer.(mautrix.ExtensibleSyncer).OnSync(c.crypto.ProcessSyncResponse)
func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string) bool {
mach.HandleDeviceLists(&resp.DeviceLists, since)
func (mach *OlmMachine) ProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) bool {
mach.HandleDeviceLists(ctx, &resp.DeviceLists, since)
for _, evt := range resp.ToDevice.Events {
evt.Type.Class = event.ToDeviceEventType
@ -278,10 +278,10 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
mach.Log.Warn().Str("event_type", evt.Type.Type).Err(err).Msg("Failed to parse to-device event")
continue
}
mach.HandleToDeviceEvent(evt)
mach.HandleToDeviceEvent(ctx, evt)
}
mach.HandleOTKCounts(&resp.DeviceOTKCount)
mach.HandleOTKCounts(ctx, &resp.DeviceOTKCount)
return true
}
@ -290,8 +290,7 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
// Currently this is not automatically called, so you must add a listener yourself:
//
// client.Syncer.(mautrix.ExtensibleSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent)
func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Event) {
ctx := context.TODO()
func (mach *OlmMachine) HandleMemberEvent(ctx context.Context, evt *event.Event) {
if isEncrypted, err := mach.StateStore.IsEncrypted(ctx, evt.RoomID); err != nil {
mach.machOrContextLog(ctx).Err(err).Stringer("room_id", evt.RoomID).
Msg("Failed to check if room is encrypted to handle member event")
@ -331,7 +330,7 @@ func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Even
// HandleToDeviceEvent handles a single to-device event. This is automatically called by ProcessSyncResponse, so you
// don't need to add any custom handlers if you use that method.
func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
func (mach *OlmMachine) HandleToDeviceEvent(ctx context.Context, evt *event.Event) {
if len(evt.ToUserID) > 0 && (evt.ToUserID != mach.Client.UserID || evt.ToDeviceID != mach.Client.DeviceID) {
// TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions
mach.Log.Debug().
@ -341,12 +340,13 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
return
}
traceID := time.Now().Format("15:04:05.000000")
// TODO use context log?
log := mach.Log.With().
Str("trace_id", traceID).
Str("sender", evt.Sender.String()).
Str("type", evt.Type.Type).
Logger()
ctx := log.WithContext(context.TODO())
ctx = log.WithContext(ctx)
if evt.Type != event.ToDeviceEncrypted {
log.Debug().Msg("Starting handling to-device event")
}

View File

@ -105,6 +105,8 @@ func (evt *Event) MarshalJSON() ([]byte, error) {
}
type MautrixInfo struct {
EventSource Source
TrustState id.TrustState
ForwardedKeys bool
WasEncrypted bool

72
event/eventsource.go Normal file
View File

@ -0,0 +1,72 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package event
import (
"fmt"
)
// Source represents the part of the sync response that an event came from.
type Source int
const (
SourcePresence Source = 1 << iota
SourceJoin
SourceInvite
SourceLeave
SourceAccountData
SourceTimeline
SourceState
SourceEphemeral
SourceToDevice
SourceDecrypted
)
const primaryTypes = SourcePresence | SourceAccountData | SourceToDevice | SourceTimeline | SourceState
const roomSections = SourceJoin | SourceInvite | SourceLeave
const roomableTypes = SourceAccountData | SourceTimeline | SourceState
const encryptableTypes = roomableTypes | SourceToDevice
func (es Source) String() string {
var typeName string
switch es & primaryTypes {
case SourcePresence:
typeName = "presence"
case SourceAccountData:
typeName = "account data"
case SourceToDevice:
typeName = "to-device"
case SourceTimeline:
typeName = "timeline"
case SourceState:
typeName = "state"
default:
return fmt.Sprintf("unknown (%d)", es)
}
if es&roomableTypes != 0 {
switch es & roomSections {
case SourceJoin:
typeName = "joined room " + typeName
case SourceInvite:
typeName = "invited room " + typeName
case SourceLeave:
typeName = "left room " + typeName
default:
return fmt.Sprintf("unknown (%s+%d)", typeName, es)
}
es &^= roomSections
}
if es&encryptableTypes != 0 && es&SourceDecrypted != 0 {
typeName += " (decrypted)"
es &^= SourceDecrypted
}
es &^= primaryTypes
if es != 0 {
return fmt.Sprintf("unknown (%s+%d)", typeName, es)
}
return typeName
}

View File

@ -67,8 +67,8 @@ func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) {
// client.Syncer.(mautrix.ExtensibleSyncer).OnEvent(client.StateStoreSyncHandler)
//
// DefaultSyncer.ParseEventContent must also be true for this to work (which it is by default).
func (cli *Client) StateStoreSyncHandler(_ EventSource, evt *event.Event) {
UpdateStateStore(cli.Log.WithContext(context.TODO()), cli.StateStore, evt)
func (cli *Client) StateStoreSyncHandler(ctx context.Context, evt *event.Event) {
UpdateStateStore(ctx, cli.StateStore, evt)
}
type MemoryStateStore struct {

140
sync.go
View File

@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@ -7,6 +7,7 @@
package mautrix
import (
"context"
"errors"
"fmt"
"runtime/debug"
@ -16,78 +17,17 @@ import (
"maunium.net/go/mautrix/id"
)
// EventSource represents the part of the sync response that an event came from.
type EventSource int
const (
EventSourcePresence EventSource = 1 << iota
EventSourceJoin
EventSourceInvite
EventSourceLeave
EventSourceAccountData
EventSourceTimeline
EventSourceState
EventSourceEphemeral
EventSourceToDevice
EventSourceDecrypted
)
const primaryTypes = EventSourcePresence | EventSourceAccountData | EventSourceToDevice | EventSourceTimeline | EventSourceState
const roomSections = EventSourceJoin | EventSourceInvite | EventSourceLeave
const roomableTypes = EventSourceAccountData | EventSourceTimeline | EventSourceState
const encryptableTypes = roomableTypes | EventSourceToDevice
func (es EventSource) String() string {
var typeName string
switch es & primaryTypes {
case EventSourcePresence:
typeName = "presence"
case EventSourceAccountData:
typeName = "account data"
case EventSourceToDevice:
typeName = "to-device"
case EventSourceTimeline:
typeName = "timeline"
case EventSourceState:
typeName = "state"
default:
return fmt.Sprintf("unknown (%d)", es)
}
if es&roomableTypes != 0 {
switch es & roomSections {
case EventSourceJoin:
typeName = "joined room " + typeName
case EventSourceInvite:
typeName = "invited room " + typeName
case EventSourceLeave:
typeName = "left room " + typeName
default:
return fmt.Sprintf("unknown (%s+%d)", typeName, es)
}
es &^= roomSections
}
if es&encryptableTypes != 0 && es&EventSourceDecrypted != 0 {
typeName += " (decrypted)"
es &^= EventSourceDecrypted
}
es &^= primaryTypes
if es != 0 {
return fmt.Sprintf("unknown (%s+%d)", typeName, es)
}
return typeName
}
// EventHandler handles a single event from a sync response.
type EventHandler func(source EventSource, evt *event.Event)
type EventHandler func(ctx context.Context, evt *event.Event)
// SyncHandler handles a whole sync response. If the return value is false, handling will be stopped completely.
type SyncHandler func(resp *RespSync, since string) bool
type SyncHandler func(ctx context.Context, resp *RespSync, since string) bool
// Syncer is an interface that must be satisfied in order to do /sync requests on a client.
type Syncer interface {
// ProcessResponse processes the /sync response. The since parameter is the since= value that was used to produce the response.
// This is useful for detecting the very first sync (since=""). If an error is return, Syncing will be stopped permanently.
ProcessResponse(resp *RespSync, since string) error
ProcessResponse(ctx context.Context, resp *RespSync, since string) error
// OnFailedSync returns either the time to wait before retrying or an error to stop syncing permanently.
OnFailedSync(res *RespSync, err error) (time.Duration, error)
// GetFilterJSON for the given user ID. NOT the filter ID.
@ -101,7 +41,7 @@ type ExtensibleSyncer interface {
}
type DispatchableSyncer interface {
Dispatch(source EventSource, evt *event.Event)
Dispatch(ctx context.Context, evt *event.Event)
}
// DefaultSyncer is the default syncing implementation. You can either write your own syncer, or selectively
@ -144,7 +84,7 @@ func NewDefaultSyncer() *DefaultSyncer {
// ProcessResponse processes the /sync response in a way suitable for bots. "Suitable for bots" means a stream of
// unrepeating events. Returns a fatal error if a listener panics.
func (s *DefaultSyncer) ProcessResponse(res *RespSync, since string) (err error) {
func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, since string) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("ProcessResponse panicked! since=%s panic=%s\n%s", since, r, debug.Stack())
@ -152,38 +92,38 @@ func (s *DefaultSyncer) ProcessResponse(res *RespSync, since string) (err error)
}()
for _, listener := range s.syncListeners {
if !listener(res, since) {
if !listener(ctx, res, since) {
return
}
}
s.processSyncEvents("", res.ToDevice.Events, EventSourceToDevice)
s.processSyncEvents("", res.Presence.Events, EventSourcePresence)
s.processSyncEvents("", res.AccountData.Events, EventSourceAccountData)
s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice)
s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence)
s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData)
for roomID, roomData := range res.Rooms.Join {
s.processSyncEvents(roomID, roomData.State.Events, EventSourceJoin|EventSourceState)
s.processSyncEvents(roomID, roomData.Timeline.Events, EventSourceJoin|EventSourceTimeline)
s.processSyncEvents(roomID, roomData.Ephemeral.Events, EventSourceJoin|EventSourceEphemeral)
s.processSyncEvents(roomID, roomData.AccountData.Events, EventSourceJoin|EventSourceAccountData)
s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState)
s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline)
s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral)
s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData)
}
for roomID, roomData := range res.Rooms.Invite {
s.processSyncEvents(roomID, roomData.State.Events, EventSourceInvite|EventSourceState)
s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState)
}
for roomID, roomData := range res.Rooms.Leave {
s.processSyncEvents(roomID, roomData.State.Events, EventSourceLeave|EventSourceState)
s.processSyncEvents(roomID, roomData.Timeline.Events, EventSourceLeave|EventSourceTimeline)
s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState)
s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline)
}
return
}
func (s *DefaultSyncer) processSyncEvents(roomID id.RoomID, events []*event.Event, source EventSource) {
func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source) {
for _, evt := range events {
s.processSyncEvent(roomID, evt, source)
s.processSyncEvent(ctx, roomID, evt, source)
}
}
func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, source EventSource) {
func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source) {
evt.RoomID = roomID
// Ensure the type class is correct. It's safe to mutate the class since the event type is not a pointer.
@ -191,11 +131,11 @@ func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, sou
switch {
case evt.StateKey != nil:
evt.Type.Class = event.StateEventType
case source == EventSourcePresence, source&EventSourceEphemeral != 0:
case source == event.SourcePresence, source&event.SourceEphemeral != 0:
evt.Type.Class = event.EphemeralEventType
case source&EventSourceAccountData != 0:
case source&event.SourceAccountData != 0:
evt.Type.Class = event.AccountDataEventType
case source == EventSourceToDevice:
case source == event.SourceToDevice:
evt.Type.Class = event.ToDeviceEventType
default:
evt.Type.Class = event.MessageEventType
@ -208,17 +148,18 @@ func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, sou
}
}
s.Dispatch(source, evt)
evt.Mautrix.EventSource = source
s.Dispatch(ctx, evt)
}
func (s *DefaultSyncer) Dispatch(source EventSource, evt *event.Event) {
func (s *DefaultSyncer) Dispatch(ctx context.Context, evt *event.Event) {
for _, fn := range s.globalListeners {
fn(source, evt)
fn(ctx, evt)
}
listeners, exists := s.listeners[evt.Type]
if exists {
for _, fn := range listeners {
fn(source, evt)
fn(ctx, evt)
}
}
}
@ -266,31 +207,18 @@ func (s *DefaultSyncer) GetFilterJSON(userID id.UserID) *Filter {
return s.FilterJSON
}
// OldEventIgnorer is a utility struct for bots to ignore events from before the bot joined the room.
//
// Deprecated: Use Client.DontProcessOldEvents instead.
type OldEventIgnorer struct {
UserID id.UserID
}
func (oei *OldEventIgnorer) Register(syncer ExtensibleSyncer) {
syncer.OnSync(oei.DontProcessOldEvents)
}
func (oei *OldEventIgnorer) DontProcessOldEvents(resp *RespSync, since string) bool {
return dontProcessOldEvents(oei.UserID, resp, since)
}
// DontProcessOldEvents is a sync handler that removes rooms that the user just joined.
// It's meant for bots to ignore events from before the bot joined the room.
//
// To use it, register it with your Syncer, e.g.:
//
// cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.DontProcessOldEvents)
func (cli *Client) DontProcessOldEvents(resp *RespSync, since string) bool {
func (cli *Client) DontProcessOldEvents(_ context.Context, resp *RespSync, since string) bool {
return dontProcessOldEvents(cli.UserID, resp, since)
}
var _ SyncHandler = (*Client)(nil).DontProcessOldEvents
func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool {
if since == "" {
return false
@ -327,7 +255,7 @@ func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool {
// To use it, register it with your Syncer, e.g.:
//
// cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.MoveInviteState)
func (cli *Client) MoveInviteState(resp *RespSync, _ string) bool {
func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string) bool {
for _, meta := range resp.Rooms.Invite {
var inviteState []event.StrippedState
var inviteEvt *event.Event
@ -352,3 +280,5 @@ func (cli *Client) MoveInviteState(resp *RespSync, _ string) bool {
}
return true
}
var _ SyncHandler = (*Client)(nil).MoveInviteState