Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

crypto: run goolm side-by-side with libolm #314

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 81 additions & 12 deletions crypto/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,64 +7,108 @@
package crypto

import (
"bytes"
"encoding/json"
"fmt"

"github.com/tidwall/sjson"

"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/crypto/goolm/account"
"maunium.net/go/mautrix/crypto/libolm"

Check failure on line 19 in crypto/account.go

View workflow job for this annotation

GitHub Actions / Build (old, goolm)

no required module provides package maunium.net/go/mautrix/crypto/libolm; to add it:

Check failure on line 19 in crypto/account.go

View workflow job for this annotation

GitHub Actions / Build (latest, goolm)

no required module provides package maunium.net/go/mautrix/crypto/libolm; to add it:

Check failure on line 19 in crypto/account.go

View workflow job for this annotation

GitHub Actions / Build (old, goolm)

no required module provides package maunium.net/go/mautrix/crypto/libolm; to add it:

Check failure on line 19 in crypto/account.go

View workflow job for this annotation

GitHub Actions / Build (latest, goolm)

no required module provides package maunium.net/go/mautrix/crypto/libolm; to add it:
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/crypto/signatures"
"maunium.net/go/mautrix/id"
)

type OlmAccount struct {
Internal olm.Account
InternalLibolm olm.Account
InternalGoolm olm.Account
signingKey id.SigningKey
identityKey id.IdentityKey
Shared bool
KeyBackupVersion id.KeyBackupVersion
}

func NewOlmAccount() *OlmAccount {
account, err := olm.NewAccount()
libolmAccount, err := libolm.NewAccount()
if err != nil {
panic(err)
}
pickled, err := libolmAccount.Pickle([]byte("key"))
if err != nil {
panic(err)
}
goolmAccount, err := account.AccountFromPickled(pickled, []byte("key"))
if err != nil {
panic(err)
}
return &OlmAccount{
Internal: account,
InternalLibolm: libolmAccount,
InternalGoolm: goolmAccount,
}
}

func (account *OlmAccount) Keys() (id.SigningKey, id.IdentityKey) {
if len(account.signingKey) == 0 || len(account.identityKey) == 0 {
var err error
account.signingKey, account.identityKey, err = account.Internal.IdentityKeys()
account.signingKey, account.identityKey, err = account.InternalLibolm.IdentityKeys()
if err != nil {
panic(err)
}
goolmSigningKey, goolmIdentityKey, err := account.InternalGoolm.IdentityKeys()
if err != nil {
panic(err)
}
if account.signingKey != goolmSigningKey {
panic("account signing keys not equal")
}
if account.identityKey != goolmIdentityKey {
panic("account identity keys not equal")
}
}
return account.signingKey, account.identityKey
}

func (account *OlmAccount) SigningKey() id.SigningKey {
if len(account.signingKey) == 0 {
var err error
account.signingKey, account.identityKey, err = account.Internal.IdentityKeys()
account.signingKey, account.identityKey, err = account.InternalLibolm.IdentityKeys()
if err != nil {
panic(err)
}
goolmSigningKey, goolmIdentityKey, err := account.InternalGoolm.IdentityKeys()
if err != nil {
panic(err)
}
if account.signingKey != goolmSigningKey {
panic("account signing keys not equal")
}
if account.identityKey != goolmIdentityKey {
panic("account identity keys not equal")
}
}
return account.signingKey
}

func (account *OlmAccount) IdentityKey() id.IdentityKey {
if len(account.identityKey) == 0 {
var err error
account.signingKey, account.identityKey, err = account.Internal.IdentityKeys()
account.signingKey, account.identityKey, err = account.InternalLibolm.IdentityKeys()
if err != nil {
panic(err)
}
goolmSigningKey, goolmIdentityKey, err := account.InternalGoolm.IdentityKeys()
if err != nil {
panic(err)
}
if account.signingKey != goolmSigningKey {
panic("account signing keys not equal")
}
if account.identityKey != goolmIdentityKey {
panic("account identity keys not equal")
}
}
return account.identityKey
}
Expand All @@ -78,7 +122,15 @@
}
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
signed, err := account.Internal.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
signed, err := account.InternalLibolm.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
goolmSigned, goolmErr := account.InternalGoolm.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
if err != nil {
if goolmErr == nil {
panic("libolm errored, but goolm did not on account.SignJSON")
}
} else if !bytes.Equal(signed, goolmSigned) {
panic("libolm and goolm signed are not equal in account.SignJSON")
}
return string(signed), err
}

