Compare commits

...

44 Commits

Author SHA1 Message Date
Sumner Evans 3b8453bd15
fixup! olm/account: make an interface
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
2024-05-12 21:11:28 -06:00
Sumner Evans 135eccbaa0
fixup! olm/account: make an interface
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
2024-05-12 20:59:46 -06:00
Sumner Evans 44b04d50a4
olm/account: make an interface
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
2024-05-12 20:42:38 -06:00
Sumner Evans 645348695a
fixup! olm/inboundgroupsession: make an interface
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
2024-05-12 20:08:27 -06:00
Sumner Evans 2ded86695b
fixup! olm/inboundgroupsession: make an interface
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
2024-05-12 20:02:06 -06:00
Sumner Evans d8d05ce0a7
olm/inboundgroupsession: make an interface
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
2024-05-12 19:57:10 -06:00
Sumner Evans 059632c845
fixup! olm/outboundgroupsession: make an interface
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
2024-05-12 19:41:01 -06:00
Sumner Evans e768e5fa53
olm/olm.go -> olm/olm_libolm.go
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
2024-05-12 19:35:23 -06:00
Sumner Evans 04c7efc0c0
olm/outboundgroupsession: make an interface
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
2024-05-12 18:56:36 -06:00
Sumner Evans 48edb28c1f
wip fuzz
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
2024-05-12 18:56:36 -06:00
Sumner Evans 0d0f04d51d
goolm/session: use string builder in Describe
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
2024-05-12 18:56:36 -06:00
Sumner Evans c8f6fa3a47
olm/session: add basic test
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
2024-05-12 18:56:36 -06:00
Sumner Evans 13d0ff3524
wip: olm/session: make an interface
Signed-off-by: Sumner Evans <sumner@beeper.com>
2024-05-12 18:56:36 -06:00
Sumner Evans c0e030fc85
crypto/olm: remove Signatures definition
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
2024-05-12 18:10:48 -06:00
Sumner Evans 2810465ef2
verificationhelper: ensure that the keys are fetched before starting
Signed-off-by: Sumner Evans <sumner@beeper.com>
2024-04-25 09:40:57 -06:00
Malte E 6cc490d9ab
check ghost != nil in correct line (#208) 2024-04-21 15:22:26 +02:00
Sumner Evans ff9e2e0f1d
machine/ShareKeys: save keys before sending server request in case it fails
Signed-off-by: Sumner Evans <sumner@beeper.com>
2024-04-19 08:47:36 -06:00
Tulir Asokan a19dab1897 Bump version to v0.18.1 2024-04-16 13:57:50 +03:00
Tulir Asokan 423d32ddf6 Add real context to HTML parser context struct 2024-04-13 13:57:30 +03:00
Malte E 640086dbf9
Fix default prevContent in bridge membership event handler (#204) 2024-04-05 02:27:36 +03:00
Toni Spets 898b235a84 Allow overriding http.Client with FullRequest 2024-04-02 13:59:48 +03:00
Toni Spets 64cc843952 Invalidate memory cache when storing own cross-signing keys
When another device does cross-signing reset we would incorrectly
cache the old keys indefinitely.
2024-04-02 13:59:07 +03:00
Toni Spets 0095e1fb78 Assume the device list is up-to-date on key backup restore
Fetching devices in a loop can cause request storming if there's a lot
of unknown signatures for a key backup.

A client implementation should always ensure that the devices are
updated from device list changed updates from sync.
2024-03-28 10:42:29 +02:00
Tulir Asokan ade00e8603
Merge pull request #193 from maltee1/join_rule
Join Rule & (Un)ban handling & Knock handling
2024-03-22 20:04:08 +02:00
Toni Spets 9fe66581e5 Check that shared IGS has higher index than stored
Copies the logic from key import.
2024-03-18 13:17:54 +02:00
Adam Van Ymeren 4dd7adc7be
Merge pull request #200 from beeper/adam/hsorder
Fix Unsigned.IsEmpty() when all we have is HSOrder
2024-03-16 11:41:48 -07:00
Adam Van Ymeren 8ba307b28d Fix Unsigned.IsEmpty() when all we have is HSOrder 2024-03-16 11:36:58 -07:00
Tulir Asokan 5dedc9806a Bump version to v0.18.0 2024-03-16 12:55:53 +02:00
Malte E b556d65da9 add handler for accepting/rejecting/retracting invites 2024-03-15 22:29:33 +01:00
Toni Spets fad4448ab7 Use a callback to receive secret response
To properly receive and store a requested secret, we usually need to
validate it against something like a public key to ensure we got the
correct one.

This changes the API so that we instead use a callback to receive any
incoming secret matching our request but we'll fail when we hit the
specified timeout if we never receive anything that is accepted.
2024-03-15 15:12:56 +02:00
Tulir Asokan a7bf485893 Update changelog 2024-03-13 21:23:04 +02:00
Tulir Asokan 20fde3d163 Remove error in ParseCommonIdentifier 2024-03-13 17:01:07 +02:00
Tulir Asokan 5224780563 Split UserID.Parse into generic ParseCommonIdentifier 2024-03-13 16:57:16 +02:00
Toni Spets f0b728f502 Require OGS update to succeed during EncryptMegolmEvent
Otherwise we could end up reusing the same ratchet multiple times.
2024-03-13 11:19:49 +02:00
Tulir Asokan 8128b00e00
Add key server that passes the federation tester (#197) 2024-03-12 21:15:39 +02:00
Brad Murray 08397c8b9a
Fix responding to m.secret.request messages (#195) 2024-03-11 18:50:06 -04:00
Tulir Asokan 94246ffc85 Drop maulogger support 2024-03-11 20:36:06 +02:00
Sumner Evans 2728a8f8aa
olm/pk: add fuzz test for the Sign function
Signed-off-by: Sumner Evans <sumner@beeper.com>
2024-03-11 09:00:11 -06:00
Sumner Evans 3b65d98c0c
olm/pk: make an interface
Signed-off-by: Sumner Evans <sumner@beeper.com>
2024-03-11 09:00:11 -06:00
Tulir Asokan d18dcfc7eb Update dependencies 2024-03-11 15:37:57 +02:00
Toni Spets a36f60a4f3 Parse Beeper inbox preview event in sync 2024-03-11 12:35:55 +02:00
Malte E db41583fdd add knock handling 2024-03-10 13:47:09 +01:00
Malte E 41dfb40064 add ban/unban handling 2024-03-09 21:17:27 +01:00
Malte E 6b1a039beb add join rule handler 2024-03-09 20:34:47 +01:00
68 changed files with 1841 additions and 1107 deletions

View File

@ -1,14 +1,43 @@
## v0.18.0 (unreleased)
## v0.18.1 (2024-04-16)
* *(format)* Added a `context.Context` field to HTMLParser's Context struct.
* *(bridge)* Added support for handling join rules, knocks, invites and bans
(thanks to [@maltee1] in [#193] and [#204]).
* *(crypto)* Changed forwarded room key handling to only accept keys with a
lower first known index than the existing session if there is one.
* *(crypto)* Changed key backup restore to assume own device list is up to date
to avoid re-requesting device list for every deleted device that has signed
key backup.
* *(crypto)* Fixed memory cache not being invalidated when storing own
cross-signing keys
[@maltee1]: https://github.com/maltee1
[#193]: https://github.com/mautrix/go/pull/193
[#204]: https://github.com/mautrix/go/pull/204
## v0.18.0 (2024-03-16)
* **Breaking change *(client, bridge, appservice)*** Dropped support for
maulogger. Only zerolog loggers are now provided by default.
* *(bridge)* Fixed upload size limit not having a default if the server
returned no value.
* *(synapseadmin)* Added wrappers for some room and user admin APIs.
(thanks to [@grvn-ht] in [#181]).
* *(crypto/verificationhelper)* Fixed bugs.
* *(crypto)* Fixed key backup uploading doing too much base64.
* *(crypto)* Changed `EncryptMegolmEvent` to return an error if persisting the
megolm session fails. This ensures that database errors won't cause messages
to be sent with duplicate indexes.
* *(crypto)* Changed `GetOrRequestSecret` to use a callback instead of returning
the value directly. This allows validating the value in order to ignore
invalid secrets.
* *(id)* Added `ParseCommonIdentifier` function to parse any Matrix identifier
in the [Common Identifier Format].
* *(federation)* Added simple key server that passes the federation tester.
[@grvn-ht]: https://github.com/grvn-ht
[#181]: https://github.com/mautrix/go/pull/181
[Common Identifier Format]: https://spec.matrix.org/v1.9/appendices/#common-identifier-format
### beta.1 (2024-02-16)

View File

@ -24,7 +24,6 @@ import (
"github.com/rs/zerolog"
"golang.org/x/net/publicsuffix"
"gopkg.in/yaml.v3"
"maunium.net/go/maulogger/v2/maulogadapt"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
@ -355,7 +354,7 @@ func (as *AppService) SetHomeserverURL(homeserverURL string) error {
// This does not do any validation, and it does not cache the client.
// Usually you should prefer [AppService.Client] or [AppService.Intent] over this method.
func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client {
client := &mautrix.Client{
return &mautrix.Client{
HomeserverURL: as.hsURLForClient,
UserID: userID,
SetAppServiceUserID: true,
@ -366,8 +365,6 @@ func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client {
Client: as.HTTPClient,
DefaultHTTPRetries: as.DefaultHTTPRetries,
}
client.Logger = maulogadapt.ZeroAsMau(&client.Log)
return client
}
// NewExternalMautrixClient creates a new [mautrix.Client] instance for an external user,

View File

@ -29,8 +29,6 @@ import (
"go.mau.fi/util/exzerolog"
"gopkg.in/yaml.v3"
flag "maunium.net/go/mauflag"
"maunium.net/go/maulogger/v2"
"maunium.net/go/maulogger/v2/maulogadapt"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
@ -96,6 +94,32 @@ type PowerLevelHandlingPortal interface {
HandleMatrixPowerLevels(sender User, evt *event.Event)
}
type JoinRuleHandlingPortal interface {
Portal
HandleMatrixJoinRule(sender User, evt *event.Event)
}
type BanHandlingPortal interface {
Portal
HandleMatrixBan(sender User, ghost Ghost, evt *event.Event)
HandleMatrixUnban(sender User, ghost Ghost, evt *event.Event)
}
type KnockHandlingPortal interface {
Portal
HandleMatrixKnock(sender User, evt *event.Event)
HandleMatrixRetractKnock(sender User, evt *event.Event)
HandleMatrixAcceptKnock(sender User, ghost Ghost, evt *event.Event)
HandleMatrixRejectKnock(sender User, ghost Ghost, evt *event.Event)
}
type InviteHandlingPortal interface {
Portal
HandleMatrixAcceptInvite(sender User, evt *event.Event)
HandleMatrixRejectInvite(sender User, evt *event.Event)
HandleMatrixRetractInvite(sender User, ghost Ghost, evt *event.Event)
}
type User interface {
GetPermissionLevel() bridgeconfig.PermissionLevel
IsLoggedIn() bool
@ -201,8 +225,6 @@ type Bridge struct {
Crypto Crypto
CryptoPickleKey string
// Deprecated: Switch to ZLog
Log maulogger.Logger
ZLog *zerolog.Logger
MediaConfig mautrix.RespMediaConfig
@ -536,7 +558,6 @@ func (br *Bridge) init() {
os.Exit(12)
}
exzerolog.SetupDefaults(br.ZLog)
br.Log = maulogadapt.ZeroAsMau(br.ZLog)
br.DoublePuppet = &doublePuppetUtil{br: br, log: br.ZLog.With().Str("component", "double puppet").Logger()}

View File

@ -12,7 +12,6 @@ import (
"strings"
"github.com/rs/zerolog"
"maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
@ -38,8 +37,6 @@ type Event struct {
ReplyTo id.EventID
Ctx context.Context
ZLog *zerolog.Logger
// Deprecated: switch to ZLog
Log maulogger.Logger
}
// MainIntent returns the intent to use when replying to the command.

View File

@ -12,7 +12,6 @@ import (
"strings"
"github.com/rs/zerolog"
"maunium.net/go/maulogger/v2/maulogadapt"
"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/id"
@ -91,7 +90,6 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id.
ReplyTo: replyTo,
Ctx: ctx,
ZLog: &log,
Log: maulogadapt.ZeroAsMau(&log),
}
log.Debug().Msg("Received command")

View File

@ -66,6 +66,7 @@ func NewMatrixHandler(br *Bridge) *MatrixHandler {
br.EventProcessor.On(event.EphemeralEventReceipt, handler.HandleReceipt)
br.EventProcessor.On(event.EphemeralEventTyping, handler.HandleTyping)
br.EventProcessor.On(event.StatePowerLevels, handler.HandlePowerLevels)
br.EventProcessor.On(event.StateJoinRules, handler.HandleJoinRule)
return handler
}
@ -275,27 +276,62 @@ func (mx *MatrixHandler) HandleMembership(ctx context.Context, evt *event.Event)
} else if user.GetPermissionLevel() < bridgeconfig.PermissionLevelUser || !user.IsLoggedIn() {
return
}
mhp, ok := portal.(MembershipHandlingPortal)
if !ok {
bhp, bhpOk := portal.(BanHandlingPortal)
mhp, mhpOk := portal.(MembershipHandlingPortal)
khp, khpOk := portal.(KnockHandlingPortal)
ihp, ihpOk := portal.(InviteHandlingPortal)
if !(mhpOk || bhpOk || khpOk) {
return
}
if content.Membership == event.MembershipLeave {
if evt.Unsigned.PrevContent != nil {
_ = evt.Unsigned.PrevContent.ParseRaw(evt.Type)
prevContent, ok := evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent)
if ok && prevContent.Membership != "join" {
return
prevContent := &event.MemberEventContent{Membership: event.MembershipLeave}
if evt.Unsigned.PrevContent != nil {
_ = evt.Unsigned.PrevContent.ParseRaw(evt.Type)
prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent)
}
if ihpOk && prevContent.Membership == event.MembershipInvite && content.Membership != event.MembershipBan {
if content.Membership == event.MembershipJoin {
ihp.HandleMatrixAcceptInvite(user, evt)
}
if content.Membership == event.MembershipLeave {
if isSelf {
ihp.HandleMatrixRejectInvite(user, evt)
} else if ghost != nil {
ihp.HandleMatrixRetractInvite(user, ghost, evt)
}
}
if isSelf {
mhp.HandleMatrixLeave(user, evt)
} else if ghost != nil {
mhp.HandleMatrixKick(user, ghost, evt)
}
if bhpOk && ghost != nil {
if content.Membership == event.MembershipBan {
bhp.HandleMatrixBan(user, ghost, evt)
} else if content.Membership == event.MembershipLeave && prevContent.Membership == event.MembershipBan {
bhp.HandleMatrixUnban(user, ghost, evt)
}
}
if khpOk {
if content.Membership == event.MembershipKnock {
khp.HandleMatrixKnock(user, evt)
} else if prevContent.Membership == event.MembershipKnock {
if content.Membership == event.MembershipInvite && ghost != nil {
khp.HandleMatrixAcceptKnock(user, ghost, evt)
} else if content.Membership == event.MembershipLeave {
if isSelf {
khp.HandleMatrixRetractKnock(user, evt)
} else if ghost != nil {
khp.HandleMatrixRejectKnock(user, ghost, evt)
}
}
}
}
if mhpOk {
if content.Membership == event.MembershipLeave && prevContent.Membership == event.MembershipJoin {
if isSelf {
mhp.HandleMatrixLeave(user, evt)
} else if ghost != nil {
mhp.HandleMatrixKick(user, ghost, evt)
}
} else if content.Membership == event.MembershipInvite && !isSelf && ghost != nil {
mhp.HandleMatrixInvite(user, ghost, evt)
}
} else if content.Membership == event.MembershipInvite && !isSelf && ghost != nil {
mhp.HandleMatrixInvite(user, ghost, evt)
}
// TODO kicking/inviting non-ghost users users
}
@ -702,3 +738,18 @@ func (mx *MatrixHandler) HandlePowerLevels(_ context.Context, evt *event.Event)
powerLevelPortal.HandleMatrixPowerLevels(user, evt)
}
}
func (mx *MatrixHandler) HandleJoinRule(_ context.Context, evt *event.Event) {
if mx.shouldIgnoreEvent(evt) {
return
}
portal := mx.bridge.Child.GetIPortal(evt.RoomID)
if portal == nil {
return
}
joinRulePortal, ok := portal.(JoinRuleHandlingPortal)
if ok {
user := mx.bridge.Child.GetIUser(evt.Sender, true)
joinRulePortal.HandleMatrixJoinRule(user, evt)
}
}

View File

@ -119,7 +119,7 @@ func (br *Bridge) PingServer() (start, serverTs, end time.Time) {
}
start = time.Now()
var resp wsPingData
br.Log.Debugln("Pinging appservice websocket")
br.ZLog.Debug().Msg("Pinging appservice websocket")
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
err := br.AS.RequestWebsocket(ctx, &appservice.WebsocketRequest{

View File

@ -19,7 +19,6 @@ import (
"github.com/rs/zerolog"
"go.mau.fi/util/retryafter"
"maunium.net/go/maulogger/v2/maulogadapt"
"maunium.net/go/mautrix/crypto/backup"
"maunium.net/go/mautrix/event"
@ -49,17 +48,6 @@ type VerificationHelper interface {
ConfirmSAS(ctx context.Context, txnID id.VerificationTransactionID) error
}
// Deprecated: switch to zerolog
type Logger interface {
Debugfln(message string, args ...interface{})
}
// Deprecated: switch to zerolog
type WarnLogger interface {
Logger
Warnfln(message string, args ...interface{})
}
// Client represents a Matrix client.
type Client struct {
HomeserverURL *url.URL // The base homeserver URL
@ -75,8 +63,6 @@ type Client struct {
Verification VerificationHelper
Log zerolog.Logger
// Deprecated: switch to the zerolog instance in Log
Logger Logger
RequestHook func(req *http.Request)
ResponseHook func(req *http.Request, resp *http.Response, duration time.Duration)
@ -352,6 +338,7 @@ type FullRequest struct {
SensitiveContent bool
Handler ClientResponseHandler
Logger *zerolog.Logger
Client *http.Client
}
var requestID int32
@ -438,7 +425,10 @@ func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]b
if len(cli.AccessToken) > 0 {
req.Header.Set("Authorization", "Bearer "+cli.AccessToken)
}
return cli.executeCompiledRequest(req, params.MaxAttempts-1, 4*time.Second, params.ResponseJSON, params.Handler)
if params.Client == nil {
params.Client = cli.Client
}
return cli.executeCompiledRequest(req, params.MaxAttempts-1, 4*time.Second, params.ResponseJSON, params.Handler, params.Client)
}
func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
@ -449,7 +439,7 @@ func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
return log
}
func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler) ([]byte, error) {
func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler, client *http.Client) ([]byte, error) {
log := zerolog.Ctx(req.Context())
if req.Body != nil {
if req.GetBody == nil {
@ -467,7 +457,7 @@ func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff
Int("retry_in_seconds", int(backoff.Seconds())).
Msg("Request failed, retrying")
time.Sleep(backoff)
return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler)
return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, client)
}
func readRequestBody(req *http.Request, res *http.Response) ([]byte, error) {
@ -549,17 +539,17 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) {
}
}
func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler) ([]byte, error) {
func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler, client *http.Client) ([]byte, error) {
cli.RequestStart(req)
startTime := time.Now()
res, err := cli.Client.Do(req)
res, err := client.Do(req)
duration := time.Now().Sub(startTime)
if res != nil {
defer res.Body.Close()
}
if err != nil {
if retries > 0 {
return cli.doRetry(req, err, retries, backoff, responseJSON, handler)
return cli.doRetry(req, err, retries, backoff, responseJSON, handler, client)
}
err = HTTPError{
Request: req,
@ -574,7 +564,7 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof
if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) {
backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff)
return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler)
return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, client)
}
var body []byte
@ -2295,7 +2285,7 @@ func NewClient(homeserverURL string, userID id.UserID, accessToken string) (*Cli
if err != nil {
return nil, err
}
cli := &Client{
return &Client{
AccessToken: accessToken,
UserAgent: DefaultUserAgent,
HomeserverURL: hsURL,
@ -2307,7 +2297,5 @@ func NewClient(homeserverURL string, userID id.UserID, accessToken string) (*Cli
// The client will work with this storer: it just won't remember across restarts.
// In practice, a database backend should be used.
Store: NewMemorySyncStore(),
}
cli.Logger = maulogadapt.ZeroAsMau(&cli.Log)
return cli, nil
}, nil
}

View File

@ -23,27 +23,39 @@ type OlmAccount struct {
func NewOlmAccount() *OlmAccount {
return &OlmAccount{
Internal: *olm.NewAccount(),
Internal: olm.NewAccount(),
}
}
func (account *OlmAccount) Keys() (id.SigningKey, id.IdentityKey) {
if len(account.signingKey) == 0 || len(account.identityKey) == 0 {
account.signingKey, account.identityKey = account.Internal.IdentityKeys()
var err error
account.signingKey, account.identityKey, err = account.Internal.IdentityKeys()
if err != nil {
panic(err)
}
}
return account.signingKey, account.identityKey
}
func (account *OlmAccount) SigningKey() id.SigningKey {
if len(account.signingKey) == 0 {
account.signingKey, account.identityKey = account.Internal.IdentityKeys()
var err error
account.signingKey, account.identityKey, err = account.Internal.IdentityKeys()
if err != nil {
panic(err)
}
}
return account.signingKey
}
func (account *OlmAccount) IdentityKey() id.IdentityKey {
if len(account.identityKey) == 0 {
account.signingKey, account.identityKey = account.Internal.IdentityKeys()
var err error
account.signingKey, account.identityKey, err = account.Internal.IdentityKeys()
if err != nil {
panic(err)
}
}
return account.identityKey
}
@ -71,16 +83,19 @@ func (account *OlmAccount) getInitialKeys(userID id.UserID, deviceID id.DeviceID
func (account *OlmAccount) getOneTimeKeys(userID id.UserID, deviceID id.DeviceID, currentOTKCount int) map[id.KeyID]mautrix.OneTimeKey {
newCount := int(account.Internal.MaxNumberOfOneTimeKeys()/2) - currentOTKCount
if newCount > 0 {
account.Internal.GenOneTimeKeys(uint(newCount))
account.Internal.GenOneTimeKeys(nil, uint(newCount))
}
oneTimeKeys := make(map[id.KeyID]mautrix.OneTimeKey)
for keyID, key := range account.Internal.OneTimeKeys() {
internalKeys, err := account.Internal.OneTimeKeys()
if err != nil {
panic(err)
}
for keyID, key := range internalKeys {
key := mautrix.OneTimeKey{Key: key}
signature, _ := account.Internal.SignJSON(key)
key.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, deviceID.String(), signature)
key.IsSigned = true
oneTimeKeys[id.NewKeyID(id.KeyAlgorithmSignedCurve25519, keyID)] = key
}
account.Internal.MarkKeysAsPublished()
return oneTimeKeys
}

View File

@ -19,16 +19,16 @@ import (
// CrossSigningKeysCache holds the three cross-signing keys for the current user.
type CrossSigningKeysCache struct {
MasterKey *olm.PkSigning
SelfSigningKey *olm.PkSigning
UserSigningKey *olm.PkSigning
MasterKey olm.PKSigning
SelfSigningKey olm.PKSigning
UserSigningKey olm.PKSigning
}
func (cskc *CrossSigningKeysCache) PublicKeys() *CrossSigningPublicKeysCache {
return &CrossSigningPublicKeysCache{
MasterKey: cskc.MasterKey.PublicKey,
SelfSigningKey: cskc.SelfSigningKey.PublicKey,
UserSigningKey: cskc.UserSigningKey.PublicKey,
MasterKey: cskc.MasterKey.PublicKey(),
SelfSigningKey: cskc.SelfSigningKey.PublicKey(),
UserSigningKey: cskc.UserSigningKey.PublicKey(),
}
}
@ -40,28 +40,28 @@ type CrossSigningSeeds struct {
func (mach *OlmMachine) ExportCrossSigningKeys() CrossSigningSeeds {
return CrossSigningSeeds{
MasterKey: mach.CrossSigningKeys.MasterKey.Seed,
SelfSigningKey: mach.CrossSigningKeys.SelfSigningKey.Seed,
UserSigningKey: mach.CrossSigningKeys.UserSigningKey.Seed,
MasterKey: mach.CrossSigningKeys.MasterKey.Seed(),
SelfSigningKey: mach.CrossSigningKeys.SelfSigningKey.Seed(),
UserSigningKey: mach.CrossSigningKeys.UserSigningKey.Seed(),
}
}
func (mach *OlmMachine) ImportCrossSigningKeys(keys CrossSigningSeeds) (err error) {
var keysCache CrossSigningKeysCache
if keysCache.MasterKey, err = olm.NewPkSigningFromSeed(keys.MasterKey); err != nil {
if keysCache.MasterKey, err = olm.NewPKSigningFromSeed(keys.MasterKey); err != nil {
return
}
if keysCache.SelfSigningKey, err = olm.NewPkSigningFromSeed(keys.SelfSigningKey); err != nil {
if keysCache.SelfSigningKey, err = olm.NewPKSigningFromSeed(keys.SelfSigningKey); err != nil {
return
}
if keysCache.UserSigningKey, err = olm.NewPkSigningFromSeed(keys.UserSigningKey); err != nil {
if keysCache.UserSigningKey, err = olm.NewPKSigningFromSeed(keys.UserSigningKey); err != nil {
return
}
mach.Log.Debug().
Str("master", keysCache.MasterKey.PublicKey.String()).
Str("self_signing", keysCache.SelfSigningKey.PublicKey.String()).
Str("user_signing", keysCache.UserSigningKey.PublicKey.String()).
Str("master", keysCache.MasterKey.PublicKey().String()).
Str("self_signing", keysCache.SelfSigningKey.PublicKey().String()).
Str("user_signing", keysCache.UserSigningKey.PublicKey().String()).
Msg("Imported own cross-signing keys")
mach.CrossSigningKeys = &keysCache
@ -73,19 +73,19 @@ func (mach *OlmMachine) ImportCrossSigningKeys(keys CrossSigningSeeds) (err erro
func (mach *OlmMachine) GenerateCrossSigningKeys() (*CrossSigningKeysCache, error) {
var keysCache CrossSigningKeysCache
var err error
if keysCache.MasterKey, err = olm.NewPkSigning(); err != nil {
if keysCache.MasterKey, err = olm.NewPKSigning(); err != nil {
return nil, fmt.Errorf("failed to generate master key: %w", err)
}
if keysCache.SelfSigningKey, err = olm.NewPkSigning(); err != nil {
if keysCache.SelfSigningKey, err = olm.NewPKSigning(); err != nil {
return nil, fmt.Errorf("failed to generate self-signing key: %w", err)
}
if keysCache.UserSigningKey, err = olm.NewPkSigning(); err != nil {
if keysCache.UserSigningKey, err = olm.NewPKSigning(); err != nil {
return nil, fmt.Errorf("failed to generate user-signing key: %w", err)
}
mach.Log.Debug().
Str("master", keysCache.MasterKey.PublicKey.String()).
Str("self_signing", keysCache.SelfSigningKey.PublicKey.String()).
Str("user_signing", keysCache.UserSigningKey.PublicKey.String()).
Str("master", keysCache.MasterKey.PublicKey().String()).
Str("self_signing", keysCache.SelfSigningKey.PublicKey().String()).
Str("user_signing", keysCache.UserSigningKey.PublicKey().String()).
Msg("Generated cross-signing keys")
return &keysCache, nil
}
@ -93,12 +93,12 @@ func (mach *OlmMachine) GenerateCrossSigningKeys() (*CrossSigningKeysCache, erro
// PublishCrossSigningKeys signs and uploads the public keys of the given cross-signing keys to the server.
func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *CrossSigningKeysCache, uiaCallback mautrix.UIACallback) error {
userID := mach.Client.UserID
masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey.String())
masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String())
masterKey := mautrix.CrossSigningKeys{
UserID: userID,
Usage: []id.CrossSigningUsage{id.XSUsageMaster},
Keys: map[id.KeyID]id.Ed25519{
masterKeyID: keys.MasterKey.PublicKey,
masterKeyID: keys.MasterKey.PublicKey(),
},
}
masterSig, err := mach.account.Internal.SignJSON(masterKey)
@ -111,27 +111,27 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross
UserID: userID,
Usage: []id.CrossSigningUsage{id.XSUsageSelfSigning},
Keys: map[id.KeyID]id.Ed25519{
id.NewKeyID(id.KeyAlgorithmEd25519, keys.SelfSigningKey.PublicKey.String()): keys.SelfSigningKey.PublicKey,
id.NewKeyID(id.KeyAlgorithmEd25519, keys.SelfSigningKey.PublicKey().String()): keys.SelfSigningKey.PublicKey(),
},
}
selfSig, err := keys.MasterKey.SignJSON(selfKey)
if err != nil {
return fmt.Errorf("failed to sign self-signing key: %w", err)
}
selfKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey.String(), selfSig)
selfKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), selfSig)
userKey := mautrix.CrossSigningKeys{
UserID: userID,
Usage: []id.CrossSigningUsage{id.XSUsageUserSigning},
Keys: map[id.KeyID]id.Ed25519{
id.NewKeyID(id.KeyAlgorithmEd25519, keys.UserSigningKey.PublicKey.String()): keys.UserSigningKey.PublicKey,
id.NewKeyID(id.KeyAlgorithmEd25519, keys.UserSigningKey.PublicKey().String()): keys.UserSigningKey.PublicKey(),
},
}
userSig, err := keys.MasterKey.SignJSON(userKey)
if err != nil {
return fmt.Errorf("failed to sign user-signing key: %w", err)
}
userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey.String(), userSig)
userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), userSig)
err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{
Master: masterKey,

View File

@ -60,7 +60,7 @@ func (mach *OlmMachine) SignUser(ctx context.Context, userID id.UserID, masterKe
Str("signature", signature).
Msg("Signed master key of user with our user-signing key")
if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil {
if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey(), signature); err != nil {
return fmt.Errorf("error storing signature in crypto store: %w", err)
}
@ -77,7 +77,7 @@ func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error {
userID := mach.Client.UserID
deviceID := mach.Client.DeviceID
masterKey := mach.CrossSigningKeys.MasterKey.PublicKey
masterKey := mach.CrossSigningKeys.MasterKey.PublicKey()
masterKeyObj := mautrix.ReqKeysSignatures{
UserID: userID,
@ -149,7 +149,7 @@ func (mach *OlmMachine) SignOwnDevice(ctx context.Context, device *id.Device) er
Str("signature", signature).
Msg("Signed own device key with self-signing key")
if err := mach.CryptoStore.PutSignature(ctx, device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil {
if err := mach.CryptoStore.PutSignature(ctx, device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey(), signature); err != nil {
return fmt.Errorf("error storing signature in crypto store: %w", err)
}
@ -180,12 +180,12 @@ func (mach *OlmMachine) getFullDeviceKeys(ctx context.Context, device *id.Device
}
// signAndUpload signs the given key signatures object and uploads it to the server.
func (mach *OlmMachine) signAndUpload(ctx context.Context, req mautrix.ReqKeysSignatures, userID id.UserID, signedThing string, key *olm.PkSigning) (string, error) {
func (mach *OlmMachine) signAndUpload(ctx context.Context, req mautrix.ReqKeysSignatures, userID id.UserID, signedThing string, key olm.PKSigning) (string, error) {
signature, err := key.SignJSON(req)
if err != nil {
return "", fmt.Errorf("failed to sign JSON: %w", err)
}
req.Signatures = signatures.NewSingleSignature(mach.Client.UserID, id.KeyAlgorithmEd25519, key.PublicKey.String(), signature)
req.Signatures = signatures.NewSingleSignature(mach.Client.UserID, id.KeyAlgorithmEd25519, key.PublicKey().String(), signature)
resp, err := mach.Client.UploadSignatures(ctx, &mautrix.ReqUploadSignatures{
userID: map[string]mautrix.ReqKeysSignatures{

View File

@ -110,24 +110,24 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, u
// UploadCrossSigningKeysToSSSS stores the given cross-signing keys on the server encrypted with the given key.
func (mach *OlmMachine) UploadCrossSigningKeysToSSSS(ctx context.Context, key *ssss.Key, keys *CrossSigningKeysCache) error {
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningMaster, keys.MasterKey.Seed, key); err != nil {
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningMaster, keys.MasterKey.Seed(), key); err != nil {
return err
}
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningSelf, keys.SelfSigningKey.Seed, key); err != nil {
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningSelf, keys.SelfSigningKey.Seed(), key); err != nil {
return err
}
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed, key); err != nil {
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed(), key); err != nil {
return err
}
// Also store these locally
if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageMaster, keys.MasterKey.PublicKey); err != nil {
if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageMaster, keys.MasterKey.PublicKey()); err != nil {
return err
}
if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageSelfSigning, keys.SelfSigningKey.PublicKey); err != nil {
if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageSelfSigning, keys.SelfSigningKey.PublicKey()); err != nil {
return err
}
if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageUserSigning, keys.UserSigningKey.PublicKey); err != nil {
if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageUserSigning, keys.UserSigningKey.PublicKey()); err != nil {
return err
}

View File

@ -96,5 +96,12 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
}
}
}
// Clear internal cache so that it refreshes from crypto store
if userID == mach.Client.UserID && mach.crossSigningPubkeys != nil {
log.Debug().Msg("Resetting internal cross-signing key cache")
mach.crossSigningPubkeys = nil
mach.crossSigningPubkeysFetched = false
}
}
}

View File

@ -37,13 +37,13 @@ func getOlmMachine(t *testing.T) *OlmMachine {
}
userID := id.UserID("@mautrix")
mk, _ := olm.NewPkSigning()
ssk, _ := olm.NewPkSigning()
usk, _ := olm.NewPkSigning()
mk, _ := olm.NewPKSigning()
ssk, _ := olm.NewPKSigning()
usk, _ := olm.NewPKSigning()
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageMaster, mk.PublicKey)
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageSelfSigning, ssk.PublicKey)
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageUserSigning, usk.PublicKey)
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageMaster, mk.PublicKey())
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageSelfSigning, ssk.PublicKey())
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageUserSigning, usk.PublicKey())
return &OlmMachine{
CryptoStore: sqlStore,
@ -70,10 +70,10 @@ func TestTrustOwnDevice(t *testing.T) {
t.Error("Own device trusted while it shouldn't be")
}
m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey,
ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1")
m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(),
ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey(), "sig1")
m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, ownDevice.SigningKey,
ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, "sig2")
ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "sig2")
if trusted, _ := m.IsUserTrusted(context.TODO(), ownDevice.UserID); !trusted {
t.Error("Own user not trusted while they should be")
@ -90,22 +90,22 @@ func TestTrustOtherUser(t *testing.T) {
t.Error("Other user trusted while they shouldn't be")
}
theirMasterKey, _ := olm.NewPkSigning()
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey)
theirMasterKey, _ := olm.NewPKSigning()
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey())
m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey,
m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1")
m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(),
m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey(), "sig1")
// sign them with self-signing instead of user-signing key
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey,
m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, "invalid_sig")
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(),
m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "invalid_sig")
if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted {
t.Error("Other user trusted before their master key has been signed with our user-signing key")
}
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey,
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, "sig2")
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(),
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2")
if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted {
t.Error("Other user not trusted while they should be")
@ -127,29 +127,29 @@ func TestTrustOtherDevice(t *testing.T) {
t.Error("Other device trusted while it shouldn't be")
}
theirMasterKey, _ := olm.NewPkSigning()
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey)
theirSSK, _ := olm.NewPkSigning()
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageSelfSigning, theirSSK.PublicKey)
theirMasterKey, _ := olm.NewPKSigning()
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey())
theirSSK, _ := olm.NewPKSigning()
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageSelfSigning, theirSSK.PublicKey())
m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey,
m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1")
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey,
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, "sig2")
m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(),
m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey(), "sig1")
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(),
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2")
if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted {
t.Error("Other user not trusted while they should be")
}
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey,
otherUser, theirMasterKey.PublicKey, "sig3")
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey(),
otherUser, theirMasterKey.PublicKey(), "sig3")
if m.IsDeviceTrusted(theirDevice) {
t.Error("Other device trusted before it has been signed with user's SSK")
}
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirDevice.SigningKey,
otherUser, theirSSK.PublicKey, "sig4")
otherUser, theirSSK.PublicKey(), "sig4")
if !m.IsDeviceTrusted(theirDevice) {
t.Error("Other device not trusted while it should be")

View File

@ -32,7 +32,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi
}
theirMSK, ok := theirKeys[id.XSUsageMaster]
if !ok {
mach.machOrContextLog(ctx).Error().
mach.machOrContextLog(ctx).Debug().
Str("user_id", device.UserID.String()).
Msg("Master key of user not found")
return id.TrustStateUnset, nil

View File

@ -93,6 +93,10 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id
}
}
// FetchKeys fetches the devices of a list of other users. If includeUntracked
// is set to false, then the users are filtered to to only include user IDs
// whose device lists have been stored with the PutDevices function on the
// [Store]. See the FilterTrackedUsers function on [Store] for details.
func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device, err error) {
req := &mautrix.ReqQueryKeys{
DeviceKeys: mautrix.DeviceKeysRequest{},

View File

@ -118,7 +118,7 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID
log.Debug().Msg("Encrypted event successfully")
err = mach.CryptoStore.UpdateOutboundGroupSession(ctx, session)
if err != nil {
log.Warn().Err(err).Msg("Failed to update megolm session in crypto store after encrypting")
return nil, fmt.Errorf("failed to update outbound group session after encrypting: %w", err)
}
encrypted := &event.EncryptedEventContent{
Algorithm: id.AlgorithmMegolmV1,

View File

@ -37,7 +37,10 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
Str("olm_session_id", session.ID().String()).
Str("olm_session_description", session.Describe()).
Msg("Encrypting olm message")
msgType, ciphertext := session.Encrypt(plaintext)
msgType, ciphertext, err := session.Encrypt(plaintext)
if err != nil {
panic(err)
}
err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session)
if err != nil {
log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting")

View File

@ -8,14 +8,18 @@ import (
"fmt"
"io"
"github.com/tidwall/sjson"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/crypto/goolm/utilities"
"maunium.net/go/mautrix/crypto/olm"
)
const (
@ -39,6 +43,9 @@ type Account struct {
NumFallbackKeys uint8 `json:"number_fallback_keys"`
}
// Ensure that Account adheres to the olm.Account interface.
var _ olm.Account = (*Account)(nil)
// AccountFromJSONPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key.
func AccountFromJSONPickled(pickled, key []byte) (*Account, error) {
if len(pickled) == 0 {
@ -82,7 +89,7 @@ func NewAccount(reader io.Reader) (*Account, error) {
}
// PickleAsJSON returns an Account as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (a Account) PickleAsJSON(key []byte) ([]byte, error) {
func (a *Account) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(a, accountPickleVersionJSON, key)
}
@ -92,44 +99,60 @@ func (a *Account) UnpickleAsJSON(pickled, key []byte) error {
}
// IdentityKeysJSON returns the public parts of the identity keys for the Account in a JSON string.
func (a Account) IdentityKeysJSON() ([]byte, error) {
func (a *Account) IdentityKeysJSON() ([]byte, error) {
res := struct {
Ed25519 string `json:"ed25519"`
Curve25519 string `json:"curve25519"`
}{}
ed25519, curve25519 := a.IdentityKeys()
ed25519, curve25519, err := a.IdentityKeys()
if err != nil {
return nil, err
}
res.Ed25519 = string(ed25519)
res.Curve25519 = string(curve25519)
return json.Marshal(res)
}
// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity keys for the Account.
func (a Account) IdentityKeys() (id.Ed25519, id.Curve25519) {
func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) {
ed25519 := id.Ed25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.PublicKey))
curve25519 := id.Curve25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Curve25519.PublicKey))
return ed25519, curve25519
return ed25519, curve25519, nil
}
// Sign returns the base64-encoded signature of a message using the Ed25519 key
// for this Account.
func (a Account) Sign(message []byte) ([]byte, error) {
func (a *Account) Sign(message []byte) ([]byte, error) {
if len(message) == 0 {
return nil, fmt.Errorf("sign: %w", goolm.ErrEmptyInput)
}
return []byte(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.Sign(message))), nil
}
// SignJSON signs the given JSON object following the Matrix specification:
// https://matrix.org/docs/spec/appendices#signing-json
func (a *Account) SignJSON(obj any) (string, error) {
objJSON, err := json.Marshal(obj)
if err != nil {
return "", err
}
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
signed, err := a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
return string(signed), err
}
// OneTimeKeys returns the public parts of the unpublished one time keys of the Account.
//
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
func (a Account) OneTimeKeys() map[string]id.Curve25519 {
func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) {
oneTimeKeys := make(map[string]id.Curve25519)
for _, curKey := range a.OTKeys {
if !curKey.Published {
oneTimeKeys[curKey.KeyIDEncoded()] = id.Curve25519(curKey.PublicKeyEncoded())
}
}
return oneTimeKeys
return oneTimeKeys, nil
}
//OneTimeKeysJSON returns the public parts of the unpublished one time keys of the Account as a JSON string.
@ -143,9 +166,12 @@ func (a Account) OneTimeKeys() map[string]id.Curve25519 {
}
}
*/
func (a Account) OneTimeKeysJSON() ([]byte, error) {
func (a *Account) OneTimeKeysJSON() ([]byte, error) {
res := make(map[string]map[string]id.Curve25519)
otKeys := a.OneTimeKeys()
otKeys, err := a.OneTimeKeys()
if err != nil {
return nil, err
}
res["Curve25519"] = otKeys
return json.Marshal(res)
}
@ -186,7 +212,7 @@ func (a *Account) GenOneTimeKeys(reader io.Reader, num uint) error {
// NewOutboundSession creates a new outbound session to a
// given curve25519 identity Key and one time key.
func (a Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*session.OlmSession, error) {
func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) {
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
return nil, fmt.Errorf("outbound session: %w", goolm.ErrEmptyInput)
}
@ -205,13 +231,18 @@ func (a Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25
return s, nil
}
// NewInboundSession creates a new inbound session from an incoming PRE_KEY message.
func (a Account) NewInboundSession(theirIdentityKey *id.Curve25519, oneTimeKeyMsg []byte) (*session.OlmSession, error) {
// NewInboundSession creates a new in-bound session for sending/receiving
// messages from an incoming PRE_KEY message. Returns error on failure.
func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) {
return a.NewInboundSessionFrom(nil, oneTimeKeyMsg)
}
// NewInboundSessionFrom creates a new inbound session from an incoming PRE_KEY message.
func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) {
if len(oneTimeKeyMsg) == 0 {
return nil, fmt.Errorf("inbound session: %w", goolm.ErrEmptyInput)
}
var theirIdentityKeyDecoded *crypto.Curve25519PublicKey
var err error
if theirIdentityKey != nil {
theirIdentityKeyDecodedByte, err := base64.RawStdEncoding.DecodeString(string(*theirIdentityKey))
if err != nil {
@ -221,14 +252,10 @@ func (a Account) NewInboundSession(theirIdentityKey *id.Curve25519, oneTimeKeyMs
theirIdentityKeyDecoded = &theirIdentityKeyCurve
}
s, err := session.NewInboundOlmSession(theirIdentityKeyDecoded, oneTimeKeyMsg, a.searchOTKForOur, a.IdKeys.Curve25519)
if err != nil {
return nil, err
}
return s, nil
return session.NewInboundOlmSession(theirIdentityKeyDecoded, []byte(oneTimeKeyMsg), a.searchOTKForOur, a.IdKeys.Curve25519)
}
func (a Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneTimeKey {
func (a *Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneTimeKey {
for curIndex := range a.OTKeys {
if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) {
return &a.OTKeys[curIndex]
@ -244,16 +271,17 @@ func (a Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneT
}
// RemoveOneTimeKeys removes the one time key in this Account which matches the one time key in the session s.
func (a *Account) RemoveOneTimeKeys(s *session.OlmSession) {
toFind := s.BobOneTimeKey
func (a *Account) RemoveOneTimeKeys(s olm.Session) error {
toFind := s.(*session.OlmSession).BobOneTimeKey
for curIndex := range a.OTKeys {
if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) {
//Remove and return
a.OTKeys[curIndex] = a.OTKeys[len(a.OTKeys)-1]
a.OTKeys = a.OTKeys[:len(a.OTKeys)-1]
return
return nil
}
}
return nil
//if the key is a fallback or prevFallback, don't remove it
}
@ -279,7 +307,7 @@ func (a *Account) GenFallbackKey(reader io.Reader) error {
// FallbackKey returns the public part of the current fallback key of the Account.
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
func (a Account) FallbackKey() map[string]id.Curve25519 {
func (a *Account) FallbackKey() map[string]id.Curve25519 {
keys := make(map[string]id.Curve25519)
if a.NumFallbackKeys >= 1 {
keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded())
@ -297,7 +325,7 @@ func (a Account) FallbackKey() map[string]id.Curve25519 {
}
}
*/
func (a Account) FallbackKeyJSON() ([]byte, error) {
func (a *Account) FallbackKeyJSON() ([]byte, error) {
res := make(map[string]map[string]id.Curve25519)
fbk := a.FallbackKey()
res["curve25519"] = fbk
@ -306,7 +334,7 @@ func (a Account) FallbackKeyJSON() ([]byte, error) {
// FallbackKeyUnpublished returns the public part of the current fallback key of the Account only if it is unpublished.
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
func (a Account) FallbackKeyUnpublished() map[string]id.Curve25519 {
func (a *Account) FallbackKeyUnpublished() map[string]id.Curve25519 {
keys := make(map[string]id.Curve25519)
if a.NumFallbackKeys >= 1 && !a.CurrentFallbackKey.Published {
keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded())
@ -324,7 +352,7 @@ func (a Account) FallbackKeyUnpublished() map[string]id.Curve25519 {
}
}
*/
func (a Account) FallbackKeyUnpublishedJSON() ([]byte, error) {
func (a *Account) FallbackKeyUnpublishedJSON() ([]byte, error) {
res := make(map[string]map[string]id.Curve25519)
fbk := a.FallbackKeyUnpublished()
res["curve25519"] = fbk
@ -448,7 +476,10 @@ func (a *Account) UnpickleLibOlm(value []byte) (int, error) {
}
// Pickle returns a base64 encoded and with key encrypted pickled account using PickleLibOlm().
func (a Account) Pickle(key []byte) ([]byte, error) {
func (a *Account) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
return nil, goolm.ErrNoKeyProvided
}
pickeledBytes := make([]byte, a.PickleLen())
written, err := a.PickleLibOlm(pickeledBytes)
if err != nil {
@ -466,7 +497,7 @@ func (a Account) Pickle(key []byte) ([]byte, error) {
// PickleLibOlm encodes the Account into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (a Account) PickleLibOlm(target []byte) (int, error) {
func (a *Account) PickleLibOlm(target []byte) (int, error) {
if len(target) < a.PickleLen() {
return 0, fmt.Errorf("pickle account: %w", goolm.ErrValueTooShort)
}
@ -510,7 +541,7 @@ func (a Account) PickleLibOlm(target []byte) (int, error) {
}
// PickleLen returns the number of bytes the pickled Account will have.
func (a Account) PickleLen() int {
func (a *Account) PickleLen() int {
length := libolmpickle.PickleUInt32Len(accountPickleVersionLibOLM)
length += a.IdKeys.Ed25519.PickleLen()
length += a.IdKeys.Curve25519.PickleLen()
@ -521,3 +552,9 @@ func (a Account) PickleLen() int {
length += libolmpickle.PickleUInt32Len(a.NextOneTimeKeyID)
return length
}
// MaxNumberOfOneTimeKeys returns the largest number of one time keys this
// Account can store.
func (a *Account) MaxNumberOfOneTimeKeys() uint {
return uint(MaxOneTimeKeys)
}

View File

@ -71,7 +71,7 @@ func TestAccount(t *testing.T) {
t.Fatal("IdentityKeys Ed25519 public unequal")
}
if len(firstAccount.OneTimeKeys()) != 2 {
if otks, err := firstAccount.OneTimeKeys(); err != nil || len(otks) != 2 {
t.Fatal("should get 2 unpublished oneTimeKeys")
}
if len(firstAccount.FallbackKeyUnpublished()) == 0 {
@ -84,7 +84,7 @@ func TestAccount(t *testing.T) {
if len(firstAccount.FallbackKeyUnpublished()) != 0 {
t.Fatal("should get no fallbackKey")
}
if len(firstAccount.OneTimeKeys()) != 0 {
if otks, err := firstAccount.OneTimeKeys(); err != nil || len(otks) != 0 {
t.Fatal("should get no oneTimeKeys")
}
}
@ -139,7 +139,7 @@ func TestSessions(t *testing.T) {
t.Fatal(err)
}
plaintext := []byte("test message")
msgType, crypttext, err := aliceSession.Encrypt(plaintext, nil)
msgType, crypttext, err := aliceSession.Encrypt(plaintext)
if err != nil {
t.Fatal(err)
}
@ -147,11 +147,11 @@ func TestSessions(t *testing.T) {
t.Fatal("wrong message type")
}
bobSession, err := bobAccount.NewInboundSession(nil, crypttext)
bobSession, err := bobAccount.NewInboundSession(string(crypttext))
if err != nil {
t.Fatal(err)
}
decodedText, err := bobSession.Decrypt(crypttext, msgType)
decodedText, err := bobSession.Decrypt(string(crypttext), msgType)
if err != nil {
t.Fatal(err)
}
@ -252,7 +252,7 @@ func TestLoopback(t *testing.T) {
}
plainText := []byte("Hello, World")
msgType, message1, err := aliceSession.Encrypt(plainText, nil)
msgType, message1, err := aliceSession.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
@ -260,12 +260,12 @@ func TestLoopback(t *testing.T) {
t.Fatal("wrong message type")
}
bobSession, err := accountB.NewInboundSession(nil, message1)
bobSession, err := accountB.NewInboundSession(string(message1))
if err != nil {
t.Fatal(err)
}
// Check that the inbound session matches the message it was created from.
sessionIsOK, err := bobSession.MatchesInboundSessionFrom(nil, message1)
sessionIsOK, err := bobSession.MatchesInboundSessionFrom("", string(message1))
if err != nil {
t.Fatal(err)
}
@ -274,7 +274,7 @@ func TestLoopback(t *testing.T) {
}
// Check that the inbound session matches the key this message is supposed to be from.
aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded()
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&aIDKey, message1)
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(aIDKey), string(message1))
if err != nil {
t.Fatal(err)
}
@ -283,7 +283,7 @@ func TestLoopback(t *testing.T) {
}
// Check that the inbound session isn't from a different user.
bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded()
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&bIDKey, message1)
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(bIDKey), string(message1))
if err != nil {
t.Fatal(err)
}
@ -291,7 +291,7 @@ func TestLoopback(t *testing.T) {
t.Fatal("session is sad to be from b but is from a")
}
// Check that we can decrypt the message.
decryptedMessage, err := bobSession.Decrypt(message1, msgType)
decryptedMessage, err := bobSession.Decrypt(string(message1), msgType)
if err != nil {
t.Fatal(err)
}
@ -299,7 +299,7 @@ func TestLoopback(t *testing.T) {
t.Fatal("messages are not the same")
}
msgTyp2, message2, err := bobSession.Encrypt(plainText, nil)
msgTyp2, message2, err := bobSession.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
@ -307,7 +307,7 @@ func TestLoopback(t *testing.T) {
t.Fatal("wrong message type")
}
decryptedMessage2, err := aliceSession.Decrypt(message2, msgTyp2)
decryptedMessage2, err := aliceSession.Decrypt(string(message2), msgTyp2)
if err != nil {
t.Fatal(err)
}
@ -316,7 +316,7 @@ func TestLoopback(t *testing.T) {
}
//decrypting again should fail, as the chain moved on
_, err = aliceSession.Decrypt(message2, msgTyp2)
_, err = aliceSession.Decrypt(string(message2), msgTyp2)
if err == nil {
t.Fatal("expected error")
}
@ -348,7 +348,7 @@ func TestMoreMessages(t *testing.T) {
}
plainText := []byte("Hello, World")
msgType, message1, err := aliceSession.Encrypt(plainText, nil)
msgType, message1, err := aliceSession.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
@ -356,11 +356,11 @@ func TestMoreMessages(t *testing.T) {
t.Fatal("wrong message type")
}
bobSession, err := accountB.NewInboundSession(nil, message1)
bobSession, err := accountB.NewInboundSession(string(message1))
if err != nil {
t.Fatal(err)
}
decryptedMessage, err := bobSession.Decrypt(message1, msgType)
decryptedMessage, err := bobSession.Decrypt(string(message1), msgType)
if err != nil {
t.Fatal(err)
}
@ -370,7 +370,7 @@ func TestMoreMessages(t *testing.T) {
for i := 0; i < 8; i++ {
//alice sends, bob reveices
msgType, message, err := aliceSession.Encrypt(plainText, nil)
msgType, message, err := aliceSession.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
@ -384,7 +384,7 @@ func TestMoreMessages(t *testing.T) {
t.Fatal("wrong message type")
}
}
decryptedMessage, err := bobSession.Decrypt(message, msgType)
decryptedMessage, err := bobSession.Decrypt(string(message), msgType)
if err != nil {
t.Fatal(err)
}
@ -393,14 +393,14 @@ func TestMoreMessages(t *testing.T) {
}
//now bob sends, alice receives
msgType, message, err = bobSession.Encrypt(plainText, nil)
msgType, message, err = bobSession.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
if msgType == id.OlmMsgTypePreKey {
t.Fatal("wrong message type")
}
decryptedMessage, err = aliceSession.Decrypt(message, msgType)
decryptedMessage, err = aliceSession.Decrypt(string(message), msgType)
if err != nil {
t.Fatal(err)
}
@ -435,7 +435,7 @@ func TestFallbackKey(t *testing.T) {
}
plainText := []byte("Hello, World")
msgType, message1, err := aliceSession.Encrypt(plainText, nil)
msgType, message1, err := aliceSession.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
@ -443,12 +443,12 @@ func TestFallbackKey(t *testing.T) {
t.Fatal("wrong message type")
}
bobSession, err := accountB.NewInboundSession(nil, message1)
bobSession, err := accountB.NewInboundSession(string(message1))
if err != nil {
t.Fatal(err)
}
// Check that the inbound session matches the message it was created from.
sessionIsOK, err := bobSession.MatchesInboundSessionFrom(nil, message1)
sessionIsOK, err := bobSession.MatchesInboundSessionFrom("", string(message1))
if err != nil {
t.Fatal(err)
}
@ -457,7 +457,7 @@ func TestFallbackKey(t *testing.T) {
}
// Check that the inbound session matches the key this message is supposed to be from.
aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded()
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&aIDKey, message1)
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(aIDKey), string(message1))
if err != nil {
t.Fatal(err)
}
@ -466,7 +466,7 @@ func TestFallbackKey(t *testing.T) {
}
// Check that the inbound session isn't from a different user.
bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded()
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&bIDKey, message1)
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(bIDKey), string(message1))
if err != nil {
t.Fatal(err)
}
@ -474,7 +474,7 @@ func TestFallbackKey(t *testing.T) {
t.Fatal("session is sad to be from b but is from a")
}
// Check that we can decrypt the message.
decryptedMessage, err := bobSession.Decrypt(message1, msgType)
decryptedMessage, err := bobSession.Decrypt(string(message1), msgType)
if err != nil {
t.Fatal(err)
}
@ -493,7 +493,7 @@ func TestFallbackKey(t *testing.T) {
t.Fatal(err)
}
msgType2, message2, err := aliceSession2.Encrypt(plainText, nil)
msgType2, message2, err := aliceSession2.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
@ -502,19 +502,19 @@ func TestFallbackKey(t *testing.T) {
}
// bobSession should not be valid for the message2
// Check that the inbound session matches the message it was created from.
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(nil, message2)
sessionIsOK, err = bobSession.MatchesInboundSessionFrom("", string(message2))
if err != nil {
t.Fatal(err)
}
if sessionIsOK {
t.Fatal("session was detected to be valid but should not")
}
bobSession2, err := accountB.NewInboundSession(nil, message2)
bobSession2, err := accountB.NewInboundSession(string(message2))
if err != nil {
t.Fatal(err)
}
// Check that the inbound session matches the message it was created from.
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(nil, message2)
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom("", string(message2))
if err != nil {
t.Fatal(err)
}
@ -522,7 +522,7 @@ func TestFallbackKey(t *testing.T) {
t.Fatal("session was not detected to be valid")
}
// Check that the inbound session matches the key this message is supposed to be from.
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(&aIDKey, message2)
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(string(aIDKey), string(message2))
if err != nil {
t.Fatal(err)
}
@ -530,7 +530,7 @@ func TestFallbackKey(t *testing.T) {
t.Fatal("session is sad to be not from a but it should")
}
// Check that the inbound session isn't from a different user.
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(&bIDKey, message2)
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(string(bIDKey), string(message2))
if err != nil {
t.Fatal(err)
}
@ -538,7 +538,7 @@ func TestFallbackKey(t *testing.T) {
t.Fatal("session is sad to be from b but is from a")
}
// Check that we can decrypt the message.
decryptedMessage2, err := bobSession2.Decrypt(message2, msgType2)
decryptedMessage2, err := bobSession2.Decrypt(string(message2), msgType2)
if err != nil {
t.Fatal(err)
}
@ -553,14 +553,14 @@ func TestFallbackKey(t *testing.T) {
if err != nil {
t.Fatal(err)
}
msgType3, message3, err := aliceSession3.Encrypt(plainText, nil)
msgType3, message3, err := aliceSession3.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
if msgType3 != id.OlmMsgTypePreKey {
t.Fatal("wrong message type")
}
_, err = accountB.NewInboundSession(nil, message3)
_, err = accountB.NewInboundSession(string(message3))
if err == nil {
t.Fatal("expected error")
}

View File

@ -95,10 +95,10 @@ func (r *Ratchet) InitializeAsAlice(sharedSecret []byte, ourRatchetKey crypto.Cu
}
// Encrypt encrypts the message in a message.Message with MAC. If reader is nil, crypto/rand is used for key generations.
func (r *Ratchet) Encrypt(plaintext []byte, reader io.Reader) ([]byte, error) {
func (r *Ratchet) Encrypt(plaintext []byte) ([]byte, error) {
var err error
if !r.SenderChains.IsSet {
newRatchetKey, err := crypto.Curve25519GenerateKey(reader)
newRatchetKey, err := crypto.Curve25519GenerateKey(nil)
if err != nil {
return nil, err
}

View File

@ -45,7 +45,7 @@ func TestSendReceive(t *testing.T) {
plainText := []byte("Hello Bob")
//Alice sends Bob a message
encryptedMessage, err := aliceRatchet.Encrypt(plainText, nil)
encryptedMessage, err := aliceRatchet.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
@ -60,7 +60,7 @@ func TestSendReceive(t *testing.T) {
//Bob sends Alice a message
plainText = []byte("Hello Alice")
encryptedMessage, err = bobRatchet.Encrypt(plainText, nil)
encryptedMessage, err = bobRatchet.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
@ -83,11 +83,11 @@ func TestOutOfOrder(t *testing.T) {
plainText2 := []byte("Second Messsage. A bit longer than the first.")
/* Alice sends Bob two messages and they arrive out of order */
message1Encrypted, err := aliceRatchet.Encrypt(plainText1, nil)
message1Encrypted, err := aliceRatchet.Encrypt(plainText1)
if err != nil {
t.Fatal(err)
}
message2Encrypted, err := aliceRatchet.Encrypt(plainText2, nil)
message2Encrypted, err := aliceRatchet.Encrypt(plainText2)
if err != nil {
t.Fatal(err)
}
@ -115,7 +115,7 @@ func TestMoreMessages(t *testing.T) {
}
plainText := []byte("These 15 bytes")
for i := 0; i < 8; i++ {
messageEncrypted, err := aliceRatchet.Encrypt(plainText, nil)
messageEncrypted, err := aliceRatchet.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
@ -128,7 +128,7 @@ func TestMoreMessages(t *testing.T) {
}
}
for i := 0; i < 8; i++ {
messageEncrypted, err := bobRatchet.Encrypt(plainText, nil)
messageEncrypted, err := bobRatchet.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
@ -140,7 +140,7 @@ func TestMoreMessages(t *testing.T) {
t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted)
}
}
messageEncrypted, err := aliceRatchet.Encrypt(plainText, nil)
messageEncrypted, err := aliceRatchet.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}
@ -171,7 +171,7 @@ func TestJSONEncoding(t *testing.T) {
plainText := []byte("These 15 bytes")
messageEncrypted, err := newRatcher.Encrypt(plainText, nil)
messageEncrypted, err := newRatcher.Encrypt(plainText)
if err != nil {
t.Fatal(err)
}

View File

@ -45,8 +45,8 @@ func NewDecryptionFromPrivate(privateKey crypto.Curve25519PrivateKey) (*Decrypti
return s, nil
}
// PubKey returns the public key base 64 encoded.
func (s Decryption) PubKey() id.Curve25519 {
// PublicKey returns the public key base 64 encoded.
func (s Decryption) PublicKey() id.Curve25519 {
return s.KeyPair.B64Encoded()
}

View File

@ -30,14 +30,14 @@ func TestEncryptionDecryption(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !bytes.Equal([]byte(decryption.PubKey()), alicePublic) {
if !bytes.Equal([]byte(decryption.PublicKey()), alicePublic) {
t.Fatal("public key not correct")
}
if !bytes.Equal(decryption.PrivateKey(), alicePrivate) {
t.Fatal("private key not correct")
}
encryption, err := pk.NewEncryption(decryption.PubKey())
encryption, err := pk.NewEncryption(decryption.PublicKey())
if err != nil {
t.Fatal(err)
}
@ -66,7 +66,10 @@ func TestSigning(t *testing.T) {
}
message := []byte("We hold these truths to be self-evident, that all men are created equal, that they are endowed by their Creator with certain unalienable Rights, that among these are Life, Liberty and the pursuit of Happiness.")
signing, _ := pk.NewSigningFromSeed(seed)
signature := signing.Sign(message)
signature, err := signing.Sign(message)
if err != nil {
t.Fatal(err)
}
signatureDecoded, err := goolm.Base64Decode(signature)
if err != nil {
t.Fatal(err)
@ -101,7 +104,7 @@ func TestDecryptionPickling(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !bytes.Equal([]byte(decryption.PubKey()), alicePublic) {
if !bytes.Equal([]byte(decryption.PublicKey()), alicePublic) {
t.Fatal("public key not correct")
}
if !bytes.Equal(decryption.PrivateKey(), alicePrivate) {
@ -125,7 +128,7 @@ func TestDecryptionPickling(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !bytes.Equal([]byte(newDecription.PubKey()), alicePublic) {
if !bytes.Equal([]byte(newDecription.PublicKey()), alicePublic) {
t.Fatal("public key not correct")
}
if !bytes.Equal(newDecription.PrivateKey(), alicePrivate) {

View File

@ -2,7 +2,11 @@ package pk
import (
"crypto/rand"
"encoding/json"
"github.com/tidwall/sjson"
"maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/id"
@ -10,15 +14,15 @@ import (
// Signing is used for signing a pk
type Signing struct {
KeyPair crypto.Ed25519KeyPair `json:"key_pair"`
Seed []byte `json:"seed"`
keyPair crypto.Ed25519KeyPair
seed []byte
}
// NewSigningFromSeed constructs a new Signing based on a seed.
func NewSigningFromSeed(seed []byte) (*Signing, error) {
s := &Signing{}
s.Seed = seed
s.KeyPair = crypto.Ed25519GenerateFromSeed(seed)
s.seed = seed
s.keyPair = crypto.Ed25519GenerateFromSeed(seed)
return s, nil
}
@ -32,13 +36,34 @@ func NewSigning() (*Signing, error) {
return NewSigningFromSeed(seed)
}
// Sign returns the signature of the message base64 encoded.
func (s Signing) Sign(message []byte) []byte {
signature := s.KeyPair.Sign(message)
return goolm.Base64Encode(signature)
// Seed returns the seed of the key pair.
func (s Signing) Seed() []byte {
return s.seed
}
// PublicKey returns the public key of the key pair base 64 encoded.
func (s Signing) PublicKey() id.Ed25519 {
return s.KeyPair.B64Encoded()
return s.keyPair.B64Encoded()
}
// Sign returns the signature of the message base64 encoded.
func (s Signing) Sign(message []byte) ([]byte, error) {
signature := s.keyPair.Sign(message)
return goolm.Base64Encode(signature), nil
}
// SignJSON creates a signature for the given object after encoding it to
// canonical JSON.
func (s Signing) SignJSON(obj any) (string, error) {
objJSON, err := json.Marshal(obj)
if err != nil {
return "", err
}
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
signature, err := s.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
if err != nil {
return "", err
}
return string(signature), nil
}

View File

@ -89,7 +89,7 @@ func MegolmInboundSessionFromPickled(pickled, key []byte) (*MegolmInboundSession
}
// getRatchet tries to find the correct ratchet for a messageIndex.
func (o MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, error) {
func (o *MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, error) {
// pick a megolm instance to use. if we are at or beyond the latest ratchet value, use that
if (messageIndex - o.Ratchet.Counter) < uint32(1<<31) {
o.Ratchet.AdvanceTo(messageIndex)
@ -107,7 +107,10 @@ func (o MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet,
}
// Decrypt decrypts a base64 encoded group message.
func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error) {
func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint, error) {
if len(ciphertext) == 0 {
return nil, 0, goolm.ErrEmptyInput
}
if o.SigningKey == nil {
return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat)
}
@ -143,17 +146,17 @@ func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error
return nil, 0, err
}
o.SigningKeyVerified = true
return decrypted, msg.MessageIndex, nil
return decrypted, uint(msg.MessageIndex), nil
}
// SessionID returns the base64 endoded signing key
func (o MegolmInboundSession) SessionID() id.SessionID {
// ID returns the base64 endoded signing key
func (o *MegolmInboundSession) ID() id.SessionID {
return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey))
}
// PickleAsJSON returns an MegolmInboundSession as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (o MegolmInboundSession) PickleAsJSON(key []byte) ([]byte, error) {
func (o *MegolmInboundSession) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(o, megolmInboundSessionPickleVersionJSON, key)
}
@ -162,8 +165,14 @@ func (o *MegolmInboundSession) UnpickleAsJSON(pickled, key []byte) error {
return utilities.UnpickleAsJSON(o, pickled, key, megolmInboundSessionPickleVersionJSON)
}
// SessionExportMessage creates an base64 encoded export of the session.
func (o MegolmInboundSession) SessionExportMessage(messageIndex uint32) ([]byte, error) {
// Export returns the base64-encoded ratchet key for this session, at the given
// index, in a format which can be used by
// InboundGroupSession.InboundGroupSessionImport(). Encrypts the
// InboundGroupSession using the supplied key. Returns error on failure.
// if we do not have a session key corresponding to the given index (ie, it was
// sent before the session key was shared with us) the error will be
// returned.
func (o *MegolmInboundSession) Export(messageIndex uint32) ([]byte, error) {
ratchet, err := o.getRatchet(messageIndex)
if err != nil {
return nil, err
@ -174,6 +183,11 @@ func (o MegolmInboundSession) SessionExportMessage(messageIndex uint32) ([]byte,
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
// The decrypted value is then passed to UnpickleLibOlm.
func (o *MegolmInboundSession) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
return goolm.ErrNoKeyProvided
} else if len(pickled) == 0 {
return goolm.ErrEmptyInput
}
decrypted, err := cipher.Unpickle(key, pickled)
if err != nil {
return err
@ -223,7 +237,10 @@ func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) (int, error) {
}
// Pickle returns a base64 encoded and with key encrypted pickled MegolmInboundSession using PickleLibOlm().
func (o MegolmInboundSession) Pickle(key []byte) ([]byte, error) {
func (o *MegolmInboundSession) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
return nil, goolm.ErrNoKeyProvided
}
pickeledBytes := make([]byte, o.PickleLen())
written, err := o.PickleLibOlm(pickeledBytes)
if err != nil {
@ -241,7 +258,7 @@ func (o MegolmInboundSession) Pickle(key []byte) ([]byte, error) {
// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (o MegolmInboundSession) PickleLibOlm(target []byte) (int, error) {
func (o *MegolmInboundSession) PickleLibOlm(target []byte) (int, error) {
if len(target) < o.PickleLen() {
return 0, fmt.Errorf("pickle MegolmInboundSession: %w", goolm.ErrValueTooShort)
}
@ -266,7 +283,7 @@ func (o MegolmInboundSession) PickleLibOlm(target []byte) (int, error) {
}
// PickleLen returns the number of bytes the pickled session will have.
func (o MegolmInboundSession) PickleLen() int {
func (o *MegolmInboundSession) PickleLen() int {
length := libolmpickle.PickleUInt32Len(megolmInboundSessionPickleVersionLibOlm)
length += o.InitialRatchet.PickleLen()
length += o.Ratchet.PickleLen()
@ -274,3 +291,15 @@ func (o MegolmInboundSession) PickleLen() int {
length += libolmpickle.PickleBoolLen(o.SigningKeyVerified)
return length
}
// FirstKnownIndex returns the first message index we know how to decrypt.
func (s *MegolmInboundSession) FirstKnownIndex() uint32 {
return s.InitialRatchet.Counter
}
// IsVerified check if the session has been verified as a valid session. (A
// session is verified either because the original session share was signed, or
// because we have subsequently successfully decrypted a message.)
func (s *MegolmInboundSession) IsVerified() bool {
return s.SigningKeyVerified
}

View File

@ -55,14 +55,14 @@ func MegolmOutboundSessionFromPickled(pickled, key []byte) (*MegolmOutboundSessi
}
a := &MegolmOutboundSession{}
err := a.Unpickle(pickled, key)
if err != nil {
return nil, err
}
return a, nil
return a, err
}
// Encrypt encrypts the plaintext as a base64 encoded group message.
func (o *MegolmOutboundSession) Encrypt(plaintext []byte) ([]byte, error) {
if len(plaintext) == 0 {
return nil, goolm.ErrEmptyInput
}
encrypted, err := o.Ratchet.Encrypt(plaintext, &o.SigningKey)
if err != nil {
return nil, err
@ -71,12 +71,12 @@ func (o *MegolmOutboundSession) Encrypt(plaintext []byte) ([]byte, error) {
}
// SessionID returns the base64 endoded public signing key
func (o MegolmOutboundSession) SessionID() id.SessionID {
func (o *MegolmOutboundSession) ID() id.SessionID {
return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey.PublicKey))
}
// PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (o MegolmOutboundSession) PickleAsJSON(key []byte) ([]byte, error) {
func (o *MegolmOutboundSession) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(o, megolmOutboundSessionPickleVersion, key)
}
@ -88,6 +88,9 @@ func (o *MegolmOutboundSession) UnpickleAsJSON(pickled, key []byte) error {
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
// The decrypted value is then passed to UnpickleLibOlm.
func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
return goolm.ErrNoKeyProvided
}
decrypted, err := cipher.Unpickle(key, pickled)
if err != nil {
return err
@ -122,7 +125,10 @@ func (o *MegolmOutboundSession) UnpickleLibOlm(value []byte) (int, error) {
}
// Pickle returns a base64 encoded and with key encrypted pickled MegolmOutboundSession using PickleLibOlm().
func (o MegolmOutboundSession) Pickle(key []byte) ([]byte, error) {
func (o *MegolmOutboundSession) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
return nil, goolm.ErrNoKeyProvided
}
pickeledBytes := make([]byte, o.PickleLen())
written, err := o.PickleLibOlm(pickeledBytes)
if err != nil {
@ -140,7 +146,7 @@ func (o MegolmOutboundSession) Pickle(key []byte) ([]byte, error) {
// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (o MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) {
func (o *MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) {
if len(target) < o.PickleLen() {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort)
}
@ -159,13 +165,28 @@ func (o MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) {
}
// PickleLen returns the number of bytes the pickled session will have.
func (o MegolmOutboundSession) PickleLen() int {
func (o *MegolmOutboundSession) PickleLen() int {
length := libolmpickle.PickleUInt32Len(megolmOutboundSessionPickleVersionLibOlm)
length += o.Ratchet.PickleLen()
length += o.SigningKey.PickleLen()
return length
}
func (o MegolmOutboundSession) SessionSharingMessage() ([]byte, error) {
func (o *MegolmOutboundSession) SessionSharingMessage() ([]byte, error) {
return o.Ratchet.SessionSharingMessage(o.SigningKey)
}
// MessageIndex returns the message index for this session. Each message is
// sent with an increasing index; this returns the index for the next message.
func (s *MegolmOutboundSession) MessageIndex() uint {
return uint(s.Ratchet.Counter)
}
// Key returns the base64-encoded current ratchet key for this session.
func (s *MegolmOutboundSession) Key() string {
message, err := s.SessionSharingMessage()
if err != nil {
panic(err)
}
return string(message)
}

View File

@ -33,7 +33,7 @@ func TestOutboundPickleJSON(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if sess.SessionID() != newSession.SessionID() {
if sess.ID() != newSession.ID() {
t.Fatal("session ids not equal")
}
if !bytes.Equal(sess.SigningKey.PrivateKey, newSession.SigningKey.PrivateKey) {
@ -75,7 +75,7 @@ func TestInboundPickleJSON(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if sess.SessionID() != newSession.SessionID() {
if sess.ID() != newSession.ID() {
t.Fatal("sess ids not equal")
}
if !bytes.Equal(sess.SigningKey, newSession.SigningKey) {
@ -128,7 +128,7 @@ func TestGroupSendReceive(t *testing.T) {
if !inboundSession.SigningKeyVerified {
t.Fatal("key not verified")
}
if inboundSession.SessionID() != outboundSession.SessionID() {
if inboundSession.ID() != outboundSession.ID() {
t.Fatal("session ids not equal")
}
@ -174,7 +174,7 @@ func TestGroupSessionExportImport(t *testing.T) {
}
//Export the keys
exported, err := inboundSession.SessionExportMessage(0)
exported, err := inboundSession.Export(0)
if err != nil {
t.Fatal(err)
}

View File

@ -5,7 +5,7 @@ import (
"encoding/base64"
"errors"
"fmt"
"io"
"strings"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
@ -204,7 +204,7 @@ func (a *OlmSession) UnpickleAsJSON(pickled, key []byte) error {
// ID returns an identifier for this Session. Will be the same for both ends of the conversation.
// Generated by hashing the public keys used to create the session.
func (s OlmSession) ID() id.SessionID {
func (s *OlmSession) ID() id.SessionID {
message := make([]byte, 3*crypto.Curve25519KeyLength)
copy(message, s.AliceIdentityKey)
copy(message[crypto.Curve25519KeyLength:], s.AliceBaseKey)
@ -215,15 +215,39 @@ func (s OlmSession) ID() id.SessionID {
}
// HasReceivedMessage returns true if this session has received any message.
func (s OlmSession) HasReceivedMessage() bool {
func (s *OlmSession) HasReceivedMessage() bool {
return s.ReceivedMessage
}
// MatchesInboundSessionFrom checks if the oneTimeKeyMsg message is set for this inbound
// Session. This can happen if multiple messages are sent to this Account
// before this Account sends a message in reply. Returns true if the session
// matches. Returns false if the session does not match.
func (s OlmSession) MatchesInboundSessionFrom(theirIdentityKeyEncoded *id.Curve25519, receivedOTKMsg []byte) (bool, error) {
// MatchesInboundSession checks if the PRE_KEY message is for this in-bound
// Session. This can happen if multiple messages are sent to this Account
// before this Account sends a message in reply. Returns true if the session
// matches. Returns false if the session does not match. Returns error on
// failure.
func (s *OlmSession) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
return s.matchesInboundSession(nil, []byte(oneTimeKeyMsg))
}
// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound
// Session. This can happen if multiple messages are sent to this Account
// before this Account sends a message in reply. Returns true if the session
// matches. Returns false if the session does not match. Returns error on
// failure.
func (s *OlmSession) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) {
var theirKey *id.Curve25519
if theirIdentityKey != "" {
theirs := id.Curve25519(theirIdentityKey)
theirKey = &theirs
}
return s.matchesInboundSession(theirKey, []byte(oneTimeKeyMsg))
}
// matchesInboundSession checks if the oneTimeKeyMsg message is set for this
// inbound Session. This can happen if multiple messages are sent to this
// Account before this Account sends a message in reply. Returns true if the
// session matches. Returns false if the session does not match.
func (s *OlmSession) matchesInboundSession(theirIdentityKeyEncoded *id.Curve25519, receivedOTKMsg []byte) (bool, error) {
if len(receivedOTKMsg) == 0 {
return false, fmt.Errorf("inbound match: %w", goolm.ErrEmptyInput)
}
@ -266,20 +290,20 @@ func (s OlmSession) MatchesInboundSessionFrom(theirIdentityKeyEncoded *id.Curve2
// EncryptMsgType returns the type of the next message that Encrypt will
// return. Returns MsgTypePreKey if the message will be a oneTimeKeyMsg.
// Returns MsgTypeMsg if the message will be a normal message.
func (s OlmSession) EncryptMsgType() id.OlmMsgType {
func (s *OlmSession) EncryptMsgType() id.OlmMsgType {
if s.ReceivedMessage {
return id.OlmMsgTypeMsg
}
return id.OlmMsgTypePreKey
}
// Encrypt encrypts a message using the Session. Returns the encrypted message base64 encoded. If reader is nil, crypto/rand is used for key generations.
func (s *OlmSession) Encrypt(plaintext []byte, reader io.Reader) (id.OlmMsgType, []byte, error) {
// Encrypt encrypts a message using the Session. Returns the encrypted message base64 encoded.
func (s *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) {
if len(plaintext) == 0 {
return 0, nil, fmt.Errorf("encrypt: %w", goolm.ErrEmptyInput)
}
messageType := s.EncryptMsgType()
encrypted, err := s.Ratchet.Encrypt(plaintext, reader)
encrypted, err := s.Ratchet.Encrypt(plaintext)
if err != nil {
return 0, nil, err
}
@ -304,11 +328,11 @@ func (s *OlmSession) Encrypt(plaintext []byte, reader io.Reader) (id.OlmMsgType,
}
// Decrypt decrypts a base64 encoded message using the Session.
func (s *OlmSession) Decrypt(crypttext []byte, msgType id.OlmMsgType) ([]byte, error) {
func (s *OlmSession) Decrypt(crypttext string, msgType id.OlmMsgType) ([]byte, error) {
if len(crypttext) == 0 {
return nil, fmt.Errorf("decrypt: %w", goolm.ErrEmptyInput)
}
decodedCrypttext, err := goolm.Base64Decode(crypttext)
decodedCrypttext, err := goolm.Base64Decode([]byte(crypttext))
if err != nil {
return nil, err
}
@ -333,6 +357,9 @@ func (s *OlmSession) Decrypt(crypttext []byte, msgType id.OlmMsgType) ([]byte, e
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
// The decrypted value is then passed to UnpickleLibOlm.
func (o *OlmSession) Unpickle(pickled, key []byte) error {
if len(pickled) == 0 {
return goolm.ErrEmptyInput
}
decrypted, err := cipher.Unpickle(key, pickled)
if err != nil {
return err
@ -386,26 +413,26 @@ func (o *OlmSession) UnpickleLibOlm(value []byte) (int, error) {
return curPos, nil
}
// Pickle returns a base64 encoded and with key encrypted pickled olmSession using PickleLibOlm().
func (o OlmSession) Pickle(key []byte) ([]byte, error) {
pickeledBytes := make([]byte, o.PickleLen())
written, err := o.PickleLibOlm(pickeledBytes)
// Pickle returns a base64 encoded and with key encrypted pickled olmSession
// using PickleLibOlm().
func (s *OlmSession) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
return nil, goolm.ErrNoKeyProvided
}
pickeledBytes := make([]byte, s.PickleLen())
written, err := s.PickleLibOlm(pickeledBytes)
if err != nil {
return nil, err
}
if written != len(pickeledBytes) {
return nil, errors.New("number of written bytes not correct")
}
encrypted, err := cipher.Pickle(key, pickeledBytes)
if err != nil {
return nil, err
}
return encrypted, nil
return cipher.Pickle(key, pickeledBytes)
}
// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (o OlmSession) PickleLibOlm(target []byte) (int, error) {
func (o *OlmSession) PickleLibOlm(target []byte) (int, error) {
if len(target) < o.PickleLen() {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort)
}
@ -435,7 +462,7 @@ func (o OlmSession) PickleLibOlm(target []byte) (int, error) {
}
// PickleLen returns the actual number of bytes the pickled session will have.
func (o OlmSession) PickleLen() int {
func (o *OlmSession) PickleLen() int {
length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm)
length += libolmpickle.PickleBoolLen(o.ReceivedMessage)
length += o.AliceIdentityKey.PickleLen()
@ -446,7 +473,7 @@ func (o OlmSession) PickleLen() int {
}
// PickleLenMin returns the minimum number of bytes the pickled session must have.
func (o OlmSession) PickleLenMin() int {
func (o *OlmSession) PickleLenMin() int {
length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm)
length += libolmpickle.PickleBoolLen(o.ReceivedMessage)
length += o.AliceIdentityKey.PickleLen()
@ -457,20 +484,17 @@ func (o OlmSession) PickleLenMin() int {
}
// Describe returns a string describing the current state of the session for debugging.
func (o OlmSession) Describe() string {
var res string
if o.Ratchet.SenderChains.IsSet {
res += fmt.Sprintf("sender chain index: %d ", o.Ratchet.SenderChains.CKey.Index)
} else {
res += "sender chain index: "
}
res += "receiver chain indicies:"
func (o *OlmSession) Describe() string {
var builder strings.Builder
builder.WriteString("sender chain index: ")
builder.WriteString(fmt.Sprint(o.Ratchet.SenderChains.CKey.Index))
builder.WriteString(" receiver chain indices:")
for _, curChain := range o.Ratchet.ReceiverChains {
res += fmt.Sprintf(" %d", curChain.CKey.Index)
builder.WriteString(fmt.Sprintf(" %d", curChain.CKey.Index))
}
res += " skipped message keys:"
builder.WriteString(" skipped message keys:")
for _, curSkip := range o.Ratchet.SkippedMessageKeys {
res += fmt.Sprintf(" %d", curSkip.MKey.Index)
builder.WriteString(fmt.Sprintf(" %d", curSkip.MKey.Index))
}
return res
return builder.String()
}

View File

@ -32,7 +32,7 @@ func TestOlmSession(t *testing.T) {
}
//create a message so that there are more keys to marshal
plaintext := []byte("Test message from Alice to Bob")
msgType, message, err := aliceSession.Encrypt(plaintext, nil)
msgType, message, err := aliceSession.Encrypt(plaintext)
if err != nil {
t.Fatal(err)
}
@ -55,7 +55,7 @@ func TestOlmSession(t *testing.T) {
if err != nil {
t.Fatal(err)
}
decryptedMsg, err := bobSession.Decrypt(message, msgType)
decryptedMsg, err := bobSession.Decrypt(string(message), msgType)
if err != nil {
t.Fatal(err)
}
@ -71,7 +71,7 @@ func TestOlmSession(t *testing.T) {
//bob sends a message
plaintext = []byte("A message from Bob to Alice")
msgType, message, err = bobSession.Encrypt(plaintext, nil)
msgType, message, err = bobSession.Encrypt(plaintext)
if err != nil {
t.Fatal(err)
}
@ -86,7 +86,7 @@ func TestOlmSession(t *testing.T) {
}
//Alice receives message
decryptedMsg, err = newAliceSession.Decrypt(message, msgType)
decryptedMsg, err = newAliceSession.Decrypt(string(message), msgType)
if err != nil {
t.Fatal(err)
}
@ -95,14 +95,14 @@ func TestOlmSession(t *testing.T) {
}
//Alice receives message again
_, err = newAliceSession.Decrypt(message, msgType)
_, err = newAliceSession.Decrypt(string(message), msgType)
if err == nil {
t.Fatal("should have gotten an error")
}
//Alice sends another message
plaintext = []byte("A second message to Bob")
msgType, message, err = newAliceSession.Encrypt(plaintext, nil)
msgType, message, err = newAliceSession.Encrypt(plaintext)
if err != nil {
t.Fatal(err)
}
@ -110,7 +110,7 @@ func TestOlmSession(t *testing.T) {
t.Fatal("Wrong message type")
}
//bob receives message
decryptedMsg, err = bobSession.Decrypt(message, msgType)
decryptedMsg, err = bobSession.Decrypt(string(message), msgType)
if err != nil {
t.Fatal(err)
}
@ -165,7 +165,7 @@ func TestDecrypts(t *testing.T) {
t.Fatal(err)
}
for curIndex, curMessage := range messages {
_, err := sess.Decrypt(curMessage, id.OlmMsgTypePreKey)
_, err := sess.Decrypt(string(curMessage), id.OlmMsgTypePreKey)
if err != nil {
if !errors.Is(err, expectedErr[curIndex]) {
t.Fatal(err)

View File

@ -66,8 +66,10 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context)
var key id.Ed25519
if keyName == crossSigningPubkeys.MasterKey.String() {
key = crossSigningPubkeys.MasterKey
} else if device, err := mach.GetOrFetchDevice(ctx, mach.Client.UserID, id.DeviceID(keyName)); err != nil {
log.Warn().Err(err).Msg("Failed to fetch device")
} else if device, err := mach.CryptoStore.GetDevice(ctx, mach.Client.UserID, id.DeviceID(keyName)); err != nil {
return nil, fmt.Errorf("failed to get device %s/%s from store: %w", mach.Client.UserID, keyName, err)
} else if device == nil {
log.Warn().Err(err).Msg("Device does not exist, ignoring signature")
continue
} else if !mach.IsDeviceTrusted(device) {
log.Warn().Err(err).Msg("Device is not trusted")
@ -163,7 +165,7 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.
}
igs := &InboundGroupSession{
Internal: *igsInternal,
Internal: igsInternal,
SigningKey: keyBackupData.SenderClaimedKeys.Ed25519,
SenderKey: keyBackupData.SenderKey,
RoomID: roomID,

View File

@ -104,7 +104,7 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor
return false, ErrMismatchingExportedSessionID
}
igs := &InboundGroupSession{
Internal: *igsInternal,
Internal: igsInternal,
SigningKey: session.SenderClaimedKeys.Ed25519,
SenderKey: session.SenderKey,
RoomID: session.RoomID,

View File

@ -172,7 +172,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session")
}
igs := &InboundGroupSession{
Internal: *igsInternal,
Internal: igsInternal,
SigningKey: evt.Keys.Ed25519,
SenderKey: content.SenderKey,
RoomID: content.RoomID,
@ -184,6 +184,11 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
MaxMessages: maxMessages,
IsScheduled: content.IsScheduled,
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID())
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
// We already have an equivalent or better session in the store, so don't override it.
return false
}
err = mach.CryptoStore.PutGroupSession(ctx, content.RoomID, content.SenderKey, content.SessionID, igs)
if err != nil {
log.Error().Err(err).Msg("Failed to store new inbound group session")

View File

@ -681,6 +681,11 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
log.Debug().Msg("No one-time keys nor device keys got when trying to share keys")
return nil
}
// Save the keys before sending the upload request in case there is a
// network failure.
if err := mach.saveAccount(ctx); err != nil {
return err
}
req := &mautrix.ReqUploadKeys{
DeviceKeys: deviceKeys,
OneTimeKeys: oneTimeKeys,
@ -691,6 +696,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
return err
}
mach.lastOTKUpload = time.Now()
mach.account.Internal.MarkKeysAsPublished()
mach.account.Shared = true
return mach.saveAccount(ctx)
}

View File

@ -77,6 +77,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
otk = otkTmp
break
}
machineIn.account.Internal.MarkKeysAsPublished()
// create outbound olm session for sending machine using OTK
olmSession, err := machineOut.account.Internal.NewOutboundSession(machineIn.account.IdentityKey(), otk.Key)

View File

@ -1,154 +1,20 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build goolm
package olm
import (
"encoding/json"
"github.com/tidwall/sjson"
"maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/crypto/goolm/account"
"maunium.net/go/mautrix/id"
)
// Account stores a device account for end to end encrypted messaging.
type Account struct {
account.Account
}
import "maunium.net/go/mautrix/crypto/goolm/account"
// NewAccount creates a new Account.
func NewAccount() *Account {
a, err := account.NewAccount(nil)
if err != nil {
panic(err)
}
ac := &Account{}
ac.Account = *a
return ac
func NewAccount() Account {
return account.NewAccount()
}
func NewBlankAccount() *Account {
return &Account{}
}
// Clear clears the memory used to back this Account.
func (a *Account) Clear() error {
a.Account = account.Account{}
return nil
}
// Pickle returns an Account as a base64 string. Encrypts the Account using the
// supplied key.
func (a *Account) Pickle(key []byte) []byte {
if len(key) == 0 {
panic(NoKeyProvided)
}
pickled, err := a.Account.Pickle(key)
if err != nil {
panic(err)
}
return pickled
}
// IdentityKeysJSON returns the public parts of the identity keys for the Account.
func (a *Account) IdentityKeysJSON() []byte {
identityKeys, err := a.Account.IdentityKeysJSON()
if err != nil {
panic(err)
}
return identityKeys
}
// Sign returns the signature of a message using the ed25519 key for this
// Account.
func (a *Account) Sign(message []byte) []byte {
if len(message) == 0 {
panic(EmptyInput)
}
signature, err := a.Account.Sign(message)
if err != nil {
panic(err)
}
return signature
}
// SignJSON signs the given JSON object following the Matrix specification:
// https://matrix.org/docs/spec/appendices#signing-json
func (a *Account) SignJSON(obj interface{}) (string, error) {
objJSON, err := json.Marshal(obj)
if err != nil {
return "", err
}
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
return string(a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))), nil
}
// MaxNumberOfOneTimeKeys returns the largest number of one time keys this
// Account can store.
func (a *Account) MaxNumberOfOneTimeKeys() uint {
return uint(account.MaxOneTimeKeys)
}
// GenOneTimeKeys generates a number of new one time keys. If the total number
// of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old
// keys are discarded.
func (a *Account) GenOneTimeKeys(num uint) {
err := a.Account.GenOneTimeKeys(nil, num)
if err != nil {
panic(err)
}
}
// NewOutboundSession creates a new out-bound session for sending messages to a
// given curve25519 identityKey and oneTimeKey. Returns error on failure.
func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*Session, error) {
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
return nil, EmptyInput
}
s := &Session{}
newSession, err := a.Account.NewOutboundSession(theirIdentityKey, theirOneTimeKey)
if err != nil {
return nil, err
}
s.OlmSession = *newSession
return s, nil
}
// NewInboundSession creates a new in-bound session for sending/receiving
// messages from an incoming PRE_KEY message. Returns error on failure.
func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) {
if len(oneTimeKeyMsg) == 0 {
return nil, EmptyInput
}
s := &Session{}
newSession, err := a.Account.NewInboundSession(nil, []byte(oneTimeKeyMsg))
if err != nil {
return nil, err
}
s.OlmSession = *newSession
return s, nil
}
// NewInboundSessionFrom creates a new in-bound session for sending/receiving
// messages from an incoming PRE_KEY message. Returns error on failure.
func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeKeyMsg string) (*Session, error) {
if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 {
return nil, EmptyInput
}
s := &Session{}
newSession, err := a.Account.NewInboundSession(&theirIdentityKey, []byte(oneTimeKeyMsg))
if err != nil {
return nil, err
}
s.OlmSession = *newSession
return s, nil
}
// RemoveOneTimeKeys removes the one time keys that the session used from the
// Account. Returns error on failure.
func (a *Account) RemoveOneTimeKeys(s *Session) error {
a.Account.RemoveOneTimeKeys(&s.OlmSession)
return nil
func NewBlankAccount() Account {
return &account.Account{}
}

View File

@ -0,0 +1,95 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package olm
import (
"io"
"maunium.net/go/mautrix/id"
)
type Account interface {
// Pickle returns an Account as a base64 string. Encrypts the Account using the
// supplied key.
Pickle(key []byte) ([]byte, error)
// Unpickle loads an Account from a pickled base64 string. Decrypts the
// Account using the supplied key. Returns error on failure.
Unpickle(pickled, key []byte) error
// IdentityKeysJSON returns the public parts of the identity keys for the Account.
IdentityKeysJSON() ([]byte, error)
// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity
// keys for the Account.
IdentityKeys() (id.Ed25519, id.Curve25519, error)
// Sign returns the signature of a message using the ed25519 key for this
// Account.
Sign(message []byte) ([]byte, error)
// SignJSON signs the given JSON object following the Matrix specification:
// https://matrix.org/docs/spec/appendices#signing-json
SignJSON(obj any) (string, error)
// OneTimeKeys returns the public parts of the unpublished one time keys for
// the Account.
//
// The returned data is a struct with the single value "Curve25519", which is
// itself an object mapping key id to base64-encoded Curve25519 key. For
// example:
//
// {
// Curve25519: {
// "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo",
// "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU"
// }
// }
OneTimeKeys() (map[string]id.Curve25519, error)
// MarkKeysAsPublished marks the current set of one time keys as being
// published.
MarkKeysAsPublished()
// MaxNumberOfOneTimeKeys returns the largest number of one time keys this
// Account can store.
MaxNumberOfOneTimeKeys() uint
// GenOneTimeKeys generates a number of new one time keys. If the total
// number of keys stored by this Account exceeds MaxNumberOfOneTimeKeys
// then the old keys are discarded. Reads random data from the given
// reader, or if nil is passed, defaults to crypto/rand.
GenOneTimeKeys(reader io.Reader, num uint) error
// NewOutboundSession creates a new out-bound session for sending messages to a
// given curve25519 identityKey and oneTimeKey. Returns error on failure. If the
// keys couldn't be decoded as base64 then the error will be "INVALID_BASE64"
NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (Session, error)
// NewInboundSession creates a new in-bound session for sending/receiving
// messages from an incoming PRE_KEY message. Returns error on failure. If
// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If
// the message was for an unsupported protocol version then the error will be
// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the
// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one
// time key then the error will be "BAD_MESSAGE_KEY_ID".
NewInboundSession(oneTimeKeyMsg string) (Session, error)
// NewInboundSessionFrom creates a new in-bound session for sending/receiving
// messages from an incoming PRE_KEY message. Returns error on failure. If
// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If
// the message was for an unsupported protocol version then the error will be
// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the
// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one
// time key then the error will be "BAD_MESSAGE_KEY_ID".
NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (Session, error)
// RemoveOneTimeKeys removes the one time keys that the session used from the
// Account. Returns error on failure. If the Account doesn't have any
// matching one time keys then the error will be "BAD_MESSAGE_KEY_ID".
RemoveOneTimeKeys(s Session) error
}

View File

@ -10,6 +10,7 @@ import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"io"
"unsafe"
"github.com/tidwall/gjson"
@ -19,18 +20,21 @@ import (
"maunium.net/go/mautrix/id"
)
// Account stores a device account for end to end encrypted messaging.
type Account struct {
// LibOlmAccount stores a device account for end to end encrypted messaging.
type LibOlmAccount struct {
int *C.OlmAccount
mem []byte
}
// Ensure that LibOlmAccount implements Account.
var _ Account = (*LibOlmAccount)(nil)
// AccountFromPickled loads an Account from a pickled base64 string. Decrypts
// the Account using the supplied key. Returns error on failure. If the key
// doesn't match the one used to encrypt the Account then the error will be
// "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then the error will be
// "INVALID_BASE64".
func AccountFromPickled(pickled, key []byte) (*Account, error) {
func AccountFromPickled(pickled, key []byte) (Account, error) {
if len(pickled) == 0 {
return nil, EmptyInput
}
@ -38,17 +42,21 @@ func AccountFromPickled(pickled, key []byte) (*Account, error) {
return a, a.Unpickle(pickled, key)
}
func NewBlankAccount() *Account {
func NewBlankLibOlmAccount() *LibOlmAccount {
memory := make([]byte, accountSize())
return &Account{
return &LibOlmAccount{
int: C.olm_account(unsafe.Pointer(&memory[0])),
mem: memory,
}
}
func NewBlankAccount() Account {
return NewBlankLibOlmAccount()
}
// NewAccount creates a new Account.
func NewAccount() *Account {
a := NewBlankAccount()
func NewAccount() Account {
a := NewBlankLibOlmAccount()
random := make([]byte, a.createRandomLen()+1)
_, err := rand.Read(random)
if err != nil {
@ -72,12 +80,12 @@ func accountSize() uint {
// lastError returns an error describing the most recent error to happen to an
// account.
func (a *Account) lastError() error {
func (a *LibOlmAccount) lastError() error {
return convertError(C.GoString(C.olm_account_last_error((*C.OlmAccount)(a.int))))
}
// Clear clears the memory used to back this Account.
func (a *Account) Clear() error {
func (a *LibOlmAccount) Clear() error {
r := C.olm_clear_account((*C.OlmAccount)(a.int))
if r == errorVal() {
return a.lastError()
@ -87,36 +95,36 @@ func (a *Account) Clear() error {
}
// pickleLen returns the number of bytes needed to store an Account.
func (a *Account) pickleLen() uint {
func (a *LibOlmAccount) pickleLen() uint {
return uint(C.olm_pickle_account_length((*C.OlmAccount)(a.int)))
}
// createRandomLen returns the number of random bytes needed to create an
// Account.
func (a *Account) createRandomLen() uint {
func (a *LibOlmAccount) createRandomLen() uint {
return uint(C.olm_create_account_random_length((*C.OlmAccount)(a.int)))
}
// identityKeysLen returns the size of the output buffer needed to hold the
// identity keys.
func (a *Account) identityKeysLen() uint {
func (a *LibOlmAccount) identityKeysLen() uint {
return uint(C.olm_account_identity_keys_length((*C.OlmAccount)(a.int)))
}
// signatureLen returns the length of an ed25519 signature encoded as base64.
func (a *Account) signatureLen() uint {
func (a *LibOlmAccount) signatureLen() uint {
return uint(C.olm_account_signature_length((*C.OlmAccount)(a.int)))
}
// oneTimeKeysLen returns the size of the output buffer needed to hold the one
// time keys.
func (a *Account) oneTimeKeysLen() uint {
func (a *LibOlmAccount) oneTimeKeysLen() uint {
return uint(C.olm_account_one_time_keys_length((*C.OlmAccount)(a.int)))
}
// genOneTimeKeysRandomLen returns the number of random bytes needed to
// generate a given number of new one time keys.
func (a *Account) genOneTimeKeysRandomLen(num uint) uint {
func (a *LibOlmAccount) genOneTimeKeysRandomLen(num uint) uint {
return uint(C.olm_account_generate_one_time_keys_random_length(
(*C.OlmAccount)(a.int),
C.size_t(num)))
@ -124,9 +132,9 @@ func (a *Account) genOneTimeKeysRandomLen(num uint) uint {
// Pickle returns an Account as a base64 string. Encrypts the Account using the
// supplied key.
func (a *Account) Pickle(key []byte) []byte {
func (a *LibOlmAccount) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
panic(NoKeyProvided)
return nil, NoKeyProvided
}
pickled := make([]byte, a.pickleLen())
r := C.olm_pickle_account(
@ -136,12 +144,12 @@ func (a *Account) Pickle(key []byte) []byte {
unsafe.Pointer(&pickled[0]),
C.size_t(len(pickled)))
if r == errorVal() {
panic(a.lastError())
return nil, a.lastError()
}
return pickled[:r]
return pickled[:r], nil
}
func (a *Account) Unpickle(pickled, key []byte) error {
func (a *LibOlmAccount) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
return NoKeyProvided
}
@ -158,18 +166,21 @@ func (a *Account) Unpickle(pickled, key []byte) error {
}
// Deprecated
func (a *Account) GobEncode() ([]byte, error) {
pickled := a.Pickle(pickleKey)
func (a *LibOlmAccount) GobEncode() ([]byte, error) {
pickled, err := a.Pickle(pickleKey)
if err != nil {
return nil, err
}
length := base64.RawStdEncoding.DecodedLen(len(pickled))
rawPickled := make([]byte, length)
_, err := base64.RawStdEncoding.Decode(rawPickled, pickled)
_, err = base64.RawStdEncoding.Decode(rawPickled, pickled)
return rawPickled, err
}
// Deprecated
func (a *Account) GobDecode(rawPickled []byte) error {
func (a *LibOlmAccount) GobDecode(rawPickled []byte) error {
if a.int == nil {
*a = *NewBlankAccount()
*a = *NewBlankLibOlmAccount()
}
length := base64.RawStdEncoding.EncodedLen(len(rawPickled))
pickled := make([]byte, length)
@ -178,8 +189,11 @@ func (a *Account) GobDecode(rawPickled []byte) error {
}
// Deprecated
func (a *Account) MarshalJSON() ([]byte, error) {
pickled := a.Pickle(pickleKey)
func (a *LibOlmAccount) MarshalJSON() ([]byte, error) {
pickled, err := a.Pickle(pickleKey)
if err != nil {
return nil, err
}
quotes := make([]byte, len(pickled)+2)
quotes[0] = '"'
quotes[len(quotes)-1] = '"'
@ -188,41 +202,44 @@ func (a *Account) MarshalJSON() ([]byte, error) {
}
// Deprecated
func (a *Account) UnmarshalJSON(data []byte) error {
func (a *LibOlmAccount) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
return InputNotJSONString
}
if a.int == nil {
*a = *NewBlankAccount()
*a = *NewBlankLibOlmAccount()
}
return a.Unpickle(data[1:len(data)-1], pickleKey)
}
// IdentityKeysJSON returns the public parts of the identity keys for the Account.
func (a *Account) IdentityKeysJSON() []byte {
func (a *LibOlmAccount) IdentityKeysJSON() ([]byte, error) {
identityKeys := make([]byte, a.identityKeysLen())
r := C.olm_account_identity_keys(
(*C.OlmAccount)(a.int),
unsafe.Pointer(&identityKeys[0]),
C.size_t(len(identityKeys)))
if r == errorVal() {
panic(a.lastError())
return nil, a.lastError()
} else {
return identityKeys
return identityKeys, nil
}
}
// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity
// keys for the Account.
func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519) {
identityKeysJSON := a.IdentityKeysJSON()
func (a *LibOlmAccount) IdentityKeys() (id.Ed25519, id.Curve25519, error) {
identityKeysJSON, err := a.IdentityKeysJSON()
if err != nil {
return "", "", err
}
results := gjson.GetManyBytes(identityKeysJSON, "ed25519", "curve25519")
return id.Ed25519(results[0].Str), id.Curve25519(results[1].Str)
return id.Ed25519(results[0].Str), id.Curve25519(results[1].Str), nil
}
// Sign returns the signature of a message using the ed25519 key for this
// Account.
func (a *Account) Sign(message []byte) []byte {
func (a *LibOlmAccount) Sign(message []byte) ([]byte, error) {
if len(message) == 0 {
panic(EmptyInput)
}
@ -236,19 +253,20 @@ func (a *Account) Sign(message []byte) []byte {
if r == errorVal() {
panic(a.lastError())
}
return signature
return signature, nil
}
// SignJSON signs the given JSON object following the Matrix specification:
// https://matrix.org/docs/spec/appendices#signing-json
func (a *Account) SignJSON(obj interface{}) (string, error) {
func (a *LibOlmAccount) SignJSON(obj interface{}) (string, error) {
objJSON, err := json.Marshal(obj)
if err != nil {
return "", err
}
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
return string(a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))), nil
signed, err := a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
return string(signed), err
}
// OneTimeKeys returns the public parts of the unpublished one time keys for
@ -264,45 +282,44 @@ func (a *Account) SignJSON(obj interface{}) (string, error) {
// "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU"
// }
// }
func (a *Account) OneTimeKeys() map[string]id.Curve25519 {
func (a *LibOlmAccount) OneTimeKeys() (map[string]id.Curve25519, error) {
oneTimeKeysJSON := make([]byte, a.oneTimeKeysLen())
r := C.olm_account_one_time_keys(
(*C.OlmAccount)(a.int),
unsafe.Pointer(&oneTimeKeysJSON[0]),
C.size_t(len(oneTimeKeysJSON)))
if r == errorVal() {
panic(a.lastError())
return nil, a.lastError()
}
var oneTimeKeys struct {
Curve25519 map[string]id.Curve25519 `json:"curve25519"`
}
err := json.Unmarshal(oneTimeKeysJSON, &oneTimeKeys)
if err != nil {
panic(err)
}
return oneTimeKeys.Curve25519
return oneTimeKeys.Curve25519, json.Unmarshal(oneTimeKeysJSON, &oneTimeKeys)
}
// MarkKeysAsPublished marks the current set of one time keys as being
// published.
func (a *Account) MarkKeysAsPublished() {
func (a *LibOlmAccount) MarkKeysAsPublished() {
C.olm_account_mark_keys_as_published((*C.OlmAccount)(a.int))
}
// MaxNumberOfOneTimeKeys returns the largest number of one time keys this
// Account can store.
func (a *Account) MaxNumberOfOneTimeKeys() uint {
func (a *LibOlmAccount) MaxNumberOfOneTimeKeys() uint {
return uint(C.olm_account_max_number_of_one_time_keys((*C.OlmAccount)(a.int)))
}
// GenOneTimeKeys generates a number of new one time keys. If the total number
// of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old
// keys are discarded.
func (a *Account) GenOneTimeKeys(num uint) {
func (a *LibOlmAccount) GenOneTimeKeys(reader io.Reader, num uint) error {
random := make([]byte, a.genOneTimeKeysRandomLen(num)+1)
_, err := rand.Read(random)
if reader == nil {
reader = rand.Reader
}
_, err := reader.Read(random)
if err != nil {
panic(NotEnoughGoRandom)
return NotEnoughGoRandom
}
r := C.olm_account_generate_one_time_keys(
(*C.OlmAccount)(a.int),
@ -310,18 +327,19 @@ func (a *Account) GenOneTimeKeys(num uint) {
unsafe.Pointer(&random[0]),
C.size_t(len(random)))
if r == errorVal() {
panic(a.lastError())
return a.lastError()
}
return nil
}
// NewOutboundSession creates a new out-bound session for sending messages to a
// given curve25519 identityKey and oneTimeKey. Returns error on failure. If the
// keys couldn't be decoded as base64 then the error will be "INVALID_BASE64"
func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*Session, error) {
func (a *LibOlmAccount) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (Session, error) {
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
return nil, EmptyInput
}
s := NewBlankSession()
s := NewBlankLibOlmSession()
random := make([]byte, s.createOutboundRandomLen()+1)
_, err := rand.Read(random)
if err != nil {
@ -349,11 +367,11 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2
// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the
// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one
// time key then the error will be "BAD_MESSAGE_KEY_ID".
func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) {
func (a *LibOlmAccount) NewInboundSession(oneTimeKeyMsg string) (Session, error) {
if len(oneTimeKeyMsg) == 0 {
return nil, EmptyInput
}
s := NewBlankSession()
s := NewBlankLibOlmSession()
r := C.olm_create_inbound_session(
(*C.OlmSession)(s.int),
(*C.OlmAccount)(a.int),
@ -372,16 +390,16 @@ func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) {
// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the
// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one
// time key then the error will be "BAD_MESSAGE_KEY_ID".
func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeKeyMsg string) (*Session, error) {
if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 {
func (a *LibOlmAccount) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (Session, error) {
if theirIdentityKey == nil || len(oneTimeKeyMsg) == 0 {
return nil, EmptyInput
}
s := NewBlankSession()
s := NewBlankLibOlmSession()
r := C.olm_create_inbound_session_from(
(*C.OlmSession)(s.int),
(*C.OlmAccount)(a.int),
unsafe.Pointer(&([]byte(theirIdentityKey)[0])),
C.size_t(len(theirIdentityKey)),
unsafe.Pointer(&([]byte(*theirIdentityKey)[0])),
C.size_t(len(*theirIdentityKey)),
unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])),
C.size_t(len(oneTimeKeyMsg)))
if r == errorVal() {
@ -393,10 +411,10 @@ func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeK
// RemoveOneTimeKeys removes the one time keys that the session used from the
// Account. Returns error on failure. If the Account doesn't have any
// matching one time keys then the error will be "BAD_MESSAGE_KEY_ID".
func (a *Account) RemoveOneTimeKeys(s *Session) error {
func (a *LibOlmAccount) RemoveOneTimeKeys(s Session) error {
r := C.olm_remove_one_time_keys(
(*C.OlmAccount)(a.int),
(*C.OlmSession)(s.int))
(*C.OlmSession)(s.(*LibOlmSession).int))
if r == errorVal() {
return a.lastError()
}

View File

@ -4,146 +4,39 @@ package olm
import (
"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/id"
)
// InboundGroupSession stores an inbound encrypted messaging session for a
// group.
type InboundGroupSession struct {
session.MegolmInboundSession
}
// InboundGroupSessionFromPickled loads an InboundGroupSession from a pickled
// base64 string. Decrypts the InboundGroupSession using the supplied key.
// Returns error on failure.
func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) {
func InboundGroupSessionFromPickled(pickled, key []byte) (InboundGroupSession, error) {
if len(pickled) == 0 {
return nil, EmptyInput
}
lenKey := len(key)
if lenKey == 0 {
if len(key) == 0 {
key = []byte(" ")
}
megolmSession, err := session.MegolmInboundSessionFromPickled(pickled, key)
if err != nil {
return nil, err
}
return &InboundGroupSession{
MegolmInboundSession: *megolmSession,
}, nil
return session.MegolmInboundSessionFromPickled(pickled, key)
}
// NewInboundGroupSession creates a new inbound group session from a key
// exported from OutboundGroupSession.Key(). Returns error on failure.
func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
func NewInboundGroupSession(sessionKey []byte) (InboundGroupSession, error) {
if len(sessionKey) == 0 {
return nil, EmptyInput
}
megolmSession, err := session.NewMegolmInboundSession(sessionKey)
if err != nil {
return nil, err
}
return &InboundGroupSession{
MegolmInboundSession: *megolmSession,
}, nil
return session.NewMegolmInboundSession(sessionKey)
}
// InboundGroupSessionImport imports an inbound group session from a previous
// export. Returns error on failure.
func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) {
func InboundGroupSessionImport(sessionKey []byte) (InboundGroupSession, error) {
if len(sessionKey) == 0 {
return nil, EmptyInput
}
megolmSession, err := session.NewMegolmInboundSessionFromExport(sessionKey)
if err != nil {
return nil, err
}
return &InboundGroupSession{
MegolmInboundSession: *megolmSession,
}, nil
return session.NewMegolmInboundSessionFromExport(sessionKey)
}
func NewBlankInboundGroupSession() *InboundGroupSession {
return &InboundGroupSession{}
}
// Clear clears the memory used to back this InboundGroupSession.
func (s *InboundGroupSession) Clear() error {
s.MegolmInboundSession = session.MegolmInboundSession{}
return nil
}
// Pickle returns an InboundGroupSession as a base64 string. Encrypts the
// InboundGroupSession using the supplied key.
func (s *InboundGroupSession) Pickle(key []byte) []byte {
if len(key) == 0 {
panic(NoKeyProvided)
}
pickled, err := s.MegolmInboundSession.Pickle(key)
if err != nil {
panic(err)
}
return pickled
}
func (s *InboundGroupSession) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
return NoKeyProvided
} else if len(pickled) == 0 {
return EmptyInput
}
sOlm, err := session.MegolmInboundSessionFromPickled(pickled, key)
if err != nil {
return err
}
s.MegolmInboundSession = *sOlm
return nil
}
// Decrypt decrypts a message using the InboundGroupSession. Returns the the
// plain-text and message index on success. Returns error on failure.
func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) {
if len(message) == 0 {
return nil, 0, EmptyInput
}
plaintext, messageIndex, err := s.MegolmInboundSession.Decrypt(message)
if err != nil {
return nil, 0, err
}
return plaintext, uint(messageIndex), nil
}
// ID returns a base64-encoded identifier for this session.
func (s *InboundGroupSession) ID() id.SessionID {
return s.MegolmInboundSession.SessionID()
}
// FirstKnownIndex returns the first message index we know how to decrypt.
func (s *InboundGroupSession) FirstKnownIndex() uint32 {
return s.MegolmInboundSession.InitialRatchet.Counter
}
// IsVerified check if the session has been verified as a valid session. (A
// session is verified either because the original session share was signed, or
// because we have subsequently successfully decrypted a message.)
func (s *InboundGroupSession) IsVerified() uint {
if s.MegolmInboundSession.SigningKeyVerified {
return 1
}
return 0
}
// Export returns the base64-encoded ratchet key for this session, at the given
// index, in a format which can be used by
// InboundGroupSession.InboundGroupSessionImport(). Encrypts the
// InboundGroupSession using the supplied key. Returns error on failure.
// if we do not have a session key corresponding to the given index (ie, it was
// sent before the session key was shared with us) the error will be
// returned.
func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) {
res, err := s.MegolmInboundSession.SessionExportMessage(messageIndex)
if err != nil {
return nil, err
}
return res, nil
func NewBlankInboundGroupSession() InboundGroupSession {
return &session.MegolmInboundSession{}
}

View File

@ -0,0 +1,57 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package olm
import (
"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/id"
)
type InboundGroupSession interface {
// Pickle returns an InboundGroupSession as a base64 string. Encrypts the
// InboundGroupSession using the supplied key.
Pickle(key []byte) ([]byte, error)
// Unpickle loads an [InboundGroupSession] from a pickled base64 string.
// Decrypts the [InboundGroupSession] using the supplied key.
Unpickle(pickled, key []byte) error
// Decrypt decrypts a message using the [InboundGroupSession]. Returns the
// plain-text and message index on success. Returns error on failure. If
// the base64 couldn't be decoded then the error will be "INVALID_BASE64".
// If the message is for an unsupported version of the protocol then the
// error will be "BAD_MESSAGE_VERSION". If the message couldn't be decoded
// then the error will be BAD_MESSAGE_FORMAT". If the MAC on the message
// was invalid then the error will be "BAD_MESSAGE_MAC". If we do not have
// a session key corresponding to the message's index (ie, it was sent
// before the session key was shared with us) the error will be
// "OLM_UNKNOWN_MESSAGE_INDEX".
Decrypt(message []byte) ([]byte, uint, error)
// ID returns a base64-encoded identifier for this session.
ID() id.SessionID
// FirstKnownIndex returns the first message index we know how to decrypt.
FirstKnownIndex() uint32
// IsVerified check if the session has been verified as a valid session.
// (A session is verified either because the original session share was
// signed, or because we have subsequently successfully decrypted a
// message.)
IsVerified() bool
// Export returns the base64-encoded ratchet key for this session, at the
// given index, in a format which can be used by
// InboundGroupSession.InboundGroupSessionImport(). Encrypts the
// InboundGroupSession using the supplied key. Returns error on failure.
// if we do not have a session key corresponding to the given index (ie, it
// was sent before the session key was shared with us) the error will be
// "OLM_UNKNOWN_MESSAGE_INDEX".
Export(messageIndex uint32) ([]byte, error)
}
var _ InboundGroupSession = (*session.MegolmInboundSession)(nil)

View File

@ -13,19 +13,22 @@ import (
"maunium.net/go/mautrix/id"
)
// InboundGroupSession stores an inbound encrypted messaging session for a
// LibOlmInboundGroupSession stores an inbound encrypted messaging session for a
// group.
type InboundGroupSession struct {
type LibOlmInboundGroupSession struct {
int *C.OlmInboundGroupSession
mem []byte
}
// Ensure that LibOlmInboundGroupSession implements InboundGroupSession.
var _ InboundGroupSession = (*LibOlmInboundGroupSession)(nil)
// InboundGroupSessionFromPickled loads an InboundGroupSession from a pickled
// base64 string. Decrypts the InboundGroupSession using the supplied key.
// Returns error on failure. If the key doesn't match the one used to encrypt
// the InboundGroupSession then the error will be "BAD_SESSION_KEY". If the
// base64 couldn't be decoded then the error will be "INVALID_BASE64".
func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) {
func InboundGroupSessionFromPickled(pickled, key []byte) (InboundGroupSession, error) {
if len(pickled) == 0 {
return nil, EmptyInput
}
@ -42,7 +45,7 @@ func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession,
// If the sessionKey is not valid base64 the error will be
// "OLM_INVALID_BASE64". If the session_key is invalid the error will be
// "OLM_BAD_SESSION_KEY".
func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
func NewInboundGroupSession(sessionKey []byte) (*LibOlmInboundGroupSession, error) {
if len(sessionKey) == 0 {
return nil, EmptyInput
}
@ -61,7 +64,7 @@ func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
// export. Returns error on failure. If the sessionKey is not valid base64
// the error will be "OLM_INVALID_BASE64". If the session_key is invalid the
// error will be "OLM_BAD_SESSION_KEY".
func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) {
func InboundGroupSessionImport(sessionKey []byte) (*LibOlmInboundGroupSession, error) {
if len(sessionKey) == 0 {
return nil, EmptyInput
}
@ -83,9 +86,9 @@ func inboundGroupSessionSize() uint {
}
// newInboundGroupSession initialises an empty InboundGroupSession.
func NewBlankInboundGroupSession() *InboundGroupSession {
func NewBlankInboundGroupSession() *LibOlmInboundGroupSession {
memory := make([]byte, inboundGroupSessionSize())
return &InboundGroupSession{
return &LibOlmInboundGroupSession{
int: C.olm_inbound_group_session(unsafe.Pointer(&memory[0])),
mem: memory,
}
@ -93,12 +96,12 @@ func NewBlankInboundGroupSession() *InboundGroupSession {
// lastError returns an error describing the most recent error to happen to an
// inbound group session.
func (s *InboundGroupSession) lastError() error {
func (s *LibOlmInboundGroupSession) lastError() error {
return convertError(C.GoString(C.olm_inbound_group_session_last_error((*C.OlmInboundGroupSession)(s.int))))
}
// Clear clears the memory used to back this InboundGroupSession.
func (s *InboundGroupSession) Clear() error {
func (s *LibOlmInboundGroupSession) Clear() error {
r := C.olm_clear_inbound_group_session((*C.OlmInboundGroupSession)(s.int))
if r == errorVal() {
return s.lastError()
@ -108,15 +111,15 @@ func (s *InboundGroupSession) Clear() error {
// pickleLen returns the number of bytes needed to store an inbound group
// session.
func (s *InboundGroupSession) pickleLen() uint {
func (s *LibOlmInboundGroupSession) pickleLen() uint {
return uint(C.olm_pickle_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int)))
}
// Pickle returns an InboundGroupSession as a base64 string. Encrypts the
// InboundGroupSession using the supplied key.
func (s *InboundGroupSession) Pickle(key []byte) []byte {
func (s *LibOlmInboundGroupSession) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
panic(NoKeyProvided)
return nil, NoKeyProvided
}
pickled := make([]byte, s.pickleLen())
r := C.olm_pickle_inbound_group_session(
@ -126,12 +129,12 @@ func (s *InboundGroupSession) Pickle(key []byte) []byte {
unsafe.Pointer(&pickled[0]),
C.size_t(len(pickled)))
if r == errorVal() {
panic(s.lastError())
return nil, s.lastError()
}
return pickled[:r]
return pickled[:r], nil
}
func (s *InboundGroupSession) Unpickle(pickled, key []byte) error {
func (s *LibOlmInboundGroupSession) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
return NoKeyProvided
} else if len(pickled) == 0 {
@ -150,16 +153,19 @@ func (s *InboundGroupSession) Unpickle(pickled, key []byte) error {
}
// Deprecated
func (s *InboundGroupSession) GobEncode() ([]byte, error) {
pickled := s.Pickle(pickleKey)
func (s *LibOlmInboundGroupSession) GobEncode() ([]byte, error) {
pickled, err := s.Pickle(pickleKey)
if err != nil {
return nil, err
}
length := base64.RawStdEncoding.DecodedLen(len(pickled))
rawPickled := make([]byte, length)
_, err := base64.RawStdEncoding.Decode(rawPickled, pickled)
_, err = base64.RawStdEncoding.Decode(rawPickled, pickled)
return rawPickled, err
}
// Deprecated
func (s *InboundGroupSession) GobDecode(rawPickled []byte) error {
func (s *LibOlmInboundGroupSession) GobDecode(rawPickled []byte) error {
if s == nil || s.int == nil {
*s = *NewBlankInboundGroupSession()
}
@ -170,8 +176,11 @@ func (s *InboundGroupSession) GobDecode(rawPickled []byte) error {
}
// Deprecated
func (s *InboundGroupSession) MarshalJSON() ([]byte, error) {
pickled := s.Pickle(pickleKey)
func (s *LibOlmInboundGroupSession) MarshalJSON() ([]byte, error) {
pickled, err := s.Pickle(pickleKey)
if err != nil {
return nil, err
}
quotes := make([]byte, len(pickled)+2)
quotes[0] = '"'
quotes[len(quotes)-1] = '"'
@ -180,7 +189,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) {
}
// Deprecated
func (s *InboundGroupSession) UnmarshalJSON(data []byte) error {
func (s *LibOlmInboundGroupSession) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
return InputNotJSONString
}
@ -203,7 +212,7 @@ func clone(original []byte) []byte {
// unsupported version of the protocol then the error will be
// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error
// will be "BAD_MESSAGE_FORMAT".
func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) {
func (s *LibOlmInboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) {
if len(message) == 0 {
return 0, EmptyInput
}
@ -228,7 +237,7 @@ func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, erro
// error will be "BAD_MESSAGE_MAC". If we do not have a session key
// corresponding to the message's index (ie, it was sent before the session key
// was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX".
func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) {
func (s *LibOlmInboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) {
if len(message) == 0 {
return nil, 0, EmptyInput
}
@ -254,12 +263,12 @@ func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) {
}
// sessionIdLen returns the number of bytes needed to store a session ID.
func (s *InboundGroupSession) sessionIdLen() uint {
func (s *LibOlmInboundGroupSession) sessionIdLen() uint {
return uint(C.olm_inbound_group_session_id_length((*C.OlmInboundGroupSession)(s.int)))
}
// ID returns a base64-encoded identifier for this session.
func (s *InboundGroupSession) ID() id.SessionID {
func (s *LibOlmInboundGroupSession) ID() id.SessionID {
sessionID := make([]byte, s.sessionIdLen())
r := C.olm_inbound_group_session_id(
(*C.OlmInboundGroupSession)(s.int),
@ -272,20 +281,20 @@ func (s *InboundGroupSession) ID() id.SessionID {
}
// FirstKnownIndex returns the first message index we know how to decrypt.
func (s *InboundGroupSession) FirstKnownIndex() uint32 {
func (s *LibOlmInboundGroupSession) FirstKnownIndex() uint32 {
return uint32(C.olm_inbound_group_session_first_known_index((*C.OlmInboundGroupSession)(s.int)))
}
// IsVerified check if the session has been verified as a valid session. (A
// session is verified either because the original session share was signed, or
// because we have subsequently successfully decrypted a message.)
func (s *InboundGroupSession) IsVerified() uint {
return uint(C.olm_inbound_group_session_is_verified((*C.OlmInboundGroupSession)(s.int)))
func (s *LibOlmInboundGroupSession) IsVerified() bool {
return uint(C.olm_inbound_group_session_is_verified((*C.OlmInboundGroupSession)(s.int))) == 1
}
// exportLen returns the number of bytes needed to export an inbound group
// session.
func (s *InboundGroupSession) exportLen() uint {
func (s *LibOlmInboundGroupSession) exportLen() uint {
return uint(C.olm_export_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int)))
}
@ -296,7 +305,7 @@ func (s *InboundGroupSession) exportLen() uint {
// if we do not have a session key corresponding to the given index (ie, it was
// sent before the session key was shared with us) the error will be
// "OLM_UNKNOWN_MESSAGE_INDEX".
func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) {
func (s *LibOlmInboundGroupSession) Export(messageIndex uint32) ([]byte, error) {
key := make([]byte, s.exportLen())
r := C.olm_export_inbound_group_session(
(*C.OlmInboundGroupSession)(s.int),

View File

@ -2,13 +2,6 @@
package olm
import (
"maunium.net/go/mautrix/id"
)
// Signatures is the data structure used to sign JSON objects.
type Signatures map[id.UserID]map[id.DeviceKeyID]string
// Version returns the version number of the olm library.
func Version() (major, minor, patch uint8) {
return 3, 2, 15

View File

@ -5,12 +5,6 @@ package olm
// #cgo LDFLAGS: -lolm -lstdc++
// #include <olm/olm.h>
import "C"
import (
"maunium.net/go/mautrix/id"
)
// Signatures is the data structure used to sign JSON objects.
type Signatures map[id.UserID]map[id.DeviceKeyID]string
// Version returns the version number of the olm library.
func Version() (major, minor, patch uint8) {

View File

@ -4,21 +4,14 @@ package olm
import (
"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/id"
)
// OutboundGroupSession stores an outbound encrypted messaging session for a
// group.
type OutboundGroupSession struct {
session.MegolmOutboundSession
}
// OutboundGroupSessionFromPickled loads an OutboundGroupSession from a pickled
// base64 string. Decrypts the OutboundGroupSession using the supplied key.
// Returns error on failure. If the key doesn't match the one used to encrypt
// the OutboundGroupSession then the error will be "BAD_SESSION_KEY". If the
// base64 couldn't be decoded then the error will be "INVALID_BASE64".
func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession, error) {
func OutboundGroupSessionFromPickled(pickled, key []byte) (OutboundGroupSession, error) {
if len(pickled) == 0 {
return nil, EmptyInput
}
@ -26,86 +19,19 @@ func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession
if lenKey == 0 {
key = []byte(" ")
}
megolmSession, err := session.MegolmOutboundSessionFromPickled(pickled, key)
if err != nil {
return nil, err
}
return &OutboundGroupSession{
MegolmOutboundSession: *megolmSession,
}, nil
return session.MegolmOutboundSessionFromPickled(pickled, key)
}
// NewOutboundGroupSession creates a new outbound group session.
func NewOutboundGroupSession() *OutboundGroupSession {
megolmSession, err := session.NewMegolmOutboundSession()
func NewOutboundGroupSession() OutboundGroupSession {
session, err := session.NewMegolmOutboundSession()
if err != nil {
panic(err)
}
return &OutboundGroupSession{
MegolmOutboundSession: *megolmSession,
}
return session
}
// newOutboundGroupSession initialises an empty OutboundGroupSession.
func NewBlankOutboundGroupSession() *OutboundGroupSession {
return &OutboundGroupSession{}
}
// Clear clears the memory used to back this OutboundGroupSession.
func (s *OutboundGroupSession) Clear() error {
s.MegolmOutboundSession = session.MegolmOutboundSession{}
return nil
}
// Pickle returns an OutboundGroupSession as a base64 string. Encrypts the
// OutboundGroupSession using the supplied key.
func (s *OutboundGroupSession) Pickle(key []byte) []byte {
if len(key) == 0 {
panic(NoKeyProvided)
}
pickled, err := s.MegolmOutboundSession.Pickle(key)
if err != nil {
panic(err)
}
return pickled
}
func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
return NoKeyProvided
}
return s.MegolmOutboundSession.Unpickle(pickled, key)
}
// Encrypt encrypts a message using the Session. Returns the encrypted message
// as base64.
func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte {
if len(plaintext) == 0 {
panic(EmptyInput)
}
message, err := s.MegolmOutboundSession.Encrypt(plaintext)
if err != nil {
panic(err)
}
return message
}
// ID returns a base64-encoded identifier for this session.
func (s *OutboundGroupSession) ID() id.SessionID {
return s.MegolmOutboundSession.SessionID()
}
// MessageIndex returns the message index for this session. Each message is
// sent with an increasing index; this returns the index for the next message.
func (s *OutboundGroupSession) MessageIndex() uint {
return uint(s.MegolmOutboundSession.Ratchet.Counter)
}
// Key returns the base64-encoded current ratchet key for this session.
func (s *OutboundGroupSession) Key() string {
message, err := s.MegolmOutboundSession.SessionSharingMessage()
if err != nil {
panic(err)
}
return string(message)
// NewBlankOutboundGroupSession initialises an empty OutboundGroupSession.
func NewBlankOutboundGroupSession() OutboundGroupSession {
return &session.MegolmOutboundSession{}
}

View File

@ -0,0 +1,39 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package olm
import (
"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/id"
)
type OutboundGroupSession interface {
// Pickle returns a Session as a base64 string. Encrypts the Session using
// the supplied key.
Pickle(key []byte) ([]byte, error)
// Unpickle loads an [OutboundGroupSession] from a pickled base64 string.
// Decrypts the [OutboundGroupSession] using the supplied key.
Unpickle(pickled, key []byte) error
// Encrypt encrypts a message using the [OutboundGroupSession]. Returns the
// encrypted message as base64.
Encrypt(plaintext []byte) ([]byte, error)
// ID returns a base64-encoded identifier for this session.
ID() id.SessionID
// MessageIndex returns the message index for this session. Each message
// is sent with an increasing index; this returns the index for the next
// message.
MessageIndex() uint
// Key returns the base64-encoded current ratchet key for this session.
Key() string
}
var _ OutboundGroupSession = (*session.MegolmOutboundSession)(nil)

View File

@ -14,29 +14,32 @@ import (
"maunium.net/go/mautrix/id"
)
// OutboundGroupSession stores an outbound encrypted messaging session for a
// group.
type OutboundGroupSession struct {
// LibOlmOutboundGroupSession stores an outbound encrypted messaging session
// for a group.
type LibOlmOutboundGroupSession struct {
int *C.OlmOutboundGroupSession
mem []byte
}
// Ensure that LibOlmOutboundGroupSession implements OutboundGroupSession.
var _ OutboundGroupSession = (*LibOlmOutboundGroupSession)(nil)
// OutboundGroupSessionFromPickled loads an OutboundGroupSession from a pickled
// base64 string. Decrypts the OutboundGroupSession using the supplied key.
// Returns error on failure. If the key doesn't match the one used to encrypt
// the OutboundGroupSession then the error will be "BAD_SESSION_KEY". If the
// base64 couldn't be decoded then the error will be "INVALID_BASE64".
func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession, error) {
func OutboundGroupSessionFromPickled(pickled, key []byte) (OutboundGroupSession, error) {
if len(pickled) == 0 {
return nil, EmptyInput
}
s := NewBlankOutboundGroupSession()
s := NewBlankLibOlmOutboundGroupSession()
return s, s.Unpickle(pickled, key)
}
// NewOutboundGroupSession creates a new outbound group session.
func NewOutboundGroupSession() *OutboundGroupSession {
s := NewBlankOutboundGroupSession()
func NewOutboundGroupSession() OutboundGroupSession {
s := NewBlankLibOlmOutboundGroupSession()
random := make([]byte, s.createRandomLen()+1)
_, err := rand.Read(random)
if err != nil {
@ -58,23 +61,29 @@ func outboundGroupSessionSize() uint {
return uint(C.olm_outbound_group_session_size())
}
// newOutboundGroupSession initialises an empty OutboundGroupSession.
func NewBlankOutboundGroupSession() *OutboundGroupSession {
// NewBlankLibOlmOutboundGroupSession initialises an empty
// LibOlmOutboundGroupSession.
func NewBlankLibOlmOutboundGroupSession() *LibOlmOutboundGroupSession {
memory := make([]byte, outboundGroupSessionSize())
return &OutboundGroupSession{
return &LibOlmOutboundGroupSession{
int: C.olm_outbound_group_session(unsafe.Pointer(&memory[0])),
mem: memory,
}
}
// NewBlankOutboundGroupSession initialises an empty OutboundGroupSession.
func NewBlankOutboundGroupSession() OutboundGroupSession {
return NewBlankLibOlmOutboundGroupSession()
}
// lastError returns an error describing the most recent error to happen to an
// outbound group session.
func (s *OutboundGroupSession) lastError() error {
func (s *LibOlmOutboundGroupSession) lastError() error {
return convertError(C.GoString(C.olm_outbound_group_session_last_error((*C.OlmOutboundGroupSession)(s.int))))
}
// Clear clears the memory used to back this OutboundGroupSession.
func (s *OutboundGroupSession) Clear() error {
func (s *LibOlmOutboundGroupSession) Clear() error {
r := C.olm_clear_outbound_group_session((*C.OlmOutboundGroupSession)(s.int))
if r == errorVal() {
return s.lastError()
@ -85,15 +94,15 @@ func (s *OutboundGroupSession) Clear() error {
// pickleLen returns the number of bytes needed to store an outbound group
// session.
func (s *OutboundGroupSession) pickleLen() uint {
func (s *LibOlmOutboundGroupSession) pickleLen() uint {
return uint(C.olm_pickle_outbound_group_session_length((*C.OlmOutboundGroupSession)(s.int)))
}
// Pickle returns an OutboundGroupSession as a base64 string. Encrypts the
// OutboundGroupSession using the supplied key.
func (s *OutboundGroupSession) Pickle(key []byte) []byte {
func (s *LibOlmOutboundGroupSession) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
panic(NoKeyProvided)
return nil, NoKeyProvided
}
pickled := make([]byte, s.pickleLen())
r := C.olm_pickle_outbound_group_session(
@ -103,12 +112,12 @@ func (s *OutboundGroupSession) Pickle(key []byte) []byte {
unsafe.Pointer(&pickled[0]),
C.size_t(len(pickled)))
if r == errorVal() {
panic(s.lastError())
return nil, s.lastError()
}
return pickled[:r]
return pickled[:r], nil
}
func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error {
func (s *LibOlmOutboundGroupSession) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
return NoKeyProvided
}
@ -125,18 +134,21 @@ func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error {
}
// Deprecated
func (s *OutboundGroupSession) GobEncode() ([]byte, error) {
pickled := s.Pickle(pickleKey)
func (s *LibOlmOutboundGroupSession) GobEncode() ([]byte, error) {
pickled, err := s.Pickle(pickleKey)
if err != nil {
return nil, err
}
length := base64.RawStdEncoding.DecodedLen(len(pickled))
rawPickled := make([]byte, length)
_, err := base64.RawStdEncoding.Decode(rawPickled, pickled)
_, err = base64.RawStdEncoding.Decode(rawPickled, pickled)
return rawPickled, err
}
// Deprecated
func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error {
func (s *LibOlmOutboundGroupSession) GobDecode(rawPickled []byte) error {
if s == nil || s.int == nil {
*s = *NewBlankOutboundGroupSession()
*s = *NewBlankLibOlmOutboundGroupSession()
}
length := base64.RawStdEncoding.EncodedLen(len(rawPickled))
pickled := make([]byte, length)
@ -145,8 +157,11 @@ func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error {
}
// Deprecated
func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) {
pickled := s.Pickle(pickleKey)
func (s *LibOlmOutboundGroupSession) MarshalJSON() ([]byte, error) {
pickled, err := s.Pickle(pickleKey)
if err != nil {
return nil, err
}
quotes := make([]byte, len(pickled)+2)
quotes[0] = '"'
quotes[len(quotes)-1] = '"'
@ -155,33 +170,33 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) {
}
// Deprecated
func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error {
func (s *LibOlmOutboundGroupSession) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
return InputNotJSONString
}
if s == nil || s.int == nil {
*s = *NewBlankOutboundGroupSession()
*s = *NewBlankLibOlmOutboundGroupSession()
}
return s.Unpickle(data[1:len(data)-1], pickleKey)
}
// createRandomLen returns the number of random bytes needed to create an
// Account.
func (s *OutboundGroupSession) createRandomLen() uint {
func (s *LibOlmOutboundGroupSession) createRandomLen() uint {
return uint(C.olm_init_outbound_group_session_random_length((*C.OlmOutboundGroupSession)(s.int)))
}
// encryptMsgLen returns the size of the next message in bytes for the given
// number of plain-text bytes.
func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint {
func (s *LibOlmOutboundGroupSession) encryptMsgLen(plainTextLen int) uint {
return uint(C.olm_group_encrypt_message_length((*C.OlmOutboundGroupSession)(s.int), C.size_t(plainTextLen)))
}
// Encrypt encrypts a message using the Session. Returns the encrypted message
// as base64.
func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte {
func (s *LibOlmOutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) {
if len(plaintext) == 0 {
panic(EmptyInput)
return nil, EmptyInput
}
message := make([]byte, s.encryptMsgLen(len(plaintext)))
r := C.olm_group_encrypt(
@ -191,18 +206,18 @@ func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte {
(*C.uint8_t)(&message[0]),
C.size_t(len(message)))
if r == errorVal() {
panic(s.lastError())
return nil, s.lastError()
}
return message[:r]
return message[:r], nil
}
// sessionIdLen returns the number of bytes needed to store a session ID.
func (s *OutboundGroupSession) sessionIdLen() uint {
func (s *LibOlmOutboundGroupSession) sessionIdLen() uint {
return uint(C.olm_outbound_group_session_id_length((*C.OlmOutboundGroupSession)(s.int)))
}
// ID returns a base64-encoded identifier for this session.
func (s *OutboundGroupSession) ID() id.SessionID {
func (s *LibOlmOutboundGroupSession) ID() id.SessionID {
sessionID := make([]byte, s.sessionIdLen())
r := C.olm_outbound_group_session_id(
(*C.OlmOutboundGroupSession)(s.int),
@ -216,17 +231,17 @@ func (s *OutboundGroupSession) ID() id.SessionID {
// MessageIndex returns the message index for this session. Each message is
// sent with an increasing index; this returns the index for the next message.
func (s *OutboundGroupSession) MessageIndex() uint {
func (s *LibOlmOutboundGroupSession) MessageIndex() uint {
return uint(C.olm_outbound_group_session_message_index((*C.OlmOutboundGroupSession)(s.int)))
}
// sessionKeyLen returns the number of bytes needed to store a session key.
func (s *OutboundGroupSession) sessionKeyLen() uint {
func (s *LibOlmOutboundGroupSession) sessionKeyLen() uint {
return uint(C.olm_outbound_group_session_key_length((*C.OlmOutboundGroupSession)(s.int)))
}
// Key returns the base64-encoded current ratchet key for this session.
func (s *OutboundGroupSession) Key() string {
func (s *LibOlmOutboundGroupSession) Key() string {
sessionKey := make([]byte, s.sessionKeyLen())
r := C.olm_outbound_group_session_key(
(*C.OlmOutboundGroupSession)(s.int),

View File

@ -1,71 +1,29 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
// When the goolm build flag is enabled, this file will make [PKSigning]
// constructors use the goolm constuctors.
//go:build goolm
package olm
import (
"encoding/json"
import "maunium.net/go/mautrix/crypto/goolm/pk"
"github.com/tidwall/sjson"
"maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/crypto/goolm/pk"
"maunium.net/go/mautrix/id"
)
// PkSigning stores a key pair for signing messages.
type PkSigning struct {
pk.Signing
PublicKey id.Ed25519
Seed []byte
// NewPKSigningFromSeed creates a new PKSigning object using the given seed.
func NewPKSigningFromSeed(seed []byte) (PKSigning, error) {
return pk.NewSigningFromSeed(seed)
}
// Clear clears the underlying memory of a PkSigning object.
func (p *PkSigning) Clear() {
p.Signing = pk.Signing{}
// NewPKSigning creates a new [PKSigning] object, containing a key pair for
// signing messages.
func NewPKSigning() (PKSigning, error) {
return pk.NewSigning()
}
// NewPkSigningFromSeed creates a new PkSigning object using the given seed.
func NewPkSigningFromSeed(seed []byte) (*PkSigning, error) {
p := &PkSigning{}
signing, err := pk.NewSigningFromSeed(seed)
if err != nil {
return nil, err
}
p.Signing = *signing
p.Seed = seed
p.PublicKey = p.Signing.PublicKey()
return p, nil
}
// NewPkSigning creates a new PkSigning object, containing a key pair for signing messages.
func NewPkSigning() (*PkSigning, error) {
p := &PkSigning{}
signing, err := pk.NewSigning()
if err != nil {
return nil, err
}
p.Signing = *signing
p.Seed = signing.Seed
p.PublicKey = p.Signing.PublicKey()
return p, err
}
// Sign creates a signature for the given message using this key.
func (p *PkSigning) Sign(message []byte) ([]byte, error) {
return p.Signing.Sign(message), nil
}
// SignJSON creates a signature for the given object after encoding it to canonical JSON.
func (p *PkSigning) SignJSON(obj interface{}) (string, error) {
objJSON, err := json.Marshal(obj)
if err != nil {
return "", err
}
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
signature, err := p.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
if err != nil {
return "", err
}
return string(signature), nil
func NewPKDecryption(privateKey []byte) (PKDecryption, error) {
return pk.NewDecryption()
}

View File

@ -0,0 +1,41 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package olm
import (
"maunium.net/go/mautrix/crypto/goolm/pk"
"maunium.net/go/mautrix/id"
)
// PKSigning is an interface for signing messages.
type PKSigning interface {
// Seed returns the seed of the key.
Seed() []byte
// PublicKey returns the public key.
PublicKey() id.Ed25519
// Sign creates a signature for the given message using this key.
Sign(message []byte) ([]byte, error)
// SignJSON creates a signature for the given object after encoding it to
// canonical JSON.
SignJSON(obj any) (string, error)
}
var _ PKSigning = (*pk.Signing)(nil)
// PKDecryption is an interface for decrypting messages.
type PKDecryption interface {
// PublicKey returns the public key.
PublicKey() id.Curve25519
// Decrypt verifies and decrypts the given message.
Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error)
}
var _ PKDecryption = (*pk.Decryption)(nil)

View File

@ -1,3 +1,9 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build !goolm
package olm
@ -18,14 +24,17 @@ import (
"maunium.net/go/mautrix/id"
)
// PkSigning stores a key pair for signing messages.
type PkSigning struct {
// LibOlmPKSigning stores a key pair for signing messages.
type LibOlmPKSigning struct {
int *C.OlmPkSigning
mem []byte
PublicKey id.Ed25519
Seed []byte
publicKey id.Ed25519
seed []byte
}
// Ensure that LibOlmPKSigning implements PKSigning.
var _ PKSigning = (*LibOlmPKSigning)(nil)
func pkSigningSize() uint {
return uint(C.olm_pk_signing_size())
}
@ -42,48 +51,57 @@ func pkSigningSignatureLength() uint {
return uint(C.olm_pk_signature_length())
}
func NewBlankPkSigning() *PkSigning {
func newBlankPKSigning() *LibOlmPKSigning {
memory := make([]byte, pkSigningSize())
return &PkSigning{
return &LibOlmPKSigning{
int: C.olm_pk_signing(unsafe.Pointer(&memory[0])),
mem: memory,
}
}
// Clear clears the underlying memory of a PkSigning object.
func (p *PkSigning) Clear() {
C.olm_clear_pk_signing((*C.OlmPkSigning)(p.int))
}
// NewPkSigningFromSeed creates a new PkSigning object using the given seed.
func NewPkSigningFromSeed(seed []byte) (*PkSigning, error) {
p := NewBlankPkSigning()
p.Clear()
// NewPKSigningFromSeed creates a new [PKSigning] object using the given seed.
func NewPKSigningFromSeed(seed []byte) (PKSigning, error) {
p := newBlankPKSigning()
p.clear()
pubKey := make([]byte, pkSigningPublicKeyLength())
if C.olm_pk_signing_key_from_seed((*C.OlmPkSigning)(p.int),
unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)),
unsafe.Pointer(&seed[0]), C.size_t(len(seed))) == errorVal() {
return nil, p.lastError()
}
p.PublicKey = id.Ed25519(pubKey)
p.Seed = seed
p.publicKey = id.Ed25519(pubKey)
p.seed = seed
return p, nil
}
// NewPkSigning creates a new PkSigning object, containing a key pair for signing messages.
func NewPkSigning() (*PkSigning, error) {
// NewPKSigning creates a new LibOlmPKSigning object, containing a key pair for
// signing messages.
func NewPKSigning() (PKSigning, error) {
// Generate the seed
seed := make([]byte, pkSigningSeedLength())
_, err := rand.Read(seed)
if err != nil {
panic(NotEnoughGoRandom)
}
pk, err := NewPkSigningFromSeed(seed)
pk, err := NewPKSigningFromSeed(seed)
return pk, err
}
func (p *LibOlmPKSigning) PublicKey() id.Ed25519 {
return p.publicKey
}
func (p *LibOlmPKSigning) Seed() []byte {
return p.seed
}
// clear clears the underlying memory of a LibOlmPKSigning object.
func (p *LibOlmPKSigning) clear() {
C.olm_clear_pk_signing((*C.OlmPkSigning)(p.int))
}
// Sign creates a signature for the given message using this key.
func (p *PkSigning) Sign(message []byte) ([]byte, error) {
func (p *LibOlmPKSigning) Sign(message []byte) ([]byte, error) {
signature := make([]byte, pkSigningSignatureLength())
if C.olm_pk_sign((*C.OlmPkSigning)(p.int), (*C.uint8_t)(unsafe.Pointer(&message[0])), C.size_t(len(message)),
(*C.uint8_t)(unsafe.Pointer(&signature[0])), C.size_t(len(signature))) == errorVal() {
@ -93,7 +111,7 @@ func (p *PkSigning) Sign(message []byte) ([]byte, error) {
}
// SignJSON creates a signature for the given object after encoding it to canonical JSON.
func (p *PkSigning) SignJSON(obj interface{}) (string, error) {
func (p *LibOlmPKSigning) SignJSON(obj interface{}) (string, error) {
objJSON, err := json.Marshal(obj)
if err != nil {
return "", err
@ -107,12 +125,13 @@ func (p *PkSigning) SignJSON(obj interface{}) (string, error) {
return string(signature), nil
}
// lastError returns the last error that happened in relation to this PkSigning object.
func (p *PkSigning) lastError() error {
// lastError returns the last error that happened in relation to this
// LibOlmPKSigning object.
func (p *LibOlmPKSigning) lastError() error {
return convertError(C.GoString(C.olm_pk_signing_last_error((*C.OlmPkSigning)(p.int))))
}
type PkDecryption struct {
type LibOlmPKDecryption struct {
int *C.OlmPkDecryption
mem []byte
PublicKey []byte
@ -126,13 +145,13 @@ func pkDecryptionPublicKeySize() uint {
return uint(C.olm_pk_key_length())
}
func NewPkDecryption(privateKey []byte) (*PkDecryption, error) {
func NewPkDecryption(privateKey []byte) (*LibOlmPKDecryption, error) {
memory := make([]byte, pkDecryptionSize())
p := &PkDecryption{
p := &LibOlmPKDecryption{
int: C.olm_pk_decryption(unsafe.Pointer(&memory[0])),
mem: memory,
}
p.Clear()
p.clear()
pubKey := make([]byte, pkDecryptionPublicKeySize())
if C.olm_pk_key_from_private((*C.OlmPkDecryption)(p.int),
@ -145,7 +164,7 @@ func NewPkDecryption(privateKey []byte) (*PkDecryption, error) {
return p, nil
}
func (p *PkDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) {
func (p *LibOlmPKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) {
maxPlaintextLength := uint(C.olm_pk_max_plaintext_length((*C.OlmPkDecryption)(p.int), C.size_t(len(ciphertext))))
plaintext := make([]byte, maxPlaintextLength)
@ -162,11 +181,12 @@ func (p *PkDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byt
}
// Clear clears the underlying memory of a PkDecryption object.
func (p *PkDecryption) Clear() {
func (p *LibOlmPKDecryption) clear() {
C.olm_clear_pk_decryption((*C.OlmPkDecryption)(p.int))
}
// lastError returns the last error that happened in relation to this PkDecryption object.
func (p *PkDecryption) lastError() error {
// lastError returns the last error that happened in relation to this
// LibOlmPKDecryption object.
func (p *LibOlmPKDecryption) lastError() error {
return convertError(C.GoString(C.olm_pk_decryption_last_error((*C.OlmPkDecryption)(p.int))))
}

45
crypto/olm/pk_test.go Normal file
View File

@ -0,0 +1,45 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
// Only run this test if goolm is disabled (that is, libolm is used).
//go:build !goolm
package olm_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"maunium.net/go/mautrix/crypto/goolm/pk"
"maunium.net/go/mautrix/crypto/olm"
)
func FuzzSign(f *testing.F) {
seed := []byte("Quohboh3ka3ooghequier9lee8Bahwoh")
goolmPkSigning, err := pk.NewSigningFromSeed(seed)
require.NoError(f, err)
libolmPkSigning, err := olm.NewPKSigningFromSeed(seed)
require.NoError(f, err)
f.Add([]byte("message"))
f.Fuzz(func(t *testing.T, message []byte) {
// libolm breaks with empty messages, so don't perform differential
// fuzzing on that.
if len(message) == 0 {
return
}
libolmResult, libolmErr := libolmPkSigning.Sign(message)
goolmResult, goolmErr := goolmPkSigning.Sign(message)
assert.Equal(t, goolmErr, libolmErr)
assert.Equal(t, goolmResult, libolmResult)
})
}

View File

@ -1,110 +1,28 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
// When the goolm build flag is enabled, this file will make [PKSigning]
// constructors use the goolm constuctors.
//go:build goolm
package olm
import (
"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/id"
)
// Session stores an end to end encrypted messaging session.
type Session struct {
session.OlmSession
}
import "maunium.net/go/mautrix/crypto/goolm/session"
// SessionFromPickled loads a Session from a pickled base64 string. Decrypts
// the Session using the supplied key. Returns error on failure.
func SessionFromPickled(pickled, key []byte) (*Session, error) {
func SessionFromPickled(pickled, key []byte) (Session, error) {
if len(pickled) == 0 {
return nil, EmptyInput
}
s := NewBlankSession()
s := session.NewOlmSession()
return s, s.Unpickle(pickled, key)
}
func NewBlankSession() *Session {
return &Session{}
}
// Clear clears the memory used to back this Session.
func (s *Session) Clear() error {
s.OlmSession = session.OlmSession{}
return nil
}
// Pickle returns a Session as a base64 string. Encrypts the Session using the
// supplied key.
func (s *Session) Pickle(key []byte) []byte {
if len(key) == 0 {
panic(NoKeyProvided)
}
pickled, err := s.OlmSession.Pickle(key)
if err != nil {
panic(err)
}
return pickled
}
func (s *Session) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
return NoKeyProvided
} else if len(pickled) == 0 {
return EmptyInput
}
sOlm, err := session.OlmSessionFromPickled(pickled, key)
if err != nil {
return err
}
s.OlmSession = *sOlm
return nil
}
// MatchesInboundSession checks if the PRE_KEY message is for this in-bound
// Session. This can happen if multiple messages are sent to this Account
// before this Account sends a message in reply. Returns true if the session
// matches. Returns false if the session does not match. Returns error on
// failure.
func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
return s.MatchesInboundSessionFrom("", oneTimeKeyMsg)
}
// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound
// Session. This can happen if multiple messages are sent to this Account
// before this Account sends a message in reply. Returns true if the session
// matches. Returns false if the session does not match. Returns error on
// failure.
func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) {
if theirIdentityKey != "" {
theirKey := id.Curve25519(theirIdentityKey)
return s.OlmSession.MatchesInboundSessionFrom(&theirKey, []byte(oneTimeKeyMsg))
}
return s.OlmSession.MatchesInboundSessionFrom(nil, []byte(oneTimeKeyMsg))
}
// Encrypt encrypts a message using the Session. Returns the encrypted message
// as base64.
func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) {
if len(plaintext) == 0 {
panic(EmptyInput)
}
messageType, message, err := s.OlmSession.Encrypt(plaintext, nil)
if err != nil {
panic(err)
}
return messageType, message
}
// Decrypt decrypts a message using the Session. Returns the the plain-text on
// success. Returns error on failure.
func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) {
if len(message) == 0 {
return nil, EmptyInput
}
return s.OlmSession.Decrypt([]byte(message), msgType)
}
// Describe generates a string describing the internal state of an olm session for debugging and logging purposes.
func (s *Session) Describe() string {
return s.OlmSession.Describe()
func NewBlankSession() Session {
return session.NewOlmSession()
}

View File

@ -0,0 +1,75 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package olm
import (
"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/id"
)
type Session interface {
// Pickle returns a Session as a base64 string. Encrypts the Session using
// the supplied key.
Pickle(key []byte) ([]byte, error)
// Unpickle loads a Session from a pickled base64 string. Decrypts the
// Session using the supplied key.
Unpickle(pickled, key []byte) error
// ID returns an identifier for this Session. Will be the same for both
// ends of the conversation.
ID() id.SessionID
// HasReceivedMessage returns true if this session has received any
// message.
HasReceivedMessage() bool
// MatchesInboundSession checks if the PRE_KEY message is for this in-bound
// Session. This can happen if multiple messages are sent to this Account
// before this Account sends a message in reply. Returns true if the
// session matches. Returns false if the session does not match. Returns
// error on failure. If the base64 couldn't be decoded then the error will
// be "INVALID_BASE64". If the message was for an unsupported protocol
// version then the error will be "BAD_MESSAGE_VERSION". If the message
// couldn't be decoded then then the error will be "BAD_MESSAGE_FORMAT".
MatchesInboundSession(oneTimeKeyMsg string) (bool, error)
// MatchesInboundSessionFrom checks if the PRE_KEY message is for this
// in-bound Session. This can happen if multiple messages are sent to this
// Account before this Account sends a message in reply. Returns true if
// the session matches. Returns false if the session does not match.
// Returns error on failure. If the base64 couldn't be decoded then the
// error will be "INVALID_BASE64". If the message was for an unsupported
// protocol version then the error will be "BAD_MESSAGE_VERSION". If the
// message couldn't be decoded then then the error will be
// "BAD_MESSAGE_FORMAT".
MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error)
// EncryptMsgType returns the type of the next message that Encrypt will
// return. Returns MsgTypePreKey if the message will be a PRE_KEY message.
// Returns MsgTypeMsg if the message will be a normal message.
EncryptMsgType() id.OlmMsgType
// Encrypt encrypts a message using the Session. Returns the encrypted
// message as base64.
Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error)
// Decrypt decrypts a message using the Session. Returns the plain-text on
// success. Returns error on failure. If the base64 couldn't be decoded
// then the error will be "INVALID_BASE64". If the message is for an
// unsupported version of the protocol then the error will be
// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error
// will be BAD_MESSAGE_FORMAT". If the MAC on the message was invalid then
// the error will be "BAD_MESSAGE_MAC".
Decrypt(message string, msgType id.OlmMsgType) ([]byte, error)
// Describe generates a string describing the internal state of an olm
// session for debugging and logging purposes.
Describe() string
}
var _ Session = (*session.OlmSession)(nil)

View File

@ -1,3 +1,9 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build !goolm
package olm
@ -24,12 +30,15 @@ import (
"maunium.net/go/mautrix/id"
)
// Session stores an end to end encrypted messaging session.
type Session struct {
// LibOlmSession stores an end to end encrypted messaging session.
type LibOlmSession struct {
int *C.OlmSession
mem []byte
}
// Ensure that LibOlmSession implements Session.
var _ Session = (*LibOlmSession)(nil)
// sessionSize is the size of a session object in bytes.
func sessionSize() uint {
return uint(C.olm_session_size())
@ -40,7 +49,7 @@ func sessionSize() uint {
// doesn't match the one used to encrypt the Session then the error will be
// "BAD_SESSION_KEY". If the base64 couldn't be decoded then the error will be
// "INVALID_BASE64".
func SessionFromPickled(pickled, key []byte) (*Session, error) {
func SessionFromPickled(pickled, key []byte) (Session, error) {
if len(pickled) == 0 {
return nil, EmptyInput
}
@ -48,22 +57,26 @@ func SessionFromPickled(pickled, key []byte) (*Session, error) {
return s, s.Unpickle(pickled, key)
}
func NewBlankSession() *Session {
func NewBlankLibOlmSession() *LibOlmSession {
memory := make([]byte, sessionSize())
return &Session{
return &LibOlmSession{
int: C.olm_session(unsafe.Pointer(&memory[0])),
mem: memory,
}
}
func NewBlankSession() Session {
return NewBlankLibOlmSession()
}
// lastError returns an error describing the most recent error to happen to a
// session.
func (s *Session) lastError() error {
func (s *LibOlmSession) lastError() error {
return convertError(C.GoString(C.olm_session_last_error((*C.OlmSession)(s.int))))
}
// Clear clears the memory used to back this Session.
func (s *Session) Clear() error {
func (s *LibOlmSession) Clear() error {
r := C.olm_clear_session((*C.OlmSession)(s.int))
if r == errorVal() {
return s.lastError()
@ -72,31 +85,31 @@ func (s *Session) Clear() error {
}
// pickleLen returns the number of bytes needed to store a session.
func (s *Session) pickleLen() uint {
func (s *LibOlmSession) pickleLen() uint {
return uint(C.olm_pickle_session_length((*C.OlmSession)(s.int)))
}
// createOutboundRandomLen returns the number of random bytes needed to create
// an outbound session.
func (s *Session) createOutboundRandomLen() uint {
func (s *LibOlmSession) createOutboundRandomLen() uint {
return uint(C.olm_create_outbound_session_random_length((*C.OlmSession)(s.int)))
}
// idLen returns the length of the buffer needed to return the id for this
// session.
func (s *Session) idLen() uint {
func (s *LibOlmSession) idLen() uint {
return uint(C.olm_session_id_length((*C.OlmSession)(s.int)))
}
// encryptRandomLen returns the number of random bytes needed to encrypt the
// next message.
func (s *Session) encryptRandomLen() uint {
func (s *LibOlmSession) encryptRandomLen() uint {
return uint(C.olm_encrypt_random_length((*C.OlmSession)(s.int)))
}
// encryptMsgLen returns the size of the next message in bytes for the given
// number of plain-text bytes.
func (s *Session) encryptMsgLen(plainTextLen int) uint {
func (s *LibOlmSession) encryptMsgLen(plainTextLen int) uint {
return uint(C.olm_encrypt_message_length((*C.OlmSession)(s.int), C.size_t(plainTextLen)))
}
@ -107,7 +120,7 @@ func (s *Session) encryptMsgLen(plainTextLen int) uint {
// unsupported version of the protocol then the error will be
// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error
// will be "BAD_MESSAGE_FORMAT".
func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) {
func (s *LibOlmSession) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) {
if len(message) == 0 {
return 0, EmptyInput
}
@ -124,9 +137,9 @@ func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType)
// Pickle returns a Session as a base64 string. Encrypts the Session using the
// supplied key.
func (s *Session) Pickle(key []byte) []byte {
func (s *LibOlmSession) Pickle(key []byte) ([]byte, error) {
if len(key) == 0 {
panic(NoKeyProvided)
return nil, NoKeyProvided
}
pickled := make([]byte, s.pickleLen())
r := C.olm_pickle_session(
@ -138,10 +151,12 @@ func (s *Session) Pickle(key []byte) []byte {
if r == errorVal() {
panic(s.lastError())
}
return pickled[:r]
return pickled[:r], nil
}
func (s *Session) Unpickle(pickled, key []byte) error {
// Unpickle unpickles the base64-encoded Olm session decrypting it with the
// provided key. This function mutates the input pickled data slice.
func (s *LibOlmSession) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
return NoKeyProvided
}
@ -158,18 +173,21 @@ func (s *Session) Unpickle(pickled, key []byte) error {
}
// Deprecated
func (s *Session) GobEncode() ([]byte, error) {
pickled := s.Pickle(pickleKey)
func (s *LibOlmSession) GobEncode() ([]byte, error) {
pickled, err := s.Pickle(pickleKey)
if err != nil {
return nil, err
}
length := base64.RawStdEncoding.DecodedLen(len(pickled))
rawPickled := make([]byte, length)
_, err := base64.RawStdEncoding.Decode(rawPickled, pickled)
_, err = base64.RawStdEncoding.Decode(rawPickled, pickled)
return rawPickled, err
}
// Deprecated
func (s *Session) GobDecode(rawPickled []byte) error {
func (s *LibOlmSession) GobDecode(rawPickled []byte) error {
if s == nil || s.int == nil {
*s = *NewBlankSession()
*s = *NewBlankLibOlmSession()
}
length := base64.RawStdEncoding.EncodedLen(len(rawPickled))
pickled := make([]byte, length)
@ -178,8 +196,11 @@ func (s *Session) GobDecode(rawPickled []byte) error {
}
// Deprecated
func (s *Session) MarshalJSON() ([]byte, error) {
pickled := s.Pickle(pickleKey)
func (s *LibOlmSession) MarshalJSON() ([]byte, error) {
pickled, err := s.Pickle(pickleKey)
if err != nil {
return nil, err
}
quotes := make([]byte, len(pickled)+2)
quotes[0] = '"'
quotes[len(quotes)-1] = '"'
@ -188,19 +209,19 @@ func (s *Session) MarshalJSON() ([]byte, error) {
}
// Deprecated
func (s *Session) UnmarshalJSON(data []byte) error {
if len(data) == 0 || len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
func (s *LibOlmSession) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
return InputNotJSONString
}
if s == nil || s.int == nil {
*s = *NewBlankSession()
*s = *NewBlankLibOlmSession()
}
return s.Unpickle(data[1:len(data)-1], pickleKey)
}
// Id returns an identifier for this Session. Will be the same for both ends
// of the conversation.
func (s *Session) ID() id.SessionID {
func (s *LibOlmSession) ID() id.SessionID {
sessionID := make([]byte, s.idLen())
r := C.olm_session_id(
(*C.OlmSession)(s.int),
@ -213,7 +234,7 @@ func (s *Session) ID() id.SessionID {
}
// HasReceivedMessage returns true if this session has received any message.
func (s *Session) HasReceivedMessage() bool {
func (s *LibOlmSession) HasReceivedMessage() bool {
switch C.olm_session_has_received_message((*C.OlmSession)(s.int)) {
case 0:
return false
@ -230,7 +251,7 @@ func (s *Session) HasReceivedMessage() bool {
// "INVALID_BASE64". If the message was for an unsupported protocol version
// then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be
// decoded then then the error will be "BAD_MESSAGE_FORMAT".
func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
func (s *LibOlmSession) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
if len(oneTimeKeyMsg) == 0 {
return false, EmptyInput
}
@ -255,7 +276,7 @@ func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
// "INVALID_BASE64". If the message was for an unsupported protocol version
// then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be
// decoded then then the error will be "BAD_MESSAGE_FORMAT".
func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) {
func (s *LibOlmSession) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) {
if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 {
return false, EmptyInput
}
@ -278,7 +299,7 @@ func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg stri
// return. Returns MsgTypePreKey if the message will be a PRE_KEY message.
// Returns MsgTypeMsg if the message will be a normal message. Returns error
// on failure.
func (s *Session) EncryptMsgType() id.OlmMsgType {
func (s *LibOlmSession) EncryptMsgType() id.OlmMsgType {
switch C.olm_encrypt_message_type((*C.OlmSession)(s.int)) {
case C.size_t(id.OlmMsgTypePreKey):
return id.OlmMsgTypePreKey
@ -291,15 +312,16 @@ func (s *Session) EncryptMsgType() id.OlmMsgType {
// Encrypt encrypts a message using the Session. Returns the encrypted message
// as base64.
func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) {
func (s *LibOlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) {
if len(plaintext) == 0 {
panic(EmptyInput)
return 0, nil, EmptyInput
}
// Make the slice be at least length 1
random := make([]byte, s.encryptRandomLen()+1)
_, err := rand.Read(random)
if err != nil {
panic(NotEnoughGoRandom)
// TODO can we just return err here?
return 0, nil, NotEnoughGoRandom
}
messageType := s.EncryptMsgType()
message := make([]byte, s.encryptMsgLen(len(plaintext)))
@ -312,9 +334,9 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) {
unsafe.Pointer(&message[0]),
C.size_t(len(message)))
if r == errorVal() {
panic(s.lastError())
return 0, nil, s.lastError()
}
return messageType, message[:r]
return messageType, message[:r], nil
}
// Decrypt decrypts a message using the Session. Returns the the plain-text on
@ -324,7 +346,7 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) {
// the message couldn't be decoded then the error will be BAD_MESSAGE_FORMAT".
// If the MAC on the message was invalid then the error will be
// "BAD_MESSAGE_MAC".
func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) {
func (s *LibOlmSession) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) {
if len(message) == 0 {
return nil, EmptyInput
}
@ -351,7 +373,7 @@ func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error)
const maxDescribeSize = 600
// Describe generates a string describing the internal state of an olm session for debugging and logging purposes.
func (s *Session) Describe() string {
func (s *LibOlmSession) Describe() string {
desc := (*C.char)(C.malloc(C.size_t(maxDescribeSize)))
defer C.free(unsafe.Pointer(desc))
C.meowlm_session_describe(

View File

@ -0,0 +1,87 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
// Only run this test if goolm is disabled (that is, libolm is used).
//go:build !goolm
package olm_test
import (
"testing"
"github.com/stretchr/testify/assert"
goolmsession "maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/crypto/olm"
)
func TestBlankSession(t *testing.T) {
libolmSession := olm.NewBlankLibOlmSession()
goolmSession := goolmsession.NewOlmSession()
assert.Equal(t, libolmSession.ID(), goolmSession.ID())
assert.Equal(t, libolmSession.HasReceivedMessage(), goolmSession.HasReceivedMessage())
assert.Equal(t, libolmSession.EncryptMsgType(), goolmSession.EncryptMsgType())
assert.Equal(t, libolmSession.Describe(), goolmSession.Describe())
libolmPickled, err := libolmSession.Pickle([]byte("test"))
assert.NoError(t, err)
goolmPickled, err := goolmSession.Pickle([]byte("test"))
assert.NoError(t, err)
assert.Equal(t, goolmPickled, libolmPickled)
}
func TestSessionPickle(t *testing.T) {
pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItVKR4ro0O9EAk6LLxJtSnRu5elSUk7YXT")
pickleKey := []byte("secret_key")
goolmSession := goolmsession.NewOlmSession()
err := goolmSession.Unpickle(pickledDataFromLibOlm, pickleKey)
assert.NoError(t, err)
libolmSession := olm.NewBlankLibOlmSession()
err = libolmSession.Unpickle(pickledDataFromLibOlm, pickleKey)
assert.NoError(t, err)
// Reset the pickle data since libolmSession.Unpickle modifies it.
pickledDataFromLibOlm = []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItVKR4ro0O9EAk6LLxJtSnRu5elSUk7YXT")
goolmPickled, err := goolmSession.Pickle(pickleKey)
assert.NoError(t, err)
assert.Equal(t, pickledDataFromLibOlm, goolmPickled)
libolmPickled, err := libolmSession.Pickle(pickleKey)
assert.Equal(t, pickledDataFromLibOlm, libolmPickled)
assert.NoError(t, err)
}
// func FuzzSession(f *testing.F) {
// f.Add([]byte("plaintext"))
// identityKeyAlice, err := crypto.Curve25519GenerateKey(nil)
// require.NoError(f, err)
// identityKeyBob, err := crypto.Curve25519GenerateKey(nil)
// require.NoError(f, err)
// f.Fuzz(func(t *testing.T, plaintext []byte) {
// // identityKeyAlice crypto.Curve25519KeyPair, identityKeyBob crypto.Curve25519PublicKey, oneTimeKeyBob crypto.Curve25519PublicKey
// goolmSession, err := goolmsession.NewOutboundOlmSession(identityKeyAlice, identityKeyBob.PublicKey, otk)
// assert.NoError(t, err)
// libolmAccount := olm.NewAccount()
// libolmSession, err := libolmAccount.NewInboundSessionFrom(id.Curve25519(identityKeyBob.PublicKey), string(otk))
// goolmMsgType, goolmCiphertext, goolmErr := goolmSession.Encrypt(plaintext)
// assert.NoError(t, goolmErr)
// libolmMsgType, libolmCiphertext, libolmErr := libolmSession.Encrypt(plaintext)
// assert.NoError(t, libolmErr)
// assert.Equal(t, goolmMsgType, libolmMsgType)
// assert.Equal(t, goolmCiphertext, libolmCiphertext)
// })
// }

View File

@ -54,9 +54,9 @@ func (session *OlmSession) Describe() string {
return session.Internal.Describe()
}
func wrapSession(session *olm.Session) *OlmSession {
func wrapSession(session olm.Session) *OlmSession {
return &OlmSession{
Internal: *session,
Internal: session,
ExpirationMixin: ExpirationMixin{
TimeMixin: TimeMixin{
CreationTime: time.Now(),
@ -68,7 +68,7 @@ func wrapSession(session *olm.Session) *OlmSession {
}
func (account *OlmAccount) NewInboundSessionFrom(senderKey id.Curve25519, ciphertext string) (*OlmSession, error) {
session, err := account.Internal.NewInboundSessionFrom(senderKey, ciphertext)
session, err := account.Internal.NewInboundSessionFrom(&senderKey, ciphertext)
if err != nil {
return nil, err
}
@ -76,7 +76,7 @@ func (account *OlmAccount) NewInboundSessionFrom(senderKey id.Curve25519, cipher
return wrapSession(session), nil
}
func (session *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) {
func (session *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) {
session.LastEncryptedTime = time.Now()
return session.Internal.Encrypt(plaintext)
}
@ -120,7 +120,7 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI
return nil, err
}
return &InboundGroupSession{
Internal: *igs,
Internal: igs,
SigningKey: signingKey,
SenderKey: senderKey,
RoomID: roomID,
@ -148,7 +148,7 @@ func (igs *InboundGroupSession) RatchetTo(index uint32) error {
if err != nil {
return err
}
igs.Internal = *imported
igs.Internal = imported
return nil
}
@ -182,7 +182,7 @@ type OutboundGroupSession struct {
func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.EncryptionEventContent) *OutboundGroupSession {
ogs := &OutboundGroupSession{
Internal: *olm.NewOutboundGroupSession(),
Internal: olm.NewOutboundGroupSession(),
ExpirationMixin: ExpirationMixin{
TimeMixin: TimeMixin{
CreationTime: time.Now(),
@ -237,7 +237,7 @@ func (ogs *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) {
}
ogs.MessageCount++
ogs.LastEncryptedTime = time.Now()
return ogs.Internal.Encrypt(plaintext), nil
return ogs.Internal.Encrypt(plaintext)
}
type TimeMixin struct {

View File

@ -16,13 +16,26 @@ import (
"maunium.net/go/mautrix/id"
)
func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, timeout time.Duration) (secret string, err error) {
secret, err = mach.CryptoStore.GetSecret(ctx, name)
if err != nil || secret != "" {
return
// Callback function to process a received secret.
//
// Returning true or an error will immediately return from the wait loop, returning false will continue waiting for new responses.
type SecretReceiverFunc func(string) (bool, error)
func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, receiver SecretReceiverFunc, timeout time.Duration) (err error) {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// always offer our stored secret first, if any
secret, err := mach.CryptoStore.GetSecret(ctx, name)
if err != nil {
return err
} else if secret != "" {
if ok, err := receiver(secret); ok || err != nil {
return err
}
}
requestID, secretChan := random.String(64), make(chan string, 1)
requestID, secretChan := random.String(64), make(chan string, 5)
mach.secretLock.Lock()
mach.secretListeners[requestID] = secretChan
mach.secretLock.Unlock()
@ -43,17 +56,27 @@ func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret,
return
}
select {
case <-ctx.Done():
err = ctx.Err()
case <-time.After(timeout):
case secret = <-secretChan:
}
// best effort cancel request from all devices when returning
defer func() {
go mach.sendToOneDevice(context.Background(), mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{
Action: event.SecretRequestCancellation,
RequestID: requestID,
RequestingDeviceID: mach.Client.DeviceID,
})
}()
if secret != "" {
err = mach.CryptoStore.PutSecret(ctx, name, secret)
for {
select {
case <-ctx.Done():
return ctx.Err()
case secret = <-secretChan:
if ok, err := receiver(secret); err != nil {
return err
} else if ok {
return mach.CryptoStore.PutSecret(ctx, name, secret)
}
}
}
return
}
func (mach *OlmMachine) HandleSecretRequest(ctx context.Context, userID id.UserID, content *event.SecretRequestEventContent) {
@ -121,9 +144,11 @@ func (mach *OlmMachine) HandleSecretRequest(ctx context.Context, userID id.UserI
return
} else if secret != "" {
log.Debug().Msg("Responding to secret request")
mach.sendToOneDevice(ctx, mach.Client.UserID, content.RequestingDeviceID, event.ToDeviceSecretRequest, &event.SecretSendEventContent{
RequestID: content.RequestID,
Secret: secret,
mach.SendEncryptedToDevice(ctx, device, event.ToDeviceSecretSend, event.Content{
Parsed: event.SecretSendEventContent{
RequestID: content.RequestID,
Secret: secret,
},
})
} else {
log.Debug().Msg("No stored secret found, secret request ignored")
@ -157,17 +182,10 @@ func (mach *OlmMachine) receiveSecret(ctx context.Context, evt *DecryptedOlmEven
return
}
// secret channel is buffered and we don't want to block
// at worst we drop _some_ of the responses
select {
case secretChan <- content.Secret:
default:
}
// best effort cancel this for all other targets
go func() {
mach.sendToOneDevice(ctx, mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{
Action: event.SecretRequestCancellation,
RequestID: content.RequestID,
RequestingDeviceID: mach.Client.DeviceID,
})
}()
}

View File

@ -123,8 +123,11 @@ func (store *SQLCryptoStore) FindDeviceID(ctx context.Context) (deviceID id.Devi
// PutAccount stores an OlmAccount in the database.
func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount) error {
store.Account = account
bytes := account.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec(ctx, `
bytes, err := account.Internal.Pickle(store.PickleKey)
if err != nil {
return err
}
_, err = store.DB.Exec(ctx, `
INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id, key_backup_version) VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token,
account=excluded.account, account_id=excluded.account_id,
@ -137,7 +140,7 @@ func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount
func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) {
if store.Account == nil {
row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account, key_backup_version FROM crypto_account WHERE account_id=$1", store.AccountID)
acc := &OlmAccount{Internal: *olm.NewBlankAccount()}
acc := &OlmAccount{Internal: olm.NewBlankAccount()}
var accountBytes []byte
err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes, &acc.KeyBackupVersion)
if err == sql.ErrNoRows {
@ -183,7 +186,7 @@ func (store *SQLCryptoStore) GetSessions(ctx context.Context, key id.SenderKey)
defer store.olmSessionCacheLock.Unlock()
cache := store.getOlmSessionCache(key)
for rows.Next() {
sess := OlmSession{Internal: *olm.NewBlankSession()}
sess := OlmSession{Internal: olm.NewBlankSession()}
var sessionBytes []byte
var sessionID id.SessionID
err = rows.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.LastEncryptedTime, &sess.LastDecryptedTime)
@ -220,7 +223,7 @@ func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.Sender
row := store.DB.QueryRow(ctx, "SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1",
key, store.AccountID)
sess := OlmSession{Internal: *olm.NewBlankSession()}
sess := OlmSession{Internal: olm.NewBlankSession()}
var sessionBytes []byte
var sessionID id.SessionID
@ -246,8 +249,11 @@ func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.Sender
func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, session *OlmSession) error {
store.olmSessionCacheLock.Lock()
defer store.olmSessionCacheLock.Unlock()
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
sessionBytes, err := session.Internal.Pickle(store.PickleKey)
if err != nil {
return err
}
_, err = store.DB.Exec(ctx, "INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
session.ID(), key, sessionBytes, session.CreationTime, session.LastEncryptedTime, session.LastDecryptedTime, store.AccountID)
store.getOlmSessionCache(key)[session.ID()] = session
return err
@ -255,8 +261,11 @@ func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, s
// UpdateSession replaces the Olm session for a sender in the database.
func (store *SQLCryptoStore) UpdateSession(ctx context.Context, _ id.SenderKey, session *OlmSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec(ctx, "UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5",
sessionBytes, err := session.Internal.Pickle(store.PickleKey)
if err != nil {
return err
}
_, err = store.DB.Exec(ctx, "UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5",
sessionBytes, session.LastEncryptedTime, session.LastDecryptedTime, session.ID(), store.AccountID)
return err
}
@ -277,7 +286,10 @@ func datePtr(t time.Time) *time.Time {
// PutGroupSession stores an inbound Megolm group session for a room, sender and session.
func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
sessionBytes, err := session.Internal.Pickle(store.PickleKey)
if err != nil {
return err
}
forwardingChains := strings.Join(session.ForwardingChains, ",")
ratchetSafety, err := json.Marshal(&session.RatchetSafety)
if err != nil {
@ -335,7 +347,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room
senderKey = id.Curve25519(senderKeyDB.String)
}
return &InboundGroupSession{
Internal: *igs,
Internal: igs,
SigningKey: id.Ed25519(signingKey.String),
SenderKey: senderKey,
RoomID: roomID,
@ -447,7 +459,7 @@ func (store *SQLCryptoStore) GetWithheldGroupSession(ctx context.Context, roomID
}, nil
}
func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs *olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) {
func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) {
igs = olm.NewBlankInboundGroupSession()
err = igs.Unpickle(sessionBytes, store.PickleKey)
if err != nil {
@ -480,7 +492,7 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In
}
igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String)
return &InboundGroupSession{
Internal: *igs,
Internal: igs,
SigningKey: id.Ed25519(signingKey.String),
SenderKey: id.Curve25519(senderKey.String),
RoomID: roomID,
@ -523,8 +535,11 @@ func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context
// AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices.
func (store *SQLCryptoStore) AddOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec(ctx, `
sessionBytes, err := session.Internal.Pickle(store.PickleKey)
if err != nil {
return err
}
_, err = store.DB.Exec(ctx, `
INSERT INTO crypto_megolm_outbound_session
(room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used, account_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
@ -539,8 +554,11 @@ func (store *SQLCryptoStore) AddOutboundGroupSession(ctx context.Context, sessio
// UpdateOutboundGroupSession replaces an outbound Megolm session with for same room and session ID.
func (store *SQLCryptoStore) UpdateOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec(ctx, "UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6",
sessionBytes, err := session.Internal.Pickle(store.PickleKey)
if err != nil {
return err
}
_, err = store.DB.Exec(ctx, "UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6",
sessionBytes, session.MessageCount, session.LastEncryptedTime, session.RoomID, session.ID(), store.AccountID)
return err
}
@ -565,7 +583,7 @@ func (store *SQLCryptoStore) GetOutboundGroupSession(ctx context.Context, roomID
if err != nil {
return nil, err
}
ogs.Internal = *intOGS
ogs.Internal = intOGS
ogs.RoomID = roomID
ogs.MaxAge = time.Duration(maxAgeMS) * time.Millisecond
return &ogs, nil

View File

@ -115,7 +115,7 @@ func TestStoreOlmSession(t *testing.T) {
olmSess := OlmSession{
id: olmSessID,
Internal: *olmInternal,
Internal: olmInternal,
}
err = store.AddSession(context.TODO(), olmSessID, &olmSess)
if err != nil {
@ -133,7 +133,13 @@ func TestStoreOlmSession(t *testing.T) {
if retrieved.ID() != olmSessID {
t.Errorf("Expected session ID to be %v, got %v", olmSessID, retrieved.ID())
}
if pickled := string(retrieved.Internal.Pickle([]byte("test"))); pickled != olmPickled {
pickled, err := retrieved.Internal.Pickle([]byte("test"))
if err != nil {
t.Fatalf("Error pickling Olm session: %v", err)
}
if string(pickled) != olmPickled {
t.Error("Pickled Olm session does not match original")
}
})
@ -152,7 +158,7 @@ func TestStoreMegolmSession(t *testing.T) {
}
igs := &InboundGroupSession{
Internal: *internal,
Internal: internal,
SigningKey: acc.SigningKey(),
SenderKey: acc.IdentityKey(),
RoomID: "room1",
@ -168,7 +174,9 @@ func TestStoreMegolmSession(t *testing.T) {
t.Errorf("Error retrieving inbound group session: %v", err)
}
if pickled := string(retrieved.Internal.Pickle([]byte("test"))); pickled != groupSession {
if pickled, err := retrieved.Internal.Pickle([]byte("test")); err != nil {
t.Fatalf("Error pickling inbound group session: %v", err)
} else if string(pickled) != groupSession {
t.Error("Pickled inbound group session does not match original")
}
})

View File

@ -337,6 +337,12 @@ func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserI
devices, err := vh.mach.CryptoStore.GetDevices(ctx, to)
if err != nil {
return "", fmt.Errorf("failed to get devices for user: %w", err)
} else if len(devices) == 0 {
// HACK: we are doing this because the client doesn't wait until it has
// the devices before starting verification.
if _, err := vh.mach.FetchKeys(ctx, []id.UserID{to}, true); err != nil {
return "", err
}
}
vh.getLog(ctx).Info().

View File

@ -149,5 +149,6 @@ type Unsigned struct {
func (us *Unsigned) IsEmpty() bool {
return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 &&
us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil
us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil &&
us.BeeperHSOrder == 0
}

203
federation/keyserver.go Normal file
View File

@ -0,0 +1,203 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package federation
import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"time"
"github.com/gorilla/mux"
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/id"
)
type ServerVersion struct {
Name string `json:"name"`
Version string `json:"version"`
}
// ServerKeyProvider is an interface that returns private server keys for server key requests.
type ServerKeyProvider interface {
Get(r *http.Request) (serverName string, key *SigningKey)
}
// StaticServerKey is an implementation of [ServerKeyProvider] that always returns the same server name and key.
type StaticServerKey struct {
ServerName string
Key *SigningKey
}
func (ssk *StaticServerKey) Get(r *http.Request) (serverName string, key *SigningKey) {
return ssk.ServerName, ssk.Key
}
// KeyServer implements a basic Matrix key server that can serve its own keys, plus the federation version endpoint.
//
// It does not implement querying keys of other servers, nor any other federation endpoints.
type KeyServer struct {
KeyProvider ServerKeyProvider
Version ServerVersion
WellKnownTarget string
}
// Register registers the key server endpoints to the given router.
func (ks *KeyServer) Register(r *mux.Router) {
r.HandleFunc("/.well-known/matrix/server", ks.GetWellKnown).Methods(http.MethodGet)
r.HandleFunc("/_matrix/federation/v1/version", ks.GetServerVersion).Methods(http.MethodGet)
keyRouter := r.PathPrefix("/_matrix/key").Subrouter()
keyRouter.HandleFunc("/v2/server", ks.GetServerKey).Methods(http.MethodGet)
keyRouter.HandleFunc("/v2/query/{serverName}", ks.GetQueryKeys).Methods(http.MethodGet)
keyRouter.HandleFunc("/v2/query", ks.PostQueryKeys).Methods(http.MethodPost)
keyRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
ErrCode: mautrix.MUnrecognized.ErrCode,
Err: "Unrecognized endpoint",
})
})
keyRouter.MethodNotAllowedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{
ErrCode: mautrix.MUnrecognized.ErrCode,
Err: "Invalid method for endpoint",
})
})
}
func jsonResponse(w http.ResponseWriter, code int, data any) {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(code)
_ = json.NewEncoder(w).Encode(data)
}
// RespWellKnown is the response body for the `GET /.well-known/matrix/server` endpoint.
type RespWellKnown struct {
Server string `json:"m.server"`
}
// GetWellKnown implements the `GET /.well-known/matrix/server` endpoint
//
// https://spec.matrix.org/v1.9/server-server-api/#get_well-knownmatrixserver
func (ks *KeyServer) GetWellKnown(w http.ResponseWriter, r *http.Request) {
if ks.WellKnownTarget == "" {
jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
ErrCode: mautrix.MNotFound.ErrCode,
Err: "No well-known target set",
})
} else {
jsonResponse(w, http.StatusOK, &RespWellKnown{Server: ks.WellKnownTarget})
}
}
// RespServerVersion is the response body for the `GET /_matrix/federation/v1/version` endpoint
type RespServerVersion struct {
Server ServerVersion `json:"server"`
}
// GetServerVersion implements the `GET /_matrix/federation/v1/version` endpoint
//
// https://spec.matrix.org/v1.9/server-server-api/#get_matrixfederationv1version
func (ks *KeyServer) GetServerVersion(w http.ResponseWriter, r *http.Request) {
jsonResponse(w, http.StatusOK, &RespServerVersion{Server: ks.Version})
}
// GetServerKey implements the `GET /_matrix/key/v2/server` endpoint.
//
// https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2server
func (ks *KeyServer) GetServerKey(w http.ResponseWriter, r *http.Request) {
domain, key := ks.KeyProvider.Get(r)
if key == nil {
jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
ErrCode: mautrix.MNotFound.ErrCode,
Err: fmt.Sprintf("No signing key found for %q", r.Host),
})
} else {
jsonResponse(w, http.StatusOK, key.GenerateKeyResponse(domain, nil))
}
}
// ReqQueryKeys is the request body for the `POST /_matrix/key/v2/query` endpoint
type ReqQueryKeys struct {
ServerKeys map[string]map[id.KeyID]QueryKeysCriteria `json:"server_keys"`
}
type QueryKeysCriteria struct {
MinimumValidUntilTS jsontime.UnixMilli `json:"minimum_valid_until_ts"`
}
// PostQueryKeysResponse is the response body for the `POST /_matrix/key/v2/query` endpoint
type PostQueryKeysResponse struct {
ServerKeys map[string]*ServerKeyResponse `json:"server_keys"`
}
// PostQueryKeys implements the `POST /_matrix/key/v2/query` endpoint
//
// https://spec.matrix.org/v1.9/server-server-api/#post_matrixkeyv2query
func (ks *KeyServer) PostQueryKeys(w http.ResponseWriter, r *http.Request) {
var req ReqQueryKeys
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
ErrCode: mautrix.MBadJSON.ErrCode,
Err: fmt.Sprintf("failed to parse request: %v", err),
})
return
}
resp := &PostQueryKeysResponse{
ServerKeys: make(map[string]*ServerKeyResponse),
}
for serverName, keys := range req.ServerKeys {
domain, key := ks.KeyProvider.Get(r)
if domain != serverName {
continue
}
for keyID, criteria := range keys {
if key.ID == keyID && criteria.MinimumValidUntilTS.Before(time.Now().Add(24*time.Hour)) {
resp.ServerKeys[serverName] = key.GenerateKeyResponse(serverName, nil)
}
}
}
jsonResponse(w, http.StatusOK, resp)
}
// GetQueryKeysResponse is the response body for the `GET /_matrix/key/v2/query/{serverName}` endpoint
type GetQueryKeysResponse struct {
ServerKeys []*ServerKeyResponse `json:"server_keys"`
}
// GetQueryKeys implements the `GET /_matrix/key/v2/query/{serverName}` endpoint
//
// https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2queryservername
func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) {
serverName := mux.Vars(r)["serverName"]
minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts")
minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64)
if err != nil && minimumValidUntilTSString != "" {
jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
ErrCode: mautrix.MInvalidParam.ErrCode,
Err: fmt.Sprintf("failed to parse ?minimum_valid_until_ts: %v", err),
})
return
} else if time.UnixMilli(minimumValidUntilTS).After(time.Now().Add(24 * time.Hour)) {
jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
ErrCode: mautrix.MInvalidParam.ErrCode,
Err: "minimum_valid_until_ts may not be more than 24 hours in the future",
})
return
}
resp := &GetQueryKeysResponse{
ServerKeys: []*ServerKeyResponse{},
}
if domain, key := ks.KeyProvider.Get(r); key != nil && domain == serverName {
resp.ServerKeys = append(resp.ServerKeys, key.GenerateKeyResponse(serverName, nil))
}
jsonResponse(w, http.StatusOK, resp)
}

123
federation/signingkey.go Normal file
View File

@ -0,0 +1,123 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package federation
import (
"crypto/ed25519"
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"time"
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/id"
)
// SigningKey is a Matrix federation signing key pair.
type SigningKey struct {
ID id.KeyID
Pub id.SigningKey
Priv ed25519.PrivateKey
}
// SynapseString returns a string representation of the private key compatible with Synapse's .signing.key file format.
//
// The output of this function can be parsed back into a [SigningKey] using the [ParseSynapseKey] function.
func (sk *SigningKey) SynapseString() string {
alg, id := sk.ID.Parse()
return fmt.Sprintf("%s %s %s", alg, id, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed()))
}
// ParseSynapseKey parses a Synapse-compatible private key string into a SigningKey.
func ParseSynapseKey(key string) (*SigningKey, error) {
parts := strings.Split(key, " ")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid key format (expected 3 space-separated parts, got %d)", len(parts))
} else if parts[0] != string(id.KeyAlgorithmEd25519) {
return nil, fmt.Errorf("unsupported key algorithm %s (only ed25519 is supported)", parts[0])
}
seed, err := base64.RawStdEncoding.DecodeString(parts[2])
if err != nil {
return nil, fmt.Errorf("invalid private key: %w", err)
}
priv := ed25519.NewKeyFromSeed(seed)
pub := base64.RawStdEncoding.EncodeToString(priv.Public().(ed25519.PublicKey))
return &SigningKey{
ID: id.NewKeyID(id.KeyAlgorithmEd25519, parts[1]),
Pub: id.SigningKey(pub),
Priv: priv,
}, nil
}
// GenerateSigningKey generates a new random signing key.
func GenerateSigningKey() *SigningKey {
pub, priv, err := ed25519.GenerateKey(nil)
if err != nil {
panic(err)
}
return &SigningKey{
ID: id.NewKeyID(id.KeyAlgorithmEd25519, base64.RawURLEncoding.EncodeToString(pub[:4])),
Pub: id.SigningKey(base64.RawStdEncoding.EncodeToString(pub)),
Priv: priv,
}
}
// ServerKeyResponse is the response body for the `GET /_matrix/key/v2/server` endpoint.
// It's also used inside the query endpoint response structs.
type ServerKeyResponse struct {
ServerName string `json:"server_name"`
VerifyKeys map[id.KeyID]ServerVerifyKey `json:"verify_keys"`
OldVerifyKeys map[id.KeyID]OldVerifyKey `json:"old_verify_keys,omitempty"`
Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"`
ValidUntilTS jsontime.UnixMilli `json:"valid_until_ts"`
}
type ServerVerifyKey struct {
Key id.SigningKey `json:"key"`
}
type OldVerifyKey struct {
Key id.SigningKey `json:"key"`
ExpiredTS jsontime.UnixMilli `json:"expired_ts"`
}
func (sk *SigningKey) SignJSON(data any) ([]byte, error) {
marshaled, err := json.Marshal(data)
if err != nil {
return nil, err
}
return sk.SignRawJSON(marshaled), nil
}
func (sk *SigningKey) SignRawJSON(data json.RawMessage) []byte {
return ed25519.Sign(sk.Priv, canonicaljson.CanonicalJSONAssumeValid(data))
}
// GenerateKeyResponse generates a key response signed by this key with the given server name and optionally some old verify keys.
func (sk *SigningKey) GenerateKeyResponse(serverName string, oldVerifyKeys map[id.KeyID]OldVerifyKey) *ServerKeyResponse {
skr := &ServerKeyResponse{
ServerName: serverName,
OldVerifyKeys: oldVerifyKeys,
ValidUntilTS: jsontime.UM(time.Now().Add(24 * time.Hour)),
VerifyKeys: map[id.KeyID]ServerVerifyKey{
sk.ID: {Key: sk.Pub},
},
}
signature, err := sk.SignJSON(skr)
if err != nil {
panic(err)
}
skr.Signatures = map[string]map[id.KeyID]string{
serverName: {
sk.ID: base64.RawURLEncoding.EncodeToString(signature),
},
}
return skr
}

View File

@ -7,6 +7,7 @@
package format
import (
"context"
"fmt"
"math"
"strconv"
@ -33,14 +34,16 @@ func (ts TagStack) Has(tag string) bool {
}
type Context struct {
Ctx context.Context
ReturnData map[string]any
TagStack TagStack
PreserveWhitespace bool
}
func NewContext() Context {
func NewContext(ctx context.Context) Context {
return Context{
Ctx: ctx,
ReturnData: map[string]any{},
TagStack: make(TagStack, 0, 4),
}
@ -411,7 +414,7 @@ func HTMLToText(html string) string {
Newline: "\n",
HorizontalLine: "\n---\n",
PillConverter: DefaultPillConverter,
}).Parse(html, NewContext())
}).Parse(html, NewContext(context.TODO()))
}
// HTMLToMarkdown converts Matrix HTML into markdown with the default settings.
@ -429,5 +432,5 @@ func HTMLToMarkdown(html string) string {
}
return fmt.Sprintf("[%s](%s)", text, href)
},
}).Parse(html, NewContext())
}).Parse(html, NewContext(context.TODO()))
}

15
go.mod
View File

@ -8,18 +8,17 @@ require (
github.com/lib/pq v1.10.9
github.com/mattn/go-sqlite3 v1.14.22
github.com/rs/zerolog v1.32.0
github.com/stretchr/testify v1.8.4
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.17.1
github.com/tidwall/sjson v1.2.5
github.com/yuin/goldmark v1.7.0
go.mau.fi/util v0.4.0
github.com/yuin/goldmark v1.7.1
go.mau.fi/util v0.4.2
go.mau.fi/zeroconfig v0.1.2
golang.org/x/crypto v0.19.0
golang.org/x/exp v0.0.0-20240213143201-ec583247a57a
golang.org/x/net v0.21.0
golang.org/x/crypto v0.22.0
golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8
golang.org/x/net v0.24.0
gopkg.in/yaml.v3 v3.0.1
maunium.net/go/mauflag v1.0.0
maunium.net/go/maulogger/v2 v2.4.1
)
require (
@ -30,6 +29,6 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
golang.org/x/sys v0.17.0 // indirect
golang.org/x/sys v0.19.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
)

30
go.sum
View File

@ -24,8 +24,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0=
github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U=
github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@ -35,23 +35,23 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.7.0 h1:EfOIvIMZIzHdB/R/zVrikYLPPwJlfMcNczJFMs1m6sA=
github.com/yuin/goldmark v1.7.0/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
go.mau.fi/util v0.4.0 h1:S2X3qU4pUcb/vxBRfAuZjbrR9xVMAXSjQojNBLPBbhs=
go.mau.fi/util v0.4.0/go.mod h1:leeiHtgVBuN+W9aDii3deAXnfC563iN3WK6BF8/AjNw=
github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U=
github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
go.mau.fi/util v0.4.2 h1:RR3TOcRHmCF9Bx/3YG4S65MYfa+nV6/rn8qBWW4Mi30=
go.mau.fi/util v0.4.2/go.mod h1:PlAVfUUcPyHPrwnvjkJM9UFcPE7qGPDJqk+Oufa1Gtw=
go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto=
go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/exp v0.0.0-20240213143201-ec583247a57a h1:HinSgX1tJRX3KsL//Gxynpw5CTOAIPhgL4W8PNiIpVE=
golang.org/x/exp v0.0.0-20240213143201-ec583247a57a/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8 h1:ESSUROHIBHg7USnszlcdmjBEwdMj9VUvU+OPk4yl2mc=
golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
@ -60,5 +60,3 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8=
maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho=

View File

@ -36,19 +36,34 @@ var (
ErrEmptyLocalpart = errors.New("empty localparts are not allowed")
)
// ParseCommonIdentifier parses a common identifier according to https://spec.matrix.org/v1.9/appendices/#common-identifier-format
func ParseCommonIdentifier[Stringish ~string](identifier Stringish) (sigil byte, localpart, homeserver string) {
if len(identifier) == 0 {
return
}
sigil = identifier[0]
strIdentifier := string(identifier)
if strings.ContainsRune(strIdentifier, ':') {
parts := strings.SplitN(strIdentifier, ":", 2)
localpart = parts[0][1:]
homeserver = parts[1]
} else {
localpart = strIdentifier[1:]
}
return
}
// Parse parses the user ID into the localpart and server name.
//
// Note that this only enforces very basic user ID formatting requirements: user IDs start with
// a @, and contain a : after the @. If you want to enforce localpart validity, see the
// ParseAndValidate and ValidateUserLocalpart functions.
func (userID UserID) Parse() (localpart, homeserver string, err error) {
if len(userID) == 0 || userID[0] != '@' || !strings.ContainsRune(string(userID), ':') {
// This error wrapping lets you use errors.Is() nicely even though the message contains the user ID
var sigil byte
sigil, localpart, homeserver = ParseCommonIdentifier(userID)
if sigil != '@' || homeserver == "" {
err = fmt.Errorf("'%s' %w", userID, ErrInvalidUserID)
return
}
parts := strings.SplitN(string(userID), ":", 2)
localpart, homeserver = strings.TrimPrefix(parts[0], "@"), parts[1]
return
}

View File

@ -309,6 +309,12 @@ func (slr SyncLeftRoom) MarshalJSON() ([]byte, error) {
return marshalAndDeleteEmpty((marshalableSyncLeftRoom)(slr), syncLeftRoomPathsToDelete)
}
type BeeperInboxPreviewEvent struct {
EventID id.EventID `json:"event_id"`
Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
Event *event.Event `json:"event,omitempty"`
}
type SyncJoinedRoom struct {
Summary LazyLoadSummary `json:"summary"`
State SyncEventsList `json:"state"`
@ -319,6 +325,8 @@ type SyncJoinedRoom struct {
UnreadNotifications *UnreadNotificationCounts `json:"unread_notifications,omitempty"`
// https://github.com/matrix-org/matrix-spec-proposals/pull/2654
MSC2654UnreadCount *int `json:"org.matrix.msc2654.unread_count,omitempty"`
// Beeper extension
BeeperInboxPreview *BeeperInboxPreviewEvent `json:"com.beeper.inbox.preview,omitempty"`
}
type UnreadNotificationCounts struct {

View File

@ -7,7 +7,7 @@ import (
"strings"
)
const Version = "v0.18.0-beta.1"
const Version = "v0.18.1"
var GoModVersion = ""
var Commit = ""