mirror of https://github.com/mautrix/go.git
Compare commits
44 Commits
404c65f05c
...
3b8453bd15
Author | SHA1 | Date |
---|---|---|
Sumner Evans | 3b8453bd15 | |
Sumner Evans | 135eccbaa0 | |
Sumner Evans | 44b04d50a4 | |
Sumner Evans | 645348695a | |
Sumner Evans | 2ded86695b | |
Sumner Evans | d8d05ce0a7 | |
Sumner Evans | 059632c845 | |
Sumner Evans | e768e5fa53 | |
Sumner Evans | 04c7efc0c0 | |
Sumner Evans | 48edb28c1f | |
Sumner Evans | 0d0f04d51d | |
Sumner Evans | c8f6fa3a47 | |
Sumner Evans | 13d0ff3524 | |
Sumner Evans | c0e030fc85 | |
Sumner Evans | 2810465ef2 | |
Malte E | 6cc490d9ab | |
Sumner Evans | ff9e2e0f1d | |
Tulir Asokan | a19dab1897 | |
Tulir Asokan | 423d32ddf6 | |
Malte E | 640086dbf9 | |
Toni Spets | 898b235a84 | |
Toni Spets | 64cc843952 | |
Toni Spets | 0095e1fb78 | |
Tulir Asokan | ade00e8603 | |
Toni Spets | 9fe66581e5 | |
Adam Van Ymeren | 4dd7adc7be | |
Adam Van Ymeren | 8ba307b28d | |
Tulir Asokan | 5dedc9806a | |
Malte E | b556d65da9 | |
Toni Spets | fad4448ab7 | |
Tulir Asokan | a7bf485893 | |
Tulir Asokan | 20fde3d163 | |
Tulir Asokan | 5224780563 | |
Toni Spets | f0b728f502 | |
Tulir Asokan | 8128b00e00 | |
Brad Murray | 08397c8b9a | |
Tulir Asokan | 94246ffc85 | |
Sumner Evans | 2728a8f8aa | |
Sumner Evans | 3b65d98c0c | |
Tulir Asokan | d18dcfc7eb | |
Toni Spets | a36f60a4f3 | |
Malte E | db41583fdd | |
Malte E | 41dfb40064 | |
Malte E | 6b1a039beb |
31
CHANGELOG.md
31
CHANGELOG.md
|
@ -1,14 +1,43 @@
|
|||
## v0.18.0 (unreleased)
|
||||
## v0.18.1 (2024-04-16)
|
||||
|
||||
* *(format)* Added a `context.Context` field to HTMLParser's Context struct.
|
||||
* *(bridge)* Added support for handling join rules, knocks, invites and bans
|
||||
(thanks to [@maltee1] in [#193] and [#204]).
|
||||
* *(crypto)* Changed forwarded room key handling to only accept keys with a
|
||||
lower first known index than the existing session if there is one.
|
||||
* *(crypto)* Changed key backup restore to assume own device list is up to date
|
||||
to avoid re-requesting device list for every deleted device that has signed
|
||||
key backup.
|
||||
* *(crypto)* Fixed memory cache not being invalidated when storing own
|
||||
cross-signing keys
|
||||
|
||||
[@maltee1]: https://github.com/maltee1
|
||||
[#193]: https://github.com/mautrix/go/pull/193
|
||||
[#204]: https://github.com/mautrix/go/pull/204
|
||||
|
||||
## v0.18.0 (2024-03-16)
|
||||
|
||||
* **Breaking change *(client, bridge, appservice)*** Dropped support for
|
||||
maulogger. Only zerolog loggers are now provided by default.
|
||||
* *(bridge)* Fixed upload size limit not having a default if the server
|
||||
returned no value.
|
||||
* *(synapseadmin)* Added wrappers for some room and user admin APIs.
|
||||
(thanks to [@grvn-ht] in [#181]).
|
||||
* *(crypto/verificationhelper)* Fixed bugs.
|
||||
* *(crypto)* Fixed key backup uploading doing too much base64.
|
||||
* *(crypto)* Changed `EncryptMegolmEvent` to return an error if persisting the
|
||||
megolm session fails. This ensures that database errors won't cause messages
|
||||
to be sent with duplicate indexes.
|
||||
* *(crypto)* Changed `GetOrRequestSecret` to use a callback instead of returning
|
||||
the value directly. This allows validating the value in order to ignore
|
||||
invalid secrets.
|
||||
* *(id)* Added `ParseCommonIdentifier` function to parse any Matrix identifier
|
||||
in the [Common Identifier Format].
|
||||
* *(federation)* Added simple key server that passes the federation tester.
|
||||
|
||||
[@grvn-ht]: https://github.com/grvn-ht
|
||||
[#181]: https://github.com/mautrix/go/pull/181
|
||||
[Common Identifier Format]: https://spec.matrix.org/v1.9/appendices/#common-identifier-format
|
||||
|
||||
### beta.1 (2024-02-16)
|
||||
|
||||
|
|
|
@ -24,7 +24,6 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"golang.org/x/net/publicsuffix"
|
||||
"gopkg.in/yaml.v3"
|
||||
"maunium.net/go/maulogger/v2/maulogadapt"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
|
@ -355,7 +354,7 @@ func (as *AppService) SetHomeserverURL(homeserverURL string) error {
|
|||
// This does not do any validation, and it does not cache the client.
|
||||
// Usually you should prefer [AppService.Client] or [AppService.Intent] over this method.
|
||||
func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client {
|
||||
client := &mautrix.Client{
|
||||
return &mautrix.Client{
|
||||
HomeserverURL: as.hsURLForClient,
|
||||
UserID: userID,
|
||||
SetAppServiceUserID: true,
|
||||
|
@ -366,8 +365,6 @@ func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client {
|
|||
Client: as.HTTPClient,
|
||||
DefaultHTTPRetries: as.DefaultHTTPRetries,
|
||||
}
|
||||
client.Logger = maulogadapt.ZeroAsMau(&client.Log)
|
||||
return client
|
||||
}
|
||||
|
||||
// NewExternalMautrixClient creates a new [mautrix.Client] instance for an external user,
|
||||
|
|
|
@ -29,8 +29,6 @@ import (
|
|||
"go.mau.fi/util/exzerolog"
|
||||
"gopkg.in/yaml.v3"
|
||||
flag "maunium.net/go/mauflag"
|
||||
"maunium.net/go/maulogger/v2"
|
||||
"maunium.net/go/maulogger/v2/maulogadapt"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
|
@ -96,6 +94,32 @@ type PowerLevelHandlingPortal interface {
|
|||
HandleMatrixPowerLevels(sender User, evt *event.Event)
|
||||
}
|
||||
|
||||
type JoinRuleHandlingPortal interface {
|
||||
Portal
|
||||
HandleMatrixJoinRule(sender User, evt *event.Event)
|
||||
}
|
||||
|
||||
type BanHandlingPortal interface {
|
||||
Portal
|
||||
HandleMatrixBan(sender User, ghost Ghost, evt *event.Event)
|
||||
HandleMatrixUnban(sender User, ghost Ghost, evt *event.Event)
|
||||
}
|
||||
|
||||
type KnockHandlingPortal interface {
|
||||
Portal
|
||||
HandleMatrixKnock(sender User, evt *event.Event)
|
||||
HandleMatrixRetractKnock(sender User, evt *event.Event)
|
||||
HandleMatrixAcceptKnock(sender User, ghost Ghost, evt *event.Event)
|
||||
HandleMatrixRejectKnock(sender User, ghost Ghost, evt *event.Event)
|
||||
}
|
||||
|
||||
type InviteHandlingPortal interface {
|
||||
Portal
|
||||
HandleMatrixAcceptInvite(sender User, evt *event.Event)
|
||||
HandleMatrixRejectInvite(sender User, evt *event.Event)
|
||||
HandleMatrixRetractInvite(sender User, ghost Ghost, evt *event.Event)
|
||||
}
|
||||
|
||||
type User interface {
|
||||
GetPermissionLevel() bridgeconfig.PermissionLevel
|
||||
IsLoggedIn() bool
|
||||
|
@ -201,8 +225,6 @@ type Bridge struct {
|
|||
Crypto Crypto
|
||||
CryptoPickleKey string
|
||||
|
||||
// Deprecated: Switch to ZLog
|
||||
Log maulogger.Logger
|
||||
ZLog *zerolog.Logger
|
||||
|
||||
MediaConfig mautrix.RespMediaConfig
|
||||
|
@ -536,7 +558,6 @@ func (br *Bridge) init() {
|
|||
os.Exit(12)
|
||||
}
|
||||
exzerolog.SetupDefaults(br.ZLog)
|
||||
br.Log = maulogadapt.ZeroAsMau(br.ZLog)
|
||||
|
||||
br.DoublePuppet = &doublePuppetUtil{br: br, log: br.ZLog.With().Str("component", "double puppet").Logger()}
|
||||
|
||||
|
|
|
@ -12,7 +12,6 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
|
@ -38,8 +37,6 @@ type Event struct {
|
|||
ReplyTo id.EventID
|
||||
Ctx context.Context
|
||||
ZLog *zerolog.Logger
|
||||
// Deprecated: switch to ZLog
|
||||
Log maulogger.Logger
|
||||
}
|
||||
|
||||
// MainIntent returns the intent to use when replying to the command.
|
||||
|
|
|
@ -12,7 +12,6 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"maunium.net/go/maulogger/v2/maulogadapt"
|
||||
|
||||
"maunium.net/go/mautrix/bridge"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
@ -91,7 +90,6 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id.
|
|||
ReplyTo: replyTo,
|
||||
Ctx: ctx,
|
||||
ZLog: &log,
|
||||
Log: maulogadapt.ZeroAsMau(&log),
|
||||
}
|
||||
log.Debug().Msg("Received command")
|
||||
|
||||
|
|
|
@ -66,6 +66,7 @@ func NewMatrixHandler(br *Bridge) *MatrixHandler {
|
|||
br.EventProcessor.On(event.EphemeralEventReceipt, handler.HandleReceipt)
|
||||
br.EventProcessor.On(event.EphemeralEventTyping, handler.HandleTyping)
|
||||
br.EventProcessor.On(event.StatePowerLevels, handler.HandlePowerLevels)
|
||||
br.EventProcessor.On(event.StateJoinRules, handler.HandleJoinRule)
|
||||
return handler
|
||||
}
|
||||
|
||||
|
@ -275,27 +276,62 @@ func (mx *MatrixHandler) HandleMembership(ctx context.Context, evt *event.Event)
|
|||
} else if user.GetPermissionLevel() < bridgeconfig.PermissionLevelUser || !user.IsLoggedIn() {
|
||||
return
|
||||
}
|
||||
|
||||
mhp, ok := portal.(MembershipHandlingPortal)
|
||||
if !ok {
|
||||
bhp, bhpOk := portal.(BanHandlingPortal)
|
||||
mhp, mhpOk := portal.(MembershipHandlingPortal)
|
||||
khp, khpOk := portal.(KnockHandlingPortal)
|
||||
ihp, ihpOk := portal.(InviteHandlingPortal)
|
||||
if !(mhpOk || bhpOk || khpOk) {
|
||||
return
|
||||
}
|
||||
|
||||
if content.Membership == event.MembershipLeave {
|
||||
if evt.Unsigned.PrevContent != nil {
|
||||
_ = evt.Unsigned.PrevContent.ParseRaw(evt.Type)
|
||||
prevContent, ok := evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent)
|
||||
if ok && prevContent.Membership != "join" {
|
||||
return
|
||||
prevContent := &event.MemberEventContent{Membership: event.MembershipLeave}
|
||||
if evt.Unsigned.PrevContent != nil {
|
||||
_ = evt.Unsigned.PrevContent.ParseRaw(evt.Type)
|
||||
prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent)
|
||||
}
|
||||
if ihpOk && prevContent.Membership == event.MembershipInvite && content.Membership != event.MembershipBan {
|
||||
if content.Membership == event.MembershipJoin {
|
||||
ihp.HandleMatrixAcceptInvite(user, evt)
|
||||
}
|
||||
if content.Membership == event.MembershipLeave {
|
||||
if isSelf {
|
||||
ihp.HandleMatrixRejectInvite(user, evt)
|
||||
} else if ghost != nil {
|
||||
ihp.HandleMatrixRetractInvite(user, ghost, evt)
|
||||
}
|
||||
}
|
||||
if isSelf {
|
||||
mhp.HandleMatrixLeave(user, evt)
|
||||
} else if ghost != nil {
|
||||
mhp.HandleMatrixKick(user, ghost, evt)
|
||||
}
|
||||
if bhpOk && ghost != nil {
|
||||
if content.Membership == event.MembershipBan {
|
||||
bhp.HandleMatrixBan(user, ghost, evt)
|
||||
} else if content.Membership == event.MembershipLeave && prevContent.Membership == event.MembershipBan {
|
||||
bhp.HandleMatrixUnban(user, ghost, evt)
|
||||
}
|
||||
}
|
||||
if khpOk {
|
||||
if content.Membership == event.MembershipKnock {
|
||||
khp.HandleMatrixKnock(user, evt)
|
||||
} else if prevContent.Membership == event.MembershipKnock {
|
||||
if content.Membership == event.MembershipInvite && ghost != nil {
|
||||
khp.HandleMatrixAcceptKnock(user, ghost, evt)
|
||||
} else if content.Membership == event.MembershipLeave {
|
||||
if isSelf {
|
||||
khp.HandleMatrixRetractKnock(user, evt)
|
||||
} else if ghost != nil {
|
||||
khp.HandleMatrixRejectKnock(user, ghost, evt)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if mhpOk {
|
||||
if content.Membership == event.MembershipLeave && prevContent.Membership == event.MembershipJoin {
|
||||
if isSelf {
|
||||
mhp.HandleMatrixLeave(user, evt)
|
||||
} else if ghost != nil {
|
||||
mhp.HandleMatrixKick(user, ghost, evt)
|
||||
}
|
||||
} else if content.Membership == event.MembershipInvite && !isSelf && ghost != nil {
|
||||
mhp.HandleMatrixInvite(user, ghost, evt)
|
||||
}
|
||||
} else if content.Membership == event.MembershipInvite && !isSelf && ghost != nil {
|
||||
mhp.HandleMatrixInvite(user, ghost, evt)
|
||||
}
|
||||
// TODO kicking/inviting non-ghost users users
|
||||
}
|
||||
|
@ -702,3 +738,18 @@ func (mx *MatrixHandler) HandlePowerLevels(_ context.Context, evt *event.Event)
|
|||
powerLevelPortal.HandleMatrixPowerLevels(user, evt)
|
||||
}
|
||||
}
|
||||
|
||||
func (mx *MatrixHandler) HandleJoinRule(_ context.Context, evt *event.Event) {
|
||||
if mx.shouldIgnoreEvent(evt) {
|
||||
return
|
||||
}
|
||||
portal := mx.bridge.Child.GetIPortal(evt.RoomID)
|
||||
if portal == nil {
|
||||
return
|
||||
}
|
||||
joinRulePortal, ok := portal.(JoinRuleHandlingPortal)
|
||||
if ok {
|
||||
user := mx.bridge.Child.GetIUser(evt.Sender, true)
|
||||
joinRulePortal.HandleMatrixJoinRule(user, evt)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -119,7 +119,7 @@ func (br *Bridge) PingServer() (start, serverTs, end time.Time) {
|
|||
}
|
||||
start = time.Now()
|
||||
var resp wsPingData
|
||||
br.Log.Debugln("Pinging appservice websocket")
|
||||
br.ZLog.Debug().Msg("Pinging appservice websocket")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
err := br.AS.RequestWebsocket(ctx, &appservice.WebsocketRequest{
|
||||
|
|
38
client.go
38
client.go
|
@ -19,7 +19,6 @@ import (
|
|||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/retryafter"
|
||||
"maunium.net/go/maulogger/v2/maulogadapt"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/backup"
|
||||
"maunium.net/go/mautrix/event"
|
||||
|
@ -49,17 +48,6 @@ type VerificationHelper interface {
|
|||
ConfirmSAS(ctx context.Context, txnID id.VerificationTransactionID) error
|
||||
}
|
||||
|
||||
// Deprecated: switch to zerolog
|
||||
type Logger interface {
|
||||
Debugfln(message string, args ...interface{})
|
||||
}
|
||||
|
||||
// Deprecated: switch to zerolog
|
||||
type WarnLogger interface {
|
||||
Logger
|
||||
Warnfln(message string, args ...interface{})
|
||||
}
|
||||
|
||||
// Client represents a Matrix client.
|
||||
type Client struct {
|
||||
HomeserverURL *url.URL // The base homeserver URL
|
||||
|
@ -75,8 +63,6 @@ type Client struct {
|
|||
Verification VerificationHelper
|
||||
|
||||
Log zerolog.Logger
|
||||
// Deprecated: switch to the zerolog instance in Log
|
||||
Logger Logger
|
||||
|
||||
RequestHook func(req *http.Request)
|
||||
ResponseHook func(req *http.Request, resp *http.Response, duration time.Duration)
|
||||
|
@ -352,6 +338,7 @@ type FullRequest struct {
|
|||
SensitiveContent bool
|
||||
Handler ClientResponseHandler
|
||||
Logger *zerolog.Logger
|
||||
Client *http.Client
|
||||
}
|
||||
|
||||
var requestID int32
|
||||
|
@ -438,7 +425,10 @@ func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]b
|
|||
if len(cli.AccessToken) > 0 {
|
||||
req.Header.Set("Authorization", "Bearer "+cli.AccessToken)
|
||||
}
|
||||
return cli.executeCompiledRequest(req, params.MaxAttempts-1, 4*time.Second, params.ResponseJSON, params.Handler)
|
||||
if params.Client == nil {
|
||||
params.Client = cli.Client
|
||||
}
|
||||
return cli.executeCompiledRequest(req, params.MaxAttempts-1, 4*time.Second, params.ResponseJSON, params.Handler, params.Client)
|
||||
}
|
||||
|
||||
func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
|
||||
|
@ -449,7 +439,7 @@ func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
|
|||
return log
|
||||
}
|
||||
|
||||
func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler) ([]byte, error) {
|
||||
func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler, client *http.Client) ([]byte, error) {
|
||||
log := zerolog.Ctx(req.Context())
|
||||
if req.Body != nil {
|
||||
if req.GetBody == nil {
|
||||
|
@ -467,7 +457,7 @@ func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff
|
|||
Int("retry_in_seconds", int(backoff.Seconds())).
|
||||
Msg("Request failed, retrying")
|
||||
time.Sleep(backoff)
|
||||
return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler)
|
||||
return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, client)
|
||||
}
|
||||
|
||||
func readRequestBody(req *http.Request, res *http.Response) ([]byte, error) {
|
||||
|
@ -549,17 +539,17 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler) ([]byte, error) {
|
||||
func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler, client *http.Client) ([]byte, error) {
|
||||
cli.RequestStart(req)
|
||||
startTime := time.Now()
|
||||
res, err := cli.Client.Do(req)
|
||||
res, err := client.Do(req)
|
||||
duration := time.Now().Sub(startTime)
|
||||
if res != nil {
|
||||
defer res.Body.Close()
|
||||
}
|
||||
if err != nil {
|
||||
if retries > 0 {
|
||||
return cli.doRetry(req, err, retries, backoff, responseJSON, handler)
|
||||
return cli.doRetry(req, err, retries, backoff, responseJSON, handler, client)
|
||||
}
|
||||
err = HTTPError{
|
||||
Request: req,
|
||||
|
@ -574,7 +564,7 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof
|
|||
|
||||
if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) {
|
||||
backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff)
|
||||
return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler)
|
||||
return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, client)
|
||||
}
|
||||
|
||||
var body []byte
|
||||
|
@ -2295,7 +2285,7 @@ func NewClient(homeserverURL string, userID id.UserID, accessToken string) (*Cli
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cli := &Client{
|
||||
return &Client{
|
||||
AccessToken: accessToken,
|
||||
UserAgent: DefaultUserAgent,
|
||||
HomeserverURL: hsURL,
|
||||
|
@ -2307,7 +2297,5 @@ func NewClient(homeserverURL string, userID id.UserID, accessToken string) (*Cli
|
|||
// The client will work with this storer: it just won't remember across restarts.
|
||||
// In practice, a database backend should be used.
|
||||
Store: NewMemorySyncStore(),
|
||||
}
|
||||
cli.Logger = maulogadapt.ZeroAsMau(&cli.Log)
|
||||
return cli, nil
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -23,27 +23,39 @@ type OlmAccount struct {
|
|||
|
||||
func NewOlmAccount() *OlmAccount {
|
||||
return &OlmAccount{
|
||||
Internal: *olm.NewAccount(),
|
||||
Internal: olm.NewAccount(),
|
||||
}
|
||||
}
|
||||
|
||||
func (account *OlmAccount) Keys() (id.SigningKey, id.IdentityKey) {
|
||||
if len(account.signingKey) == 0 || len(account.identityKey) == 0 {
|
||||
account.signingKey, account.identityKey = account.Internal.IdentityKeys()
|
||||
var err error
|
||||
account.signingKey, account.identityKey, err = account.Internal.IdentityKeys()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return account.signingKey, account.identityKey
|
||||
}
|
||||
|
||||
func (account *OlmAccount) SigningKey() id.SigningKey {
|
||||
if len(account.signingKey) == 0 {
|
||||
account.signingKey, account.identityKey = account.Internal.IdentityKeys()
|
||||
var err error
|
||||
account.signingKey, account.identityKey, err = account.Internal.IdentityKeys()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return account.signingKey
|
||||
}
|
||||
|
||||
func (account *OlmAccount) IdentityKey() id.IdentityKey {
|
||||
if len(account.identityKey) == 0 {
|
||||
account.signingKey, account.identityKey = account.Internal.IdentityKeys()
|
||||
var err error
|
||||
account.signingKey, account.identityKey, err = account.Internal.IdentityKeys()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return account.identityKey
|
||||
}
|
||||
|
@ -71,16 +83,19 @@ func (account *OlmAccount) getInitialKeys(userID id.UserID, deviceID id.DeviceID
|
|||
func (account *OlmAccount) getOneTimeKeys(userID id.UserID, deviceID id.DeviceID, currentOTKCount int) map[id.KeyID]mautrix.OneTimeKey {
|
||||
newCount := int(account.Internal.MaxNumberOfOneTimeKeys()/2) - currentOTKCount
|
||||
if newCount > 0 {
|
||||
account.Internal.GenOneTimeKeys(uint(newCount))
|
||||
account.Internal.GenOneTimeKeys(nil, uint(newCount))
|
||||
}
|
||||
oneTimeKeys := make(map[id.KeyID]mautrix.OneTimeKey)
|
||||
for keyID, key := range account.Internal.OneTimeKeys() {
|
||||
internalKeys, err := account.Internal.OneTimeKeys()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for keyID, key := range internalKeys {
|
||||
key := mautrix.OneTimeKey{Key: key}
|
||||
signature, _ := account.Internal.SignJSON(key)
|
||||
key.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, deviceID.String(), signature)
|
||||
key.IsSigned = true
|
||||
oneTimeKeys[id.NewKeyID(id.KeyAlgorithmSignedCurve25519, keyID)] = key
|
||||
}
|
||||
account.Internal.MarkKeysAsPublished()
|
||||
return oneTimeKeys
|
||||
}
|
||||
|
|
|
@ -19,16 +19,16 @@ import (
|
|||
|
||||
// CrossSigningKeysCache holds the three cross-signing keys for the current user.
|
||||
type CrossSigningKeysCache struct {
|
||||
MasterKey *olm.PkSigning
|
||||
SelfSigningKey *olm.PkSigning
|
||||
UserSigningKey *olm.PkSigning
|
||||
MasterKey olm.PKSigning
|
||||
SelfSigningKey olm.PKSigning
|
||||
UserSigningKey olm.PKSigning
|
||||
}
|
||||
|
||||
func (cskc *CrossSigningKeysCache) PublicKeys() *CrossSigningPublicKeysCache {
|
||||
return &CrossSigningPublicKeysCache{
|
||||
MasterKey: cskc.MasterKey.PublicKey,
|
||||
SelfSigningKey: cskc.SelfSigningKey.PublicKey,
|
||||
UserSigningKey: cskc.UserSigningKey.PublicKey,
|
||||
MasterKey: cskc.MasterKey.PublicKey(),
|
||||
SelfSigningKey: cskc.SelfSigningKey.PublicKey(),
|
||||
UserSigningKey: cskc.UserSigningKey.PublicKey(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -40,28 +40,28 @@ type CrossSigningSeeds struct {
|
|||
|
||||
func (mach *OlmMachine) ExportCrossSigningKeys() CrossSigningSeeds {
|
||||
return CrossSigningSeeds{
|
||||
MasterKey: mach.CrossSigningKeys.MasterKey.Seed,
|
||||
SelfSigningKey: mach.CrossSigningKeys.SelfSigningKey.Seed,
|
||||
UserSigningKey: mach.CrossSigningKeys.UserSigningKey.Seed,
|
||||
MasterKey: mach.CrossSigningKeys.MasterKey.Seed(),
|
||||
SelfSigningKey: mach.CrossSigningKeys.SelfSigningKey.Seed(),
|
||||
UserSigningKey: mach.CrossSigningKeys.UserSigningKey.Seed(),
|
||||
}
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) ImportCrossSigningKeys(keys CrossSigningSeeds) (err error) {
|
||||
var keysCache CrossSigningKeysCache
|
||||
if keysCache.MasterKey, err = olm.NewPkSigningFromSeed(keys.MasterKey); err != nil {
|
||||
if keysCache.MasterKey, err = olm.NewPKSigningFromSeed(keys.MasterKey); err != nil {
|
||||
return
|
||||
}
|
||||
if keysCache.SelfSigningKey, err = olm.NewPkSigningFromSeed(keys.SelfSigningKey); err != nil {
|
||||
if keysCache.SelfSigningKey, err = olm.NewPKSigningFromSeed(keys.SelfSigningKey); err != nil {
|
||||
return
|
||||
}
|
||||
if keysCache.UserSigningKey, err = olm.NewPkSigningFromSeed(keys.UserSigningKey); err != nil {
|
||||
if keysCache.UserSigningKey, err = olm.NewPKSigningFromSeed(keys.UserSigningKey); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
mach.Log.Debug().
|
||||
Str("master", keysCache.MasterKey.PublicKey.String()).
|
||||
Str("self_signing", keysCache.SelfSigningKey.PublicKey.String()).
|
||||
Str("user_signing", keysCache.UserSigningKey.PublicKey.String()).
|
||||
Str("master", keysCache.MasterKey.PublicKey().String()).
|
||||
Str("self_signing", keysCache.SelfSigningKey.PublicKey().String()).
|
||||
Str("user_signing", keysCache.UserSigningKey.PublicKey().String()).
|
||||
Msg("Imported own cross-signing keys")
|
||||
|
||||
mach.CrossSigningKeys = &keysCache
|
||||
|
@ -73,19 +73,19 @@ func (mach *OlmMachine) ImportCrossSigningKeys(keys CrossSigningSeeds) (err erro
|
|||
func (mach *OlmMachine) GenerateCrossSigningKeys() (*CrossSigningKeysCache, error) {
|
||||
var keysCache CrossSigningKeysCache
|
||||
var err error
|
||||
if keysCache.MasterKey, err = olm.NewPkSigning(); err != nil {
|
||||
if keysCache.MasterKey, err = olm.NewPKSigning(); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate master key: %w", err)
|
||||
}
|
||||
if keysCache.SelfSigningKey, err = olm.NewPkSigning(); err != nil {
|
||||
if keysCache.SelfSigningKey, err = olm.NewPKSigning(); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate self-signing key: %w", err)
|
||||
}
|
||||
if keysCache.UserSigningKey, err = olm.NewPkSigning(); err != nil {
|
||||
if keysCache.UserSigningKey, err = olm.NewPKSigning(); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate user-signing key: %w", err)
|
||||
}
|
||||
mach.Log.Debug().
|
||||
Str("master", keysCache.MasterKey.PublicKey.String()).
|
||||
Str("self_signing", keysCache.SelfSigningKey.PublicKey.String()).
|
||||
Str("user_signing", keysCache.UserSigningKey.PublicKey.String()).
|
||||
Str("master", keysCache.MasterKey.PublicKey().String()).
|
||||
Str("self_signing", keysCache.SelfSigningKey.PublicKey().String()).
|
||||
Str("user_signing", keysCache.UserSigningKey.PublicKey().String()).
|
||||
Msg("Generated cross-signing keys")
|
||||
return &keysCache, nil
|
||||
}
|
||||
|
@ -93,12 +93,12 @@ func (mach *OlmMachine) GenerateCrossSigningKeys() (*CrossSigningKeysCache, erro
|
|||
// PublishCrossSigningKeys signs and uploads the public keys of the given cross-signing keys to the server.
|
||||
func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *CrossSigningKeysCache, uiaCallback mautrix.UIACallback) error {
|
||||
userID := mach.Client.UserID
|
||||
masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey.String())
|
||||
masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String())
|
||||
masterKey := mautrix.CrossSigningKeys{
|
||||
UserID: userID,
|
||||
Usage: []id.CrossSigningUsage{id.XSUsageMaster},
|
||||
Keys: map[id.KeyID]id.Ed25519{
|
||||
masterKeyID: keys.MasterKey.PublicKey,
|
||||
masterKeyID: keys.MasterKey.PublicKey(),
|
||||
},
|
||||
}
|
||||
masterSig, err := mach.account.Internal.SignJSON(masterKey)
|
||||
|
@ -111,27 +111,27 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross
|
|||
UserID: userID,
|
||||
Usage: []id.CrossSigningUsage{id.XSUsageSelfSigning},
|
||||
Keys: map[id.KeyID]id.Ed25519{
|
||||
id.NewKeyID(id.KeyAlgorithmEd25519, keys.SelfSigningKey.PublicKey.String()): keys.SelfSigningKey.PublicKey,
|
||||
id.NewKeyID(id.KeyAlgorithmEd25519, keys.SelfSigningKey.PublicKey().String()): keys.SelfSigningKey.PublicKey(),
|
||||
},
|
||||
}
|
||||
selfSig, err := keys.MasterKey.SignJSON(selfKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sign self-signing key: %w", err)
|
||||
}
|
||||
selfKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey.String(), selfSig)
|
||||
selfKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), selfSig)
|
||||
|
||||
userKey := mautrix.CrossSigningKeys{
|
||||
UserID: userID,
|
||||
Usage: []id.CrossSigningUsage{id.XSUsageUserSigning},
|
||||
Keys: map[id.KeyID]id.Ed25519{
|
||||
id.NewKeyID(id.KeyAlgorithmEd25519, keys.UserSigningKey.PublicKey.String()): keys.UserSigningKey.PublicKey,
|
||||
id.NewKeyID(id.KeyAlgorithmEd25519, keys.UserSigningKey.PublicKey().String()): keys.UserSigningKey.PublicKey(),
|
||||
},
|
||||
}
|
||||
userSig, err := keys.MasterKey.SignJSON(userKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sign user-signing key: %w", err)
|
||||
}
|
||||
userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey.String(), userSig)
|
||||
userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), userSig)
|
||||
|
||||
err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{
|
||||
Master: masterKey,
|
||||
|
|
|
@ -60,7 +60,7 @@ func (mach *OlmMachine) SignUser(ctx context.Context, userID id.UserID, masterKe
|
|||
Str("signature", signature).
|
||||
Msg("Signed master key of user with our user-signing key")
|
||||
|
||||
if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil {
|
||||
if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey(), signature); err != nil {
|
||||
return fmt.Errorf("error storing signature in crypto store: %w", err)
|
||||
}
|
||||
|
||||
|
@ -77,7 +77,7 @@ func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error {
|
|||
|
||||
userID := mach.Client.UserID
|
||||
deviceID := mach.Client.DeviceID
|
||||
masterKey := mach.CrossSigningKeys.MasterKey.PublicKey
|
||||
masterKey := mach.CrossSigningKeys.MasterKey.PublicKey()
|
||||
|
||||
masterKeyObj := mautrix.ReqKeysSignatures{
|
||||
UserID: userID,
|
||||
|
@ -149,7 +149,7 @@ func (mach *OlmMachine) SignOwnDevice(ctx context.Context, device *id.Device) er
|
|||
Str("signature", signature).
|
||||
Msg("Signed own device key with self-signing key")
|
||||
|
||||
if err := mach.CryptoStore.PutSignature(ctx, device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil {
|
||||
if err := mach.CryptoStore.PutSignature(ctx, device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey(), signature); err != nil {
|
||||
return fmt.Errorf("error storing signature in crypto store: %w", err)
|
||||
}
|
||||
|
||||
|
@ -180,12 +180,12 @@ func (mach *OlmMachine) getFullDeviceKeys(ctx context.Context, device *id.Device
|
|||
}
|
||||
|
||||
// signAndUpload signs the given key signatures object and uploads it to the server.
|
||||
func (mach *OlmMachine) signAndUpload(ctx context.Context, req mautrix.ReqKeysSignatures, userID id.UserID, signedThing string, key *olm.PkSigning) (string, error) {
|
||||
func (mach *OlmMachine) signAndUpload(ctx context.Context, req mautrix.ReqKeysSignatures, userID id.UserID, signedThing string, key olm.PKSigning) (string, error) {
|
||||
signature, err := key.SignJSON(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign JSON: %w", err)
|
||||
}
|
||||
req.Signatures = signatures.NewSingleSignature(mach.Client.UserID, id.KeyAlgorithmEd25519, key.PublicKey.String(), signature)
|
||||
req.Signatures = signatures.NewSingleSignature(mach.Client.UserID, id.KeyAlgorithmEd25519, key.PublicKey().String(), signature)
|
||||
|
||||
resp, err := mach.Client.UploadSignatures(ctx, &mautrix.ReqUploadSignatures{
|
||||
userID: map[string]mautrix.ReqKeysSignatures{
|
||||
|
|
|
@ -110,24 +110,24 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, u
|
|||
|
||||
// UploadCrossSigningKeysToSSSS stores the given cross-signing keys on the server encrypted with the given key.
|
||||
func (mach *OlmMachine) UploadCrossSigningKeysToSSSS(ctx context.Context, key *ssss.Key, keys *CrossSigningKeysCache) error {
|
||||
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningMaster, keys.MasterKey.Seed, key); err != nil {
|
||||
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningMaster, keys.MasterKey.Seed(), key); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningSelf, keys.SelfSigningKey.Seed, key); err != nil {
|
||||
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningSelf, keys.SelfSigningKey.Seed(), key); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed, key); err != nil {
|
||||
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed(), key); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Also store these locally
|
||||
if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageMaster, keys.MasterKey.PublicKey); err != nil {
|
||||
if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageMaster, keys.MasterKey.PublicKey()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageSelfSigning, keys.SelfSigningKey.PublicKey); err != nil {
|
||||
if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageSelfSigning, keys.SelfSigningKey.PublicKey()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageUserSigning, keys.UserSigningKey.PublicKey); err != nil {
|
||||
if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageUserSigning, keys.UserSigningKey.PublicKey()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -96,5 +96,12 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clear internal cache so that it refreshes from crypto store
|
||||
if userID == mach.Client.UserID && mach.crossSigningPubkeys != nil {
|
||||
log.Debug().Msg("Resetting internal cross-signing key cache")
|
||||
mach.crossSigningPubkeys = nil
|
||||
mach.crossSigningPubkeysFetched = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,13 +37,13 @@ func getOlmMachine(t *testing.T) *OlmMachine {
|
|||
}
|
||||
|
||||
userID := id.UserID("@mautrix")
|
||||
mk, _ := olm.NewPkSigning()
|
||||
ssk, _ := olm.NewPkSigning()
|
||||
usk, _ := olm.NewPkSigning()
|
||||
mk, _ := olm.NewPKSigning()
|
||||
ssk, _ := olm.NewPKSigning()
|
||||
usk, _ := olm.NewPKSigning()
|
||||
|
||||
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageMaster, mk.PublicKey)
|
||||
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageSelfSigning, ssk.PublicKey)
|
||||
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageUserSigning, usk.PublicKey)
|
||||
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageMaster, mk.PublicKey())
|
||||
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageSelfSigning, ssk.PublicKey())
|
||||
sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageUserSigning, usk.PublicKey())
|
||||
|
||||
return &OlmMachine{
|
||||
CryptoStore: sqlStore,
|
||||
|
@ -70,10 +70,10 @@ func TestTrustOwnDevice(t *testing.T) {
|
|||
t.Error("Own device trusted while it shouldn't be")
|
||||
}
|
||||
|
||||
m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey,
|
||||
ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1")
|
||||
m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(),
|
||||
ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey(), "sig1")
|
||||
m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, ownDevice.SigningKey,
|
||||
ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, "sig2")
|
||||
ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "sig2")
|
||||
|
||||
if trusted, _ := m.IsUserTrusted(context.TODO(), ownDevice.UserID); !trusted {
|
||||
t.Error("Own user not trusted while they should be")
|
||||
|
@ -90,22 +90,22 @@ func TestTrustOtherUser(t *testing.T) {
|
|||
t.Error("Other user trusted while they shouldn't be")
|
||||
}
|
||||
|
||||
theirMasterKey, _ := olm.NewPkSigning()
|
||||
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey)
|
||||
theirMasterKey, _ := olm.NewPKSigning()
|
||||
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey())
|
||||
|
||||
m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey,
|
||||
m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1")
|
||||
m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(),
|
||||
m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey(), "sig1")
|
||||
|
||||
// sign them with self-signing instead of user-signing key
|
||||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey,
|
||||
m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, "invalid_sig")
|
||||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(),
|
||||
m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "invalid_sig")
|
||||
|
||||
if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted {
|
||||
t.Error("Other user trusted before their master key has been signed with our user-signing key")
|
||||
}
|
||||
|
||||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey,
|
||||
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, "sig2")
|
||||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(),
|
||||
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2")
|
||||
|
||||
if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted {
|
||||
t.Error("Other user not trusted while they should be")
|
||||
|
@ -127,29 +127,29 @@ func TestTrustOtherDevice(t *testing.T) {
|
|||
t.Error("Other device trusted while it shouldn't be")
|
||||
}
|
||||
|
||||
theirMasterKey, _ := olm.NewPkSigning()
|
||||
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey)
|
||||
theirSSK, _ := olm.NewPkSigning()
|
||||
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageSelfSigning, theirSSK.PublicKey)
|
||||
theirMasterKey, _ := olm.NewPKSigning()
|
||||
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey())
|
||||
theirSSK, _ := olm.NewPKSigning()
|
||||
m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageSelfSigning, theirSSK.PublicKey())
|
||||
|
||||
m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey,
|
||||
m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1")
|
||||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey,
|
||||
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, "sig2")
|
||||
m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(),
|
||||
m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey(), "sig1")
|
||||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(),
|
||||
m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2")
|
||||
|
||||
if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted {
|
||||
t.Error("Other user not trusted while they should be")
|
||||
}
|
||||
|
||||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey,
|
||||
otherUser, theirMasterKey.PublicKey, "sig3")
|
||||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey(),
|
||||
otherUser, theirMasterKey.PublicKey(), "sig3")
|
||||
|
||||
if m.IsDeviceTrusted(theirDevice) {
|
||||
t.Error("Other device trusted before it has been signed with user's SSK")
|
||||
}
|
||||
|
||||
m.CryptoStore.PutSignature(context.TODO(), otherUser, theirDevice.SigningKey,
|
||||
otherUser, theirSSK.PublicKey, "sig4")
|
||||
otherUser, theirSSK.PublicKey(), "sig4")
|
||||
|
||||
if !m.IsDeviceTrusted(theirDevice) {
|
||||
t.Error("Other device not trusted while it should be")
|
||||
|
|
|
@ -32,7 +32,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi
|
|||
}
|
||||
theirMSK, ok := theirKeys[id.XSUsageMaster]
|
||||
if !ok {
|
||||
mach.machOrContextLog(ctx).Error().
|
||||
mach.machOrContextLog(ctx).Debug().
|
||||
Str("user_id", device.UserID.String()).
|
||||
Msg("Master key of user not found")
|
||||
return id.TrustStateUnset, nil
|
||||
|
|
|
@ -93,6 +93,10 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id
|
|||
}
|
||||
}
|
||||
|
||||
// FetchKeys fetches the devices of a list of other users. If includeUntracked
|
||||
// is set to false, then the users are filtered to to only include user IDs
|
||||
// whose device lists have been stored with the PutDevices function on the
|
||||
// [Store]. See the FilterTrackedUsers function on [Store] for details.
|
||||
func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device, err error) {
|
||||
req := &mautrix.ReqQueryKeys{
|
||||
DeviceKeys: mautrix.DeviceKeysRequest{},
|
||||
|
|
|
@ -118,7 +118,7 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID
|
|||
log.Debug().Msg("Encrypted event successfully")
|
||||
err = mach.CryptoStore.UpdateOutboundGroupSession(ctx, session)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to update megolm session in crypto store after encrypting")
|
||||
return nil, fmt.Errorf("failed to update outbound group session after encrypting: %w", err)
|
||||
}
|
||||
encrypted := &event.EncryptedEventContent{
|
||||
Algorithm: id.AlgorithmMegolmV1,
|
||||
|
|
|
@ -37,7 +37,10 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
|
|||
Str("olm_session_id", session.ID().String()).
|
||||
Str("olm_session_description", session.Describe()).
|
||||
Msg("Encrypting olm message")
|
||||
msgType, ciphertext := session.Encrypt(plaintext)
|
||||
msgType, ciphertext, err := session.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting")
|
||||
|
|
|
@ -8,14 +8,18 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/canonicaljson"
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/cipher"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
|
||||
"maunium.net/go/mautrix/crypto/goolm/session"
|
||||
"maunium.net/go/mautrix/crypto/goolm/utilities"
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -39,6 +43,9 @@ type Account struct {
|
|||
NumFallbackKeys uint8 `json:"number_fallback_keys"`
|
||||
}
|
||||
|
||||
// Ensure that Account adheres to the olm.Account interface.
|
||||
var _ olm.Account = (*Account)(nil)
|
||||
|
||||
// AccountFromJSONPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key.
|
||||
func AccountFromJSONPickled(pickled, key []byte) (*Account, error) {
|
||||
if len(pickled) == 0 {
|
||||
|
@ -82,7 +89,7 @@ func NewAccount(reader io.Reader) (*Account, error) {
|
|||
}
|
||||
|
||||
// PickleAsJSON returns an Account as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
|
||||
func (a Account) PickleAsJSON(key []byte) ([]byte, error) {
|
||||
func (a *Account) PickleAsJSON(key []byte) ([]byte, error) {
|
||||
return utilities.PickleAsJSON(a, accountPickleVersionJSON, key)
|
||||
}
|
||||
|
||||
|
@ -92,44 +99,60 @@ func (a *Account) UnpickleAsJSON(pickled, key []byte) error {
|
|||
}
|
||||
|
||||
// IdentityKeysJSON returns the public parts of the identity keys for the Account in a JSON string.
|
||||
func (a Account) IdentityKeysJSON() ([]byte, error) {
|
||||
func (a *Account) IdentityKeysJSON() ([]byte, error) {
|
||||
res := struct {
|
||||
Ed25519 string `json:"ed25519"`
|
||||
Curve25519 string `json:"curve25519"`
|
||||
}{}
|
||||
ed25519, curve25519 := a.IdentityKeys()
|
||||
ed25519, curve25519, err := a.IdentityKeys()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res.Ed25519 = string(ed25519)
|
||||
res.Curve25519 = string(curve25519)
|
||||
return json.Marshal(res)
|
||||
}
|
||||
|
||||
// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity keys for the Account.
|
||||
func (a Account) IdentityKeys() (id.Ed25519, id.Curve25519) {
|
||||
func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) {
|
||||
ed25519 := id.Ed25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.PublicKey))
|
||||
curve25519 := id.Curve25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Curve25519.PublicKey))
|
||||
return ed25519, curve25519
|
||||
return ed25519, curve25519, nil
|
||||
}
|
||||
|
||||
// Sign returns the base64-encoded signature of a message using the Ed25519 key
|
||||
// for this Account.
|
||||
func (a Account) Sign(message []byte) ([]byte, error) {
|
||||
func (a *Account) Sign(message []byte) ([]byte, error) {
|
||||
if len(message) == 0 {
|
||||
return nil, fmt.Errorf("sign: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
return []byte(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.Sign(message))), nil
|
||||
}
|
||||
|
||||
// SignJSON signs the given JSON object following the Matrix specification:
|
||||
// https://matrix.org/docs/spec/appendices#signing-json
|
||||
func (a *Account) SignJSON(obj any) (string, error) {
|
||||
objJSON, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
|
||||
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
|
||||
signed, err := a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
|
||||
return string(signed), err
|
||||
}
|
||||
|
||||
// OneTimeKeys returns the public parts of the unpublished one time keys of the Account.
|
||||
//
|
||||
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
|
||||
func (a Account) OneTimeKeys() map[string]id.Curve25519 {
|
||||
func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) {
|
||||
oneTimeKeys := make(map[string]id.Curve25519)
|
||||
for _, curKey := range a.OTKeys {
|
||||
if !curKey.Published {
|
||||
oneTimeKeys[curKey.KeyIDEncoded()] = id.Curve25519(curKey.PublicKeyEncoded())
|
||||
}
|
||||
}
|
||||
return oneTimeKeys
|
||||
return oneTimeKeys, nil
|
||||
}
|
||||
|
||||
//OneTimeKeysJSON returns the public parts of the unpublished one time keys of the Account as a JSON string.
|
||||
|
@ -143,9 +166,12 @@ func (a Account) OneTimeKeys() map[string]id.Curve25519 {
|
|||
}
|
||||
}
|
||||
*/
|
||||
func (a Account) OneTimeKeysJSON() ([]byte, error) {
|
||||
func (a *Account) OneTimeKeysJSON() ([]byte, error) {
|
||||
res := make(map[string]map[string]id.Curve25519)
|
||||
otKeys := a.OneTimeKeys()
|
||||
otKeys, err := a.OneTimeKeys()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res["Curve25519"] = otKeys
|
||||
return json.Marshal(res)
|
||||
}
|
||||
|
@ -186,7 +212,7 @@ func (a *Account) GenOneTimeKeys(reader io.Reader, num uint) error {
|
|||
|
||||
// NewOutboundSession creates a new outbound session to a
|
||||
// given curve25519 identity Key and one time key.
|
||||
func (a Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*session.OlmSession, error) {
|
||||
func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) {
|
||||
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
|
||||
return nil, fmt.Errorf("outbound session: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
|
@ -205,13 +231,18 @@ func (a Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25
|
|||
return s, nil
|
||||
}
|
||||
|
||||
// NewInboundSession creates a new inbound session from an incoming PRE_KEY message.
|
||||
func (a Account) NewInboundSession(theirIdentityKey *id.Curve25519, oneTimeKeyMsg []byte) (*session.OlmSession, error) {
|
||||
// NewInboundSession creates a new in-bound session for sending/receiving
|
||||
// messages from an incoming PRE_KEY message. Returns error on failure.
|
||||
func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) {
|
||||
return a.NewInboundSessionFrom(nil, oneTimeKeyMsg)
|
||||
}
|
||||
|
||||
// NewInboundSessionFrom creates a new inbound session from an incoming PRE_KEY message.
|
||||
func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) {
|
||||
if len(oneTimeKeyMsg) == 0 {
|
||||
return nil, fmt.Errorf("inbound session: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
var theirIdentityKeyDecoded *crypto.Curve25519PublicKey
|
||||
var err error
|
||||
if theirIdentityKey != nil {
|
||||
theirIdentityKeyDecodedByte, err := base64.RawStdEncoding.DecodeString(string(*theirIdentityKey))
|
||||
if err != nil {
|
||||
|
@ -221,14 +252,10 @@ func (a Account) NewInboundSession(theirIdentityKey *id.Curve25519, oneTimeKeyMs
|
|||
theirIdentityKeyDecoded = &theirIdentityKeyCurve
|
||||
}
|
||||
|
||||
s, err := session.NewInboundOlmSession(theirIdentityKeyDecoded, oneTimeKeyMsg, a.searchOTKForOur, a.IdKeys.Curve25519)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
return session.NewInboundOlmSession(theirIdentityKeyDecoded, []byte(oneTimeKeyMsg), a.searchOTKForOur, a.IdKeys.Curve25519)
|
||||
}
|
||||
|
||||
func (a Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneTimeKey {
|
||||
func (a *Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneTimeKey {
|
||||
for curIndex := range a.OTKeys {
|
||||
if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) {
|
||||
return &a.OTKeys[curIndex]
|
||||
|
@ -244,16 +271,17 @@ func (a Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneT
|
|||
}
|
||||
|
||||
// RemoveOneTimeKeys removes the one time key in this Account which matches the one time key in the session s.
|
||||
func (a *Account) RemoveOneTimeKeys(s *session.OlmSession) {
|
||||
toFind := s.BobOneTimeKey
|
||||
func (a *Account) RemoveOneTimeKeys(s olm.Session) error {
|
||||
toFind := s.(*session.OlmSession).BobOneTimeKey
|
||||
for curIndex := range a.OTKeys {
|
||||
if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) {
|
||||
//Remove and return
|
||||
a.OTKeys[curIndex] = a.OTKeys[len(a.OTKeys)-1]
|
||||
a.OTKeys = a.OTKeys[:len(a.OTKeys)-1]
|
||||
return
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
//if the key is a fallback or prevFallback, don't remove it
|
||||
}
|
||||
|
||||
|
@ -279,7 +307,7 @@ func (a *Account) GenFallbackKey(reader io.Reader) error {
|
|||
|
||||
// FallbackKey returns the public part of the current fallback key of the Account.
|
||||
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
|
||||
func (a Account) FallbackKey() map[string]id.Curve25519 {
|
||||
func (a *Account) FallbackKey() map[string]id.Curve25519 {
|
||||
keys := make(map[string]id.Curve25519)
|
||||
if a.NumFallbackKeys >= 1 {
|
||||
keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded())
|
||||
|
@ -297,7 +325,7 @@ func (a Account) FallbackKey() map[string]id.Curve25519 {
|
|||
}
|
||||
}
|
||||
*/
|
||||
func (a Account) FallbackKeyJSON() ([]byte, error) {
|
||||
func (a *Account) FallbackKeyJSON() ([]byte, error) {
|
||||
res := make(map[string]map[string]id.Curve25519)
|
||||
fbk := a.FallbackKey()
|
||||
res["curve25519"] = fbk
|
||||
|
@ -306,7 +334,7 @@ func (a Account) FallbackKeyJSON() ([]byte, error) {
|
|||
|
||||
// FallbackKeyUnpublished returns the public part of the current fallback key of the Account only if it is unpublished.
|
||||
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
|
||||
func (a Account) FallbackKeyUnpublished() map[string]id.Curve25519 {
|
||||
func (a *Account) FallbackKeyUnpublished() map[string]id.Curve25519 {
|
||||
keys := make(map[string]id.Curve25519)
|
||||
if a.NumFallbackKeys >= 1 && !a.CurrentFallbackKey.Published {
|
||||
keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded())
|
||||
|
@ -324,7 +352,7 @@ func (a Account) FallbackKeyUnpublished() map[string]id.Curve25519 {
|
|||
}
|
||||
}
|
||||
*/
|
||||
func (a Account) FallbackKeyUnpublishedJSON() ([]byte, error) {
|
||||
func (a *Account) FallbackKeyUnpublishedJSON() ([]byte, error) {
|
||||
res := make(map[string]map[string]id.Curve25519)
|
||||
fbk := a.FallbackKeyUnpublished()
|
||||
res["curve25519"] = fbk
|
||||
|
@ -448,7 +476,10 @@ func (a *Account) UnpickleLibOlm(value []byte) (int, error) {
|
|||
}
|
||||
|
||||
// Pickle returns a base64 encoded and with key encrypted pickled account using PickleLibOlm().
|
||||
func (a Account) Pickle(key []byte) ([]byte, error) {
|
||||
func (a *Account) Pickle(key []byte) ([]byte, error) {
|
||||
if len(key) == 0 {
|
||||
return nil, goolm.ErrNoKeyProvided
|
||||
}
|
||||
pickeledBytes := make([]byte, a.PickleLen())
|
||||
written, err := a.PickleLibOlm(pickeledBytes)
|
||||
if err != nil {
|
||||
|
@ -466,7 +497,7 @@ func (a Account) Pickle(key []byte) ([]byte, error) {
|
|||
|
||||
// PickleLibOlm encodes the Account into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (a Account) PickleLibOlm(target []byte) (int, error) {
|
||||
func (a *Account) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < a.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle account: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
|
@ -510,7 +541,7 @@ func (a Account) PickleLibOlm(target []byte) (int, error) {
|
|||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled Account will have.
|
||||
func (a Account) PickleLen() int {
|
||||
func (a *Account) PickleLen() int {
|
||||
length := libolmpickle.PickleUInt32Len(accountPickleVersionLibOLM)
|
||||
length += a.IdKeys.Ed25519.PickleLen()
|
||||
length += a.IdKeys.Curve25519.PickleLen()
|
||||
|
@ -521,3 +552,9 @@ func (a Account) PickleLen() int {
|
|||
length += libolmpickle.PickleUInt32Len(a.NextOneTimeKeyID)
|
||||
return length
|
||||
}
|
||||
|
||||
// MaxNumberOfOneTimeKeys returns the largest number of one time keys this
|
||||
// Account can store.
|
||||
func (a *Account) MaxNumberOfOneTimeKeys() uint {
|
||||
return uint(MaxOneTimeKeys)
|
||||
}
|
||||
|
|
|
@ -71,7 +71,7 @@ func TestAccount(t *testing.T) {
|
|||
t.Fatal("IdentityKeys Ed25519 public unequal")
|
||||
}
|
||||
|
||||
if len(firstAccount.OneTimeKeys()) != 2 {
|
||||
if otks, err := firstAccount.OneTimeKeys(); err != nil || len(otks) != 2 {
|
||||
t.Fatal("should get 2 unpublished oneTimeKeys")
|
||||
}
|
||||
if len(firstAccount.FallbackKeyUnpublished()) == 0 {
|
||||
|
@ -84,7 +84,7 @@ func TestAccount(t *testing.T) {
|
|||
if len(firstAccount.FallbackKeyUnpublished()) != 0 {
|
||||
t.Fatal("should get no fallbackKey")
|
||||
}
|
||||
if len(firstAccount.OneTimeKeys()) != 0 {
|
||||
if otks, err := firstAccount.OneTimeKeys(); err != nil || len(otks) != 0 {
|
||||
t.Fatal("should get no oneTimeKeys")
|
||||
}
|
||||
}
|
||||
|
@ -139,7 +139,7 @@ func TestSessions(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
plaintext := []byte("test message")
|
||||
msgType, crypttext, err := aliceSession.Encrypt(plaintext, nil)
|
||||
msgType, crypttext, err := aliceSession.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -147,11 +147,11 @@ func TestSessions(t *testing.T) {
|
|||
t.Fatal("wrong message type")
|
||||
}
|
||||
|
||||
bobSession, err := bobAccount.NewInboundSession(nil, crypttext)
|
||||
bobSession, err := bobAccount.NewInboundSession(string(crypttext))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
decodedText, err := bobSession.Decrypt(crypttext, msgType)
|
||||
decodedText, err := bobSession.Decrypt(string(crypttext), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -252,7 +252,7 @@ func TestLoopback(t *testing.T) {
|
|||
}
|
||||
|
||||
plainText := []byte("Hello, World")
|
||||
msgType, message1, err := aliceSession.Encrypt(plainText, nil)
|
||||
msgType, message1, err := aliceSession.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -260,12 +260,12 @@ func TestLoopback(t *testing.T) {
|
|||
t.Fatal("wrong message type")
|
||||
}
|
||||
|
||||
bobSession, err := accountB.NewInboundSession(nil, message1)
|
||||
bobSession, err := accountB.NewInboundSession(string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Check that the inbound session matches the message it was created from.
|
||||
sessionIsOK, err := bobSession.MatchesInboundSessionFrom(nil, message1)
|
||||
sessionIsOK, err := bobSession.MatchesInboundSessionFrom("", string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -274,7 +274,7 @@ func TestLoopback(t *testing.T) {
|
|||
}
|
||||
// Check that the inbound session matches the key this message is supposed to be from.
|
||||
aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded()
|
||||
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&aIDKey, message1)
|
||||
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(aIDKey), string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -283,7 +283,7 @@ func TestLoopback(t *testing.T) {
|
|||
}
|
||||
// Check that the inbound session isn't from a different user.
|
||||
bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded()
|
||||
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&bIDKey, message1)
|
||||
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(bIDKey), string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -291,7 +291,7 @@ func TestLoopback(t *testing.T) {
|
|||
t.Fatal("session is sad to be from b but is from a")
|
||||
}
|
||||
// Check that we can decrypt the message.
|
||||
decryptedMessage, err := bobSession.Decrypt(message1, msgType)
|
||||
decryptedMessage, err := bobSession.Decrypt(string(message1), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -299,7 +299,7 @@ func TestLoopback(t *testing.T) {
|
|||
t.Fatal("messages are not the same")
|
||||
}
|
||||
|
||||
msgTyp2, message2, err := bobSession.Encrypt(plainText, nil)
|
||||
msgTyp2, message2, err := bobSession.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -307,7 +307,7 @@ func TestLoopback(t *testing.T) {
|
|||
t.Fatal("wrong message type")
|
||||
}
|
||||
|
||||
decryptedMessage2, err := aliceSession.Decrypt(message2, msgTyp2)
|
||||
decryptedMessage2, err := aliceSession.Decrypt(string(message2), msgTyp2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -316,7 +316,7 @@ func TestLoopback(t *testing.T) {
|
|||
}
|
||||
|
||||
//decrypting again should fail, as the chain moved on
|
||||
_, err = aliceSession.Decrypt(message2, msgTyp2)
|
||||
_, err = aliceSession.Decrypt(string(message2), msgTyp2)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
@ -348,7 +348,7 @@ func TestMoreMessages(t *testing.T) {
|
|||
}
|
||||
|
||||
plainText := []byte("Hello, World")
|
||||
msgType, message1, err := aliceSession.Encrypt(plainText, nil)
|
||||
msgType, message1, err := aliceSession.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -356,11 +356,11 @@ func TestMoreMessages(t *testing.T) {
|
|||
t.Fatal("wrong message type")
|
||||
}
|
||||
|
||||
bobSession, err := accountB.NewInboundSession(nil, message1)
|
||||
bobSession, err := accountB.NewInboundSession(string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
decryptedMessage, err := bobSession.Decrypt(message1, msgType)
|
||||
decryptedMessage, err := bobSession.Decrypt(string(message1), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -370,7 +370,7 @@ func TestMoreMessages(t *testing.T) {
|
|||
|
||||
for i := 0; i < 8; i++ {
|
||||
//alice sends, bob reveices
|
||||
msgType, message, err := aliceSession.Encrypt(plainText, nil)
|
||||
msgType, message, err := aliceSession.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -384,7 +384,7 @@ func TestMoreMessages(t *testing.T) {
|
|||
t.Fatal("wrong message type")
|
||||
}
|
||||
}
|
||||
decryptedMessage, err := bobSession.Decrypt(message, msgType)
|
||||
decryptedMessage, err := bobSession.Decrypt(string(message), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -393,14 +393,14 @@ func TestMoreMessages(t *testing.T) {
|
|||
}
|
||||
|
||||
//now bob sends, alice receives
|
||||
msgType, message, err = bobSession.Encrypt(plainText, nil)
|
||||
msgType, message, err = bobSession.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msgType == id.OlmMsgTypePreKey {
|
||||
t.Fatal("wrong message type")
|
||||
}
|
||||
decryptedMessage, err = aliceSession.Decrypt(message, msgType)
|
||||
decryptedMessage, err = aliceSession.Decrypt(string(message), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -435,7 +435,7 @@ func TestFallbackKey(t *testing.T) {
|
|||
}
|
||||
|
||||
plainText := []byte("Hello, World")
|
||||
msgType, message1, err := aliceSession.Encrypt(plainText, nil)
|
||||
msgType, message1, err := aliceSession.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -443,12 +443,12 @@ func TestFallbackKey(t *testing.T) {
|
|||
t.Fatal("wrong message type")
|
||||
}
|
||||
|
||||
bobSession, err := accountB.NewInboundSession(nil, message1)
|
||||
bobSession, err := accountB.NewInboundSession(string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Check that the inbound session matches the message it was created from.
|
||||
sessionIsOK, err := bobSession.MatchesInboundSessionFrom(nil, message1)
|
||||
sessionIsOK, err := bobSession.MatchesInboundSessionFrom("", string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -457,7 +457,7 @@ func TestFallbackKey(t *testing.T) {
|
|||
}
|
||||
// Check that the inbound session matches the key this message is supposed to be from.
|
||||
aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded()
|
||||
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&aIDKey, message1)
|
||||
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(aIDKey), string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -466,7 +466,7 @@ func TestFallbackKey(t *testing.T) {
|
|||
}
|
||||
// Check that the inbound session isn't from a different user.
|
||||
bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded()
|
||||
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&bIDKey, message1)
|
||||
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(bIDKey), string(message1))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -474,7 +474,7 @@ func TestFallbackKey(t *testing.T) {
|
|||
t.Fatal("session is sad to be from b but is from a")
|
||||
}
|
||||
// Check that we can decrypt the message.
|
||||
decryptedMessage, err := bobSession.Decrypt(message1, msgType)
|
||||
decryptedMessage, err := bobSession.Decrypt(string(message1), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -493,7 +493,7 @@ func TestFallbackKey(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
msgType2, message2, err := aliceSession2.Encrypt(plainText, nil)
|
||||
msgType2, message2, err := aliceSession2.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -502,19 +502,19 @@ func TestFallbackKey(t *testing.T) {
|
|||
}
|
||||
// bobSession should not be valid for the message2
|
||||
// Check that the inbound session matches the message it was created from.
|
||||
sessionIsOK, err = bobSession.MatchesInboundSessionFrom(nil, message2)
|
||||
sessionIsOK, err = bobSession.MatchesInboundSessionFrom("", string(message2))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if sessionIsOK {
|
||||
t.Fatal("session was detected to be valid but should not")
|
||||
}
|
||||
bobSession2, err := accountB.NewInboundSession(nil, message2)
|
||||
bobSession2, err := accountB.NewInboundSession(string(message2))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Check that the inbound session matches the message it was created from.
|
||||
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(nil, message2)
|
||||
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom("", string(message2))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -522,7 +522,7 @@ func TestFallbackKey(t *testing.T) {
|
|||
t.Fatal("session was not detected to be valid")
|
||||
}
|
||||
// Check that the inbound session matches the key this message is supposed to be from.
|
||||
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(&aIDKey, message2)
|
||||
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(string(aIDKey), string(message2))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -530,7 +530,7 @@ func TestFallbackKey(t *testing.T) {
|
|||
t.Fatal("session is sad to be not from a but it should")
|
||||
}
|
||||
// Check that the inbound session isn't from a different user.
|
||||
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(&bIDKey, message2)
|
||||
sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(string(bIDKey), string(message2))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -538,7 +538,7 @@ func TestFallbackKey(t *testing.T) {
|
|||
t.Fatal("session is sad to be from b but is from a")
|
||||
}
|
||||
// Check that we can decrypt the message.
|
||||
decryptedMessage2, err := bobSession2.Decrypt(message2, msgType2)
|
||||
decryptedMessage2, err := bobSession2.Decrypt(string(message2), msgType2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -553,14 +553,14 @@ func TestFallbackKey(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
msgType3, message3, err := aliceSession3.Encrypt(plainText, nil)
|
||||
msgType3, message3, err := aliceSession3.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msgType3 != id.OlmMsgTypePreKey {
|
||||
t.Fatal("wrong message type")
|
||||
}
|
||||
_, err = accountB.NewInboundSession(nil, message3)
|
||||
_, err = accountB.NewInboundSession(string(message3))
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
|
|
@ -95,10 +95,10 @@ func (r *Ratchet) InitializeAsAlice(sharedSecret []byte, ourRatchetKey crypto.Cu
|
|||
}
|
||||
|
||||
// Encrypt encrypts the message in a message.Message with MAC. If reader is nil, crypto/rand is used for key generations.
|
||||
func (r *Ratchet) Encrypt(plaintext []byte, reader io.Reader) ([]byte, error) {
|
||||
func (r *Ratchet) Encrypt(plaintext []byte) ([]byte, error) {
|
||||
var err error
|
||||
if !r.SenderChains.IsSet {
|
||||
newRatchetKey, err := crypto.Curve25519GenerateKey(reader)
|
||||
newRatchetKey, err := crypto.Curve25519GenerateKey(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ func TestSendReceive(t *testing.T) {
|
|||
plainText := []byte("Hello Bob")
|
||||
|
||||
//Alice sends Bob a message
|
||||
encryptedMessage, err := aliceRatchet.Encrypt(plainText, nil)
|
||||
encryptedMessage, err := aliceRatchet.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -60,7 +60,7 @@ func TestSendReceive(t *testing.T) {
|
|||
|
||||
//Bob sends Alice a message
|
||||
plainText = []byte("Hello Alice")
|
||||
encryptedMessage, err = bobRatchet.Encrypt(plainText, nil)
|
||||
encryptedMessage, err = bobRatchet.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -83,11 +83,11 @@ func TestOutOfOrder(t *testing.T) {
|
|||
plainText2 := []byte("Second Messsage. A bit longer than the first.")
|
||||
|
||||
/* Alice sends Bob two messages and they arrive out of order */
|
||||
message1Encrypted, err := aliceRatchet.Encrypt(plainText1, nil)
|
||||
message1Encrypted, err := aliceRatchet.Encrypt(plainText1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
message2Encrypted, err := aliceRatchet.Encrypt(plainText2, nil)
|
||||
message2Encrypted, err := aliceRatchet.Encrypt(plainText2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -115,7 +115,7 @@ func TestMoreMessages(t *testing.T) {
|
|||
}
|
||||
plainText := []byte("These 15 bytes")
|
||||
for i := 0; i < 8; i++ {
|
||||
messageEncrypted, err := aliceRatchet.Encrypt(plainText, nil)
|
||||
messageEncrypted, err := aliceRatchet.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -128,7 +128,7 @@ func TestMoreMessages(t *testing.T) {
|
|||
}
|
||||
}
|
||||
for i := 0; i < 8; i++ {
|
||||
messageEncrypted, err := bobRatchet.Encrypt(plainText, nil)
|
||||
messageEncrypted, err := bobRatchet.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -140,7 +140,7 @@ func TestMoreMessages(t *testing.T) {
|
|||
t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted)
|
||||
}
|
||||
}
|
||||
messageEncrypted, err := aliceRatchet.Encrypt(plainText, nil)
|
||||
messageEncrypted, err := aliceRatchet.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -171,7 +171,7 @@ func TestJSONEncoding(t *testing.T) {
|
|||
|
||||
plainText := []byte("These 15 bytes")
|
||||
|
||||
messageEncrypted, err := newRatcher.Encrypt(plainText, nil)
|
||||
messageEncrypted, err := newRatcher.Encrypt(plainText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -45,8 +45,8 @@ func NewDecryptionFromPrivate(privateKey crypto.Curve25519PrivateKey) (*Decrypti
|
|||
return s, nil
|
||||
}
|
||||
|
||||
// PubKey returns the public key base 64 encoded.
|
||||
func (s Decryption) PubKey() id.Curve25519 {
|
||||
// PublicKey returns the public key base 64 encoded.
|
||||
func (s Decryption) PublicKey() id.Curve25519 {
|
||||
return s.KeyPair.B64Encoded()
|
||||
}
|
||||
|
||||
|
|
|
@ -30,14 +30,14 @@ func TestEncryptionDecryption(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal([]byte(decryption.PubKey()), alicePublic) {
|
||||
if !bytes.Equal([]byte(decryption.PublicKey()), alicePublic) {
|
||||
t.Fatal("public key not correct")
|
||||
}
|
||||
if !bytes.Equal(decryption.PrivateKey(), alicePrivate) {
|
||||
t.Fatal("private key not correct")
|
||||
}
|
||||
|
||||
encryption, err := pk.NewEncryption(decryption.PubKey())
|
||||
encryption, err := pk.NewEncryption(decryption.PublicKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -66,7 +66,10 @@ func TestSigning(t *testing.T) {
|
|||
}
|
||||
message := []byte("We hold these truths to be self-evident, that all men are created equal, that they are endowed by their Creator with certain unalienable Rights, that among these are Life, Liberty and the pursuit of Happiness.")
|
||||
signing, _ := pk.NewSigningFromSeed(seed)
|
||||
signature := signing.Sign(message)
|
||||
signature, err := signing.Sign(message)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
signatureDecoded, err := goolm.Base64Decode(signature)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -101,7 +104,7 @@ func TestDecryptionPickling(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal([]byte(decryption.PubKey()), alicePublic) {
|
||||
if !bytes.Equal([]byte(decryption.PublicKey()), alicePublic) {
|
||||
t.Fatal("public key not correct")
|
||||
}
|
||||
if !bytes.Equal(decryption.PrivateKey(), alicePrivate) {
|
||||
|
@ -125,7 +128,7 @@ func TestDecryptionPickling(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal([]byte(newDecription.PubKey()), alicePublic) {
|
||||
if !bytes.Equal([]byte(newDecription.PublicKey()), alicePublic) {
|
||||
t.Fatal("public key not correct")
|
||||
}
|
||||
if !bytes.Equal(newDecription.PrivateKey(), alicePrivate) {
|
||||
|
|
|
@ -2,7 +2,11 @@ package pk
|
|||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/canonicaljson"
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
@ -10,15 +14,15 @@ import (
|
|||
|
||||
// Signing is used for signing a pk
|
||||
type Signing struct {
|
||||
KeyPair crypto.Ed25519KeyPair `json:"key_pair"`
|
||||
Seed []byte `json:"seed"`
|
||||
keyPair crypto.Ed25519KeyPair
|
||||
seed []byte
|
||||
}
|
||||
|
||||
// NewSigningFromSeed constructs a new Signing based on a seed.
|
||||
func NewSigningFromSeed(seed []byte) (*Signing, error) {
|
||||
s := &Signing{}
|
||||
s.Seed = seed
|
||||
s.KeyPair = crypto.Ed25519GenerateFromSeed(seed)
|
||||
s.seed = seed
|
||||
s.keyPair = crypto.Ed25519GenerateFromSeed(seed)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
@ -32,13 +36,34 @@ func NewSigning() (*Signing, error) {
|
|||
return NewSigningFromSeed(seed)
|
||||
}
|
||||
|
||||
// Sign returns the signature of the message base64 encoded.
|
||||
func (s Signing) Sign(message []byte) []byte {
|
||||
signature := s.KeyPair.Sign(message)
|
||||
return goolm.Base64Encode(signature)
|
||||
// Seed returns the seed of the key pair.
|
||||
func (s Signing) Seed() []byte {
|
||||
return s.seed
|
||||
}
|
||||
|
||||
// PublicKey returns the public key of the key pair base 64 encoded.
|
||||
func (s Signing) PublicKey() id.Ed25519 {
|
||||
return s.KeyPair.B64Encoded()
|
||||
return s.keyPair.B64Encoded()
|
||||
}
|
||||
|
||||
// Sign returns the signature of the message base64 encoded.
|
||||
func (s Signing) Sign(message []byte) ([]byte, error) {
|
||||
signature := s.keyPair.Sign(message)
|
||||
return goolm.Base64Encode(signature), nil
|
||||
}
|
||||
|
||||
// SignJSON creates a signature for the given object after encoding it to
|
||||
// canonical JSON.
|
||||
func (s Signing) SignJSON(obj any) (string, error) {
|
||||
objJSON, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
|
||||
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
|
||||
signature, err := s.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(signature), nil
|
||||
}
|
||||
|
|
|
@ -89,7 +89,7 @@ func MegolmInboundSessionFromPickled(pickled, key []byte) (*MegolmInboundSession
|
|||
}
|
||||
|
||||
// getRatchet tries to find the correct ratchet for a messageIndex.
|
||||
func (o MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, error) {
|
||||
func (o *MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, error) {
|
||||
// pick a megolm instance to use. if we are at or beyond the latest ratchet value, use that
|
||||
if (messageIndex - o.Ratchet.Counter) < uint32(1<<31) {
|
||||
o.Ratchet.AdvanceTo(messageIndex)
|
||||
|
@ -107,7 +107,10 @@ func (o MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet,
|
|||
}
|
||||
|
||||
// Decrypt decrypts a base64 encoded group message.
|
||||
func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error) {
|
||||
func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint, error) {
|
||||
if len(ciphertext) == 0 {
|
||||
return nil, 0, goolm.ErrEmptyInput
|
||||
}
|
||||
if o.SigningKey == nil {
|
||||
return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat)
|
||||
}
|
||||
|
@ -143,17 +146,17 @@ func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error
|
|||
return nil, 0, err
|
||||
}
|
||||
o.SigningKeyVerified = true
|
||||
return decrypted, msg.MessageIndex, nil
|
||||
return decrypted, uint(msg.MessageIndex), nil
|
||||
|
||||
}
|
||||
|
||||
// SessionID returns the base64 endoded signing key
|
||||
func (o MegolmInboundSession) SessionID() id.SessionID {
|
||||
// ID returns the base64 endoded signing key
|
||||
func (o *MegolmInboundSession) ID() id.SessionID {
|
||||
return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey))
|
||||
}
|
||||
|
||||
// PickleAsJSON returns an MegolmInboundSession as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
|
||||
func (o MegolmInboundSession) PickleAsJSON(key []byte) ([]byte, error) {
|
||||
func (o *MegolmInboundSession) PickleAsJSON(key []byte) ([]byte, error) {
|
||||
return utilities.PickleAsJSON(o, megolmInboundSessionPickleVersionJSON, key)
|
||||
}
|
||||
|
||||
|
@ -162,8 +165,14 @@ func (o *MegolmInboundSession) UnpickleAsJSON(pickled, key []byte) error {
|
|||
return utilities.UnpickleAsJSON(o, pickled, key, megolmInboundSessionPickleVersionJSON)
|
||||
}
|
||||
|
||||
// SessionExportMessage creates an base64 encoded export of the session.
|
||||
func (o MegolmInboundSession) SessionExportMessage(messageIndex uint32) ([]byte, error) {
|
||||
// Export returns the base64-encoded ratchet key for this session, at the given
|
||||
// index, in a format which can be used by
|
||||
// InboundGroupSession.InboundGroupSessionImport(). Encrypts the
|
||||
// InboundGroupSession using the supplied key. Returns error on failure.
|
||||
// if we do not have a session key corresponding to the given index (ie, it was
|
||||
// sent before the session key was shared with us) the error will be
|
||||
// returned.
|
||||
func (o *MegolmInboundSession) Export(messageIndex uint32) ([]byte, error) {
|
||||
ratchet, err := o.getRatchet(messageIndex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -174,6 +183,11 @@ func (o MegolmInboundSession) SessionExportMessage(messageIndex uint32) ([]byte,
|
|||
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
|
||||
// The decrypted value is then passed to UnpickleLibOlm.
|
||||
func (o *MegolmInboundSession) Unpickle(pickled, key []byte) error {
|
||||
if len(key) == 0 {
|
||||
return goolm.ErrNoKeyProvided
|
||||
} else if len(pickled) == 0 {
|
||||
return goolm.ErrEmptyInput
|
||||
}
|
||||
decrypted, err := cipher.Unpickle(key, pickled)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -223,7 +237,10 @@ func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) (int, error) {
|
|||
}
|
||||
|
||||
// Pickle returns a base64 encoded and with key encrypted pickled MegolmInboundSession using PickleLibOlm().
|
||||
func (o MegolmInboundSession) Pickle(key []byte) ([]byte, error) {
|
||||
func (o *MegolmInboundSession) Pickle(key []byte) ([]byte, error) {
|
||||
if len(key) == 0 {
|
||||
return nil, goolm.ErrNoKeyProvided
|
||||
}
|
||||
pickeledBytes := make([]byte, o.PickleLen())
|
||||
written, err := o.PickleLibOlm(pickeledBytes)
|
||||
if err != nil {
|
||||
|
@ -241,7 +258,7 @@ func (o MegolmInboundSession) Pickle(key []byte) ([]byte, error) {
|
|||
|
||||
// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (o MegolmInboundSession) PickleLibOlm(target []byte) (int, error) {
|
||||
func (o *MegolmInboundSession) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < o.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle MegolmInboundSession: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
|
@ -266,7 +283,7 @@ func (o MegolmInboundSession) PickleLibOlm(target []byte) (int, error) {
|
|||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled session will have.
|
||||
func (o MegolmInboundSession) PickleLen() int {
|
||||
func (o *MegolmInboundSession) PickleLen() int {
|
||||
length := libolmpickle.PickleUInt32Len(megolmInboundSessionPickleVersionLibOlm)
|
||||
length += o.InitialRatchet.PickleLen()
|
||||
length += o.Ratchet.PickleLen()
|
||||
|
@ -274,3 +291,15 @@ func (o MegolmInboundSession) PickleLen() int {
|
|||
length += libolmpickle.PickleBoolLen(o.SigningKeyVerified)
|
||||
return length
|
||||
}
|
||||
|
||||
// FirstKnownIndex returns the first message index we know how to decrypt.
|
||||
func (s *MegolmInboundSession) FirstKnownIndex() uint32 {
|
||||
return s.InitialRatchet.Counter
|
||||
}
|
||||
|
||||
// IsVerified check if the session has been verified as a valid session. (A
|
||||
// session is verified either because the original session share was signed, or
|
||||
// because we have subsequently successfully decrypted a message.)
|
||||
func (s *MegolmInboundSession) IsVerified() bool {
|
||||
return s.SigningKeyVerified
|
||||
}
|
||||
|
|
|
@ -55,14 +55,14 @@ func MegolmOutboundSessionFromPickled(pickled, key []byte) (*MegolmOutboundSessi
|
|||
}
|
||||
a := &MegolmOutboundSession{}
|
||||
err := a.Unpickle(pickled, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a, nil
|
||||
return a, err
|
||||
}
|
||||
|
||||
// Encrypt encrypts the plaintext as a base64 encoded group message.
|
||||
func (o *MegolmOutboundSession) Encrypt(plaintext []byte) ([]byte, error) {
|
||||
if len(plaintext) == 0 {
|
||||
return nil, goolm.ErrEmptyInput
|
||||
}
|
||||
encrypted, err := o.Ratchet.Encrypt(plaintext, &o.SigningKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -71,12 +71,12 @@ func (o *MegolmOutboundSession) Encrypt(plaintext []byte) ([]byte, error) {
|
|||
}
|
||||
|
||||
// SessionID returns the base64 endoded public signing key
|
||||
func (o MegolmOutboundSession) SessionID() id.SessionID {
|
||||
func (o *MegolmOutboundSession) ID() id.SessionID {
|
||||
return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey.PublicKey))
|
||||
}
|
||||
|
||||
// PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
|
||||
func (o MegolmOutboundSession) PickleAsJSON(key []byte) ([]byte, error) {
|
||||
func (o *MegolmOutboundSession) PickleAsJSON(key []byte) ([]byte, error) {
|
||||
return utilities.PickleAsJSON(o, megolmOutboundSessionPickleVersion, key)
|
||||
}
|
||||
|
||||
|
@ -88,6 +88,9 @@ func (o *MegolmOutboundSession) UnpickleAsJSON(pickled, key []byte) error {
|
|||
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
|
||||
// The decrypted value is then passed to UnpickleLibOlm.
|
||||
func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error {
|
||||
if len(key) == 0 {
|
||||
return goolm.ErrNoKeyProvided
|
||||
}
|
||||
decrypted, err := cipher.Unpickle(key, pickled)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -122,7 +125,10 @@ func (o *MegolmOutboundSession) UnpickleLibOlm(value []byte) (int, error) {
|
|||
}
|
||||
|
||||
// Pickle returns a base64 encoded and with key encrypted pickled MegolmOutboundSession using PickleLibOlm().
|
||||
func (o MegolmOutboundSession) Pickle(key []byte) ([]byte, error) {
|
||||
func (o *MegolmOutboundSession) Pickle(key []byte) ([]byte, error) {
|
||||
if len(key) == 0 {
|
||||
return nil, goolm.ErrNoKeyProvided
|
||||
}
|
||||
pickeledBytes := make([]byte, o.PickleLen())
|
||||
written, err := o.PickleLibOlm(pickeledBytes)
|
||||
if err != nil {
|
||||
|
@ -140,7 +146,7 @@ func (o MegolmOutboundSession) Pickle(key []byte) ([]byte, error) {
|
|||
|
||||
// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (o MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) {
|
||||
func (o *MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < o.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
|
@ -159,13 +165,28 @@ func (o MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) {
|
|||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled session will have.
|
||||
func (o MegolmOutboundSession) PickleLen() int {
|
||||
func (o *MegolmOutboundSession) PickleLen() int {
|
||||
length := libolmpickle.PickleUInt32Len(megolmOutboundSessionPickleVersionLibOlm)
|
||||
length += o.Ratchet.PickleLen()
|
||||
length += o.SigningKey.PickleLen()
|
||||
return length
|
||||
}
|
||||
|
||||
func (o MegolmOutboundSession) SessionSharingMessage() ([]byte, error) {
|
||||
func (o *MegolmOutboundSession) SessionSharingMessage() ([]byte, error) {
|
||||
return o.Ratchet.SessionSharingMessage(o.SigningKey)
|
||||
}
|
||||
|
||||
// MessageIndex returns the message index for this session. Each message is
|
||||
// sent with an increasing index; this returns the index for the next message.
|
||||
func (s *MegolmOutboundSession) MessageIndex() uint {
|
||||
return uint(s.Ratchet.Counter)
|
||||
}
|
||||
|
||||
// Key returns the base64-encoded current ratchet key for this session.
|
||||
func (s *MegolmOutboundSession) Key() string {
|
||||
message, err := s.SessionSharingMessage()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return string(message)
|
||||
}
|
||||
|
|
|
@ -33,7 +33,7 @@ func TestOutboundPickleJSON(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if sess.SessionID() != newSession.SessionID() {
|
||||
if sess.ID() != newSession.ID() {
|
||||
t.Fatal("session ids not equal")
|
||||
}
|
||||
if !bytes.Equal(sess.SigningKey.PrivateKey, newSession.SigningKey.PrivateKey) {
|
||||
|
@ -75,7 +75,7 @@ func TestInboundPickleJSON(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if sess.SessionID() != newSession.SessionID() {
|
||||
if sess.ID() != newSession.ID() {
|
||||
t.Fatal("sess ids not equal")
|
||||
}
|
||||
if !bytes.Equal(sess.SigningKey, newSession.SigningKey) {
|
||||
|
@ -128,7 +128,7 @@ func TestGroupSendReceive(t *testing.T) {
|
|||
if !inboundSession.SigningKeyVerified {
|
||||
t.Fatal("key not verified")
|
||||
}
|
||||
if inboundSession.SessionID() != outboundSession.SessionID() {
|
||||
if inboundSession.ID() != outboundSession.ID() {
|
||||
t.Fatal("session ids not equal")
|
||||
}
|
||||
|
||||
|
@ -174,7 +174,7 @@ func TestGroupSessionExportImport(t *testing.T) {
|
|||
}
|
||||
|
||||
//Export the keys
|
||||
exported, err := inboundSession.SessionExportMessage(0)
|
||||
exported, err := inboundSession.Export(0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/cipher"
|
||||
|
@ -204,7 +204,7 @@ func (a *OlmSession) UnpickleAsJSON(pickled, key []byte) error {
|
|||
|
||||
// ID returns an identifier for this Session. Will be the same for both ends of the conversation.
|
||||
// Generated by hashing the public keys used to create the session.
|
||||
func (s OlmSession) ID() id.SessionID {
|
||||
func (s *OlmSession) ID() id.SessionID {
|
||||
message := make([]byte, 3*crypto.Curve25519KeyLength)
|
||||
copy(message, s.AliceIdentityKey)
|
||||
copy(message[crypto.Curve25519KeyLength:], s.AliceBaseKey)
|
||||
|
@ -215,15 +215,39 @@ func (s OlmSession) ID() id.SessionID {
|
|||
}
|
||||
|
||||
// HasReceivedMessage returns true if this session has received any message.
|
||||
func (s OlmSession) HasReceivedMessage() bool {
|
||||
func (s *OlmSession) HasReceivedMessage() bool {
|
||||
return s.ReceivedMessage
|
||||
}
|
||||
|
||||
// MatchesInboundSessionFrom checks if the oneTimeKeyMsg message is set for this inbound
|
||||
// Session. This can happen if multiple messages are sent to this Account
|
||||
// before this Account sends a message in reply. Returns true if the session
|
||||
// matches. Returns false if the session does not match.
|
||||
func (s OlmSession) MatchesInboundSessionFrom(theirIdentityKeyEncoded *id.Curve25519, receivedOTKMsg []byte) (bool, error) {
|
||||
// MatchesInboundSession checks if the PRE_KEY message is for this in-bound
|
||||
// Session. This can happen if multiple messages are sent to this Account
|
||||
// before this Account sends a message in reply. Returns true if the session
|
||||
// matches. Returns false if the session does not match. Returns error on
|
||||
// failure.
|
||||
func (s *OlmSession) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
|
||||
return s.matchesInboundSession(nil, []byte(oneTimeKeyMsg))
|
||||
}
|
||||
|
||||
// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound
|
||||
// Session. This can happen if multiple messages are sent to this Account
|
||||
// before this Account sends a message in reply. Returns true if the session
|
||||
// matches. Returns false if the session does not match. Returns error on
|
||||
// failure.
|
||||
func (s *OlmSession) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) {
|
||||
var theirKey *id.Curve25519
|
||||
if theirIdentityKey != "" {
|
||||
theirs := id.Curve25519(theirIdentityKey)
|
||||
theirKey = &theirs
|
||||
}
|
||||
|
||||
return s.matchesInboundSession(theirKey, []byte(oneTimeKeyMsg))
|
||||
}
|
||||
|
||||
// matchesInboundSession checks if the oneTimeKeyMsg message is set for this
|
||||
// inbound Session. This can happen if multiple messages are sent to this
|
||||
// Account before this Account sends a message in reply. Returns true if the
|
||||
// session matches. Returns false if the session does not match.
|
||||
func (s *OlmSession) matchesInboundSession(theirIdentityKeyEncoded *id.Curve25519, receivedOTKMsg []byte) (bool, error) {
|
||||
if len(receivedOTKMsg) == 0 {
|
||||
return false, fmt.Errorf("inbound match: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
|
@ -266,20 +290,20 @@ func (s OlmSession) MatchesInboundSessionFrom(theirIdentityKeyEncoded *id.Curve2
|
|||
// EncryptMsgType returns the type of the next message that Encrypt will
|
||||
// return. Returns MsgTypePreKey if the message will be a oneTimeKeyMsg.
|
||||
// Returns MsgTypeMsg if the message will be a normal message.
|
||||
func (s OlmSession) EncryptMsgType() id.OlmMsgType {
|
||||
func (s *OlmSession) EncryptMsgType() id.OlmMsgType {
|
||||
if s.ReceivedMessage {
|
||||
return id.OlmMsgTypeMsg
|
||||
}
|
||||
return id.OlmMsgTypePreKey
|
||||
}
|
||||
|
||||
// Encrypt encrypts a message using the Session. Returns the encrypted message base64 encoded. If reader is nil, crypto/rand is used for key generations.
|
||||
func (s *OlmSession) Encrypt(plaintext []byte, reader io.Reader) (id.OlmMsgType, []byte, error) {
|
||||
// Encrypt encrypts a message using the Session. Returns the encrypted message base64 encoded.
|
||||
func (s *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) {
|
||||
if len(plaintext) == 0 {
|
||||
return 0, nil, fmt.Errorf("encrypt: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
messageType := s.EncryptMsgType()
|
||||
encrypted, err := s.Ratchet.Encrypt(plaintext, reader)
|
||||
encrypted, err := s.Ratchet.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
@ -304,11 +328,11 @@ func (s *OlmSession) Encrypt(plaintext []byte, reader io.Reader) (id.OlmMsgType,
|
|||
}
|
||||
|
||||
// Decrypt decrypts a base64 encoded message using the Session.
|
||||
func (s *OlmSession) Decrypt(crypttext []byte, msgType id.OlmMsgType) ([]byte, error) {
|
||||
func (s *OlmSession) Decrypt(crypttext string, msgType id.OlmMsgType) ([]byte, error) {
|
||||
if len(crypttext) == 0 {
|
||||
return nil, fmt.Errorf("decrypt: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
decodedCrypttext, err := goolm.Base64Decode(crypttext)
|
||||
decodedCrypttext, err := goolm.Base64Decode([]byte(crypttext))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -333,6 +357,9 @@ func (s *OlmSession) Decrypt(crypttext []byte, msgType id.OlmMsgType) ([]byte, e
|
|||
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
|
||||
// The decrypted value is then passed to UnpickleLibOlm.
|
||||
func (o *OlmSession) Unpickle(pickled, key []byte) error {
|
||||
if len(pickled) == 0 {
|
||||
return goolm.ErrEmptyInput
|
||||
}
|
||||
decrypted, err := cipher.Unpickle(key, pickled)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -386,26 +413,26 @@ func (o *OlmSession) UnpickleLibOlm(value []byte) (int, error) {
|
|||
return curPos, nil
|
||||
}
|
||||
|
||||
// Pickle returns a base64 encoded and with key encrypted pickled olmSession using PickleLibOlm().
|
||||
func (o OlmSession) Pickle(key []byte) ([]byte, error) {
|
||||
pickeledBytes := make([]byte, o.PickleLen())
|
||||
written, err := o.PickleLibOlm(pickeledBytes)
|
||||
// Pickle returns a base64 encoded and with key encrypted pickled olmSession
|
||||
// using PickleLibOlm().
|
||||
func (s *OlmSession) Pickle(key []byte) ([]byte, error) {
|
||||
if len(key) == 0 {
|
||||
return nil, goolm.ErrNoKeyProvided
|
||||
}
|
||||
pickeledBytes := make([]byte, s.PickleLen())
|
||||
written, err := s.PickleLibOlm(pickeledBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if written != len(pickeledBytes) {
|
||||
return nil, errors.New("number of written bytes not correct")
|
||||
}
|
||||
encrypted, err := cipher.Pickle(key, pickeledBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encrypted, nil
|
||||
return cipher.Pickle(key, pickeledBytes)
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (o OlmSession) PickleLibOlm(target []byte) (int, error) {
|
||||
func (o *OlmSession) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < o.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
|
@ -435,7 +462,7 @@ func (o OlmSession) PickleLibOlm(target []byte) (int, error) {
|
|||
}
|
||||
|
||||
// PickleLen returns the actual number of bytes the pickled session will have.
|
||||
func (o OlmSession) PickleLen() int {
|
||||
func (o *OlmSession) PickleLen() int {
|
||||
length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm)
|
||||
length += libolmpickle.PickleBoolLen(o.ReceivedMessage)
|
||||
length += o.AliceIdentityKey.PickleLen()
|
||||
|
@ -446,7 +473,7 @@ func (o OlmSession) PickleLen() int {
|
|||
}
|
||||
|
||||
// PickleLenMin returns the minimum number of bytes the pickled session must have.
|
||||
func (o OlmSession) PickleLenMin() int {
|
||||
func (o *OlmSession) PickleLenMin() int {
|
||||
length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm)
|
||||
length += libolmpickle.PickleBoolLen(o.ReceivedMessage)
|
||||
length += o.AliceIdentityKey.PickleLen()
|
||||
|
@ -457,20 +484,17 @@ func (o OlmSession) PickleLenMin() int {
|
|||
}
|
||||
|
||||
// Describe returns a string describing the current state of the session for debugging.
|
||||
func (o OlmSession) Describe() string {
|
||||
var res string
|
||||
if o.Ratchet.SenderChains.IsSet {
|
||||
res += fmt.Sprintf("sender chain index: %d ", o.Ratchet.SenderChains.CKey.Index)
|
||||
} else {
|
||||
res += "sender chain index: "
|
||||
}
|
||||
res += "receiver chain indicies:"
|
||||
func (o *OlmSession) Describe() string {
|
||||
var builder strings.Builder
|
||||
builder.WriteString("sender chain index: ")
|
||||
builder.WriteString(fmt.Sprint(o.Ratchet.SenderChains.CKey.Index))
|
||||
builder.WriteString(" receiver chain indices:")
|
||||
for _, curChain := range o.Ratchet.ReceiverChains {
|
||||
res += fmt.Sprintf(" %d", curChain.CKey.Index)
|
||||
builder.WriteString(fmt.Sprintf(" %d", curChain.CKey.Index))
|
||||
}
|
||||
res += " skipped message keys:"
|
||||
builder.WriteString(" skipped message keys:")
|
||||
for _, curSkip := range o.Ratchet.SkippedMessageKeys {
|
||||
res += fmt.Sprintf(" %d", curSkip.MKey.Index)
|
||||
builder.WriteString(fmt.Sprintf(" %d", curSkip.MKey.Index))
|
||||
}
|
||||
return res
|
||||
return builder.String()
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ func TestOlmSession(t *testing.T) {
|
|||
}
|
||||
//create a message so that there are more keys to marshal
|
||||
plaintext := []byte("Test message from Alice to Bob")
|
||||
msgType, message, err := aliceSession.Encrypt(plaintext, nil)
|
||||
msgType, message, err := aliceSession.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ func TestOlmSession(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
decryptedMsg, err := bobSession.Decrypt(message, msgType)
|
||||
decryptedMsg, err := bobSession.Decrypt(string(message), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -71,7 +71,7 @@ func TestOlmSession(t *testing.T) {
|
|||
|
||||
//bob sends a message
|
||||
plaintext = []byte("A message from Bob to Alice")
|
||||
msgType, message, err = bobSession.Encrypt(plaintext, nil)
|
||||
msgType, message, err = bobSession.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -86,7 +86,7 @@ func TestOlmSession(t *testing.T) {
|
|||
}
|
||||
|
||||
//Alice receives message
|
||||
decryptedMsg, err = newAliceSession.Decrypt(message, msgType)
|
||||
decryptedMsg, err = newAliceSession.Decrypt(string(message), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -95,14 +95,14 @@ func TestOlmSession(t *testing.T) {
|
|||
}
|
||||
|
||||
//Alice receives message again
|
||||
_, err = newAliceSession.Decrypt(message, msgType)
|
||||
_, err = newAliceSession.Decrypt(string(message), msgType)
|
||||
if err == nil {
|
||||
t.Fatal("should have gotten an error")
|
||||
}
|
||||
|
||||
//Alice sends another message
|
||||
plaintext = []byte("A second message to Bob")
|
||||
msgType, message, err = newAliceSession.Encrypt(plaintext, nil)
|
||||
msgType, message, err = newAliceSession.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -110,7 +110,7 @@ func TestOlmSession(t *testing.T) {
|
|||
t.Fatal("Wrong message type")
|
||||
}
|
||||
//bob receives message
|
||||
decryptedMsg, err = bobSession.Decrypt(message, msgType)
|
||||
decryptedMsg, err = bobSession.Decrypt(string(message), msgType)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -165,7 +165,7 @@ func TestDecrypts(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
for curIndex, curMessage := range messages {
|
||||
_, err := sess.Decrypt(curMessage, id.OlmMsgTypePreKey)
|
||||
_, err := sess.Decrypt(string(curMessage), id.OlmMsgTypePreKey)
|
||||
if err != nil {
|
||||
if !errors.Is(err, expectedErr[curIndex]) {
|
||||
t.Fatal(err)
|
||||
|
|
|
@ -66,8 +66,10 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context)
|
|||
var key id.Ed25519
|
||||
if keyName == crossSigningPubkeys.MasterKey.String() {
|
||||
key = crossSigningPubkeys.MasterKey
|
||||
} else if device, err := mach.GetOrFetchDevice(ctx, mach.Client.UserID, id.DeviceID(keyName)); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to fetch device")
|
||||
} else if device, err := mach.CryptoStore.GetDevice(ctx, mach.Client.UserID, id.DeviceID(keyName)); err != nil {
|
||||
return nil, fmt.Errorf("failed to get device %s/%s from store: %w", mach.Client.UserID, keyName, err)
|
||||
} else if device == nil {
|
||||
log.Warn().Err(err).Msg("Device does not exist, ignoring signature")
|
||||
continue
|
||||
} else if !mach.IsDeviceTrusted(device) {
|
||||
log.Warn().Err(err).Msg("Device is not trusted")
|
||||
|
@ -163,7 +165,7 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.
|
|||
}
|
||||
|
||||
igs := &InboundGroupSession{
|
||||
Internal: *igsInternal,
|
||||
Internal: igsInternal,
|
||||
SigningKey: keyBackupData.SenderClaimedKeys.Ed25519,
|
||||
SenderKey: keyBackupData.SenderKey,
|
||||
RoomID: roomID,
|
||||
|
|
|
@ -104,7 +104,7 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor
|
|||
return false, ErrMismatchingExportedSessionID
|
||||
}
|
||||
igs := &InboundGroupSession{
|
||||
Internal: *igsInternal,
|
||||
Internal: igsInternal,
|
||||
SigningKey: session.SenderClaimedKeys.Ed25519,
|
||||
SenderKey: session.SenderKey,
|
||||
RoomID: session.RoomID,
|
||||
|
|
|
@ -172,7 +172,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
|
|||
log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session")
|
||||
}
|
||||
igs := &InboundGroupSession{
|
||||
Internal: *igsInternal,
|
||||
Internal: igsInternal,
|
||||
SigningKey: evt.Keys.Ed25519,
|
||||
SenderKey: content.SenderKey,
|
||||
RoomID: content.RoomID,
|
||||
|
@ -184,6 +184,11 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
|
|||
MaxMessages: maxMessages,
|
||||
IsScheduled: content.IsScheduled,
|
||||
}
|
||||
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID())
|
||||
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
|
||||
// We already have an equivalent or better session in the store, so don't override it.
|
||||
return false
|
||||
}
|
||||
err = mach.CryptoStore.PutGroupSession(ctx, content.RoomID, content.SenderKey, content.SessionID, igs)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to store new inbound group session")
|
||||
|
|
|
@ -681,6 +681,11 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
|
|||
log.Debug().Msg("No one-time keys nor device keys got when trying to share keys")
|
||||
return nil
|
||||
}
|
||||
// Save the keys before sending the upload request in case there is a
|
||||
// network failure.
|
||||
if err := mach.saveAccount(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
req := &mautrix.ReqUploadKeys{
|
||||
DeviceKeys: deviceKeys,
|
||||
OneTimeKeys: oneTimeKeys,
|
||||
|
@ -691,6 +696,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
|
|||
return err
|
||||
}
|
||||
mach.lastOTKUpload = time.Now()
|
||||
mach.account.Internal.MarkKeysAsPublished()
|
||||
mach.account.Shared = true
|
||||
return mach.saveAccount(ctx)
|
||||
}
|
||||
|
|
|
@ -77,6 +77,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) {
|
|||
otk = otkTmp
|
||||
break
|
||||
}
|
||||
machineIn.account.Internal.MarkKeysAsPublished()
|
||||
|
||||
// create outbound olm session for sending machine using OTK
|
||||
olmSession, err := machineOut.account.Internal.NewOutboundSession(machineIn.account.IdentityKey(), otk.Key)
|
||||
|
|
|
@ -1,154 +1,20 @@
|
|||
// Copyright (c) 2024 Sumner Evans
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
//go:build goolm
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/canonicaljson"
|
||||
"maunium.net/go/mautrix/crypto/goolm/account"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// Account stores a device account for end to end encrypted messaging.
|
||||
type Account struct {
|
||||
account.Account
|
||||
}
|
||||
import "maunium.net/go/mautrix/crypto/goolm/account"
|
||||
|
||||
// NewAccount creates a new Account.
|
||||
func NewAccount() *Account {
|
||||
a, err := account.NewAccount(nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ac := &Account{}
|
||||
ac.Account = *a
|
||||
return ac
|
||||
func NewAccount() Account {
|
||||
return account.NewAccount()
|
||||
}
|
||||
|
||||
func NewBlankAccount() *Account {
|
||||
return &Account{}
|
||||
}
|
||||
|
||||
// Clear clears the memory used to back this Account.
|
||||
func (a *Account) Clear() error {
|
||||
a.Account = account.Account{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pickle returns an Account as a base64 string. Encrypts the Account using the
|
||||
// supplied key.
|
||||
func (a *Account) Pickle(key []byte) []byte {
|
||||
if len(key) == 0 {
|
||||
panic(NoKeyProvided)
|
||||
}
|
||||
pickled, err := a.Account.Pickle(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return pickled
|
||||
}
|
||||
|
||||
// IdentityKeysJSON returns the public parts of the identity keys for the Account.
|
||||
func (a *Account) IdentityKeysJSON() []byte {
|
||||
identityKeys, err := a.Account.IdentityKeysJSON()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return identityKeys
|
||||
}
|
||||
|
||||
// Sign returns the signature of a message using the ed25519 key for this
|
||||
// Account.
|
||||
func (a *Account) Sign(message []byte) []byte {
|
||||
if len(message) == 0 {
|
||||
panic(EmptyInput)
|
||||
}
|
||||
signature, err := a.Account.Sign(message)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return signature
|
||||
}
|
||||
|
||||
// SignJSON signs the given JSON object following the Matrix specification:
|
||||
// https://matrix.org/docs/spec/appendices#signing-json
|
||||
func (a *Account) SignJSON(obj interface{}) (string, error) {
|
||||
objJSON, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
|
||||
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
|
||||
return string(a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))), nil
|
||||
}
|
||||
|
||||
// MaxNumberOfOneTimeKeys returns the largest number of one time keys this
|
||||
// Account can store.
|
||||
func (a *Account) MaxNumberOfOneTimeKeys() uint {
|
||||
return uint(account.MaxOneTimeKeys)
|
||||
}
|
||||
|
||||
// GenOneTimeKeys generates a number of new one time keys. If the total number
|
||||
// of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old
|
||||
// keys are discarded.
|
||||
func (a *Account) GenOneTimeKeys(num uint) {
|
||||
err := a.Account.GenOneTimeKeys(nil, num)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// NewOutboundSession creates a new out-bound session for sending messages to a
|
||||
// given curve25519 identityKey and oneTimeKey. Returns error on failure.
|
||||
func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*Session, error) {
|
||||
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
s := &Session{}
|
||||
newSession, err := a.Account.NewOutboundSession(theirIdentityKey, theirOneTimeKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.OlmSession = *newSession
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// NewInboundSession creates a new in-bound session for sending/receiving
|
||||
// messages from an incoming PRE_KEY message. Returns error on failure.
|
||||
func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) {
|
||||
if len(oneTimeKeyMsg) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
s := &Session{}
|
||||
newSession, err := a.Account.NewInboundSession(nil, []byte(oneTimeKeyMsg))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.OlmSession = *newSession
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// NewInboundSessionFrom creates a new in-bound session for sending/receiving
|
||||
// messages from an incoming PRE_KEY message. Returns error on failure.
|
||||
func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeKeyMsg string) (*Session, error) {
|
||||
if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
s := &Session{}
|
||||
newSession, err := a.Account.NewInboundSession(&theirIdentityKey, []byte(oneTimeKeyMsg))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.OlmSession = *newSession
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// RemoveOneTimeKeys removes the one time keys that the session used from the
|
||||
// Account. Returns error on failure.
|
||||
func (a *Account) RemoveOneTimeKeys(s *Session) error {
|
||||
a.Account.RemoveOneTimeKeys(&s.OlmSession)
|
||||
return nil
|
||||
func NewBlankAccount() Account {
|
||||
return &account.Account{}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
// Copyright (c) 2024 Sumner Evans
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type Account interface {
|
||||
// Pickle returns an Account as a base64 string. Encrypts the Account using the
|
||||
// supplied key.
|
||||
Pickle(key []byte) ([]byte, error)
|
||||
|
||||
// Unpickle loads an Account from a pickled base64 string. Decrypts the
|
||||
// Account using the supplied key. Returns error on failure.
|
||||
Unpickle(pickled, key []byte) error
|
||||
|
||||
// IdentityKeysJSON returns the public parts of the identity keys for the Account.
|
||||
IdentityKeysJSON() ([]byte, error)
|
||||
|
||||
// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity
|
||||
// keys for the Account.
|
||||
IdentityKeys() (id.Ed25519, id.Curve25519, error)
|
||||
|
||||
// Sign returns the signature of a message using the ed25519 key for this
|
||||
// Account.
|
||||
Sign(message []byte) ([]byte, error)
|
||||
|
||||
// SignJSON signs the given JSON object following the Matrix specification:
|
||||
// https://matrix.org/docs/spec/appendices#signing-json
|
||||
SignJSON(obj any) (string, error)
|
||||
|
||||
// OneTimeKeys returns the public parts of the unpublished one time keys for
|
||||
// the Account.
|
||||
//
|
||||
// The returned data is a struct with the single value "Curve25519", which is
|
||||
// itself an object mapping key id to base64-encoded Curve25519 key. For
|
||||
// example:
|
||||
//
|
||||
// {
|
||||
// Curve25519: {
|
||||
// "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo",
|
||||
// "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU"
|
||||
// }
|
||||
// }
|
||||
OneTimeKeys() (map[string]id.Curve25519, error)
|
||||
|
||||
// MarkKeysAsPublished marks the current set of one time keys as being
|
||||
// published.
|
||||
MarkKeysAsPublished()
|
||||
|
||||
// MaxNumberOfOneTimeKeys returns the largest number of one time keys this
|
||||
// Account can store.
|
||||
MaxNumberOfOneTimeKeys() uint
|
||||
|
||||
// GenOneTimeKeys generates a number of new one time keys. If the total
|
||||
// number of keys stored by this Account exceeds MaxNumberOfOneTimeKeys
|
||||
// then the old keys are discarded. Reads random data from the given
|
||||
// reader, or if nil is passed, defaults to crypto/rand.
|
||||
GenOneTimeKeys(reader io.Reader, num uint) error
|
||||
|
||||
// NewOutboundSession creates a new out-bound session for sending messages to a
|
||||
// given curve25519 identityKey and oneTimeKey. Returns error on failure. If the
|
||||
// keys couldn't be decoded as base64 then the error will be "INVALID_BASE64"
|
||||
NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (Session, error)
|
||||
|
||||
// NewInboundSession creates a new in-bound session for sending/receiving
|
||||
// messages from an incoming PRE_KEY message. Returns error on failure. If
|
||||
// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If
|
||||
// the message was for an unsupported protocol version then the error will be
|
||||
// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the
|
||||
// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one
|
||||
// time key then the error will be "BAD_MESSAGE_KEY_ID".
|
||||
NewInboundSession(oneTimeKeyMsg string) (Session, error)
|
||||
|
||||
// NewInboundSessionFrom creates a new in-bound session for sending/receiving
|
||||
// messages from an incoming PRE_KEY message. Returns error on failure. If
|
||||
// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If
|
||||
// the message was for an unsupported protocol version then the error will be
|
||||
// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the
|
||||
// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one
|
||||
// time key then the error will be "BAD_MESSAGE_KEY_ID".
|
||||
NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (Session, error)
|
||||
|
||||
// RemoveOneTimeKeys removes the one time keys that the session used from the
|
||||
// Account. Returns error on failure. If the Account doesn't have any
|
||||
// matching one time keys then the error will be "BAD_MESSAGE_KEY_ID".
|
||||
RemoveOneTimeKeys(s Session) error
|
||||
}
|
|
@ -10,6 +10,7 @@ import (
|
|||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"unsafe"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
|
@ -19,18 +20,21 @@ import (
|
|||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// Account stores a device account for end to end encrypted messaging.
|
||||
type Account struct {
|
||||
// LibOlmAccount stores a device account for end to end encrypted messaging.
|
||||
type LibOlmAccount struct {
|
||||
int *C.OlmAccount
|
||||
mem []byte
|
||||
}
|
||||
|
||||
// Ensure that LibOlmAccount implements Account.
|
||||
var _ Account = (*LibOlmAccount)(nil)
|
||||
|
||||
// AccountFromPickled loads an Account from a pickled base64 string. Decrypts
|
||||
// the Account using the supplied key. Returns error on failure. If the key
|
||||
// doesn't match the one used to encrypt the Account then the error will be
|
||||
// "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then the error will be
|
||||
// "INVALID_BASE64".
|
||||
func AccountFromPickled(pickled, key []byte) (*Account, error) {
|
||||
func AccountFromPickled(pickled, key []byte) (Account, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
|
@ -38,17 +42,21 @@ func AccountFromPickled(pickled, key []byte) (*Account, error) {
|
|||
return a, a.Unpickle(pickled, key)
|
||||
}
|
||||
|
||||
func NewBlankAccount() *Account {
|
||||
func NewBlankLibOlmAccount() *LibOlmAccount {
|
||||
memory := make([]byte, accountSize())
|
||||
return &Account{
|
||||
return &LibOlmAccount{
|
||||
int: C.olm_account(unsafe.Pointer(&memory[0])),
|
||||
mem: memory,
|
||||
}
|
||||
}
|
||||
|
||||
func NewBlankAccount() Account {
|
||||
return NewBlankLibOlmAccount()
|
||||
}
|
||||
|
||||
// NewAccount creates a new Account.
|
||||
func NewAccount() *Account {
|
||||
a := NewBlankAccount()
|
||||
func NewAccount() Account {
|
||||
a := NewBlankLibOlmAccount()
|
||||
random := make([]byte, a.createRandomLen()+1)
|
||||
_, err := rand.Read(random)
|
||||
if err != nil {
|
||||
|
@ -72,12 +80,12 @@ func accountSize() uint {
|
|||
|
||||
// lastError returns an error describing the most recent error to happen to an
|
||||
// account.
|
||||
func (a *Account) lastError() error {
|
||||
func (a *LibOlmAccount) lastError() error {
|
||||
return convertError(C.GoString(C.olm_account_last_error((*C.OlmAccount)(a.int))))
|
||||
}
|
||||
|
||||
// Clear clears the memory used to back this Account.
|
||||
func (a *Account) Clear() error {
|
||||
func (a *LibOlmAccount) Clear() error {
|
||||
r := C.olm_clear_account((*C.OlmAccount)(a.int))
|
||||
if r == errorVal() {
|
||||
return a.lastError()
|
||||
|
@ -87,36 +95,36 @@ func (a *Account) Clear() error {
|
|||
}
|
||||
|
||||
// pickleLen returns the number of bytes needed to store an Account.
|
||||
func (a *Account) pickleLen() uint {
|
||||
func (a *LibOlmAccount) pickleLen() uint {
|
||||
return uint(C.olm_pickle_account_length((*C.OlmAccount)(a.int)))
|
||||
}
|
||||
|
||||
// createRandomLen returns the number of random bytes needed to create an
|
||||
// Account.
|
||||
func (a *Account) createRandomLen() uint {
|
||||
func (a *LibOlmAccount) createRandomLen() uint {
|
||||
return uint(C.olm_create_account_random_length((*C.OlmAccount)(a.int)))
|
||||
}
|
||||
|
||||
// identityKeysLen returns the size of the output buffer needed to hold the
|
||||
// identity keys.
|
||||
func (a *Account) identityKeysLen() uint {
|
||||
func (a *LibOlmAccount) identityKeysLen() uint {
|
||||
return uint(C.olm_account_identity_keys_length((*C.OlmAccount)(a.int)))
|
||||
}
|
||||
|
||||
// signatureLen returns the length of an ed25519 signature encoded as base64.
|
||||
func (a *Account) signatureLen() uint {
|
||||
func (a *LibOlmAccount) signatureLen() uint {
|
||||
return uint(C.olm_account_signature_length((*C.OlmAccount)(a.int)))
|
||||
}
|
||||
|
||||
// oneTimeKeysLen returns the size of the output buffer needed to hold the one
|
||||
// time keys.
|
||||
func (a *Account) oneTimeKeysLen() uint {
|
||||
func (a *LibOlmAccount) oneTimeKeysLen() uint {
|
||||
return uint(C.olm_account_one_time_keys_length((*C.OlmAccount)(a.int)))
|
||||
}
|
||||
|
||||
// genOneTimeKeysRandomLen returns the number of random bytes needed to
|
||||
// generate a given number of new one time keys.
|
||||
func (a *Account) genOneTimeKeysRandomLen(num uint) uint {
|
||||
func (a *LibOlmAccount) genOneTimeKeysRandomLen(num uint) uint {
|
||||
return uint(C.olm_account_generate_one_time_keys_random_length(
|
||||
(*C.OlmAccount)(a.int),
|
||||
C.size_t(num)))
|
||||
|
@ -124,9 +132,9 @@ func (a *Account) genOneTimeKeysRandomLen(num uint) uint {
|
|||
|
||||
// Pickle returns an Account as a base64 string. Encrypts the Account using the
|
||||
// supplied key.
|
||||
func (a *Account) Pickle(key []byte) []byte {
|
||||
func (a *LibOlmAccount) Pickle(key []byte) ([]byte, error) {
|
||||
if len(key) == 0 {
|
||||
panic(NoKeyProvided)
|
||||
return nil, NoKeyProvided
|
||||
}
|
||||
pickled := make([]byte, a.pickleLen())
|
||||
r := C.olm_pickle_account(
|
||||
|
@ -136,12 +144,12 @@ func (a *Account) Pickle(key []byte) []byte {
|
|||
unsafe.Pointer(&pickled[0]),
|
||||
C.size_t(len(pickled)))
|
||||
if r == errorVal() {
|
||||
panic(a.lastError())
|
||||
return nil, a.lastError()
|
||||
}
|
||||
return pickled[:r]
|
||||
return pickled[:r], nil
|
||||
}
|
||||
|
||||
func (a *Account) Unpickle(pickled, key []byte) error {
|
||||
func (a *LibOlmAccount) Unpickle(pickled, key []byte) error {
|
||||
if len(key) == 0 {
|
||||
return NoKeyProvided
|
||||
}
|
||||
|
@ -158,18 +166,21 @@ func (a *Account) Unpickle(pickled, key []byte) error {
|
|||
}
|
||||
|
||||
// Deprecated
|
||||
func (a *Account) GobEncode() ([]byte, error) {
|
||||
pickled := a.Pickle(pickleKey)
|
||||
func (a *LibOlmAccount) GobEncode() ([]byte, error) {
|
||||
pickled, err := a.Pickle(pickleKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
length := base64.RawStdEncoding.DecodedLen(len(pickled))
|
||||
rawPickled := make([]byte, length)
|
||||
_, err := base64.RawStdEncoding.Decode(rawPickled, pickled)
|
||||
_, err = base64.RawStdEncoding.Decode(rawPickled, pickled)
|
||||
return rawPickled, err
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (a *Account) GobDecode(rawPickled []byte) error {
|
||||
func (a *LibOlmAccount) GobDecode(rawPickled []byte) error {
|
||||
if a.int == nil {
|
||||
*a = *NewBlankAccount()
|
||||
*a = *NewBlankLibOlmAccount()
|
||||
}
|
||||
length := base64.RawStdEncoding.EncodedLen(len(rawPickled))
|
||||
pickled := make([]byte, length)
|
||||
|
@ -178,8 +189,11 @@ func (a *Account) GobDecode(rawPickled []byte) error {
|
|||
}
|
||||
|
||||
// Deprecated
|
||||
func (a *Account) MarshalJSON() ([]byte, error) {
|
||||
pickled := a.Pickle(pickleKey)
|
||||
func (a *LibOlmAccount) MarshalJSON() ([]byte, error) {
|
||||
pickled, err := a.Pickle(pickleKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
quotes := make([]byte, len(pickled)+2)
|
||||
quotes[0] = '"'
|
||||
quotes[len(quotes)-1] = '"'
|
||||
|
@ -188,41 +202,44 @@ func (a *Account) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
// Deprecated
|
||||
func (a *Account) UnmarshalJSON(data []byte) error {
|
||||
func (a *LibOlmAccount) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
|
||||
return InputNotJSONString
|
||||
}
|
||||
if a.int == nil {
|
||||
*a = *NewBlankAccount()
|
||||
*a = *NewBlankLibOlmAccount()
|
||||
}
|
||||
return a.Unpickle(data[1:len(data)-1], pickleKey)
|
||||
}
|
||||
|
||||
// IdentityKeysJSON returns the public parts of the identity keys for the Account.
|
||||
func (a *Account) IdentityKeysJSON() []byte {
|
||||
func (a *LibOlmAccount) IdentityKeysJSON() ([]byte, error) {
|
||||
identityKeys := make([]byte, a.identityKeysLen())
|
||||
r := C.olm_account_identity_keys(
|
||||
(*C.OlmAccount)(a.int),
|
||||
unsafe.Pointer(&identityKeys[0]),
|
||||
C.size_t(len(identityKeys)))
|
||||
if r == errorVal() {
|
||||
panic(a.lastError())
|
||||
return nil, a.lastError()
|
||||
} else {
|
||||
return identityKeys
|
||||
return identityKeys, nil
|
||||
}
|
||||
}
|
||||
|
||||
// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity
|
||||
// keys for the Account.
|
||||
func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519) {
|
||||
identityKeysJSON := a.IdentityKeysJSON()
|
||||
func (a *LibOlmAccount) IdentityKeys() (id.Ed25519, id.Curve25519, error) {
|
||||
identityKeysJSON, err := a.IdentityKeysJSON()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
results := gjson.GetManyBytes(identityKeysJSON, "ed25519", "curve25519")
|
||||
return id.Ed25519(results[0].Str), id.Curve25519(results[1].Str)
|
||||
return id.Ed25519(results[0].Str), id.Curve25519(results[1].Str), nil
|
||||
}
|
||||
|
||||
// Sign returns the signature of a message using the ed25519 key for this
|
||||
// Account.
|
||||
func (a *Account) Sign(message []byte) []byte {
|
||||
func (a *LibOlmAccount) Sign(message []byte) ([]byte, error) {
|
||||
if len(message) == 0 {
|
||||
panic(EmptyInput)
|
||||
}
|
||||
|
@ -236,19 +253,20 @@ func (a *Account) Sign(message []byte) []byte {
|
|||
if r == errorVal() {
|
||||
panic(a.lastError())
|
||||
}
|
||||
return signature
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
// SignJSON signs the given JSON object following the Matrix specification:
|
||||
// https://matrix.org/docs/spec/appendices#signing-json
|
||||
func (a *Account) SignJSON(obj interface{}) (string, error) {
|
||||
func (a *LibOlmAccount) SignJSON(obj interface{}) (string, error) {
|
||||
objJSON, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
|
||||
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
|
||||
return string(a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))), nil
|
||||
signed, err := a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
|
||||
return string(signed), err
|
||||
}
|
||||
|
||||
// OneTimeKeys returns the public parts of the unpublished one time keys for
|
||||
|
@ -264,45 +282,44 @@ func (a *Account) SignJSON(obj interface{}) (string, error) {
|
|||
// "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU"
|
||||
// }
|
||||
// }
|
||||
func (a *Account) OneTimeKeys() map[string]id.Curve25519 {
|
||||
func (a *LibOlmAccount) OneTimeKeys() (map[string]id.Curve25519, error) {
|
||||
oneTimeKeysJSON := make([]byte, a.oneTimeKeysLen())
|
||||
r := C.olm_account_one_time_keys(
|
||||
(*C.OlmAccount)(a.int),
|
||||
unsafe.Pointer(&oneTimeKeysJSON[0]),
|
||||
C.size_t(len(oneTimeKeysJSON)))
|
||||
if r == errorVal() {
|
||||
panic(a.lastError())
|
||||
return nil, a.lastError()
|
||||
}
|
||||
var oneTimeKeys struct {
|
||||
Curve25519 map[string]id.Curve25519 `json:"curve25519"`
|
||||
}
|
||||
err := json.Unmarshal(oneTimeKeysJSON, &oneTimeKeys)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return oneTimeKeys.Curve25519
|
||||
return oneTimeKeys.Curve25519, json.Unmarshal(oneTimeKeysJSON, &oneTimeKeys)
|
||||
}
|
||||
|
||||
// MarkKeysAsPublished marks the current set of one time keys as being
|
||||
// published.
|
||||
func (a *Account) MarkKeysAsPublished() {
|
||||
func (a *LibOlmAccount) MarkKeysAsPublished() {
|
||||
C.olm_account_mark_keys_as_published((*C.OlmAccount)(a.int))
|
||||
}
|
||||
|
||||
// MaxNumberOfOneTimeKeys returns the largest number of one time keys this
|
||||
// Account can store.
|
||||
func (a *Account) MaxNumberOfOneTimeKeys() uint {
|
||||
func (a *LibOlmAccount) MaxNumberOfOneTimeKeys() uint {
|
||||
return uint(C.olm_account_max_number_of_one_time_keys((*C.OlmAccount)(a.int)))
|
||||
}
|
||||
|
||||
// GenOneTimeKeys generates a number of new one time keys. If the total number
|
||||
// of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old
|
||||
// keys are discarded.
|
||||
func (a *Account) GenOneTimeKeys(num uint) {
|
||||
func (a *LibOlmAccount) GenOneTimeKeys(reader io.Reader, num uint) error {
|
||||
random := make([]byte, a.genOneTimeKeysRandomLen(num)+1)
|
||||
_, err := rand.Read(random)
|
||||
if reader == nil {
|
||||
reader = rand.Reader
|
||||
}
|
||||
_, err := reader.Read(random)
|
||||
if err != nil {
|
||||
panic(NotEnoughGoRandom)
|
||||
return NotEnoughGoRandom
|
||||
}
|
||||
r := C.olm_account_generate_one_time_keys(
|
||||
(*C.OlmAccount)(a.int),
|
||||
|
@ -310,18 +327,19 @@ func (a *Account) GenOneTimeKeys(num uint) {
|
|||
unsafe.Pointer(&random[0]),
|
||||
C.size_t(len(random)))
|
||||
if r == errorVal() {
|
||||
panic(a.lastError())
|
||||
return a.lastError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewOutboundSession creates a new out-bound session for sending messages to a
|
||||
// given curve25519 identityKey and oneTimeKey. Returns error on failure. If the
|
||||
// keys couldn't be decoded as base64 then the error will be "INVALID_BASE64"
|
||||
func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*Session, error) {
|
||||
func (a *LibOlmAccount) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (Session, error) {
|
||||
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
s := NewBlankSession()
|
||||
s := NewBlankLibOlmSession()
|
||||
random := make([]byte, s.createOutboundRandomLen()+1)
|
||||
_, err := rand.Read(random)
|
||||
if err != nil {
|
||||
|
@ -349,11 +367,11 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2
|
|||
// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the
|
||||
// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one
|
||||
// time key then the error will be "BAD_MESSAGE_KEY_ID".
|
||||
func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) {
|
||||
func (a *LibOlmAccount) NewInboundSession(oneTimeKeyMsg string) (Session, error) {
|
||||
if len(oneTimeKeyMsg) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
s := NewBlankSession()
|
||||
s := NewBlankLibOlmSession()
|
||||
r := C.olm_create_inbound_session(
|
||||
(*C.OlmSession)(s.int),
|
||||
(*C.OlmAccount)(a.int),
|
||||
|
@ -372,16 +390,16 @@ func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) {
|
|||
// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the
|
||||
// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one
|
||||
// time key then the error will be "BAD_MESSAGE_KEY_ID".
|
||||
func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeKeyMsg string) (*Session, error) {
|
||||
if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 {
|
||||
func (a *LibOlmAccount) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (Session, error) {
|
||||
if theirIdentityKey == nil || len(oneTimeKeyMsg) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
s := NewBlankSession()
|
||||
s := NewBlankLibOlmSession()
|
||||
r := C.olm_create_inbound_session_from(
|
||||
(*C.OlmSession)(s.int),
|
||||
(*C.OlmAccount)(a.int),
|
||||
unsafe.Pointer(&([]byte(theirIdentityKey)[0])),
|
||||
C.size_t(len(theirIdentityKey)),
|
||||
unsafe.Pointer(&([]byte(*theirIdentityKey)[0])),
|
||||
C.size_t(len(*theirIdentityKey)),
|
||||
unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])),
|
||||
C.size_t(len(oneTimeKeyMsg)))
|
||||
if r == errorVal() {
|
||||
|
@ -393,10 +411,10 @@ func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeK
|
|||
// RemoveOneTimeKeys removes the one time keys that the session used from the
|
||||
// Account. Returns error on failure. If the Account doesn't have any
|
||||
// matching one time keys then the error will be "BAD_MESSAGE_KEY_ID".
|
||||
func (a *Account) RemoveOneTimeKeys(s *Session) error {
|
||||
func (a *LibOlmAccount) RemoveOneTimeKeys(s Session) error {
|
||||
r := C.olm_remove_one_time_keys(
|
||||
(*C.OlmAccount)(a.int),
|
||||
(*C.OlmSession)(s.int))
|
||||
(*C.OlmSession)(s.(*LibOlmSession).int))
|
||||
if r == errorVal() {
|
||||
return a.lastError()
|
||||
}
|
|
@ -4,146 +4,39 @@ package olm
|
|||
|
||||
import (
|
||||
"maunium.net/go/mautrix/crypto/goolm/session"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// InboundGroupSession stores an inbound encrypted messaging session for a
|
||||
// group.
|
||||
type InboundGroupSession struct {
|
||||
session.MegolmInboundSession
|
||||
}
|
||||
|
||||
// InboundGroupSessionFromPickled loads an InboundGroupSession from a pickled
|
||||
// base64 string. Decrypts the InboundGroupSession using the supplied key.
|
||||
// Returns error on failure.
|
||||
func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) {
|
||||
func InboundGroupSessionFromPickled(pickled, key []byte) (InboundGroupSession, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
lenKey := len(key)
|
||||
if lenKey == 0 {
|
||||
if len(key) == 0 {
|
||||
key = []byte(" ")
|
||||
}
|
||||
megolmSession, err := session.MegolmInboundSessionFromPickled(pickled, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &InboundGroupSession{
|
||||
MegolmInboundSession: *megolmSession,
|
||||
}, nil
|
||||
return session.MegolmInboundSessionFromPickled(pickled, key)
|
||||
}
|
||||
|
||||
// NewInboundGroupSession creates a new inbound group session from a key
|
||||
// exported from OutboundGroupSession.Key(). Returns error on failure.
|
||||
func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
|
||||
func NewInboundGroupSession(sessionKey []byte) (InboundGroupSession, error) {
|
||||
if len(sessionKey) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
megolmSession, err := session.NewMegolmInboundSession(sessionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &InboundGroupSession{
|
||||
MegolmInboundSession: *megolmSession,
|
||||
}, nil
|
||||
return session.NewMegolmInboundSession(sessionKey)
|
||||
}
|
||||
|
||||
// InboundGroupSessionImport imports an inbound group session from a previous
|
||||
// export. Returns error on failure.
|
||||
func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) {
|
||||
func InboundGroupSessionImport(sessionKey []byte) (InboundGroupSession, error) {
|
||||
if len(sessionKey) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
megolmSession, err := session.NewMegolmInboundSessionFromExport(sessionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &InboundGroupSession{
|
||||
MegolmInboundSession: *megolmSession,
|
||||
}, nil
|
||||
return session.NewMegolmInboundSessionFromExport(sessionKey)
|
||||
}
|
||||
|
||||
func NewBlankInboundGroupSession() *InboundGroupSession {
|
||||
return &InboundGroupSession{}
|
||||
}
|
||||
|
||||
// Clear clears the memory used to back this InboundGroupSession.
|
||||
func (s *InboundGroupSession) Clear() error {
|
||||
s.MegolmInboundSession = session.MegolmInboundSession{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pickle returns an InboundGroupSession as a base64 string. Encrypts the
|
||||
// InboundGroupSession using the supplied key.
|
||||
func (s *InboundGroupSession) Pickle(key []byte) []byte {
|
||||
if len(key) == 0 {
|
||||
panic(NoKeyProvided)
|
||||
}
|
||||
pickled, err := s.MegolmInboundSession.Pickle(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return pickled
|
||||
}
|
||||
|
||||
func (s *InboundGroupSession) Unpickle(pickled, key []byte) error {
|
||||
if len(key) == 0 {
|
||||
return NoKeyProvided
|
||||
} else if len(pickled) == 0 {
|
||||
return EmptyInput
|
||||
}
|
||||
sOlm, err := session.MegolmInboundSessionFromPickled(pickled, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.MegolmInboundSession = *sOlm
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts a message using the InboundGroupSession. Returns the the
|
||||
// plain-text and message index on success. Returns error on failure.
|
||||
func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) {
|
||||
if len(message) == 0 {
|
||||
return nil, 0, EmptyInput
|
||||
}
|
||||
plaintext, messageIndex, err := s.MegolmInboundSession.Decrypt(message)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return plaintext, uint(messageIndex), nil
|
||||
}
|
||||
|
||||
// ID returns a base64-encoded identifier for this session.
|
||||
func (s *InboundGroupSession) ID() id.SessionID {
|
||||
return s.MegolmInboundSession.SessionID()
|
||||
}
|
||||
|
||||
// FirstKnownIndex returns the first message index we know how to decrypt.
|
||||
func (s *InboundGroupSession) FirstKnownIndex() uint32 {
|
||||
return s.MegolmInboundSession.InitialRatchet.Counter
|
||||
}
|
||||
|
||||
// IsVerified check if the session has been verified as a valid session. (A
|
||||
// session is verified either because the original session share was signed, or
|
||||
// because we have subsequently successfully decrypted a message.)
|
||||
func (s *InboundGroupSession) IsVerified() uint {
|
||||
if s.MegolmInboundSession.SigningKeyVerified {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Export returns the base64-encoded ratchet key for this session, at the given
|
||||
// index, in a format which can be used by
|
||||
// InboundGroupSession.InboundGroupSessionImport(). Encrypts the
|
||||
// InboundGroupSession using the supplied key. Returns error on failure.
|
||||
// if we do not have a session key corresponding to the given index (ie, it was
|
||||
// sent before the session key was shared with us) the error will be
|
||||
// returned.
|
||||
func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) {
|
||||
res, err := s.MegolmInboundSession.SessionExportMessage(messageIndex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res, nil
|
||||
func NewBlankInboundGroupSession() InboundGroupSession {
|
||||
return &session.MegolmInboundSession{}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
// Copyright (c) 2024 Sumner Evans
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"maunium.net/go/mautrix/crypto/goolm/session"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type InboundGroupSession interface {
|
||||
// Pickle returns an InboundGroupSession as a base64 string. Encrypts the
|
||||
// InboundGroupSession using the supplied key.
|
||||
Pickle(key []byte) ([]byte, error)
|
||||
|
||||
// Unpickle loads an [InboundGroupSession] from a pickled base64 string.
|
||||
// Decrypts the [InboundGroupSession] using the supplied key.
|
||||
Unpickle(pickled, key []byte) error
|
||||
|
||||
// Decrypt decrypts a message using the [InboundGroupSession]. Returns the
|
||||
// plain-text and message index on success. Returns error on failure. If
|
||||
// the base64 couldn't be decoded then the error will be "INVALID_BASE64".
|
||||
// If the message is for an unsupported version of the protocol then the
|
||||
// error will be "BAD_MESSAGE_VERSION". If the message couldn't be decoded
|
||||
// then the error will be BAD_MESSAGE_FORMAT". If the MAC on the message
|
||||
// was invalid then the error will be "BAD_MESSAGE_MAC". If we do not have
|
||||
// a session key corresponding to the message's index (ie, it was sent
|
||||
// before the session key was shared with us) the error will be
|
||||
// "OLM_UNKNOWN_MESSAGE_INDEX".
|
||||
Decrypt(message []byte) ([]byte, uint, error)
|
||||
|
||||
// ID returns a base64-encoded identifier for this session.
|
||||
ID() id.SessionID
|
||||
|
||||
// FirstKnownIndex returns the first message index we know how to decrypt.
|
||||
FirstKnownIndex() uint32
|
||||
|
||||
// IsVerified check if the session has been verified as a valid session.
|
||||
// (A session is verified either because the original session share was
|
||||
// signed, or because we have subsequently successfully decrypted a
|
||||
// message.)
|
||||
IsVerified() bool
|
||||
|
||||
// Export returns the base64-encoded ratchet key for this session, at the
|
||||
// given index, in a format which can be used by
|
||||
// InboundGroupSession.InboundGroupSessionImport(). Encrypts the
|
||||
// InboundGroupSession using the supplied key. Returns error on failure.
|
||||
// if we do not have a session key corresponding to the given index (ie, it
|
||||
// was sent before the session key was shared with us) the error will be
|
||||
// "OLM_UNKNOWN_MESSAGE_INDEX".
|
||||
Export(messageIndex uint32) ([]byte, error)
|
||||
}
|
||||
|
||||
var _ InboundGroupSession = (*session.MegolmInboundSession)(nil)
|
|
@ -13,19 +13,22 @@ import (
|
|||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// InboundGroupSession stores an inbound encrypted messaging session for a
|
||||
// LibOlmInboundGroupSession stores an inbound encrypted messaging session for a
|
||||
// group.
|
||||
type InboundGroupSession struct {
|
||||
type LibOlmInboundGroupSession struct {
|
||||
int *C.OlmInboundGroupSession
|
||||
mem []byte
|
||||
}
|
||||
|
||||
// Ensure that LibOlmInboundGroupSession implements InboundGroupSession.
|
||||
var _ InboundGroupSession = (*LibOlmInboundGroupSession)(nil)
|
||||
|
||||
// InboundGroupSessionFromPickled loads an InboundGroupSession from a pickled
|
||||
// base64 string. Decrypts the InboundGroupSession using the supplied key.
|
||||
// Returns error on failure. If the key doesn't match the one used to encrypt
|
||||
// the InboundGroupSession then the error will be "BAD_SESSION_KEY". If the
|
||||
// base64 couldn't be decoded then the error will be "INVALID_BASE64".
|
||||
func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) {
|
||||
func InboundGroupSessionFromPickled(pickled, key []byte) (InboundGroupSession, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
|
@ -42,7 +45,7 @@ func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession,
|
|||
// If the sessionKey is not valid base64 the error will be
|
||||
// "OLM_INVALID_BASE64". If the session_key is invalid the error will be
|
||||
// "OLM_BAD_SESSION_KEY".
|
||||
func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
|
||||
func NewInboundGroupSession(sessionKey []byte) (*LibOlmInboundGroupSession, error) {
|
||||
if len(sessionKey) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
|
@ -61,7 +64,7 @@ func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
|
|||
// export. Returns error on failure. If the sessionKey is not valid base64
|
||||
// the error will be "OLM_INVALID_BASE64". If the session_key is invalid the
|
||||
// error will be "OLM_BAD_SESSION_KEY".
|
||||
func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) {
|
||||
func InboundGroupSessionImport(sessionKey []byte) (*LibOlmInboundGroupSession, error) {
|
||||
if len(sessionKey) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
|
@ -83,9 +86,9 @@ func inboundGroupSessionSize() uint {
|
|||
}
|
||||
|
||||
// newInboundGroupSession initialises an empty InboundGroupSession.
|
||||
func NewBlankInboundGroupSession() *InboundGroupSession {
|
||||
func NewBlankInboundGroupSession() *LibOlmInboundGroupSession {
|
||||
memory := make([]byte, inboundGroupSessionSize())
|
||||
return &InboundGroupSession{
|
||||
return &LibOlmInboundGroupSession{
|
||||
int: C.olm_inbound_group_session(unsafe.Pointer(&memory[0])),
|
||||
mem: memory,
|
||||
}
|
||||
|
@ -93,12 +96,12 @@ func NewBlankInboundGroupSession() *InboundGroupSession {
|
|||
|
||||
// lastError returns an error describing the most recent error to happen to an
|
||||
// inbound group session.
|
||||
func (s *InboundGroupSession) lastError() error {
|
||||
func (s *LibOlmInboundGroupSession) lastError() error {
|
||||
return convertError(C.GoString(C.olm_inbound_group_session_last_error((*C.OlmInboundGroupSession)(s.int))))
|
||||
}
|
||||
|
||||
// Clear clears the memory used to back this InboundGroupSession.
|
||||
func (s *InboundGroupSession) Clear() error {
|
||||
func (s *LibOlmInboundGroupSession) Clear() error {
|
||||
r := C.olm_clear_inbound_group_session((*C.OlmInboundGroupSession)(s.int))
|
||||
if r == errorVal() {
|
||||
return s.lastError()
|
||||
|
@ -108,15 +111,15 @@ func (s *InboundGroupSession) Clear() error {
|
|||
|
||||
// pickleLen returns the number of bytes needed to store an inbound group
|
||||
// session.
|
||||
func (s *InboundGroupSession) pickleLen() uint {
|
||||
func (s *LibOlmInboundGroupSession) pickleLen() uint {
|
||||
return uint(C.olm_pickle_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int)))
|
||||
}
|
||||
|
||||
// Pickle returns an InboundGroupSession as a base64 string. Encrypts the
|
||||
// InboundGroupSession using the supplied key.
|
||||
func (s *InboundGroupSession) Pickle(key []byte) []byte {
|
||||
func (s *LibOlmInboundGroupSession) Pickle(key []byte) ([]byte, error) {
|
||||
if len(key) == 0 {
|
||||
panic(NoKeyProvided)
|
||||
return nil, NoKeyProvided
|
||||
}
|
||||
pickled := make([]byte, s.pickleLen())
|
||||
r := C.olm_pickle_inbound_group_session(
|
||||
|
@ -126,12 +129,12 @@ func (s *InboundGroupSession) Pickle(key []byte) []byte {
|
|||
unsafe.Pointer(&pickled[0]),
|
||||
C.size_t(len(pickled)))
|
||||
if r == errorVal() {
|
||||
panic(s.lastError())
|
||||
return nil, s.lastError()
|
||||
}
|
||||
return pickled[:r]
|
||||
return pickled[:r], nil
|
||||
}
|
||||
|
||||
func (s *InboundGroupSession) Unpickle(pickled, key []byte) error {
|
||||
func (s *LibOlmInboundGroupSession) Unpickle(pickled, key []byte) error {
|
||||
if len(key) == 0 {
|
||||
return NoKeyProvided
|
||||
} else if len(pickled) == 0 {
|
||||
|
@ -150,16 +153,19 @@ func (s *InboundGroupSession) Unpickle(pickled, key []byte) error {
|
|||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *InboundGroupSession) GobEncode() ([]byte, error) {
|
||||
pickled := s.Pickle(pickleKey)
|
||||
func (s *LibOlmInboundGroupSession) GobEncode() ([]byte, error) {
|
||||
pickled, err := s.Pickle(pickleKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
length := base64.RawStdEncoding.DecodedLen(len(pickled))
|
||||
rawPickled := make([]byte, length)
|
||||
_, err := base64.RawStdEncoding.Decode(rawPickled, pickled)
|
||||
_, err = base64.RawStdEncoding.Decode(rawPickled, pickled)
|
||||
return rawPickled, err
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *InboundGroupSession) GobDecode(rawPickled []byte) error {
|
||||
func (s *LibOlmInboundGroupSession) GobDecode(rawPickled []byte) error {
|
||||
if s == nil || s.int == nil {
|
||||
*s = *NewBlankInboundGroupSession()
|
||||
}
|
||||
|
@ -170,8 +176,11 @@ func (s *InboundGroupSession) GobDecode(rawPickled []byte) error {
|
|||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *InboundGroupSession) MarshalJSON() ([]byte, error) {
|
||||
pickled := s.Pickle(pickleKey)
|
||||
func (s *LibOlmInboundGroupSession) MarshalJSON() ([]byte, error) {
|
||||
pickled, err := s.Pickle(pickleKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
quotes := make([]byte, len(pickled)+2)
|
||||
quotes[0] = '"'
|
||||
quotes[len(quotes)-1] = '"'
|
||||
|
@ -180,7 +189,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *InboundGroupSession) UnmarshalJSON(data []byte) error {
|
||||
func (s *LibOlmInboundGroupSession) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
|
||||
return InputNotJSONString
|
||||
}
|
||||
|
@ -203,7 +212,7 @@ func clone(original []byte) []byte {
|
|||
// unsupported version of the protocol then the error will be
|
||||
// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error
|
||||
// will be "BAD_MESSAGE_FORMAT".
|
||||
func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) {
|
||||
func (s *LibOlmInboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) {
|
||||
if len(message) == 0 {
|
||||
return 0, EmptyInput
|
||||
}
|
||||
|
@ -228,7 +237,7 @@ func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, erro
|
|||
// error will be "BAD_MESSAGE_MAC". If we do not have a session key
|
||||
// corresponding to the message's index (ie, it was sent before the session key
|
||||
// was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX".
|
||||
func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) {
|
||||
func (s *LibOlmInboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) {
|
||||
if len(message) == 0 {
|
||||
return nil, 0, EmptyInput
|
||||
}
|
||||
|
@ -254,12 +263,12 @@ func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) {
|
|||
}
|
||||
|
||||
// sessionIdLen returns the number of bytes needed to store a session ID.
|
||||
func (s *InboundGroupSession) sessionIdLen() uint {
|
||||
func (s *LibOlmInboundGroupSession) sessionIdLen() uint {
|
||||
return uint(C.olm_inbound_group_session_id_length((*C.OlmInboundGroupSession)(s.int)))
|
||||
}
|
||||
|
||||
// ID returns a base64-encoded identifier for this session.
|
||||
func (s *InboundGroupSession) ID() id.SessionID {
|
||||
func (s *LibOlmInboundGroupSession) ID() id.SessionID {
|
||||
sessionID := make([]byte, s.sessionIdLen())
|
||||
r := C.olm_inbound_group_session_id(
|
||||
(*C.OlmInboundGroupSession)(s.int),
|
||||
|
@ -272,20 +281,20 @@ func (s *InboundGroupSession) ID() id.SessionID {
|
|||
}
|
||||
|
||||
// FirstKnownIndex returns the first message index we know how to decrypt.
|
||||
func (s *InboundGroupSession) FirstKnownIndex() uint32 {
|
||||
func (s *LibOlmInboundGroupSession) FirstKnownIndex() uint32 {
|
||||
return uint32(C.olm_inbound_group_session_first_known_index((*C.OlmInboundGroupSession)(s.int)))
|
||||
}
|
||||
|
||||
// IsVerified check if the session has been verified as a valid session. (A
|
||||
// session is verified either because the original session share was signed, or
|
||||
// because we have subsequently successfully decrypted a message.)
|
||||
func (s *InboundGroupSession) IsVerified() uint {
|
||||
return uint(C.olm_inbound_group_session_is_verified((*C.OlmInboundGroupSession)(s.int)))
|
||||
func (s *LibOlmInboundGroupSession) IsVerified() bool {
|
||||
return uint(C.olm_inbound_group_session_is_verified((*C.OlmInboundGroupSession)(s.int))) == 1
|
||||
}
|
||||
|
||||
// exportLen returns the number of bytes needed to export an inbound group
|
||||
// session.
|
||||
func (s *InboundGroupSession) exportLen() uint {
|
||||
func (s *LibOlmInboundGroupSession) exportLen() uint {
|
||||
return uint(C.olm_export_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int)))
|
||||
}
|
||||
|
||||
|
@ -296,7 +305,7 @@ func (s *InboundGroupSession) exportLen() uint {
|
|||
// if we do not have a session key corresponding to the given index (ie, it was
|
||||
// sent before the session key was shared with us) the error will be
|
||||
// "OLM_UNKNOWN_MESSAGE_INDEX".
|
||||
func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) {
|
||||
func (s *LibOlmInboundGroupSession) Export(messageIndex uint32) ([]byte, error) {
|
||||
key := make([]byte, s.exportLen())
|
||||
r := C.olm_export_inbound_group_session(
|
||||
(*C.OlmInboundGroupSession)(s.int),
|
|
@ -2,13 +2,6 @@
|
|||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// Signatures is the data structure used to sign JSON objects.
|
||||
type Signatures map[id.UserID]map[id.DeviceKeyID]string
|
||||
|
||||
// Version returns the version number of the olm library.
|
||||
func Version() (major, minor, patch uint8) {
|
||||
return 3, 2, 15
|
||||
|
|
|
@ -5,12 +5,6 @@ package olm
|
|||
// #cgo LDFLAGS: -lolm -lstdc++
|
||||
// #include <olm/olm.h>
|
||||
import "C"
|
||||
import (
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// Signatures is the data structure used to sign JSON objects.
|
||||
type Signatures map[id.UserID]map[id.DeviceKeyID]string
|
||||
|
||||
// Version returns the version number of the olm library.
|
||||
func Version() (major, minor, patch uint8) {
|
|
@ -4,21 +4,14 @@ package olm
|
|||
|
||||
import (
|
||||
"maunium.net/go/mautrix/crypto/goolm/session"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// OutboundGroupSession stores an outbound encrypted messaging session for a
|
||||
// group.
|
||||
type OutboundGroupSession struct {
|
||||
session.MegolmOutboundSession
|
||||
}
|
||||
|
||||
// OutboundGroupSessionFromPickled loads an OutboundGroupSession from a pickled
|
||||
// base64 string. Decrypts the OutboundGroupSession using the supplied key.
|
||||
// Returns error on failure. If the key doesn't match the one used to encrypt
|
||||
// the OutboundGroupSession then the error will be "BAD_SESSION_KEY". If the
|
||||
// base64 couldn't be decoded then the error will be "INVALID_BASE64".
|
||||
func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession, error) {
|
||||
func OutboundGroupSessionFromPickled(pickled, key []byte) (OutboundGroupSession, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
|
@ -26,86 +19,19 @@ func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession
|
|||
if lenKey == 0 {
|
||||
key = []byte(" ")
|
||||
}
|
||||
megolmSession, err := session.MegolmOutboundSessionFromPickled(pickled, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OutboundGroupSession{
|
||||
MegolmOutboundSession: *megolmSession,
|
||||
}, nil
|
||||
return session.MegolmOutboundSessionFromPickled(pickled, key)
|
||||
}
|
||||
|
||||
// NewOutboundGroupSession creates a new outbound group session.
|
||||
func NewOutboundGroupSession() *OutboundGroupSession {
|
||||
megolmSession, err := session.NewMegolmOutboundSession()
|
||||
func NewOutboundGroupSession() OutboundGroupSession {
|
||||
session, err := session.NewMegolmOutboundSession()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &OutboundGroupSession{
|
||||
MegolmOutboundSession: *megolmSession,
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
// newOutboundGroupSession initialises an empty OutboundGroupSession.
|
||||
func NewBlankOutboundGroupSession() *OutboundGroupSession {
|
||||
return &OutboundGroupSession{}
|
||||
}
|
||||
|
||||
// Clear clears the memory used to back this OutboundGroupSession.
|
||||
func (s *OutboundGroupSession) Clear() error {
|
||||
s.MegolmOutboundSession = session.MegolmOutboundSession{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pickle returns an OutboundGroupSession as a base64 string. Encrypts the
|
||||
// OutboundGroupSession using the supplied key.
|
||||
func (s *OutboundGroupSession) Pickle(key []byte) []byte {
|
||||
if len(key) == 0 {
|
||||
panic(NoKeyProvided)
|
||||
}
|
||||
pickled, err := s.MegolmOutboundSession.Pickle(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return pickled
|
||||
}
|
||||
|
||||
func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error {
|
||||
if len(key) == 0 {
|
||||
return NoKeyProvided
|
||||
}
|
||||
return s.MegolmOutboundSession.Unpickle(pickled, key)
|
||||
}
|
||||
|
||||
// Encrypt encrypts a message using the Session. Returns the encrypted message
|
||||
// as base64.
|
||||
func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte {
|
||||
if len(plaintext) == 0 {
|
||||
panic(EmptyInput)
|
||||
}
|
||||
message, err := s.MegolmOutboundSession.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
// ID returns a base64-encoded identifier for this session.
|
||||
func (s *OutboundGroupSession) ID() id.SessionID {
|
||||
return s.MegolmOutboundSession.SessionID()
|
||||
}
|
||||
|
||||
// MessageIndex returns the message index for this session. Each message is
|
||||
// sent with an increasing index; this returns the index for the next message.
|
||||
func (s *OutboundGroupSession) MessageIndex() uint {
|
||||
return uint(s.MegolmOutboundSession.Ratchet.Counter)
|
||||
}
|
||||
|
||||
// Key returns the base64-encoded current ratchet key for this session.
|
||||
func (s *OutboundGroupSession) Key() string {
|
||||
message, err := s.MegolmOutboundSession.SessionSharingMessage()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return string(message)
|
||||
// NewBlankOutboundGroupSession initialises an empty OutboundGroupSession.
|
||||
func NewBlankOutboundGroupSession() OutboundGroupSession {
|
||||
return &session.MegolmOutboundSession{}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright (c) 2024 Sumner Evans
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"maunium.net/go/mautrix/crypto/goolm/session"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type OutboundGroupSession interface {
|
||||
// Pickle returns a Session as a base64 string. Encrypts the Session using
|
||||
// the supplied key.
|
||||
Pickle(key []byte) ([]byte, error)
|
||||
|
||||
// Unpickle loads an [OutboundGroupSession] from a pickled base64 string.
|
||||
// Decrypts the [OutboundGroupSession] using the supplied key.
|
||||
Unpickle(pickled, key []byte) error
|
||||
|
||||
// Encrypt encrypts a message using the [OutboundGroupSession]. Returns the
|
||||
// encrypted message as base64.
|
||||
Encrypt(plaintext []byte) ([]byte, error)
|
||||
|
||||
// ID returns a base64-encoded identifier for this session.
|
||||
ID() id.SessionID
|
||||
|
||||
// MessageIndex returns the message index for this session. Each message
|
||||
// is sent with an increasing index; this returns the index for the next
|
||||
// message.
|
||||
MessageIndex() uint
|
||||
|
||||
// Key returns the base64-encoded current ratchet key for this session.
|
||||
Key() string
|
||||
}
|
||||
|
||||
var _ OutboundGroupSession = (*session.MegolmOutboundSession)(nil)
|
|
@ -14,29 +14,32 @@ import (
|
|||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// OutboundGroupSession stores an outbound encrypted messaging session for a
|
||||
// group.
|
||||
type OutboundGroupSession struct {
|
||||
// LibOlmOutboundGroupSession stores an outbound encrypted messaging session
|
||||
// for a group.
|
||||
type LibOlmOutboundGroupSession struct {
|
||||
int *C.OlmOutboundGroupSession
|
||||
mem []byte
|
||||
}
|
||||
|
||||
// Ensure that LibOlmOutboundGroupSession implements OutboundGroupSession.
|
||||
var _ OutboundGroupSession = (*LibOlmOutboundGroupSession)(nil)
|
||||
|
||||
// OutboundGroupSessionFromPickled loads an OutboundGroupSession from a pickled
|
||||
// base64 string. Decrypts the OutboundGroupSession using the supplied key.
|
||||
// Returns error on failure. If the key doesn't match the one used to encrypt
|
||||
// the OutboundGroupSession then the error will be "BAD_SESSION_KEY". If the
|
||||
// base64 couldn't be decoded then the error will be "INVALID_BASE64".
|
||||
func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession, error) {
|
||||
func OutboundGroupSessionFromPickled(pickled, key []byte) (OutboundGroupSession, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
s := NewBlankOutboundGroupSession()
|
||||
s := NewBlankLibOlmOutboundGroupSession()
|
||||
return s, s.Unpickle(pickled, key)
|
||||
}
|
||||
|
||||
// NewOutboundGroupSession creates a new outbound group session.
|
||||
func NewOutboundGroupSession() *OutboundGroupSession {
|
||||
s := NewBlankOutboundGroupSession()
|
||||
func NewOutboundGroupSession() OutboundGroupSession {
|
||||
s := NewBlankLibOlmOutboundGroupSession()
|
||||
random := make([]byte, s.createRandomLen()+1)
|
||||
_, err := rand.Read(random)
|
||||
if err != nil {
|
||||
|
@ -58,23 +61,29 @@ func outboundGroupSessionSize() uint {
|
|||
return uint(C.olm_outbound_group_session_size())
|
||||
}
|
||||
|
||||
// newOutboundGroupSession initialises an empty OutboundGroupSession.
|
||||
func NewBlankOutboundGroupSession() *OutboundGroupSession {
|
||||
// NewBlankLibOlmOutboundGroupSession initialises an empty
|
||||
// LibOlmOutboundGroupSession.
|
||||
func NewBlankLibOlmOutboundGroupSession() *LibOlmOutboundGroupSession {
|
||||
memory := make([]byte, outboundGroupSessionSize())
|
||||
return &OutboundGroupSession{
|
||||
return &LibOlmOutboundGroupSession{
|
||||
int: C.olm_outbound_group_session(unsafe.Pointer(&memory[0])),
|
||||
mem: memory,
|
||||
}
|
||||
}
|
||||
|
||||
// NewBlankOutboundGroupSession initialises an empty OutboundGroupSession.
|
||||
func NewBlankOutboundGroupSession() OutboundGroupSession {
|
||||
return NewBlankLibOlmOutboundGroupSession()
|
||||
}
|
||||
|
||||
// lastError returns an error describing the most recent error to happen to an
|
||||
// outbound group session.
|
||||
func (s *OutboundGroupSession) lastError() error {
|
||||
func (s *LibOlmOutboundGroupSession) lastError() error {
|
||||
return convertError(C.GoString(C.olm_outbound_group_session_last_error((*C.OlmOutboundGroupSession)(s.int))))
|
||||
}
|
||||
|
||||
// Clear clears the memory used to back this OutboundGroupSession.
|
||||
func (s *OutboundGroupSession) Clear() error {
|
||||
func (s *LibOlmOutboundGroupSession) Clear() error {
|
||||
r := C.olm_clear_outbound_group_session((*C.OlmOutboundGroupSession)(s.int))
|
||||
if r == errorVal() {
|
||||
return s.lastError()
|
||||
|
@ -85,15 +94,15 @@ func (s *OutboundGroupSession) Clear() error {
|
|||
|
||||
// pickleLen returns the number of bytes needed to store an outbound group
|
||||
// session.
|
||||
func (s *OutboundGroupSession) pickleLen() uint {
|
||||
func (s *LibOlmOutboundGroupSession) pickleLen() uint {
|
||||
return uint(C.olm_pickle_outbound_group_session_length((*C.OlmOutboundGroupSession)(s.int)))
|
||||
}
|
||||
|
||||
// Pickle returns an OutboundGroupSession as a base64 string. Encrypts the
|
||||
// OutboundGroupSession using the supplied key.
|
||||
func (s *OutboundGroupSession) Pickle(key []byte) []byte {
|
||||
func (s *LibOlmOutboundGroupSession) Pickle(key []byte) ([]byte, error) {
|
||||
if len(key) == 0 {
|
||||
panic(NoKeyProvided)
|
||||
return nil, NoKeyProvided
|
||||
}
|
||||
pickled := make([]byte, s.pickleLen())
|
||||
r := C.olm_pickle_outbound_group_session(
|
||||
|
@ -103,12 +112,12 @@ func (s *OutboundGroupSession) Pickle(key []byte) []byte {
|
|||
unsafe.Pointer(&pickled[0]),
|
||||
C.size_t(len(pickled)))
|
||||
if r == errorVal() {
|
||||
panic(s.lastError())
|
||||
return nil, s.lastError()
|
||||
}
|
||||
return pickled[:r]
|
||||
return pickled[:r], nil
|
||||
}
|
||||
|
||||
func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error {
|
||||
func (s *LibOlmOutboundGroupSession) Unpickle(pickled, key []byte) error {
|
||||
if len(key) == 0 {
|
||||
return NoKeyProvided
|
||||
}
|
||||
|
@ -125,18 +134,21 @@ func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error {
|
|||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *OutboundGroupSession) GobEncode() ([]byte, error) {
|
||||
pickled := s.Pickle(pickleKey)
|
||||
func (s *LibOlmOutboundGroupSession) GobEncode() ([]byte, error) {
|
||||
pickled, err := s.Pickle(pickleKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
length := base64.RawStdEncoding.DecodedLen(len(pickled))
|
||||
rawPickled := make([]byte, length)
|
||||
_, err := base64.RawStdEncoding.Decode(rawPickled, pickled)
|
||||
_, err = base64.RawStdEncoding.Decode(rawPickled, pickled)
|
||||
return rawPickled, err
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error {
|
||||
func (s *LibOlmOutboundGroupSession) GobDecode(rawPickled []byte) error {
|
||||
if s == nil || s.int == nil {
|
||||
*s = *NewBlankOutboundGroupSession()
|
||||
*s = *NewBlankLibOlmOutboundGroupSession()
|
||||
}
|
||||
length := base64.RawStdEncoding.EncodedLen(len(rawPickled))
|
||||
pickled := make([]byte, length)
|
||||
|
@ -145,8 +157,11 @@ func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error {
|
|||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) {
|
||||
pickled := s.Pickle(pickleKey)
|
||||
func (s *LibOlmOutboundGroupSession) MarshalJSON() ([]byte, error) {
|
||||
pickled, err := s.Pickle(pickleKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
quotes := make([]byte, len(pickled)+2)
|
||||
quotes[0] = '"'
|
||||
quotes[len(quotes)-1] = '"'
|
||||
|
@ -155,33 +170,33 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error {
|
||||
func (s *LibOlmOutboundGroupSession) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
|
||||
return InputNotJSONString
|
||||
}
|
||||
if s == nil || s.int == nil {
|
||||
*s = *NewBlankOutboundGroupSession()
|
||||
*s = *NewBlankLibOlmOutboundGroupSession()
|
||||
}
|
||||
return s.Unpickle(data[1:len(data)-1], pickleKey)
|
||||
}
|
||||
|
||||
// createRandomLen returns the number of random bytes needed to create an
|
||||
// Account.
|
||||
func (s *OutboundGroupSession) createRandomLen() uint {
|
||||
func (s *LibOlmOutboundGroupSession) createRandomLen() uint {
|
||||
return uint(C.olm_init_outbound_group_session_random_length((*C.OlmOutboundGroupSession)(s.int)))
|
||||
}
|
||||
|
||||
// encryptMsgLen returns the size of the next message in bytes for the given
|
||||
// number of plain-text bytes.
|
||||
func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint {
|
||||
func (s *LibOlmOutboundGroupSession) encryptMsgLen(plainTextLen int) uint {
|
||||
return uint(C.olm_group_encrypt_message_length((*C.OlmOutboundGroupSession)(s.int), C.size_t(plainTextLen)))
|
||||
}
|
||||
|
||||
// Encrypt encrypts a message using the Session. Returns the encrypted message
|
||||
// as base64.
|
||||
func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte {
|
||||
func (s *LibOlmOutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) {
|
||||
if len(plaintext) == 0 {
|
||||
panic(EmptyInput)
|
||||
return nil, EmptyInput
|
||||
}
|
||||
message := make([]byte, s.encryptMsgLen(len(plaintext)))
|
||||
r := C.olm_group_encrypt(
|
||||
|
@ -191,18 +206,18 @@ func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte {
|
|||
(*C.uint8_t)(&message[0]),
|
||||
C.size_t(len(message)))
|
||||
if r == errorVal() {
|
||||
panic(s.lastError())
|
||||
return nil, s.lastError()
|
||||
}
|
||||
return message[:r]
|
||||
return message[:r], nil
|
||||
}
|
||||
|
||||
// sessionIdLen returns the number of bytes needed to store a session ID.
|
||||
func (s *OutboundGroupSession) sessionIdLen() uint {
|
||||
func (s *LibOlmOutboundGroupSession) sessionIdLen() uint {
|
||||
return uint(C.olm_outbound_group_session_id_length((*C.OlmOutboundGroupSession)(s.int)))
|
||||
}
|
||||
|
||||
// ID returns a base64-encoded identifier for this session.
|
||||
func (s *OutboundGroupSession) ID() id.SessionID {
|
||||
func (s *LibOlmOutboundGroupSession) ID() id.SessionID {
|
||||
sessionID := make([]byte, s.sessionIdLen())
|
||||
r := C.olm_outbound_group_session_id(
|
||||
(*C.OlmOutboundGroupSession)(s.int),
|
||||
|
@ -216,17 +231,17 @@ func (s *OutboundGroupSession) ID() id.SessionID {
|
|||
|
||||
// MessageIndex returns the message index for this session. Each message is
|
||||
// sent with an increasing index; this returns the index for the next message.
|
||||
func (s *OutboundGroupSession) MessageIndex() uint {
|
||||
func (s *LibOlmOutboundGroupSession) MessageIndex() uint {
|
||||
return uint(C.olm_outbound_group_session_message_index((*C.OlmOutboundGroupSession)(s.int)))
|
||||
}
|
||||
|
||||
// sessionKeyLen returns the number of bytes needed to store a session key.
|
||||
func (s *OutboundGroupSession) sessionKeyLen() uint {
|
||||
func (s *LibOlmOutboundGroupSession) sessionKeyLen() uint {
|
||||
return uint(C.olm_outbound_group_session_key_length((*C.OlmOutboundGroupSession)(s.int)))
|
||||
}
|
||||
|
||||
// Key returns the base64-encoded current ratchet key for this session.
|
||||
func (s *OutboundGroupSession) Key() string {
|
||||
func (s *LibOlmOutboundGroupSession) Key() string {
|
||||
sessionKey := make([]byte, s.sessionKeyLen())
|
||||
r := C.olm_outbound_group_session_key(
|
||||
(*C.OlmOutboundGroupSession)(s.int),
|
|
@ -1,71 +1,29 @@
|
|||
// Copyright (c) 2024 Sumner Evans
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// When the goolm build flag is enabled, this file will make [PKSigning]
|
||||
// constructors use the goolm constuctors.
|
||||
|
||||
//go:build goolm
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
import "maunium.net/go/mautrix/crypto/goolm/pk"
|
||||
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/canonicaljson"
|
||||
"maunium.net/go/mautrix/crypto/goolm/pk"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// PkSigning stores a key pair for signing messages.
|
||||
type PkSigning struct {
|
||||
pk.Signing
|
||||
PublicKey id.Ed25519
|
||||
Seed []byte
|
||||
// NewPKSigningFromSeed creates a new PKSigning object using the given seed.
|
||||
func NewPKSigningFromSeed(seed []byte) (PKSigning, error) {
|
||||
return pk.NewSigningFromSeed(seed)
|
||||
}
|
||||
|
||||
// Clear clears the underlying memory of a PkSigning object.
|
||||
func (p *PkSigning) Clear() {
|
||||
p.Signing = pk.Signing{}
|
||||
// NewPKSigning creates a new [PKSigning] object, containing a key pair for
|
||||
// signing messages.
|
||||
func NewPKSigning() (PKSigning, error) {
|
||||
return pk.NewSigning()
|
||||
}
|
||||
|
||||
// NewPkSigningFromSeed creates a new PkSigning object using the given seed.
|
||||
func NewPkSigningFromSeed(seed []byte) (*PkSigning, error) {
|
||||
p := &PkSigning{}
|
||||
signing, err := pk.NewSigningFromSeed(seed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.Signing = *signing
|
||||
p.Seed = seed
|
||||
p.PublicKey = p.Signing.PublicKey()
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// NewPkSigning creates a new PkSigning object, containing a key pair for signing messages.
|
||||
func NewPkSigning() (*PkSigning, error) {
|
||||
p := &PkSigning{}
|
||||
signing, err := pk.NewSigning()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.Signing = *signing
|
||||
p.Seed = signing.Seed
|
||||
p.PublicKey = p.Signing.PublicKey()
|
||||
return p, err
|
||||
}
|
||||
|
||||
// Sign creates a signature for the given message using this key.
|
||||
func (p *PkSigning) Sign(message []byte) ([]byte, error) {
|
||||
return p.Signing.Sign(message), nil
|
||||
}
|
||||
|
||||
// SignJSON creates a signature for the given object after encoding it to canonical JSON.
|
||||
func (p *PkSigning) SignJSON(obj interface{}) (string, error) {
|
||||
objJSON, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
|
||||
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
|
||||
signature, err := p.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(signature), nil
|
||||
func NewPKDecryption(privateKey []byte) (PKDecryption, error) {
|
||||
return pk.NewDecryption()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright (c) 2024 Sumner Evans
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"maunium.net/go/mautrix/crypto/goolm/pk"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// PKSigning is an interface for signing messages.
|
||||
type PKSigning interface {
|
||||
// Seed returns the seed of the key.
|
||||
Seed() []byte
|
||||
|
||||
// PublicKey returns the public key.
|
||||
PublicKey() id.Ed25519
|
||||
|
||||
// Sign creates a signature for the given message using this key.
|
||||
Sign(message []byte) ([]byte, error)
|
||||
|
||||
// SignJSON creates a signature for the given object after encoding it to
|
||||
// canonical JSON.
|
||||
SignJSON(obj any) (string, error)
|
||||
}
|
||||
|
||||
var _ PKSigning = (*pk.Signing)(nil)
|
||||
|
||||
// PKDecryption is an interface for decrypting messages.
|
||||
type PKDecryption interface {
|
||||
// PublicKey returns the public key.
|
||||
PublicKey() id.Curve25519
|
||||
|
||||
// Decrypt verifies and decrypts the given message.
|
||||
Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error)
|
||||
}
|
||||
|
||||
var _ PKDecryption = (*pk.Decryption)(nil)
|
|
@ -1,3 +1,9 @@
|
|||
// Copyright (c) 2024 Sumner Evans
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
//go:build !goolm
|
||||
|
||||
package olm
|
||||
|
@ -18,14 +24,17 @@ import (
|
|||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// PkSigning stores a key pair for signing messages.
|
||||
type PkSigning struct {
|
||||
// LibOlmPKSigning stores a key pair for signing messages.
|
||||
type LibOlmPKSigning struct {
|
||||
int *C.OlmPkSigning
|
||||
mem []byte
|
||||
PublicKey id.Ed25519
|
||||
Seed []byte
|
||||
publicKey id.Ed25519
|
||||
seed []byte
|
||||
}
|
||||
|
||||
// Ensure that LibOlmPKSigning implements PKSigning.
|
||||
var _ PKSigning = (*LibOlmPKSigning)(nil)
|
||||
|
||||
func pkSigningSize() uint {
|
||||
return uint(C.olm_pk_signing_size())
|
||||
}
|
||||
|
@ -42,48 +51,57 @@ func pkSigningSignatureLength() uint {
|
|||
return uint(C.olm_pk_signature_length())
|
||||
}
|
||||
|
||||
func NewBlankPkSigning() *PkSigning {
|
||||
func newBlankPKSigning() *LibOlmPKSigning {
|
||||
memory := make([]byte, pkSigningSize())
|
||||
return &PkSigning{
|
||||
return &LibOlmPKSigning{
|
||||
int: C.olm_pk_signing(unsafe.Pointer(&memory[0])),
|
||||
mem: memory,
|
||||
}
|
||||
}
|
||||
|
||||
// Clear clears the underlying memory of a PkSigning object.
|
||||
func (p *PkSigning) Clear() {
|
||||
C.olm_clear_pk_signing((*C.OlmPkSigning)(p.int))
|
||||
}
|
||||
|
||||
// NewPkSigningFromSeed creates a new PkSigning object using the given seed.
|
||||
func NewPkSigningFromSeed(seed []byte) (*PkSigning, error) {
|
||||
p := NewBlankPkSigning()
|
||||
p.Clear()
|
||||
// NewPKSigningFromSeed creates a new [PKSigning] object using the given seed.
|
||||
func NewPKSigningFromSeed(seed []byte) (PKSigning, error) {
|
||||
p := newBlankPKSigning()
|
||||
p.clear()
|
||||
pubKey := make([]byte, pkSigningPublicKeyLength())
|
||||
if C.olm_pk_signing_key_from_seed((*C.OlmPkSigning)(p.int),
|
||||
unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)),
|
||||
unsafe.Pointer(&seed[0]), C.size_t(len(seed))) == errorVal() {
|
||||
return nil, p.lastError()
|
||||
}
|
||||
p.PublicKey = id.Ed25519(pubKey)
|
||||
p.Seed = seed
|
||||
p.publicKey = id.Ed25519(pubKey)
|
||||
p.seed = seed
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// NewPkSigning creates a new PkSigning object, containing a key pair for signing messages.
|
||||
func NewPkSigning() (*PkSigning, error) {
|
||||
// NewPKSigning creates a new LibOlmPKSigning object, containing a key pair for
|
||||
// signing messages.
|
||||
func NewPKSigning() (PKSigning, error) {
|
||||
// Generate the seed
|
||||
seed := make([]byte, pkSigningSeedLength())
|
||||
_, err := rand.Read(seed)
|
||||
if err != nil {
|
||||
panic(NotEnoughGoRandom)
|
||||
}
|
||||
pk, err := NewPkSigningFromSeed(seed)
|
||||
pk, err := NewPKSigningFromSeed(seed)
|
||||
return pk, err
|
||||
}
|
||||
|
||||
func (p *LibOlmPKSigning) PublicKey() id.Ed25519 {
|
||||
return p.publicKey
|
||||
}
|
||||
|
||||
func (p *LibOlmPKSigning) Seed() []byte {
|
||||
return p.seed
|
||||
}
|
||||
|
||||
// clear clears the underlying memory of a LibOlmPKSigning object.
|
||||
func (p *LibOlmPKSigning) clear() {
|
||||
C.olm_clear_pk_signing((*C.OlmPkSigning)(p.int))
|
||||
}
|
||||
|
||||
// Sign creates a signature for the given message using this key.
|
||||
func (p *PkSigning) Sign(message []byte) ([]byte, error) {
|
||||
func (p *LibOlmPKSigning) Sign(message []byte) ([]byte, error) {
|
||||
signature := make([]byte, pkSigningSignatureLength())
|
||||
if C.olm_pk_sign((*C.OlmPkSigning)(p.int), (*C.uint8_t)(unsafe.Pointer(&message[0])), C.size_t(len(message)),
|
||||
(*C.uint8_t)(unsafe.Pointer(&signature[0])), C.size_t(len(signature))) == errorVal() {
|
||||
|
@ -93,7 +111,7 @@ func (p *PkSigning) Sign(message []byte) ([]byte, error) {
|
|||
}
|
||||
|
||||
// SignJSON creates a signature for the given object after encoding it to canonical JSON.
|
||||
func (p *PkSigning) SignJSON(obj interface{}) (string, error) {
|
||||
func (p *LibOlmPKSigning) SignJSON(obj interface{}) (string, error) {
|
||||
objJSON, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -107,12 +125,13 @@ func (p *PkSigning) SignJSON(obj interface{}) (string, error) {
|
|||
return string(signature), nil
|
||||
}
|
||||
|
||||
// lastError returns the last error that happened in relation to this PkSigning object.
|
||||
func (p *PkSigning) lastError() error {
|
||||
// lastError returns the last error that happened in relation to this
|
||||
// LibOlmPKSigning object.
|
||||
func (p *LibOlmPKSigning) lastError() error {
|
||||
return convertError(C.GoString(C.olm_pk_signing_last_error((*C.OlmPkSigning)(p.int))))
|
||||
}
|
||||
|
||||
type PkDecryption struct {
|
||||
type LibOlmPKDecryption struct {
|
||||
int *C.OlmPkDecryption
|
||||
mem []byte
|
||||
PublicKey []byte
|
||||
|
@ -126,13 +145,13 @@ func pkDecryptionPublicKeySize() uint {
|
|||
return uint(C.olm_pk_key_length())
|
||||
}
|
||||
|
||||
func NewPkDecryption(privateKey []byte) (*PkDecryption, error) {
|
||||
func NewPkDecryption(privateKey []byte) (*LibOlmPKDecryption, error) {
|
||||
memory := make([]byte, pkDecryptionSize())
|
||||
p := &PkDecryption{
|
||||
p := &LibOlmPKDecryption{
|
||||
int: C.olm_pk_decryption(unsafe.Pointer(&memory[0])),
|
||||
mem: memory,
|
||||
}
|
||||
p.Clear()
|
||||
p.clear()
|
||||
pubKey := make([]byte, pkDecryptionPublicKeySize())
|
||||
|
||||
if C.olm_pk_key_from_private((*C.OlmPkDecryption)(p.int),
|
||||
|
@ -145,7 +164,7 @@ func NewPkDecryption(privateKey []byte) (*PkDecryption, error) {
|
|||
return p, nil
|
||||
}
|
||||
|
||||
func (p *PkDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) {
|
||||
func (p *LibOlmPKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) {
|
||||
maxPlaintextLength := uint(C.olm_pk_max_plaintext_length((*C.OlmPkDecryption)(p.int), C.size_t(len(ciphertext))))
|
||||
plaintext := make([]byte, maxPlaintextLength)
|
||||
|
||||
|
@ -162,11 +181,12 @@ func (p *PkDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byt
|
|||
}
|
||||
|
||||
// Clear clears the underlying memory of a PkDecryption object.
|
||||
func (p *PkDecryption) Clear() {
|
||||
func (p *LibOlmPKDecryption) clear() {
|
||||
C.olm_clear_pk_decryption((*C.OlmPkDecryption)(p.int))
|
||||
}
|
||||
|
||||
// lastError returns the last error that happened in relation to this PkDecryption object.
|
||||
func (p *PkDecryption) lastError() error {
|
||||
// lastError returns the last error that happened in relation to this
|
||||
// LibOlmPKDecryption object.
|
||||
func (p *LibOlmPKDecryption) lastError() error {
|
||||
return convertError(C.GoString(C.olm_pk_decryption_last_error((*C.OlmPkDecryption)(p.int))))
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright (c) 2024 Sumner Evans
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// Only run this test if goolm is disabled (that is, libolm is used).
|
||||
//go:build !goolm
|
||||
|
||||
package olm_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/pk"
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
)
|
||||
|
||||
func FuzzSign(f *testing.F) {
|
||||
seed := []byte("Quohboh3ka3ooghequier9lee8Bahwoh")
|
||||
goolmPkSigning, err := pk.NewSigningFromSeed(seed)
|
||||
require.NoError(f, err)
|
||||
|
||||
libolmPkSigning, err := olm.NewPKSigningFromSeed(seed)
|
||||
require.NoError(f, err)
|
||||
|
||||
f.Add([]byte("message"))
|
||||
|
||||
f.Fuzz(func(t *testing.T, message []byte) {
|
||||
// libolm breaks with empty messages, so don't perform differential
|
||||
// fuzzing on that.
|
||||
if len(message) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
libolmResult, libolmErr := libolmPkSigning.Sign(message)
|
||||
goolmResult, goolmErr := goolmPkSigning.Sign(message)
|
||||
|
||||
assert.Equal(t, goolmErr, libolmErr)
|
||||
assert.Equal(t, goolmResult, libolmResult)
|
||||
})
|
||||
}
|
|
@ -1,110 +1,28 @@
|
|||
// Copyright (c) 2024 Sumner Evans
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// When the goolm build flag is enabled, this file will make [PKSigning]
|
||||
// constructors use the goolm constuctors.
|
||||
|
||||
//go:build goolm
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"maunium.net/go/mautrix/crypto/goolm/session"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// Session stores an end to end encrypted messaging session.
|
||||
type Session struct {
|
||||
session.OlmSession
|
||||
}
|
||||
import "maunium.net/go/mautrix/crypto/goolm/session"
|
||||
|
||||
// SessionFromPickled loads a Session from a pickled base64 string. Decrypts
|
||||
// the Session using the supplied key. Returns error on failure.
|
||||
func SessionFromPickled(pickled, key []byte) (*Session, error) {
|
||||
func SessionFromPickled(pickled, key []byte) (Session, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
s := NewBlankSession()
|
||||
s := session.NewOlmSession()
|
||||
return s, s.Unpickle(pickled, key)
|
||||
}
|
||||
|
||||
func NewBlankSession() *Session {
|
||||
return &Session{}
|
||||
}
|
||||
|
||||
// Clear clears the memory used to back this Session.
|
||||
func (s *Session) Clear() error {
|
||||
s.OlmSession = session.OlmSession{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pickle returns a Session as a base64 string. Encrypts the Session using the
|
||||
// supplied key.
|
||||
func (s *Session) Pickle(key []byte) []byte {
|
||||
if len(key) == 0 {
|
||||
panic(NoKeyProvided)
|
||||
}
|
||||
pickled, err := s.OlmSession.Pickle(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return pickled
|
||||
}
|
||||
|
||||
func (s *Session) Unpickle(pickled, key []byte) error {
|
||||
if len(key) == 0 {
|
||||
return NoKeyProvided
|
||||
} else if len(pickled) == 0 {
|
||||
return EmptyInput
|
||||
}
|
||||
sOlm, err := session.OlmSessionFromPickled(pickled, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.OlmSession = *sOlm
|
||||
return nil
|
||||
}
|
||||
|
||||
// MatchesInboundSession checks if the PRE_KEY message is for this in-bound
|
||||
// Session. This can happen if multiple messages are sent to this Account
|
||||
// before this Account sends a message in reply. Returns true if the session
|
||||
// matches. Returns false if the session does not match. Returns error on
|
||||
// failure.
|
||||
func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
|
||||
return s.MatchesInboundSessionFrom("", oneTimeKeyMsg)
|
||||
}
|
||||
|
||||
// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound
|
||||
// Session. This can happen if multiple messages are sent to this Account
|
||||
// before this Account sends a message in reply. Returns true if the session
|
||||
// matches. Returns false if the session does not match. Returns error on
|
||||
// failure.
|
||||
func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) {
|
||||
if theirIdentityKey != "" {
|
||||
theirKey := id.Curve25519(theirIdentityKey)
|
||||
return s.OlmSession.MatchesInboundSessionFrom(&theirKey, []byte(oneTimeKeyMsg))
|
||||
}
|
||||
return s.OlmSession.MatchesInboundSessionFrom(nil, []byte(oneTimeKeyMsg))
|
||||
|
||||
}
|
||||
|
||||
// Encrypt encrypts a message using the Session. Returns the encrypted message
|
||||
// as base64.
|
||||
func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) {
|
||||
if len(plaintext) == 0 {
|
||||
panic(EmptyInput)
|
||||
}
|
||||
messageType, message, err := s.OlmSession.Encrypt(plaintext, nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return messageType, message
|
||||
}
|
||||
|
||||
// Decrypt decrypts a message using the Session. Returns the the plain-text on
|
||||
// success. Returns error on failure.
|
||||
func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) {
|
||||
if len(message) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
return s.OlmSession.Decrypt([]byte(message), msgType)
|
||||
}
|
||||
|
||||
// Describe generates a string describing the internal state of an olm session for debugging and logging purposes.
|
||||
func (s *Session) Describe() string {
|
||||
return s.OlmSession.Describe()
|
||||
func NewBlankSession() Session {
|
||||
return session.NewOlmSession()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
// Copyright (c) 2024 Sumner Evans
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"maunium.net/go/mautrix/crypto/goolm/session"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type Session interface {
|
||||
// Pickle returns a Session as a base64 string. Encrypts the Session using
|
||||
// the supplied key.
|
||||
Pickle(key []byte) ([]byte, error)
|
||||
|
||||
// Unpickle loads a Session from a pickled base64 string. Decrypts the
|
||||
// Session using the supplied key.
|
||||
Unpickle(pickled, key []byte) error
|
||||
|
||||
// ID returns an identifier for this Session. Will be the same for both
|
||||
// ends of the conversation.
|
||||
ID() id.SessionID
|
||||
|
||||
// HasReceivedMessage returns true if this session has received any
|
||||
// message.
|
||||
HasReceivedMessage() bool
|
||||
|
||||
// MatchesInboundSession checks if the PRE_KEY message is for this in-bound
|
||||
// Session. This can happen if multiple messages are sent to this Account
|
||||
// before this Account sends a message in reply. Returns true if the
|
||||
// session matches. Returns false if the session does not match. Returns
|
||||
// error on failure. If the base64 couldn't be decoded then the error will
|
||||
// be "INVALID_BASE64". If the message was for an unsupported protocol
|
||||
// version then the error will be "BAD_MESSAGE_VERSION". If the message
|
||||
// couldn't be decoded then then the error will be "BAD_MESSAGE_FORMAT".
|
||||
MatchesInboundSession(oneTimeKeyMsg string) (bool, error)
|
||||
|
||||
// MatchesInboundSessionFrom checks if the PRE_KEY message is for this
|
||||
// in-bound Session. This can happen if multiple messages are sent to this
|
||||
// Account before this Account sends a message in reply. Returns true if
|
||||
// the session matches. Returns false if the session does not match.
|
||||
// Returns error on failure. If the base64 couldn't be decoded then the
|
||||
// error will be "INVALID_BASE64". If the message was for an unsupported
|
||||
// protocol version then the error will be "BAD_MESSAGE_VERSION". If the
|
||||
// message couldn't be decoded then then the error will be
|
||||
// "BAD_MESSAGE_FORMAT".
|
||||
MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error)
|
||||
|
||||
// EncryptMsgType returns the type of the next message that Encrypt will
|
||||
// return. Returns MsgTypePreKey if the message will be a PRE_KEY message.
|
||||
// Returns MsgTypeMsg if the message will be a normal message.
|
||||
EncryptMsgType() id.OlmMsgType
|
||||
|
||||
// Encrypt encrypts a message using the Session. Returns the encrypted
|
||||
// message as base64.
|
||||
Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error)
|
||||
|
||||
// Decrypt decrypts a message using the Session. Returns the plain-text on
|
||||
// success. Returns error on failure. If the base64 couldn't be decoded
|
||||
// then the error will be "INVALID_BASE64". If the message is for an
|
||||
// unsupported version of the protocol then the error will be
|
||||
// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error
|
||||
// will be BAD_MESSAGE_FORMAT". If the MAC on the message was invalid then
|
||||
// the error will be "BAD_MESSAGE_MAC".
|
||||
Decrypt(message string, msgType id.OlmMsgType) ([]byte, error)
|
||||
|
||||
// Describe generates a string describing the internal state of an olm
|
||||
// session for debugging and logging purposes.
|
||||
Describe() string
|
||||
}
|
||||
|
||||
var _ Session = (*session.OlmSession)(nil)
|
|
@ -1,3 +1,9 @@
|
|||
// Copyright (c) 2024 Sumner Evans
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
//go:build !goolm
|
||||
|
||||
package olm
|
||||
|
@ -24,12 +30,15 @@ import (
|
|||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// Session stores an end to end encrypted messaging session.
|
||||
type Session struct {
|
||||
// LibOlmSession stores an end to end encrypted messaging session.
|
||||
type LibOlmSession struct {
|
||||
int *C.OlmSession
|
||||
mem []byte
|
||||
}
|
||||
|
||||
// Ensure that LibOlmSession implements Session.
|
||||
var _ Session = (*LibOlmSession)(nil)
|
||||
|
||||
// sessionSize is the size of a session object in bytes.
|
||||
func sessionSize() uint {
|
||||
return uint(C.olm_session_size())
|
||||
|
@ -40,7 +49,7 @@ func sessionSize() uint {
|
|||
// doesn't match the one used to encrypt the Session then the error will be
|
||||
// "BAD_SESSION_KEY". If the base64 couldn't be decoded then the error will be
|
||||
// "INVALID_BASE64".
|
||||
func SessionFromPickled(pickled, key []byte) (*Session, error) {
|
||||
func SessionFromPickled(pickled, key []byte) (Session, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
|
@ -48,22 +57,26 @@ func SessionFromPickled(pickled, key []byte) (*Session, error) {
|
|||
return s, s.Unpickle(pickled, key)
|
||||
}
|
||||
|
||||
func NewBlankSession() *Session {
|
||||
func NewBlankLibOlmSession() *LibOlmSession {
|
||||
memory := make([]byte, sessionSize())
|
||||
return &Session{
|
||||
return &LibOlmSession{
|
||||
int: C.olm_session(unsafe.Pointer(&memory[0])),
|
||||
mem: memory,
|
||||
}
|
||||
}
|
||||
|
||||
func NewBlankSession() Session {
|
||||
return NewBlankLibOlmSession()
|
||||
}
|
||||
|
||||
// lastError returns an error describing the most recent error to happen to a
|
||||
// session.
|
||||
func (s *Session) lastError() error {
|
||||
func (s *LibOlmSession) lastError() error {
|
||||
return convertError(C.GoString(C.olm_session_last_error((*C.OlmSession)(s.int))))
|
||||
}
|
||||
|
||||
// Clear clears the memory used to back this Session.
|
||||
func (s *Session) Clear() error {
|
||||
func (s *LibOlmSession) Clear() error {
|
||||
r := C.olm_clear_session((*C.OlmSession)(s.int))
|
||||
if r == errorVal() {
|
||||
return s.lastError()
|
||||
|
@ -72,31 +85,31 @@ func (s *Session) Clear() error {
|
|||
}
|
||||
|
||||
// pickleLen returns the number of bytes needed to store a session.
|
||||
func (s *Session) pickleLen() uint {
|
||||
func (s *LibOlmSession) pickleLen() uint {
|
||||
return uint(C.olm_pickle_session_length((*C.OlmSession)(s.int)))
|
||||
}
|
||||
|
||||
// createOutboundRandomLen returns the number of random bytes needed to create
|
||||
// an outbound session.
|
||||
func (s *Session) createOutboundRandomLen() uint {
|
||||
func (s *LibOlmSession) createOutboundRandomLen() uint {
|
||||
return uint(C.olm_create_outbound_session_random_length((*C.OlmSession)(s.int)))
|
||||
}
|
||||
|
||||
// idLen returns the length of the buffer needed to return the id for this
|
||||
// session.
|
||||
func (s *Session) idLen() uint {
|
||||
func (s *LibOlmSession) idLen() uint {
|
||||
return uint(C.olm_session_id_length((*C.OlmSession)(s.int)))
|
||||
}
|
||||
|
||||
// encryptRandomLen returns the number of random bytes needed to encrypt the
|
||||
// next message.
|
||||
func (s *Session) encryptRandomLen() uint {
|
||||
func (s *LibOlmSession) encryptRandomLen() uint {
|
||||
return uint(C.olm_encrypt_random_length((*C.OlmSession)(s.int)))
|
||||
}
|
||||
|
||||
// encryptMsgLen returns the size of the next message in bytes for the given
|
||||
// number of plain-text bytes.
|
||||
func (s *Session) encryptMsgLen(plainTextLen int) uint {
|
||||
func (s *LibOlmSession) encryptMsgLen(plainTextLen int) uint {
|
||||
return uint(C.olm_encrypt_message_length((*C.OlmSession)(s.int), C.size_t(plainTextLen)))
|
||||
}
|
||||
|
||||
|
@ -107,7 +120,7 @@ func (s *Session) encryptMsgLen(plainTextLen int) uint {
|
|||
// unsupported version of the protocol then the error will be
|
||||
// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error
|
||||
// will be "BAD_MESSAGE_FORMAT".
|
||||
func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) {
|
||||
func (s *LibOlmSession) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) {
|
||||
if len(message) == 0 {
|
||||
return 0, EmptyInput
|
||||
}
|
||||
|
@ -124,9 +137,9 @@ func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType)
|
|||
|
||||
// Pickle returns a Session as a base64 string. Encrypts the Session using the
|
||||
// supplied key.
|
||||
func (s *Session) Pickle(key []byte) []byte {
|
||||
func (s *LibOlmSession) Pickle(key []byte) ([]byte, error) {
|
||||
if len(key) == 0 {
|
||||
panic(NoKeyProvided)
|
||||
return nil, NoKeyProvided
|
||||
}
|
||||
pickled := make([]byte, s.pickleLen())
|
||||
r := C.olm_pickle_session(
|
||||
|
@ -138,10 +151,12 @@ func (s *Session) Pickle(key []byte) []byte {
|
|||
if r == errorVal() {
|
||||
panic(s.lastError())
|
||||
}
|
||||
return pickled[:r]
|
||||
return pickled[:r], nil
|
||||
}
|
||||
|
||||
func (s *Session) Unpickle(pickled, key []byte) error {
|
||||
// Unpickle unpickles the base64-encoded Olm session decrypting it with the
|
||||
// provided key. This function mutates the input pickled data slice.
|
||||
func (s *LibOlmSession) Unpickle(pickled, key []byte) error {
|
||||
if len(key) == 0 {
|
||||
return NoKeyProvided
|
||||
}
|
||||
|
@ -158,18 +173,21 @@ func (s *Session) Unpickle(pickled, key []byte) error {
|
|||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *Session) GobEncode() ([]byte, error) {
|
||||
pickled := s.Pickle(pickleKey)
|
||||
func (s *LibOlmSession) GobEncode() ([]byte, error) {
|
||||
pickled, err := s.Pickle(pickleKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
length := base64.RawStdEncoding.DecodedLen(len(pickled))
|
||||
rawPickled := make([]byte, length)
|
||||
_, err := base64.RawStdEncoding.Decode(rawPickled, pickled)
|
||||
_, err = base64.RawStdEncoding.Decode(rawPickled, pickled)
|
||||
return rawPickled, err
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *Session) GobDecode(rawPickled []byte) error {
|
||||
func (s *LibOlmSession) GobDecode(rawPickled []byte) error {
|
||||
if s == nil || s.int == nil {
|
||||
*s = *NewBlankSession()
|
||||
*s = *NewBlankLibOlmSession()
|
||||
}
|
||||
length := base64.RawStdEncoding.EncodedLen(len(rawPickled))
|
||||
pickled := make([]byte, length)
|
||||
|
@ -178,8 +196,11 @@ func (s *Session) GobDecode(rawPickled []byte) error {
|
|||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *Session) MarshalJSON() ([]byte, error) {
|
||||
pickled := s.Pickle(pickleKey)
|
||||
func (s *LibOlmSession) MarshalJSON() ([]byte, error) {
|
||||
pickled, err := s.Pickle(pickleKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
quotes := make([]byte, len(pickled)+2)
|
||||
quotes[0] = '"'
|
||||
quotes[len(quotes)-1] = '"'
|
||||
|
@ -188,19 +209,19 @@ func (s *Session) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *Session) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 || len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
|
||||
func (s *LibOlmSession) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
|
||||
return InputNotJSONString
|
||||
}
|
||||
if s == nil || s.int == nil {
|
||||
*s = *NewBlankSession()
|
||||
*s = *NewBlankLibOlmSession()
|
||||
}
|
||||
return s.Unpickle(data[1:len(data)-1], pickleKey)
|
||||
}
|
||||
|
||||
// Id returns an identifier for this Session. Will be the same for both ends
|
||||
// of the conversation.
|
||||
func (s *Session) ID() id.SessionID {
|
||||
func (s *LibOlmSession) ID() id.SessionID {
|
||||
sessionID := make([]byte, s.idLen())
|
||||
r := C.olm_session_id(
|
||||
(*C.OlmSession)(s.int),
|
||||
|
@ -213,7 +234,7 @@ func (s *Session) ID() id.SessionID {
|
|||
}
|
||||
|
||||
// HasReceivedMessage returns true if this session has received any message.
|
||||
func (s *Session) HasReceivedMessage() bool {
|
||||
func (s *LibOlmSession) HasReceivedMessage() bool {
|
||||
switch C.olm_session_has_received_message((*C.OlmSession)(s.int)) {
|
||||
case 0:
|
||||
return false
|
||||
|
@ -230,7 +251,7 @@ func (s *Session) HasReceivedMessage() bool {
|
|||
// "INVALID_BASE64". If the message was for an unsupported protocol version
|
||||
// then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be
|
||||
// decoded then then the error will be "BAD_MESSAGE_FORMAT".
|
||||
func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
|
||||
func (s *LibOlmSession) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
|
||||
if len(oneTimeKeyMsg) == 0 {
|
||||
return false, EmptyInput
|
||||
}
|
||||
|
@ -255,7 +276,7 @@ func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
|
|||
// "INVALID_BASE64". If the message was for an unsupported protocol version
|
||||
// then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be
|
||||
// decoded then then the error will be "BAD_MESSAGE_FORMAT".
|
||||
func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) {
|
||||
func (s *LibOlmSession) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) {
|
||||
if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 {
|
||||
return false, EmptyInput
|
||||
}
|
||||
|
@ -278,7 +299,7 @@ func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg stri
|
|||
// return. Returns MsgTypePreKey if the message will be a PRE_KEY message.
|
||||
// Returns MsgTypeMsg if the message will be a normal message. Returns error
|
||||
// on failure.
|
||||
func (s *Session) EncryptMsgType() id.OlmMsgType {
|
||||
func (s *LibOlmSession) EncryptMsgType() id.OlmMsgType {
|
||||
switch C.olm_encrypt_message_type((*C.OlmSession)(s.int)) {
|
||||
case C.size_t(id.OlmMsgTypePreKey):
|
||||
return id.OlmMsgTypePreKey
|
||||
|
@ -291,15 +312,16 @@ func (s *Session) EncryptMsgType() id.OlmMsgType {
|
|||
|
||||
// Encrypt encrypts a message using the Session. Returns the encrypted message
|
||||
// as base64.
|
||||
func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) {
|
||||
func (s *LibOlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) {
|
||||
if len(plaintext) == 0 {
|
||||
panic(EmptyInput)
|
||||
return 0, nil, EmptyInput
|
||||
}
|
||||
// Make the slice be at least length 1
|
||||
random := make([]byte, s.encryptRandomLen()+1)
|
||||
_, err := rand.Read(random)
|
||||
if err != nil {
|
||||
panic(NotEnoughGoRandom)
|
||||
// TODO can we just return err here?
|
||||
return 0, nil, NotEnoughGoRandom
|
||||
}
|
||||
messageType := s.EncryptMsgType()
|
||||
message := make([]byte, s.encryptMsgLen(len(plaintext)))
|
||||
|
@ -312,9 +334,9 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) {
|
|||
unsafe.Pointer(&message[0]),
|
||||
C.size_t(len(message)))
|
||||
if r == errorVal() {
|
||||
panic(s.lastError())
|
||||
return 0, nil, s.lastError()
|
||||
}
|
||||
return messageType, message[:r]
|
||||
return messageType, message[:r], nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts a message using the Session. Returns the the plain-text on
|
||||
|
@ -324,7 +346,7 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) {
|
|||
// the message couldn't be decoded then the error will be BAD_MESSAGE_FORMAT".
|
||||
// If the MAC on the message was invalid then the error will be
|
||||
// "BAD_MESSAGE_MAC".
|
||||
func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) {
|
||||
func (s *LibOlmSession) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) {
|
||||
if len(message) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
|
@ -351,7 +373,7 @@ func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error)
|
|||
const maxDescribeSize = 600
|
||||
|
||||
// Describe generates a string describing the internal state of an olm session for debugging and logging purposes.
|
||||
func (s *Session) Describe() string {
|
||||
func (s *LibOlmSession) Describe() string {
|
||||
desc := (*C.char)(C.malloc(C.size_t(maxDescribeSize)))
|
||||
defer C.free(unsafe.Pointer(desc))
|
||||
C.meowlm_session_describe(
|
|
@ -0,0 +1,87 @@
|
|||
// Copyright (c) 2024 Sumner Evans
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// Only run this test if goolm is disabled (that is, libolm is used).
|
||||
//go:build !goolm
|
||||
|
||||
package olm_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
goolmsession "maunium.net/go/mautrix/crypto/goolm/session"
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
)
|
||||
|
||||
func TestBlankSession(t *testing.T) {
|
||||
libolmSession := olm.NewBlankLibOlmSession()
|
||||
goolmSession := goolmsession.NewOlmSession()
|
||||
|
||||
assert.Equal(t, libolmSession.ID(), goolmSession.ID())
|
||||
assert.Equal(t, libolmSession.HasReceivedMessage(), goolmSession.HasReceivedMessage())
|
||||
assert.Equal(t, libolmSession.EncryptMsgType(), goolmSession.EncryptMsgType())
|
||||
assert.Equal(t, libolmSession.Describe(), goolmSession.Describe())
|
||||
|
||||
libolmPickled, err := libolmSession.Pickle([]byte("test"))
|
||||
assert.NoError(t, err)
|
||||
goolmPickled, err := goolmSession.Pickle([]byte("test"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, goolmPickled, libolmPickled)
|
||||
}
|
||||
|
||||
func TestSessionPickle(t *testing.T) {
|
||||
pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItVKR4ro0O9EAk6LLxJtSnRu5elSUk7YXT")
|
||||
pickleKey := []byte("secret_key")
|
||||
|
||||
goolmSession := goolmsession.NewOlmSession()
|
||||
err := goolmSession.Unpickle(pickledDataFromLibOlm, pickleKey)
|
||||
assert.NoError(t, err)
|
||||
|
||||
libolmSession := olm.NewBlankLibOlmSession()
|
||||
err = libolmSession.Unpickle(pickledDataFromLibOlm, pickleKey)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Reset the pickle data since libolmSession.Unpickle modifies it.
|
||||
pickledDataFromLibOlm = []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItVKR4ro0O9EAk6LLxJtSnRu5elSUk7YXT")
|
||||
|
||||
goolmPickled, err := goolmSession.Pickle(pickleKey)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, pickledDataFromLibOlm, goolmPickled)
|
||||
|
||||
libolmPickled, err := libolmSession.Pickle(pickleKey)
|
||||
assert.Equal(t, pickledDataFromLibOlm, libolmPickled)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// func FuzzSession(f *testing.F) {
|
||||
// f.Add([]byte("plaintext"))
|
||||
|
||||
// identityKeyAlice, err := crypto.Curve25519GenerateKey(nil)
|
||||
// require.NoError(f, err)
|
||||
// identityKeyBob, err := crypto.Curve25519GenerateKey(nil)
|
||||
// require.NoError(f, err)
|
||||
|
||||
// f.Fuzz(func(t *testing.T, plaintext []byte) {
|
||||
// // identityKeyAlice crypto.Curve25519KeyPair, identityKeyBob crypto.Curve25519PublicKey, oneTimeKeyBob crypto.Curve25519PublicKey
|
||||
|
||||
// goolmSession, err := goolmsession.NewOutboundOlmSession(identityKeyAlice, identityKeyBob.PublicKey, otk)
|
||||
// assert.NoError(t, err)
|
||||
|
||||
// libolmAccount := olm.NewAccount()
|
||||
// libolmSession, err := libolmAccount.NewInboundSessionFrom(id.Curve25519(identityKeyBob.PublicKey), string(otk))
|
||||
|
||||
// goolmMsgType, goolmCiphertext, goolmErr := goolmSession.Encrypt(plaintext)
|
||||
// assert.NoError(t, goolmErr)
|
||||
|
||||
// libolmMsgType, libolmCiphertext, libolmErr := libolmSession.Encrypt(plaintext)
|
||||
// assert.NoError(t, libolmErr)
|
||||
|
||||
// assert.Equal(t, goolmMsgType, libolmMsgType)
|
||||
// assert.Equal(t, goolmCiphertext, libolmCiphertext)
|
||||
// })
|
||||
// }
|
|
@ -54,9 +54,9 @@ func (session *OlmSession) Describe() string {
|
|||
return session.Internal.Describe()
|
||||
}
|
||||
|
||||
func wrapSession(session *olm.Session) *OlmSession {
|
||||
func wrapSession(session olm.Session) *OlmSession {
|
||||
return &OlmSession{
|
||||
Internal: *session,
|
||||
Internal: session,
|
||||
ExpirationMixin: ExpirationMixin{
|
||||
TimeMixin: TimeMixin{
|
||||
CreationTime: time.Now(),
|
||||
|
@ -68,7 +68,7 @@ func wrapSession(session *olm.Session) *OlmSession {
|
|||
}
|
||||
|
||||
func (account *OlmAccount) NewInboundSessionFrom(senderKey id.Curve25519, ciphertext string) (*OlmSession, error) {
|
||||
session, err := account.Internal.NewInboundSessionFrom(senderKey, ciphertext)
|
||||
session, err := account.Internal.NewInboundSessionFrom(&senderKey, ciphertext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -76,7 +76,7 @@ func (account *OlmAccount) NewInboundSessionFrom(senderKey id.Curve25519, cipher
|
|||
return wrapSession(session), nil
|
||||
}
|
||||
|
||||
func (session *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) {
|
||||
func (session *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) {
|
||||
session.LastEncryptedTime = time.Now()
|
||||
return session.Internal.Encrypt(plaintext)
|
||||
}
|
||||
|
@ -120,7 +120,7 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI
|
|||
return nil, err
|
||||
}
|
||||
return &InboundGroupSession{
|
||||
Internal: *igs,
|
||||
Internal: igs,
|
||||
SigningKey: signingKey,
|
||||
SenderKey: senderKey,
|
||||
RoomID: roomID,
|
||||
|
@ -148,7 +148,7 @@ func (igs *InboundGroupSession) RatchetTo(index uint32) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
igs.Internal = *imported
|
||||
igs.Internal = imported
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -182,7 +182,7 @@ type OutboundGroupSession struct {
|
|||
|
||||
func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.EncryptionEventContent) *OutboundGroupSession {
|
||||
ogs := &OutboundGroupSession{
|
||||
Internal: *olm.NewOutboundGroupSession(),
|
||||
Internal: olm.NewOutboundGroupSession(),
|
||||
ExpirationMixin: ExpirationMixin{
|
||||
TimeMixin: TimeMixin{
|
||||
CreationTime: time.Now(),
|
||||
|
@ -237,7 +237,7 @@ func (ogs *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) {
|
|||
}
|
||||
ogs.MessageCount++
|
||||
ogs.LastEncryptedTime = time.Now()
|
||||
return ogs.Internal.Encrypt(plaintext), nil
|
||||
return ogs.Internal.Encrypt(plaintext)
|
||||
}
|
||||
|
||||
type TimeMixin struct {
|
||||
|
|
|
@ -16,13 +16,26 @@ import (
|
|||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, timeout time.Duration) (secret string, err error) {
|
||||
secret, err = mach.CryptoStore.GetSecret(ctx, name)
|
||||
if err != nil || secret != "" {
|
||||
return
|
||||
// Callback function to process a received secret.
|
||||
//
|
||||
// Returning true or an error will immediately return from the wait loop, returning false will continue waiting for new responses.
|
||||
type SecretReceiverFunc func(string) (bool, error)
|
||||
|
||||
func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, receiver SecretReceiverFunc, timeout time.Duration) (err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
// always offer our stored secret first, if any
|
||||
secret, err := mach.CryptoStore.GetSecret(ctx, name)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if secret != "" {
|
||||
if ok, err := receiver(secret); ok || err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
requestID, secretChan := random.String(64), make(chan string, 1)
|
||||
requestID, secretChan := random.String(64), make(chan string, 5)
|
||||
mach.secretLock.Lock()
|
||||
mach.secretListeners[requestID] = secretChan
|
||||
mach.secretLock.Unlock()
|
||||
|
@ -43,17 +56,27 @@ func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret,
|
|||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err = ctx.Err()
|
||||
case <-time.After(timeout):
|
||||
case secret = <-secretChan:
|
||||
}
|
||||
// best effort cancel request from all devices when returning
|
||||
defer func() {
|
||||
go mach.sendToOneDevice(context.Background(), mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{
|
||||
Action: event.SecretRequestCancellation,
|
||||
RequestID: requestID,
|
||||
RequestingDeviceID: mach.Client.DeviceID,
|
||||
})
|
||||
}()
|
||||
|
||||
if secret != "" {
|
||||
err = mach.CryptoStore.PutSecret(ctx, name, secret)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case secret = <-secretChan:
|
||||
if ok, err := receiver(secret); err != nil {
|
||||
return err
|
||||
} else if ok {
|
||||
return mach.CryptoStore.PutSecret(ctx, name, secret)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) HandleSecretRequest(ctx context.Context, userID id.UserID, content *event.SecretRequestEventContent) {
|
||||
|
@ -121,9 +144,11 @@ func (mach *OlmMachine) HandleSecretRequest(ctx context.Context, userID id.UserI
|
|||
return
|
||||
} else if secret != "" {
|
||||
log.Debug().Msg("Responding to secret request")
|
||||
mach.sendToOneDevice(ctx, mach.Client.UserID, content.RequestingDeviceID, event.ToDeviceSecretRequest, &event.SecretSendEventContent{
|
||||
RequestID: content.RequestID,
|
||||
Secret: secret,
|
||||
mach.SendEncryptedToDevice(ctx, device, event.ToDeviceSecretSend, event.Content{
|
||||
Parsed: event.SecretSendEventContent{
|
||||
RequestID: content.RequestID,
|
||||
Secret: secret,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
log.Debug().Msg("No stored secret found, secret request ignored")
|
||||
|
@ -157,17 +182,10 @@ func (mach *OlmMachine) receiveSecret(ctx context.Context, evt *DecryptedOlmEven
|
|||
return
|
||||
}
|
||||
|
||||
// secret channel is buffered and we don't want to block
|
||||
// at worst we drop _some_ of the responses
|
||||
select {
|
||||
case secretChan <- content.Secret:
|
||||
default:
|
||||
}
|
||||
|
||||
// best effort cancel this for all other targets
|
||||
go func() {
|
||||
mach.sendToOneDevice(ctx, mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{
|
||||
Action: event.SecretRequestCancellation,
|
||||
RequestID: content.RequestID,
|
||||
RequestingDeviceID: mach.Client.DeviceID,
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
|
|
@ -123,8 +123,11 @@ func (store *SQLCryptoStore) FindDeviceID(ctx context.Context) (deviceID id.Devi
|
|||
// PutAccount stores an OlmAccount in the database.
|
||||
func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount) error {
|
||||
store.Account = account
|
||||
bytes := account.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.DB.Exec(ctx, `
|
||||
bytes, err := account.Internal.Pickle(store.PickleKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = store.DB.Exec(ctx, `
|
||||
INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id, key_backup_version) VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token,
|
||||
account=excluded.account, account_id=excluded.account_id,
|
||||
|
@ -137,7 +140,7 @@ func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount
|
|||
func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) {
|
||||
if store.Account == nil {
|
||||
row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account, key_backup_version FROM crypto_account WHERE account_id=$1", store.AccountID)
|
||||
acc := &OlmAccount{Internal: *olm.NewBlankAccount()}
|
||||
acc := &OlmAccount{Internal: olm.NewBlankAccount()}
|
||||
var accountBytes []byte
|
||||
err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes, &acc.KeyBackupVersion)
|
||||
if err == sql.ErrNoRows {
|
||||
|
@ -183,7 +186,7 @@ func (store *SQLCryptoStore) GetSessions(ctx context.Context, key id.SenderKey)
|
|||
defer store.olmSessionCacheLock.Unlock()
|
||||
cache := store.getOlmSessionCache(key)
|
||||
for rows.Next() {
|
||||
sess := OlmSession{Internal: *olm.NewBlankSession()}
|
||||
sess := OlmSession{Internal: olm.NewBlankSession()}
|
||||
var sessionBytes []byte
|
||||
var sessionID id.SessionID
|
||||
err = rows.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.LastEncryptedTime, &sess.LastDecryptedTime)
|
||||
|
@ -220,7 +223,7 @@ func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.Sender
|
|||
row := store.DB.QueryRow(ctx, "SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1",
|
||||
key, store.AccountID)
|
||||
|
||||
sess := OlmSession{Internal: *olm.NewBlankSession()}
|
||||
sess := OlmSession{Internal: olm.NewBlankSession()}
|
||||
var sessionBytes []byte
|
||||
var sessionID id.SessionID
|
||||
|
||||
|
@ -246,8 +249,11 @@ func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.Sender
|
|||
func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, session *OlmSession) error {
|
||||
store.olmSessionCacheLock.Lock()
|
||||
defer store.olmSessionCacheLock.Unlock()
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
|
||||
sessionBytes, err := session.Internal.Pickle(store.PickleKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = store.DB.Exec(ctx, "INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
|
||||
session.ID(), key, sessionBytes, session.CreationTime, session.LastEncryptedTime, session.LastDecryptedTime, store.AccountID)
|
||||
store.getOlmSessionCache(key)[session.ID()] = session
|
||||
return err
|
||||
|
@ -255,8 +261,11 @@ func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, s
|
|||
|
||||
// UpdateSession replaces the Olm session for a sender in the database.
|
||||
func (store *SQLCryptoStore) UpdateSession(ctx context.Context, _ id.SenderKey, session *OlmSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.DB.Exec(ctx, "UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5",
|
||||
sessionBytes, err := session.Internal.Pickle(store.PickleKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = store.DB.Exec(ctx, "UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5",
|
||||
sessionBytes, session.LastEncryptedTime, session.LastDecryptedTime, session.ID(), store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
@ -277,7 +286,10 @@ func datePtr(t time.Time) *time.Time {
|
|||
|
||||
// PutGroupSession stores an inbound Megolm group session for a room, sender and session.
|
||||
func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
sessionBytes, err := session.Internal.Pickle(store.PickleKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
forwardingChains := strings.Join(session.ForwardingChains, ",")
|
||||
ratchetSafety, err := json.Marshal(&session.RatchetSafety)
|
||||
if err != nil {
|
||||
|
@ -335,7 +347,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room
|
|||
senderKey = id.Curve25519(senderKeyDB.String)
|
||||
}
|
||||
return &InboundGroupSession{
|
||||
Internal: *igs,
|
||||
Internal: igs,
|
||||
SigningKey: id.Ed25519(signingKey.String),
|
||||
SenderKey: senderKey,
|
||||
RoomID: roomID,
|
||||
|
@ -447,7 +459,7 @@ func (store *SQLCryptoStore) GetWithheldGroupSession(ctx context.Context, roomID
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs *olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) {
|
||||
func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) {
|
||||
igs = olm.NewBlankInboundGroupSession()
|
||||
err = igs.Unpickle(sessionBytes, store.PickleKey)
|
||||
if err != nil {
|
||||
|
@ -480,7 +492,7 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In
|
|||
}
|
||||
igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String)
|
||||
return &InboundGroupSession{
|
||||
Internal: *igs,
|
||||
Internal: igs,
|
||||
SigningKey: id.Ed25519(signingKey.String),
|
||||
SenderKey: id.Curve25519(senderKey.String),
|
||||
RoomID: roomID,
|
||||
|
@ -523,8 +535,11 @@ func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context
|
|||
|
||||
// AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices.
|
||||
func (store *SQLCryptoStore) AddOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.DB.Exec(ctx, `
|
||||
sessionBytes, err := session.Internal.Pickle(store.PickleKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = store.DB.Exec(ctx, `
|
||||
INSERT INTO crypto_megolm_outbound_session
|
||||
(room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used, account_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
|
@ -539,8 +554,11 @@ func (store *SQLCryptoStore) AddOutboundGroupSession(ctx context.Context, sessio
|
|||
|
||||
// UpdateOutboundGroupSession replaces an outbound Megolm session with for same room and session ID.
|
||||
func (store *SQLCryptoStore) UpdateOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.DB.Exec(ctx, "UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6",
|
||||
sessionBytes, err := session.Internal.Pickle(store.PickleKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = store.DB.Exec(ctx, "UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6",
|
||||
sessionBytes, session.MessageCount, session.LastEncryptedTime, session.RoomID, session.ID(), store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
@ -565,7 +583,7 @@ func (store *SQLCryptoStore) GetOutboundGroupSession(ctx context.Context, roomID
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ogs.Internal = *intOGS
|
||||
ogs.Internal = intOGS
|
||||
ogs.RoomID = roomID
|
||||
ogs.MaxAge = time.Duration(maxAgeMS) * time.Millisecond
|
||||
return &ogs, nil
|
||||
|
|
|
@ -115,7 +115,7 @@ func TestStoreOlmSession(t *testing.T) {
|
|||
|
||||
olmSess := OlmSession{
|
||||
id: olmSessID,
|
||||
Internal: *olmInternal,
|
||||
Internal: olmInternal,
|
||||
}
|
||||
err = store.AddSession(context.TODO(), olmSessID, &olmSess)
|
||||
if err != nil {
|
||||
|
@ -133,7 +133,13 @@ func TestStoreOlmSession(t *testing.T) {
|
|||
if retrieved.ID() != olmSessID {
|
||||
t.Errorf("Expected session ID to be %v, got %v", olmSessID, retrieved.ID())
|
||||
}
|
||||
if pickled := string(retrieved.Internal.Pickle([]byte("test"))); pickled != olmPickled {
|
||||
|
||||
pickled, err := retrieved.Internal.Pickle([]byte("test"))
|
||||
if err != nil {
|
||||
t.Fatalf("Error pickling Olm session: %v", err)
|
||||
}
|
||||
|
||||
if string(pickled) != olmPickled {
|
||||
t.Error("Pickled Olm session does not match original")
|
||||
}
|
||||
})
|
||||
|
@ -152,7 +158,7 @@ func TestStoreMegolmSession(t *testing.T) {
|
|||
}
|
||||
|
||||
igs := &InboundGroupSession{
|
||||
Internal: *internal,
|
||||
Internal: internal,
|
||||
SigningKey: acc.SigningKey(),
|
||||
SenderKey: acc.IdentityKey(),
|
||||
RoomID: "room1",
|
||||
|
@ -168,7 +174,9 @@ func TestStoreMegolmSession(t *testing.T) {
|
|||
t.Errorf("Error retrieving inbound group session: %v", err)
|
||||
}
|
||||
|
||||
if pickled := string(retrieved.Internal.Pickle([]byte("test"))); pickled != groupSession {
|
||||
if pickled, err := retrieved.Internal.Pickle([]byte("test")); err != nil {
|
||||
t.Fatalf("Error pickling inbound group session: %v", err)
|
||||
} else if string(pickled) != groupSession {
|
||||
t.Error("Pickled inbound group session does not match original")
|
||||
}
|
||||
})
|
||||
|
|
|
@ -337,6 +337,12 @@ func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserI
|
|||
devices, err := vh.mach.CryptoStore.GetDevices(ctx, to)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get devices for user: %w", err)
|
||||
} else if len(devices) == 0 {
|
||||
// HACK: we are doing this because the client doesn't wait until it has
|
||||
// the devices before starting verification.
|
||||
if _, err := vh.mach.FetchKeys(ctx, []id.UserID{to}, true); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
vh.getLog(ctx).Info().
|
||||
|
|
|
@ -149,5 +149,6 @@ type Unsigned struct {
|
|||
|
||||
func (us *Unsigned) IsEmpty() bool {
|
||||
return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 &&
|
||||
us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil
|
||||
us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil &&
|
||||
us.BeeperHSOrder == 0
|
||||
}
|
||||
|
|
|
@ -0,0 +1,203 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package federation
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"go.mau.fi/util/jsontime"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type ServerVersion struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// ServerKeyProvider is an interface that returns private server keys for server key requests.
|
||||
type ServerKeyProvider interface {
|
||||
Get(r *http.Request) (serverName string, key *SigningKey)
|
||||
}
|
||||
|
||||
// StaticServerKey is an implementation of [ServerKeyProvider] that always returns the same server name and key.
|
||||
type StaticServerKey struct {
|
||||
ServerName string
|
||||
Key *SigningKey
|
||||
}
|
||||
|
||||
func (ssk *StaticServerKey) Get(r *http.Request) (serverName string, key *SigningKey) {
|
||||
return ssk.ServerName, ssk.Key
|
||||
}
|
||||
|
||||
// KeyServer implements a basic Matrix key server that can serve its own keys, plus the federation version endpoint.
|
||||
//
|
||||
// It does not implement querying keys of other servers, nor any other federation endpoints.
|
||||
type KeyServer struct {
|
||||
KeyProvider ServerKeyProvider
|
||||
Version ServerVersion
|
||||
WellKnownTarget string
|
||||
}
|
||||
|
||||
// Register registers the key server endpoints to the given router.
|
||||
func (ks *KeyServer) Register(r *mux.Router) {
|
||||
r.HandleFunc("/.well-known/matrix/server", ks.GetWellKnown).Methods(http.MethodGet)
|
||||
r.HandleFunc("/_matrix/federation/v1/version", ks.GetServerVersion).Methods(http.MethodGet)
|
||||
keyRouter := r.PathPrefix("/_matrix/key").Subrouter()
|
||||
keyRouter.HandleFunc("/v2/server", ks.GetServerKey).Methods(http.MethodGet)
|
||||
keyRouter.HandleFunc("/v2/query/{serverName}", ks.GetQueryKeys).Methods(http.MethodGet)
|
||||
keyRouter.HandleFunc("/v2/query", ks.PostQueryKeys).Methods(http.MethodPost)
|
||||
keyRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
|
||||
ErrCode: mautrix.MUnrecognized.ErrCode,
|
||||
Err: "Unrecognized endpoint",
|
||||
})
|
||||
})
|
||||
keyRouter.MethodNotAllowedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{
|
||||
ErrCode: mautrix.MUnrecognized.ErrCode,
|
||||
Err: "Invalid method for endpoint",
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func jsonResponse(w http.ResponseWriter, code int, data any) {
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
_ = json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
// RespWellKnown is the response body for the `GET /.well-known/matrix/server` endpoint.
|
||||
type RespWellKnown struct {
|
||||
Server string `json:"m.server"`
|
||||
}
|
||||
|
||||
// GetWellKnown implements the `GET /.well-known/matrix/server` endpoint
|
||||
//
|
||||
// https://spec.matrix.org/v1.9/server-server-api/#get_well-knownmatrixserver
|
||||
func (ks *KeyServer) GetWellKnown(w http.ResponseWriter, r *http.Request) {
|
||||
if ks.WellKnownTarget == "" {
|
||||
jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
|
||||
ErrCode: mautrix.MNotFound.ErrCode,
|
||||
Err: "No well-known target set",
|
||||
})
|
||||
} else {
|
||||
jsonResponse(w, http.StatusOK, &RespWellKnown{Server: ks.WellKnownTarget})
|
||||
}
|
||||
}
|
||||
|
||||
// RespServerVersion is the response body for the `GET /_matrix/federation/v1/version` endpoint
|
||||
type RespServerVersion struct {
|
||||
Server ServerVersion `json:"server"`
|
||||
}
|
||||
|
||||
// GetServerVersion implements the `GET /_matrix/federation/v1/version` endpoint
|
||||
//
|
||||
// https://spec.matrix.org/v1.9/server-server-api/#get_matrixfederationv1version
|
||||
func (ks *KeyServer) GetServerVersion(w http.ResponseWriter, r *http.Request) {
|
||||
jsonResponse(w, http.StatusOK, &RespServerVersion{Server: ks.Version})
|
||||
}
|
||||
|
||||
// GetServerKey implements the `GET /_matrix/key/v2/server` endpoint.
|
||||
//
|
||||
// https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2server
|
||||
func (ks *KeyServer) GetServerKey(w http.ResponseWriter, r *http.Request) {
|
||||
domain, key := ks.KeyProvider.Get(r)
|
||||
if key == nil {
|
||||
jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
|
||||
ErrCode: mautrix.MNotFound.ErrCode,
|
||||
Err: fmt.Sprintf("No signing key found for %q", r.Host),
|
||||
})
|
||||
} else {
|
||||
jsonResponse(w, http.StatusOK, key.GenerateKeyResponse(domain, nil))
|
||||
}
|
||||
}
|
||||
|
||||
// ReqQueryKeys is the request body for the `POST /_matrix/key/v2/query` endpoint
|
||||
type ReqQueryKeys struct {
|
||||
ServerKeys map[string]map[id.KeyID]QueryKeysCriteria `json:"server_keys"`
|
||||
}
|
||||
|
||||
type QueryKeysCriteria struct {
|
||||
MinimumValidUntilTS jsontime.UnixMilli `json:"minimum_valid_until_ts"`
|
||||
}
|
||||
|
||||
// PostQueryKeysResponse is the response body for the `POST /_matrix/key/v2/query` endpoint
|
||||
type PostQueryKeysResponse struct {
|
||||
ServerKeys map[string]*ServerKeyResponse `json:"server_keys"`
|
||||
}
|
||||
|
||||
// PostQueryKeys implements the `POST /_matrix/key/v2/query` endpoint
|
||||
//
|
||||
// https://spec.matrix.org/v1.9/server-server-api/#post_matrixkeyv2query
|
||||
func (ks *KeyServer) PostQueryKeys(w http.ResponseWriter, r *http.Request) {
|
||||
var req ReqQueryKeys
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
|
||||
ErrCode: mautrix.MBadJSON.ErrCode,
|
||||
Err: fmt.Sprintf("failed to parse request: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp := &PostQueryKeysResponse{
|
||||
ServerKeys: make(map[string]*ServerKeyResponse),
|
||||
}
|
||||
for serverName, keys := range req.ServerKeys {
|
||||
domain, key := ks.KeyProvider.Get(r)
|
||||
if domain != serverName {
|
||||
continue
|
||||
}
|
||||
for keyID, criteria := range keys {
|
||||
if key.ID == keyID && criteria.MinimumValidUntilTS.Before(time.Now().Add(24*time.Hour)) {
|
||||
resp.ServerKeys[serverName] = key.GenerateKeyResponse(serverName, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
jsonResponse(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// GetQueryKeysResponse is the response body for the `GET /_matrix/key/v2/query/{serverName}` endpoint
|
||||
type GetQueryKeysResponse struct {
|
||||
ServerKeys []*ServerKeyResponse `json:"server_keys"`
|
||||
}
|
||||
|
||||
// GetQueryKeys implements the `GET /_matrix/key/v2/query/{serverName}` endpoint
|
||||
//
|
||||
// https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2queryservername
|
||||
func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) {
|
||||
serverName := mux.Vars(r)["serverName"]
|
||||
minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts")
|
||||
minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64)
|
||||
if err != nil && minimumValidUntilTSString != "" {
|
||||
jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
|
||||
ErrCode: mautrix.MInvalidParam.ErrCode,
|
||||
Err: fmt.Sprintf("failed to parse ?minimum_valid_until_ts: %v", err),
|
||||
})
|
||||
return
|
||||
} else if time.UnixMilli(minimumValidUntilTS).After(time.Now().Add(24 * time.Hour)) {
|
||||
jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{
|
||||
ErrCode: mautrix.MInvalidParam.ErrCode,
|
||||
Err: "minimum_valid_until_ts may not be more than 24 hours in the future",
|
||||
})
|
||||
return
|
||||
}
|
||||
resp := &GetQueryKeysResponse{
|
||||
ServerKeys: []*ServerKeyResponse{},
|
||||
}
|
||||
if domain, key := ks.KeyProvider.Get(r); key != nil && domain == serverName {
|
||||
resp.ServerKeys = append(resp.ServerKeys, key.GenerateKeyResponse(serverName, nil))
|
||||
}
|
||||
jsonResponse(w, http.StatusOK, resp)
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package federation
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/util/jsontime"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/canonicaljson"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// SigningKey is a Matrix federation signing key pair.
|
||||
type SigningKey struct {
|
||||
ID id.KeyID
|
||||
Pub id.SigningKey
|
||||
Priv ed25519.PrivateKey
|
||||
}
|
||||
|
||||
// SynapseString returns a string representation of the private key compatible with Synapse's .signing.key file format.
|
||||
//
|
||||
// The output of this function can be parsed back into a [SigningKey] using the [ParseSynapseKey] function.
|
||||
func (sk *SigningKey) SynapseString() string {
|
||||
alg, id := sk.ID.Parse()
|
||||
return fmt.Sprintf("%s %s %s", alg, id, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed()))
|
||||
}
|
||||
|
||||
// ParseSynapseKey parses a Synapse-compatible private key string into a SigningKey.
|
||||
func ParseSynapseKey(key string) (*SigningKey, error) {
|
||||
parts := strings.Split(key, " ")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid key format (expected 3 space-separated parts, got %d)", len(parts))
|
||||
} else if parts[0] != string(id.KeyAlgorithmEd25519) {
|
||||
return nil, fmt.Errorf("unsupported key algorithm %s (only ed25519 is supported)", parts[0])
|
||||
}
|
||||
seed, err := base64.RawStdEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid private key: %w", err)
|
||||
}
|
||||
priv := ed25519.NewKeyFromSeed(seed)
|
||||
pub := base64.RawStdEncoding.EncodeToString(priv.Public().(ed25519.PublicKey))
|
||||
return &SigningKey{
|
||||
ID: id.NewKeyID(id.KeyAlgorithmEd25519, parts[1]),
|
||||
Pub: id.SigningKey(pub),
|
||||
Priv: priv,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateSigningKey generates a new random signing key.
|
||||
func GenerateSigningKey() *SigningKey {
|
||||
pub, priv, err := ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &SigningKey{
|
||||
ID: id.NewKeyID(id.KeyAlgorithmEd25519, base64.RawURLEncoding.EncodeToString(pub[:4])),
|
||||
Pub: id.SigningKey(base64.RawStdEncoding.EncodeToString(pub)),
|
||||
Priv: priv,
|
||||
}
|
||||
}
|
||||
|
||||
// ServerKeyResponse is the response body for the `GET /_matrix/key/v2/server` endpoint.
|
||||
// It's also used inside the query endpoint response structs.
|
||||
type ServerKeyResponse struct {
|
||||
ServerName string `json:"server_name"`
|
||||
VerifyKeys map[id.KeyID]ServerVerifyKey `json:"verify_keys"`
|
||||
OldVerifyKeys map[id.KeyID]OldVerifyKey `json:"old_verify_keys,omitempty"`
|
||||
Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"`
|
||||
ValidUntilTS jsontime.UnixMilli `json:"valid_until_ts"`
|
||||
}
|
||||
|
||||
type ServerVerifyKey struct {
|
||||
Key id.SigningKey `json:"key"`
|
||||
}
|
||||
|
||||
type OldVerifyKey struct {
|
||||
Key id.SigningKey `json:"key"`
|
||||
ExpiredTS jsontime.UnixMilli `json:"expired_ts"`
|
||||
}
|
||||
|
||||
func (sk *SigningKey) SignJSON(data any) ([]byte, error) {
|
||||
marshaled, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sk.SignRawJSON(marshaled), nil
|
||||
}
|
||||
|
||||
func (sk *SigningKey) SignRawJSON(data json.RawMessage) []byte {
|
||||
return ed25519.Sign(sk.Priv, canonicaljson.CanonicalJSONAssumeValid(data))
|
||||
}
|
||||
|
||||
// GenerateKeyResponse generates a key response signed by this key with the given server name and optionally some old verify keys.
|
||||
func (sk *SigningKey) GenerateKeyResponse(serverName string, oldVerifyKeys map[id.KeyID]OldVerifyKey) *ServerKeyResponse {
|
||||
skr := &ServerKeyResponse{
|
||||
ServerName: serverName,
|
||||
OldVerifyKeys: oldVerifyKeys,
|
||||
ValidUntilTS: jsontime.UM(time.Now().Add(24 * time.Hour)),
|
||||
VerifyKeys: map[id.KeyID]ServerVerifyKey{
|
||||
sk.ID: {Key: sk.Pub},
|
||||
},
|
||||
}
|
||||
signature, err := sk.SignJSON(skr)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
skr.Signatures = map[string]map[id.KeyID]string{
|
||||
serverName: {
|
||||
sk.ID: base64.RawURLEncoding.EncodeToString(signature),
|
||||
},
|
||||
}
|
||||
return skr
|
||||
}
|
|
@ -7,6 +7,7 @@
|
|||
package format
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
|
@ -33,14 +34,16 @@ func (ts TagStack) Has(tag string) bool {
|
|||
}
|
||||
|
||||
type Context struct {
|
||||
Ctx context.Context
|
||||
ReturnData map[string]any
|
||||
TagStack TagStack
|
||||
|
||||
PreserveWhitespace bool
|
||||
}
|
||||
|
||||
func NewContext() Context {
|
||||
func NewContext(ctx context.Context) Context {
|
||||
return Context{
|
||||
Ctx: ctx,
|
||||
ReturnData: map[string]any{},
|
||||
TagStack: make(TagStack, 0, 4),
|
||||
}
|
||||
|
@ -411,7 +414,7 @@ func HTMLToText(html string) string {
|
|||
Newline: "\n",
|
||||
HorizontalLine: "\n---\n",
|
||||
PillConverter: DefaultPillConverter,
|
||||
}).Parse(html, NewContext())
|
||||
}).Parse(html, NewContext(context.TODO()))
|
||||
}
|
||||
|
||||
// HTMLToMarkdown converts Matrix HTML into markdown with the default settings.
|
||||
|
@ -429,5 +432,5 @@ func HTMLToMarkdown(html string) string {
|
|||
}
|
||||
return fmt.Sprintf("[%s](%s)", text, href)
|
||||
},
|
||||
}).Parse(html, NewContext())
|
||||
}).Parse(html, NewContext(context.TODO()))
|
||||
}
|
||||
|
|
15
go.mod
15
go.mod
|
@ -8,18 +8,17 @@ require (
|
|||
github.com/lib/pq v1.10.9
|
||||
github.com/mattn/go-sqlite3 v1.14.22
|
||||
github.com/rs/zerolog v1.32.0
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.17.1
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/yuin/goldmark v1.7.0
|
||||
go.mau.fi/util v0.4.0
|
||||
github.com/yuin/goldmark v1.7.1
|
||||
go.mau.fi/util v0.4.2
|
||||
go.mau.fi/zeroconfig v0.1.2
|
||||
golang.org/x/crypto v0.19.0
|
||||
golang.org/x/exp v0.0.0-20240213143201-ec583247a57a
|
||||
golang.org/x/net v0.21.0
|
||||
golang.org/x/crypto v0.22.0
|
||||
golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8
|
||||
golang.org/x/net v0.24.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
maunium.net/go/mauflag v1.0.0
|
||||
maunium.net/go/maulogger/v2 v2.4.1
|
||||
)
|
||||
|
||||
require (
|
||||
|
@ -30,6 +29,6 @@ require (
|
|||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
golang.org/x/sys v0.17.0 // indirect
|
||||
golang.org/x/sys v0.19.0 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
)
|
||||
|
|
30
go.sum
30
go.sum
|
@ -24,8 +24,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
|
|||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0=
|
||||
github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U=
|
||||
github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
|
@ -35,23 +35,23 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
|||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/yuin/goldmark v1.7.0 h1:EfOIvIMZIzHdB/R/zVrikYLPPwJlfMcNczJFMs1m6sA=
|
||||
github.com/yuin/goldmark v1.7.0/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
|
||||
go.mau.fi/util v0.4.0 h1:S2X3qU4pUcb/vxBRfAuZjbrR9xVMAXSjQojNBLPBbhs=
|
||||
go.mau.fi/util v0.4.0/go.mod h1:leeiHtgVBuN+W9aDii3deAXnfC563iN3WK6BF8/AjNw=
|
||||
github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U=
|
||||
github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
|
||||
go.mau.fi/util v0.4.2 h1:RR3TOcRHmCF9Bx/3YG4S65MYfa+nV6/rn8qBWW4Mi30=
|
||||
go.mau.fi/util v0.4.2/go.mod h1:PlAVfUUcPyHPrwnvjkJM9UFcPE7qGPDJqk+Oufa1Gtw=
|
||||
go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto=
|
||||
go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
|
||||
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/exp v0.0.0-20240213143201-ec583247a57a h1:HinSgX1tJRX3KsL//Gxynpw5CTOAIPhgL4W8PNiIpVE=
|
||||
golang.org/x/exp v0.0.0-20240213143201-ec583247a57a/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc=
|
||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
|
||||
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
|
||||
golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8 h1:ESSUROHIBHg7USnszlcdmjBEwdMj9VUvU+OPk4yl2mc=
|
||||
golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
|
||||
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
|
||||
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
|
||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||
|
@ -60,5 +60,3 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
|||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
|
||||
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
|
||||
maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8=
|
||||
maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho=
|
||||
|
|
25
id/userid.go
25
id/userid.go
|
@ -36,19 +36,34 @@ var (
|
|||
ErrEmptyLocalpart = errors.New("empty localparts are not allowed")
|
||||
)
|
||||
|
||||
// ParseCommonIdentifier parses a common identifier according to https://spec.matrix.org/v1.9/appendices/#common-identifier-format
|
||||
func ParseCommonIdentifier[Stringish ~string](identifier Stringish) (sigil byte, localpart, homeserver string) {
|
||||
if len(identifier) == 0 {
|
||||
return
|
||||
}
|
||||
sigil = identifier[0]
|
||||
strIdentifier := string(identifier)
|
||||
if strings.ContainsRune(strIdentifier, ':') {
|
||||
parts := strings.SplitN(strIdentifier, ":", 2)
|
||||
localpart = parts[0][1:]
|
||||
homeserver = parts[1]
|
||||
} else {
|
||||
localpart = strIdentifier[1:]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Parse parses the user ID into the localpart and server name.
|
||||
//
|
||||
// Note that this only enforces very basic user ID formatting requirements: user IDs start with
|
||||
// a @, and contain a : after the @. If you want to enforce localpart validity, see the
|
||||
// ParseAndValidate and ValidateUserLocalpart functions.
|
||||
func (userID UserID) Parse() (localpart, homeserver string, err error) {
|
||||
if len(userID) == 0 || userID[0] != '@' || !strings.ContainsRune(string(userID), ':') {
|
||||
// This error wrapping lets you use errors.Is() nicely even though the message contains the user ID
|
||||
var sigil byte
|
||||
sigil, localpart, homeserver = ParseCommonIdentifier(userID)
|
||||
if sigil != '@' || homeserver == "" {
|
||||
err = fmt.Errorf("'%s' %w", userID, ErrInvalidUserID)
|
||||
return
|
||||
}
|
||||
parts := strings.SplitN(string(userID), ":", 2)
|
||||
localpart, homeserver = strings.TrimPrefix(parts[0], "@"), parts[1]
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -309,6 +309,12 @@ func (slr SyncLeftRoom) MarshalJSON() ([]byte, error) {
|
|||
return marshalAndDeleteEmpty((marshalableSyncLeftRoom)(slr), syncLeftRoomPathsToDelete)
|
||||
}
|
||||
|
||||
type BeeperInboxPreviewEvent struct {
|
||||
EventID id.EventID `json:"event_id"`
|
||||
Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
|
||||
Event *event.Event `json:"event,omitempty"`
|
||||
}
|
||||
|
||||
type SyncJoinedRoom struct {
|
||||
Summary LazyLoadSummary `json:"summary"`
|
||||
State SyncEventsList `json:"state"`
|
||||
|
@ -319,6 +325,8 @@ type SyncJoinedRoom struct {
|
|||
UnreadNotifications *UnreadNotificationCounts `json:"unread_notifications,omitempty"`
|
||||
// https://github.com/matrix-org/matrix-spec-proposals/pull/2654
|
||||
MSC2654UnreadCount *int `json:"org.matrix.msc2654.unread_count,omitempty"`
|
||||
// Beeper extension
|
||||
BeeperInboxPreview *BeeperInboxPreviewEvent `json:"com.beeper.inbox.preview,omitempty"`
|
||||
}
|
||||
|
||||
type UnreadNotificationCounts struct {
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
const Version = "v0.18.0-beta.1"
|
||||
const Version = "v0.18.1"
|
||||
|
||||
var GoModVersion = ""
|
||||
var Commit = ""
|
||||
|
|
Loading…
Reference in New Issue