Expand All @@ -102,19 +154,36 @@
return deviceKeys
}

func (account *OlmAccount) getOneTimeKeys(userID id.UserID, deviceID id.DeviceID, currentOTKCount int) map[id.KeyID]mautrix.OneTimeKey {
newCount := int(account.Internal.MaxNumberOfOneTimeKeys()/2) - currentOTKCount
func (a *OlmAccount) getOneTimeKeys(userID id.UserID, deviceID id.DeviceID, currentOTKCount int) map[id.KeyID]mautrix.OneTimeKey {
newCount := int(a.InternalLibolm.MaxNumberOfOneTimeKeys()/2) - currentOTKCount
if newCount > 0 {
account.Internal.GenOneTimeKeys(uint(newCount))
a.InternalLibolm.GenOneTimeKeys(uint(newCount))

pickled, err := a.InternalLibolm.Pickle([]byte("key"))
if err != nil {
panic(err)
}
a.InternalGoolm, err = account.AccountFromPickled(pickled, []byte("key"))
if err != nil {
panic(err)
}
}
oneTimeKeys := make(map[id.KeyID]mautrix.OneTimeKey)
internalKeys, err := account.Internal.OneTimeKeys()
internalKeys, err := a.InternalLibolm.OneTimeKeys()
if err != nil {
panic(err)
}
goolmInternalKeys, err := a.InternalGoolm.OneTimeKeys()
if err != nil {
panic(err)
}
for keyID, key := range internalKeys {
if goolmInternalKeys[keyID] != key {
panic(fmt.Sprintf("key %s not found in getOneTimeKeys", keyID))
}

key := mautrix.OneTimeKey{Key: key}
signature, _ := account.SignJSON(key)
signature, _ := a.SignJSON(key)
key.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, deviceID.String(), signature)
key.IsSigned = true
oneTimeKeys[id.NewKeyID(id.KeyAlgorithmSignedCurve25519, keyID)] = key
Expand Down
24 changes: 21 additions & 3 deletions crypto/decryptmegolm.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package crypto

import (
"bytes"
"context"
"encoding/json"
"errors"
Expand Down Expand Up @@ -203,7 +204,11 @@ func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Co
log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt")
return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex)
}
firstKnown := sess.Internal.FirstKnownIndex()
firstKnown := sess.InternalLibolm.FirstKnownIndex()
firstKnownGoolm := sess.InternalGoolm.FirstKnownIndex()
if firstKnown != firstKnownGoolm {
panic(fmt.Sprintf("firstKnown not the same %d != %d", firstKnown, firstKnownGoolm))
}
log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger()
if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
log.Debug().Err(err).Msg("Failed to check if message index is duplicate")
Expand All @@ -228,7 +233,16 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
return sess, nil, 0, SenderKeyMismatch
}
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
plaintextGoolm, messageIndexGoolm, errGoolm := sess.InternalGoolm.Decrypt(content.MegolmCiphertext)
plaintext, messageIndex, err := sess.InternalLibolm.Decrypt(content.MegolmCiphertext)
if !bytes.Equal(plaintextGoolm, plaintext) {
panic("plaintext different")
} else if messageIndexGoolm != messageIndex {
panic(fmt.Sprintf("message index different %d != %d", messageIndexGoolm, messageIndex))
} else if err != nil && errGoolm == nil {
panic(fmt.Sprintf("goolm didn't error %v", err))
}

if err != nil {
if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt {
messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content)
Expand Down Expand Up @@ -277,7 +291,11 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
if len(sess.RatchetSafety.MissedIndices) > 0 {
ratchetTargetIndex = uint32(sess.RatchetSafety.MissedIndices[0])
}
ratchetCurrentIndex := sess.Internal.FirstKnownIndex()
ratchetCurrentIndexGoolm := sess.InternalGoolm.FirstKnownIndex()
ratchetCurrentIndex := sess.InternalLibolm.FirstKnownIndex()
if ratchetCurrentIndexGoolm != ratchetCurrentIndex {
panic(fmt.Sprintf("ratchet current index different %d != %d", ratchetCurrentIndexGoolm, ratchetCurrentIndex))
}
log := zerolog.Ctx(ctx).With().
Uint32("prev_ratchet_index", ratchetCurrentIndex).
Uint32("new_ratchet_index", ratchetTargetIndex).
Expand Down
10 changes: 9 additions & 1 deletion crypto/encryptolm.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,17 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id
log.Error().Err(err).Msg("Failed to verify signature of one-time key")
} else if !ok {
log.Warn().Msg("One-time key has invalid signature from device")
} else if sess, err := mach.account.Internal.NewOutboundSession(identity.IdentityKey, oneTimeKey.Key); err != nil {
} else if sess, err := mach.account.InternalLibolm.NewOutboundSession(identity.IdentityKey, oneTimeKey.Key); err != nil {
log.Error().Err(err).Msg("Failed to create outbound session with claimed one-time key")
} else {
goolmSess, err := mach.account.InternalGoolm.NewOutboundSession(identity.IdentityKey, oneTimeKey.Key)
if err != nil {
panic("goolm NewOutboundSession errored")
}
if sess.Describe() != goolmSess.Describe() {
panic("goolm NewOutboundSession and libolm NewOutboundSession returned different values")
}

wrapped := wrapSession(sess)
err = mach.CryptoStore.AddSession(ctx, identity.IdentityKey, wrapped)
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions crypto/goolm/account/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,10 @@ func (a *Account) UnpickleLibOlm(buf []byte) error {
} else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 {
return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrBadVersion, pickledVersion)
} else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair
fmt.Printf("123 %+v\n", err)
return err
} else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair
fmt.Printf("456 %+v\n", err)
return err
}

Expand Down
4 changes: 2 additions & 2 deletions crypto/goolm/libolmpickle/pickle.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ func Pickle(key, plaintext []byte) ([]byte, error) {

// Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256.
func Unpickle(key, input []byte) ([]byte, error) {
ciphertext, err := goolmbase64.Decode(input)
decoded, err := goolmbase64.Decode(input)
if err != nil {
return nil, err
}
ciphertext, mac := ciphertext[:len(ciphertext)-pickleMACLength], ciphertext[len(ciphertext)-pickleMACLength:]
ciphertext, mac := decoded[:len(decoded)-pickleMACLength], decoded[len(decoded)-pickleMACLength:]
if c, err := aessha2.NewAESSHA2(key, kdfPickle); err != nil {
return nil, err
} else if verified, err := c.VerifyMAC(ciphertext, mac); err != nil {
Expand Down
13 changes: 10 additions & 3 deletions crypto/keybackup.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import (

"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/backup"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/crypto/libolm"
"maunium.net/go/mautrix/crypto/signatures"
"maunium.net/go/mautrix/id"
)
Expand Down Expand Up @@ -144,7 +145,12 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.
return nil, fmt.Errorf("ignoring room key in backup with weird algorithm %s", keyBackupData.Algorithm)
}

igsInternal, err := olm.InboundGroupSessionImport([]byte(keyBackupData.SessionKey))
igsInternalGoolm, err := session.NewMegolmInboundSessionFromExport([]byte(keyBackupData.SessionKey))
if err != nil {
return nil, err
}

igsInternal, err := libolm.InboundGroupSessionImport([]byte(keyBackupData.SessionKey))
if err != nil {
return nil, fmt.Errorf("failed to import inbound group session: %w", err)
} else if igsInternal.ID() != sessionID {
Expand All @@ -169,7 +175,8 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.
}

igs := &InboundGroupSession{
Internal: igsInternal,
InternalLibolm: igsInternal,
InternalGoolm: igsInternalGoolm,
SigningKey: keyBackupData.SenderClaimedKeys.Ed25519,
SenderKey: keyBackupData.SenderKey,
RoomID: roomID,
Expand Down
7 changes: 6 additions & 1 deletion crypto/keyexport.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"math"

"go.mau.fi/util/exerrors"
"go.mau.fi/util/random"
"golang.org/x/crypto/pbkdf2"

Expand Down Expand Up @@ -81,10 +82,14 @@ func makeExportKeys(passphrase string) (encryptionKey, hashKey, salt, iv []byte)
func exportSessions(sessions []*InboundGroupSession) ([]ExportedSession, error) {
export := make([]ExportedSession, len(sessions))
for i, session := range sessions {
key, err := session.Internal.Export(session.Internal.FirstKnownIndex())
key, err := session.InternalLibolm.Export(session.InternalLibolm.FirstKnownIndex())
if err != nil {
return nil, fmt.Errorf("failed to export session: %w", err)
}
keyGoolm := exerrors.Must(session.InternalGoolm.Export(session.InternalGoolm.FirstKnownIndex()))
if !bytes.Equal(key, keyGoolm) {
panic("keys not equal")
}
export[i] = ExportedSession{
Algorithm: id.AlgorithmMegolmV1,
ForwardingChains: session.ForwardingChains,
Expand Down
33 changes: 20 additions & 13 deletions crypto/keyimport.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ import (
"fmt"
"time"

"maunium.net/go/mautrix/crypto/olm"
"go.mau.fi/util/exerrors"

"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/crypto/libolm"
"maunium.net/go/mautrix/id"
)

Expand Down Expand Up @@ -92,38 +95,42 @@ func decryptKeyExport(passphrase string, exportData []byte) ([]ExportedSession,
return sessionsJSON, nil
}

func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session ExportedSession) (bool, error) {
if session.Algorithm != id.AlgorithmMegolmV1 {
func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, sess ExportedSession) (bool, error) {
if sess.Algorithm != id.AlgorithmMegolmV1 {
return false, ErrInvalidExportedAlgorithm
}

igsInternal, err := olm.InboundGroupSessionImport([]byte(session.SessionKey))
igsInternal, err := libolm.InboundGroupSessionImport([]byte(sess.SessionKey))
if err != nil {
return false, fmt.Errorf("failed to import session: %w", err)
} else if igsInternal.ID() != session.SessionID {
} else if igsInternal.ID() != sess.SessionID {
return false, ErrMismatchingExportedSessionID
}
igs := &InboundGroupSession{
Internal: igsInternal,
SigningKey: session.SenderClaimedKeys.Ed25519,
SenderKey: session.SenderKey,
RoomID: session.RoomID,
InternalLibolm: igsInternal,
InternalGoolm: exerrors.Must(session.NewMegolmInboundSessionFromExport([]byte(sess.SessionKey))),
SigningKey: sess.SenderClaimedKeys.Ed25519,
SenderKey: sess.SenderKey,
RoomID: sess.RoomID,
// TODO should we add something here to mark the signing key as unverified like key requests do?
ForwardingChains: session.ForwardingChains,
ForwardingChains: sess.ForwardingChains,

ReceivedAt: time.Now().UTC(),
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID())
firstKnownIndex := igs.Internal.FirstKnownIndex()
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= firstKnownIndex {
firstKnownIndex := igs.InternalLibolm.FirstKnownIndex()
if firstKnownIndex != igs.InternalGoolm.FirstKnownIndex() {
panic("indexes different")
}
if existingIGS != nil && existingIGS.InternalLibolm.FirstKnownIndex() <= firstKnownIndex {
// We already have an equivalent or better session in the store, so don't override it.
return false, nil
}
err = mach.CryptoStore.PutGroupSession(ctx, igs)
if err != nil {
return false, fmt.Errorf("failed to store imported session: %w", err)
}
mach.markSessionReceived(ctx, session.RoomID, igs.ID(), firstKnownIndex)
mach.markSessionReceived(ctx, sess.RoomID, igs.ID(), firstKnownIndex)
return true, nil
}

Expand Down
Loading
Loading