diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index 395775e1..21276218 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -35,13 +35,13 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(ctx, qrCode.TransactionID) - if err != nil { - return fmt.Errorf("failed to get transaction %s: %w", qrCode.TransactionID, err) - } else if txn.VerificationState != VerificationStateReady { + txn, ok := vh.activeTransactions[qrCode.TransactionID] + if !ok { + return fmt.Errorf("unknown transaction ID found in QR code") + } else if txn.VerificationState != verificationStateReady { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "transaction found in the QR code is not in the ready state") } - txn.VerificationState = VerificationStateTheirQRScanned + txn.VerificationState = verificationStateTheirQRScanned // Verify the keys log.Info().Msg("Verifying keys from QR code") @@ -53,9 +53,9 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by switch qrCode.Mode { case QRCodeModeCrossSigning: - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) if err != nil { - return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUser, err) } if bytes.Equal(theirSigningKeys.MasterKey.Bytes(), qrCode.Key1[:]) { log.Info().Msg("Verified that the other device has the master key we expected") @@ -70,7 +70,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the master key does not match") } - if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { + if err := vh.mach.SignUser(ctx, txn.TheirUser, theirSigningKeys.MasterKey); err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their master key: %w", err) } case QRCodeModeSelfVerifyingMasterKeyTrusted: @@ -78,7 +78,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // means that we don't trust the key. Key1 is the master key public // key, and Key2 is what the other device thinks our device key is. - if vh.client.UserID != txn.TheirUserID { + if vh.client.UserID != txn.TheirUser { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) } @@ -114,12 +114,12 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeMasterKeyNotTrusted, "the master key is not trusted by this device, cannot verify device that does not trust the master key") } - if vh.client.UserID != txn.TheirUserID { + if vh.client.UserID != txn.TheirUser { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) } // Get their device - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) if err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to get their device: %w", err) } @@ -140,7 +140,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // Trust their device theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) if err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to update device trust state after verifying: %+v", err) } @@ -177,12 +177,8 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by txn.SentOurDone = true if txn.ReceivedTheirDone { log.Debug().Msg("We already received their done event. Setting verification state to done.") - if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { - return err - } + delete(vh.activeTransactions, txn.TransactionID) vh.verificationDone(ctx, txn.TransactionID) - } else { - vh.store.SaveVerificationTransaction(ctx, txn) } return nil } @@ -200,27 +196,28 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(ctx, txnID) - if err != nil { - return fmt.Errorf("failed to get transaction %s: %w", txnID, err) - } else if txn.VerificationState != VerificationStateOurQRScanned { + txn, ok := vh.activeTransactions[txnID] + if !ok { + log.Warn().Msg("Ignoring QR code scan confirmation for an unknown transaction") + return nil + } else if txn.VerificationState != verificationStateOurQRScanned { return fmt.Errorf("transaction is not in the scanned state") } log.Info().Msg("Confirming QR code scanned") - if txn.TheirUserID == vh.client.UserID { + if txn.TheirUser == vh.client.UserID { // Self-signing situation. Trust their device. // Get their device - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) if err != nil { return err } // Trust their device theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) if err != nil { return fmt.Errorf("failed to update device trust state after verifying: %w", err) } @@ -234,33 +231,29 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id } } else { // Cross-signing situation. Sign their master key. - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) if err != nil { - return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUser, err) } - if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { + if err := vh.mach.SignUser(ctx, txn.TheirUser, theirSigningKeys.MasterKey); err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their master key: %w", err) } } - err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { return err } txn.SentOurDone = true if txn.ReceivedTheirDone { - if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { - return err - } + delete(vh.activeTransactions, txn.TransactionID) vh.verificationDone(ctx, txn.TransactionID) - } else { - vh.store.SaveVerificationTransaction(ctx, txn) } return nil } -func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn VerificationTransaction) error { +func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *verificationTransaction) error { log := vh.getLog(ctx).With(). Str("verification_action", "generate and show QR code"). Stringer("transaction_id", txn.TransactionID). @@ -283,7 +276,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn Ver return err } mode := QRCodeModeCrossSigning - if vh.client.UserID == txn.TheirUserID { + if vh.client.UserID == txn.TheirUser { // This is a self-signing situation. if ownMasterKeyTrusted { mode = QRCodeModeSelfVerifyingMasterKeyTrusted @@ -305,7 +298,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn Ver key1 = ownCrossSigningPublicKeys.MasterKey.Bytes() // Key 2 is the other user's master signing key. - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) if err != nil { return err } @@ -315,7 +308,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn Ver key1 = ownCrossSigningPublicKeys.MasterKey.Bytes() // Key 2 is the other device's key. - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) if err != nil { return err } @@ -333,5 +326,5 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn Ver qrCode := NewQRCode(mode, txn.TransactionID, [32]byte(key1), [32]byte(key2)) txn.QRCodeSharedSecret = qrCode.SharedSecret vh.showQRCode(ctx, txn.TransactionID, qrCode) - return vh.store.SaveVerificationTransaction(ctx, txn) + return nil } diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 0492dd8d..e28ec405 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -40,23 +40,23 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(ctx, txnID) - if err != nil { - return fmt.Errorf("failed to get verification transaction %s: %w", txnID, err) - } else if txn.VerificationState != VerificationStateReady { + txn, ok := vh.activeTransactions[txnID] + if !ok { + return fmt.Errorf("unknown transaction ID") + } else if txn.VerificationState != verificationStateReady { return errors.New("transaction is not in ready state") } else if txn.StartEventContent != nil { return errors.New("start event already sent or received") } - txn.VerificationState = VerificationStateSASStarted + txn.VerificationState = verificationStateSASStarted txn.StartedByUs = true if !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS) { return fmt.Errorf("the other device does not support SAS verification") } // Ensure that we have their device key. - _, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + _, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) if err != nil { log.Err(err).Msg("Failed to fetch device") return err @@ -78,9 +78,6 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio event.SASMethodEmoji, }, } - if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - return err - } return vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationStart, txn.StartEventContent) } @@ -97,13 +94,14 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(ctx, txnID) - if err != nil { - return fmt.Errorf("failed to get transaction %s: %w", txnID, err) - } else if txn.VerificationState != VerificationStateSASKeysExchanged { + txn, ok := vh.activeTransactions[txnID] + if !ok { + return fmt.Errorf("unknown transaction ID") + } else if txn.VerificationState != verificationStateSASKeysExchanged { return errors.New("transaction is not in keys exchanged state") } + var err error keys := map[id.KeyID]jsonbytes.UnpaddedBytes{} log.Info().Msg("Signing keys") @@ -111,7 +109,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat // My device key myDevice := vh.mach.OwnIdentity() myDeviceKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, myDevice.DeviceID.String()) - keys[myDeviceKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, myDeviceKeyID.String(), myDevice.SigningKey.String()) + keys[myDeviceKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, myDeviceKeyID.String(), myDevice.SigningKey.String()) if err != nil { return err } @@ -120,7 +118,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat crossSigningKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) if crossSigningKeys != nil { crossSigningKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, crossSigningKeys.MasterKey.String()) - keys[crossSigningKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, crossSigningKeyID.String(), crossSigningKeys.MasterKey.String()) + keys[crossSigningKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, crossSigningKeyID.String(), crossSigningKeys.MasterKey.String()) if err != nil { return err } @@ -131,7 +129,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat keyIDs = append(keyIDs, keyID.String()) } slices.Sort(keyIDs) - keysMAC, err := vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) + keysMAC, err := vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, "KEY_IDS", strings.Join(keyIDs, ",")) if err != nil { return err } @@ -147,14 +145,14 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat txn.SentOurMAC = true if txn.ReceivedTheirMAC { - txn.VerificationState = VerificationStateSASMACExchanged + txn.VerificationState = verificationStateSASMACExchanged err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { return err } txn.SentOurDone = true } - return vh.store.SaveVerificationTransaction(ctx, txn) + return nil } // onVerificationStartSAS handles the m.key.verification.start events with @@ -162,7 +160,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat // Spec. // // [Section 11.12.2.2]: https://spec.matrix.org/v1.9/client-server-api/#short-authentication-string-sas-verification -func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn VerificationTransaction, evt *event.Event) error { +func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *verificationTransaction, evt *event.Event) error { startEvt := evt.Content.AsVerificationStart() log := vh.getLog(ctx).With(). Str("verification_action", "start_sas"). @@ -210,7 +208,7 @@ func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn Ve return fmt.Errorf("failed to generate ephemeral key: %w", err) } txn.MACMethod = macMethod - txn.EphemeralKey = &ECDHPrivateKey{ephemeralKey} + txn.EphemeralKey = ephemeralKey txn.StartEventContent = startEvt commitment, err := calculateCommitment(ephemeralKey.PublicKey(), startEvt) @@ -228,8 +226,8 @@ func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn Ve if err != nil { return fmt.Errorf("failed to send accept event: %w", err) } - txn.VerificationState = VerificationStateSASAccepted - return vh.store.SaveVerificationTransaction(ctx, txn) + txn.VerificationState = verificationStateSASAccepted + return nil } func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, startEvt *event.VerificationStartEventContent) ([]byte, error) { @@ -254,7 +252,7 @@ func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, startEvt *event.Verifi // event. This follows Step 4 of [Section 11.12.2.2] of the Spec. // // [Section 11.12.2.2]: https://spec.matrix.org/v1.9/client-server-api/#short-authentication-string-sas-verification -func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *verificationTransaction, evt *event.Event) { acceptEvt := evt.Content.AsVerificationAccept() log := vh.getLog(ctx).With(). Str("verification_action", "accept"). @@ -269,7 +267,7 @@ func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn Veri vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != VerificationStateSASStarted { + if txn.VerificationState != verificationStateSASStarted { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received accept event for a transaction that is not in the started state") return @@ -289,18 +287,14 @@ func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn Veri return } - txn.VerificationState = VerificationStateSASAccepted + txn.VerificationState = verificationStateSASAccepted txn.MACMethod = acceptEvt.MessageAuthenticationCode txn.Commitment = acceptEvt.Commitment - txn.EphemeralKey = &ECDHPrivateKey{ephemeralKey} + txn.EphemeralKey = ephemeralKey txn.EphemeralPublicKeyShared = true - - if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - log.Err(err).Msg("failed to save verification transaction") - } } -func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "key"). Logger() @@ -308,23 +302,22 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != VerificationStateSASAccepted { + if txn.VerificationState != verificationStateSASAccepted { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received key event for a transaction that is not in the accepted state") return } var err error - publicKey, err := ecdh.X25519().NewPublicKey(keyEvt.Key) + txn.OtherPublicKey, err = ecdh.X25519().NewPublicKey(keyEvt.Key) if err != nil { log.Err(err).Msg("Failed to generate other public key") return } - txn.OtherPublicKey = &ECDHPublicKey{publicKey} if txn.EphemeralPublicKeyShared { // Verify that the commitment hash is correct - commitment, err := calculateCommitment(publicKey, txn.StartEventContent) + commitment, err := calculateCommitment(txn.OtherPublicKey, txn.StartEventContent) if err != nil { log.Err(err).Msg("Failed to calculate commitment") return @@ -349,7 +342,7 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific } txn.EphemeralPublicKeyShared = true } - txn.VerificationState = VerificationStateSASKeysExchanged + txn.VerificationState = verificationStateSASKeysExchanged sasBytes, err := vh.verificationSASHKDF(txn) if err != nil { @@ -377,14 +370,10 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific } } vh.showSAS(ctx, txn.TransactionID, emojis, decimals) - - if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - log.Err(err).Msg("failed to save verification transaction") - } } -func (vh *VerificationHelper) verificationSASHKDF(txn VerificationTransaction) ([]byte, error) { - sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey.PublicKey) +func (vh *VerificationHelper) verificationSASHKDF(txn *verificationTransaction) ([]byte, error) { + sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey) if err != nil { return nil, err } @@ -399,8 +388,8 @@ func (vh *VerificationHelper) verificationSASHKDF(txn VerificationTransaction) ( }, "|") theirInfo := strings.Join([]string{ - txn.TheirUserID.String(), - txn.TheirDeviceID.String(), + txn.TheirUser.String(), + txn.TheirDevice.String(), base64.RawStdEncoding.EncodeToString(txn.OtherPublicKey.Bytes()), }, "|") @@ -473,8 +462,8 @@ func BrokenB64Encode(input []byte) string { return string(output) } -func (vh *VerificationHelper) verificationMACHKDF(txn VerificationTransaction, senderUser id.UserID, senderDevice id.DeviceID, receivingUser id.UserID, receivingDevice id.DeviceID, keyID, key string) ([]byte, error) { - sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey.PublicKey) +func (vh *VerificationHelper) verificationMACHKDF(txn *verificationTransaction, senderUser id.UserID, senderDevice id.DeviceID, receivingUser id.UserID, receivingDevice id.DeviceID, keyID, key string) ([]byte, error) { + sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey) if err != nil { return nil, err } @@ -574,7 +563,7 @@ var allEmojis = []rune{ '📌', } -func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "mac"). Logger() @@ -590,12 +579,12 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific for keyID := range macEvt.MAC { keyIDs = append(keyIDs, keyID.String()) _, kID := keyID.Parse() - if kID == txn.TheirDeviceID.String() { + if kID == txn.TheirDevice.String() { hasTheirDeviceKey = true } } slices.Sort(keyIDs) - expectedKeyMAC, err := vh.verificationMACHKDF(txn, txn.TheirUserID, txn.TheirDeviceID, vh.client.UserID, vh.client.DeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) + expectedKeyMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeSASMismatch, "failed to calculate key list MAC: %w", err) return @@ -621,8 +610,8 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific var key string var theirDevice *id.Device - if kID == txn.TheirDeviceID.String() { - theirDevice, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + if kID == txn.TheirDevice.String() { + theirDevice, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to fetch their device: %w", err) return @@ -641,7 +630,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific key = crossSigningKeys.MasterKey.String() } - expectedMAC, err := vh.verificationMACHKDF(txn, txn.TheirUserID, txn.TheirDeviceID, vh.client.UserID, vh.client.DeviceID, keyID.String(), key) + expectedMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, keyID.String(), key) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to calculate key MAC: %w", err) return @@ -652,9 +641,9 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific } // Trust their device - if kID == txn.TheirDeviceID.String() { + if kID == txn.TheirDevice.String() { theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to update device trust state after verifying: %w", err) return @@ -665,7 +654,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific txn.ReceivedTheirMAC = true if txn.SentOurMAC { - txn.VerificationState = VerificationStateSASMACExchanged + txn.VerificationState = verificationStateSASMACExchanged err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to send verification done event: %w", err) @@ -673,8 +662,4 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific } txn.SentOurDone = true } - - if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - log.Err(err).Msg("failed to save verification transaction") - } } diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index bc424624..be8357f5 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -9,7 +9,7 @@ package verificationhelper import ( "bytes" "context" - "errors" + "crypto/ecdh" "fmt" "sync" "time" @@ -25,6 +25,86 @@ import ( "maunium.net/go/mautrix/id" ) +type verificationState int + +const ( + verificationStateRequested verificationState = iota + verificationStateReady + + verificationStateTheirQRScanned // We scanned their QR code + verificationStateOurQRScanned // They scanned our QR code + + verificationStateSASStarted // An SAS verification has been started + verificationStateSASAccepted // An SAS verification has been accepted + verificationStateSASKeysExchanged // An SAS verification has exchanged keys + verificationStateSASMACExchanged // An SAS verification has exchanged MACs +) + +func (step verificationState) String() string { + switch step { + case verificationStateRequested: + return "requested" + case verificationStateReady: + return "ready" + case verificationStateTheirQRScanned: + return "their_qr_scanned" + case verificationStateOurQRScanned: + return "our_qr_scanned" + case verificationStateSASStarted: + return "sas_started" + case verificationStateSASAccepted: + return "sas_accepted" + case verificationStateSASKeysExchanged: + return "sas_keys_exchanged" + case verificationStateSASMACExchanged: + return "sas_mac" + default: + return fmt.Sprintf("verificationStep(%d)", step) + } +} + +type verificationTransaction struct { + // RoomID is the room ID if the verification is happening in a room or + // empty if it is a to-device verification. + RoomID id.RoomID + + // VerificationState is the current step of the verification flow. + VerificationState verificationState + // TransactionID is the ID of the verification transaction. + TransactionID id.VerificationTransactionID + + // TheirDevice is the device ID of the device that either made the initial + // request or accepted our request. + TheirDevice id.DeviceID + // TheirUser is the user ID of the other user. + TheirUser id.UserID + // TheirSupportedMethods is a list of verification methods that the other + // device supports. + TheirSupportedMethods []event.VerificationMethod + + // SentToDeviceIDs is a list of devices which the initial request was sent + // to. This is only used for to-device verification requests, and is meant + // to be used to send cancellation requests to all other devices when a + // verification request is accepted via a m.key.verification.ready event. + SentToDeviceIDs []id.DeviceID + + // QRCodeSharedSecret is the shared secret that was encoded in the QR code + // that we showed. + QRCodeSharedSecret []byte + + StartedByUs bool // Whether the verification was started by us + StartEventContent *event.VerificationStartEventContent // The m.key.verification.start event content + Commitment []byte // The commitment from the m.key.verification.accept event + MACMethod event.MACMethod // The method used to calculate the MAC + EphemeralKey *ecdh.PrivateKey // The ephemeral key + EphemeralPublicKeyShared bool // Whether this device's ephemeral public key has been shared + OtherPublicKey *ecdh.PublicKey // The other device's ephemeral public key + ReceivedTheirMAC bool // Whether we have received their MAC + SentOurMAC bool // Whether we have sent our MAC + ReceivedTheirDone bool // Whether we have received their done event + SentOurDone bool // Whether we have sent our done event +} + // RequiredCallbacks is an interface representing the callbacks required for // the [VerificationHelper]. type RequiredCallbacks interface { @@ -65,9 +145,8 @@ type VerificationHelper struct { client *mautrix.Client mach *crypto.OlmMachine - store VerificationStore + activeTransactions map[id.VerificationTransactionID]*verificationTransaction activeTransactionsLock sync.Mutex - // activeTransactions map[id.VerificationTransactionID]*verificationTransaction // supportedMethods are the methods that *we* support supportedMethods []event.VerificationMethod @@ -84,19 +163,15 @@ type VerificationHelper struct { var _ mautrix.VerificationHelper = (*VerificationHelper)(nil) -func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsScan bool) *VerificationHelper { +func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, callbacks any, supportsScan bool) *VerificationHelper { if client.Crypto == nil { panic("client.Crypto is nil") } - if store == nil { - store = NewInMemoryVerificationStore() - } - helper := VerificationHelper{ - client: client, - mach: mach, - store: store, + client: client, + mach: mach, + activeTransactions: map[id.VerificationTransactionID]*verificationTransaction{}, } if c, ok := callbacks.(RequiredCallbacks); !ok { @@ -158,7 +233,7 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { // Wrapper for the event handlers to check that the transaction ID is known // and ignore the event if it isn't. - wrapHandler := func(callback func(context.Context, VerificationTransaction, *event.Event)) func(context.Context, *event.Event) { + wrapHandler := func(callback func(context.Context, *verificationTransaction, *event.Event)) func(context.Context, *event.Event) { return func(ctx context.Context, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "check transaction ID"). @@ -182,11 +257,8 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { log = log.With().Stringer("transaction_id", transactionID).Logger() vh.activeTransactionsLock.Lock() - txn, err := vh.store.GetVerificationTransaction(ctx, transactionID) - if err != nil && errors.Is(err, ErrUnknownVerificationTransaction) { - log.Err(err).Msg("failed to get verification transaction") - return - } else if errors.Is(err, ErrUnknownVerificationTransaction) { + txn, ok := vh.activeTransactions[transactionID] + if !ok { // If it's a cancellation event for an unknown transaction, we // can just ignore it. if evt.Type == event.ToDeviceVerificationCancel || evt.Type == event.InRoomVerificationCancel { @@ -199,9 +271,9 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { // We have to create a fake transaction so that the call to // cancelVerificationTxn works. - txn = VerificationTransaction{ - RoomID: evt.RoomID, - TheirUserID: evt.Sender, + txn = &verificationTransaction{ + RoomID: evt.RoomID, + TheirUser: evt.Sender, } if transactionable, ok := evt.Content.Parsed.(event.VerificationTransactionable); ok { txn.TransactionID = transactionable.GetTransactionID() @@ -209,7 +281,7 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { txn.TransactionID = id.VerificationTransactionID(evt.ID) } if fromDevice, ok := evt.Content.Raw["from_device"]; ok { - txn.TheirDeviceID = id.DeviceID(fromDevice.(string)) + txn.TheirDevice = id.DeviceID(fromDevice.(string)) } // Send a cancellation event. @@ -250,11 +322,7 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { syncer.OnEventType(event.InRoomVerificationKey, wrapHandler(vh.onVerificationKey)) // SAS syncer.OnEventType(event.InRoomVerificationMAC, wrapHandler(vh.onVerificationMAC)) // SAS - allTransactions, err := vh.store.GetAllVerificationTransactions(ctx) - for _, txn := range allTransactions { - vh.expireTransactionAt(txn.TransactionID, txn.ExpirationTime.Time) - } - return err + return nil } // StartVerification starts an interactive verification flow with the given @@ -314,12 +382,13 @@ func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserI vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - return txnID, vh.store.SaveVerificationTransaction(ctx, VerificationTransaction{ - VerificationState: VerificationStateRequested, + vh.activeTransactions[txnID] = &verificationTransaction{ + VerificationState: verificationStateRequested, TransactionID: txnID, - TheirUserID: to, + TheirUser: to, SentToDeviceIDs: maps.Keys(devices), - }) + } + return txnID, nil } // StartInRoomVerification starts an interactive verification flow with the @@ -353,12 +422,13 @@ func (vh *VerificationHelper) StartInRoomVerification(ctx context.Context, roomI vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - return txnID, vh.store.SaveVerificationTransaction(ctx, VerificationTransaction{ + vh.activeTransactions[txnID] = &verificationTransaction{ RoomID: roomID, - VerificationState: VerificationStateRequested, + VerificationState: verificationStateRequested, TransactionID: txnID, - TheirUserID: to, - }) + TheirUser: to, + } + return txnID, nil } // AcceptVerification accepts a verification request. The transaction ID should @@ -370,10 +440,10 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V Stringer("transaction_id", txnID). Logger() - txn, err := vh.store.GetVerificationTransaction(ctx, txnID) - if err != nil { - return err - } else if txn.VerificationState != VerificationStateRequested { + txn, ok := vh.activeTransactions[txnID] + if !ok { + return fmt.Errorf("unknown transaction ID") + } else if txn.VerificationState != verificationStateRequested { return fmt.Errorf("transaction is not in the requested state") } @@ -402,11 +472,11 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V FromDevice: vh.client.DeviceID, Methods: maps.Keys(supportedMethods), } - err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationReady, readyEvt) + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationReady, readyEvt) if err != nil { return err } - txn.VerificationState = VerificationStateReady + txn.VerificationState = verificationStateReady if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { vh.scanQRCode(ctx, txn.TransactionID) @@ -422,7 +492,8 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V func (vh *VerificationHelper) DismissVerification(ctx context.Context, txnID id.VerificationTransactionID) error { vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - return vh.store.DeleteVerification(ctx, txnID) + delete(vh.activeTransactions, txnID) + return nil } // DismissVerification cancels the verification request with the given @@ -433,9 +504,9 @@ func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.V vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(ctx, txnID) - if err != nil { - return err + txn, ok := vh.activeTransactions[txnID] + if !ok { + return fmt.Errorf("unknown transaction ID") } log := vh.getLog(ctx).With(). Str("verification_action", "cancel verification"). @@ -456,28 +527,29 @@ func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.V } else { cancelEvt.SetTransactionID(txn.TransactionID) req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{ - txn.TheirUserID: {}, + txn.TheirUser: {}, }} - if len(txn.TheirDeviceID) > 0 { + if len(txn.TheirDevice) > 0 { // Send the cancellation event to only the device that accepted the // verification request. All of the other devices already received a // cancellation event with code "m.acceped". - req.Messages[txn.TheirUserID][txn.TheirDeviceID] = &event.Content{Parsed: cancelEvt} + req.Messages[txn.TheirUser][txn.TheirDevice] = &event.Content{Parsed: cancelEvt} } else { // Send the cancellation event to all of the devices that we sent the // request to. for _, deviceID := range txn.SentToDeviceIDs { if deviceID != vh.client.DeviceID { - req.Messages[txn.TheirUserID][deviceID] = &event.Content{Parsed: cancelEvt} + req.Messages[txn.TheirUser][deviceID] = &event.Content{Parsed: cancelEvt} } } } _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { - return fmt.Errorf("failed to send m.key.verification.cancel event to %v: %w", maps.Keys(req.Messages[txn.TheirUserID]), err) + return fmt.Errorf("failed to send m.key.verification.cancel event to %v: %w", maps.Keys(req.Messages[txn.TheirUser]), err) } } - return vh.store.DeleteVerification(ctx, txn.TransactionID) + delete(vh.activeTransactions, txn.TransactionID) + return nil } // sendVerificationEvent sends a verification event to the other user's device @@ -489,7 +561,7 @@ func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.V // [event.VerificationTransactionable]. // - evtType can be either the to-device or in-room version of the event type // as it is always stringified. -func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn VerificationTransaction, evtType event.Type, content any) error { +func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn *verificationTransaction, evtType event.Type, content any) error { if txn.RoomID != "" { content.(event.Relatable).SetRelatesTo(&event.RelatesTo{Type: event.RelReference, EventID: id.EventID(txn.TransactionID)}) _, err := vh.client.SendMessageEvent(ctx, txn.RoomID, evtType, &event.Content{ @@ -501,13 +573,13 @@ func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn Ver } else { content.(event.VerificationTransactionable).SetTransactionID(txn.TransactionID) req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{ - txn.TheirUserID: { - txn.TheirDeviceID: &event.Content{Parsed: content}, + txn.TheirUser: { + txn.TheirDevice: &event.Content{Parsed: content}, }, }} _, err := vh.client.SendToDevice(ctx, evtType, &req) if err != nil { - return fmt.Errorf("failed to send %s event to %s: %w", evtType.String(), txn.TheirDeviceID, err) + return fmt.Errorf("failed to send %s event to %s: %w", evtType.String(), txn.TheirDevice, err) } } return nil @@ -519,7 +591,7 @@ func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn Ver // directly to expose the error to its caller). // // Must always be called with the activeTransactionsLock held. -func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn VerificationTransaction, code event.VerificationCancelCode, reasonFmtStr string, fmtArgs ...any) error { +func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn *verificationTransaction, code event.VerificationCancelCode, reasonFmtStr string, fmtArgs ...any) error { log := vh.getLog(ctx) reason := fmt.Errorf(reasonFmtStr, fmtArgs...).Error() log.Info(). @@ -533,9 +605,7 @@ func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn Ver log.Err(err).Msg("failed to send cancellation event") return fmt.Errorf("failed to send cancel verification event (code: %s, reason: %s): %w", code, reason, err) } - if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { - log.Err(err).Msg("deleting verification failed") - } + delete(vh.activeTransactions, txn.TransactionID) vh.verificationCancelledCallback(ctx, txn.TransactionID, code, reason) return fmt.Errorf("verification cancelled (code: %s): %s", code, reason) } @@ -614,58 +684,54 @@ func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *ev } vh.activeTransactionsLock.Lock() - newTxn := VerificationTransaction{ - ExpirationTime: jsontime.UnixMilli{Time: verificationRequest.Timestamp.Add(time.Minute * 10)}, + newTxn := &verificationTransaction{ RoomID: evt.RoomID, - VerificationState: VerificationStateRequested, + VerificationState: verificationStateRequested, TransactionID: verificationRequest.TransactionID, - TheirDeviceID: verificationRequest.FromDevice, - TheirUserID: evt.Sender, + TheirDevice: verificationRequest.FromDevice, + TheirUser: evt.Sender, TheirSupportedMethods: verificationRequest.Methods, } - if txn, err := vh.store.FindVerificationTransactionForUserDevice(ctx, evt.Sender, verificationRequest.FromDevice); err != nil && !errors.Is(err, ErrUnknownVerificationTransaction) { - log.Err(err).Stringer("sender", evt.Sender).Stringer("device_id", verificationRequest.FromDevice).Msg("failed to find verification transaction") - vh.activeTransactionsLock.Unlock() - return - } else if !errors.Is(err, ErrUnknownVerificationTransaction) { - if txn.TransactionID == verificationRequest.TransactionID { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received a new verification request for the same transaction ID") - } else { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") + for existingTxnID, existingTxn := range vh.activeTransactions { + if existingTxn.TheirUser == evt.Sender && existingTxn.TheirDevice == verificationRequest.FromDevice && existingTxnID != verificationRequest.TransactionID { + vh.cancelVerificationTxn(ctx, existingTxn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") vh.cancelVerificationTxn(ctx, newTxn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") + delete(vh.activeTransactions, existingTxnID) + vh.activeTransactionsLock.Unlock() + return + } + + if existingTxnID == verificationRequest.TransactionID { + vh.cancelVerificationTxn(ctx, existingTxn, event.VerificationCancelCodeUnexpectedMessage, "received a new verification request for the same transaction ID") + delete(vh.activeTransactions, existingTxnID) + vh.activeTransactionsLock.Unlock() + return } - vh.activeTransactionsLock.Unlock() - return - } - if err := vh.store.SaveVerificationTransaction(ctx, newTxn); err != nil { - log.Err(err).Msg("failed to save verification transaction") } + vh.activeTransactions[verificationRequest.TransactionID] = newTxn vh.activeTransactionsLock.Unlock() vh.expireTransactionAt(verificationRequest.TransactionID, verificationRequest.Timestamp.Add(time.Minute*10)) vh.verificationRequested(ctx, verificationRequest.TransactionID, evt.Sender) } -func (vh *VerificationHelper) expireTransactionAt(txnID id.VerificationTransactionID, expiresAt time.Time) { +func (vh *VerificationHelper) expireTransactionAt(txnID id.VerificationTransactionID, expireAt time.Time) { go func() { - time.Sleep(time.Until(expiresAt)) + time.Sleep(time.Until(expireAt)) vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(context.Background(), txnID) - if err == ErrUnknownVerificationTransaction { - // Already deleted, nothing to expire + txn, ok := vh.activeTransactions[txnID] + if !ok { return - } else if err != nil { - vh.getLog(context.Background()).Err(err).Msg("failed to get verification transaction to expire") - } else { - vh.cancelVerificationTxn(context.Background(), txn, event.VerificationCancelCodeTimeout, "verification timed out") } + + vh.cancelVerificationTxn(context.Background(), txn, event.VerificationCancelCodeTimeout, "verification timed out") }() } -func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *verificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "verification ready"). Logger() @@ -673,7 +739,7 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != VerificationStateRequested { + if txn.VerificationState != verificationStateRequested { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "verification ready event received for a transaction that is not in the requested state") return } @@ -681,12 +747,12 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif readyEvt := evt.Content.AsVerificationReady() // Update the transaction state. - txn.VerificationState = VerificationStateReady - txn.TheirDeviceID = readyEvt.FromDevice + txn.VerificationState = verificationStateReady + txn.TheirDevice = readyEvt.FromDevice txn.TheirSupportedMethods = readyEvt.Methods log.Info(). - Stringer("their_device_id", txn.TheirDeviceID). + Stringer("their_device_id", txn.TheirDevice). Any("their_supported_methods", txn.TheirSupportedMethods). Msg("Received verification ready event") @@ -700,16 +766,16 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif Reason: "The verification was accepted on another device.", }, } - req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUserID: {}}} + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} for _, deviceID := range txn.SentToDeviceIDs { - if deviceID == txn.TheirDeviceID || deviceID == vh.client.DeviceID { + if deviceID == txn.TheirDevice || deviceID == vh.client.DeviceID { // Don't ever send a cancellation to the device that accepted // the request or to our own device (which can happen if this // is a self-verification). continue } - req.Messages[txn.TheirUserID][deviceID] = content + req.Messages[txn.TheirUser][deviceID] = content } _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { @@ -721,17 +787,18 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif vh.scanQRCode(ctx, txn.TransactionID) } - if err := vh.generateAndShowQRCode(ctx, txn); err != nil { + err := vh.generateAndShowQRCode(ctx, txn) + if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to generate and show QR code: %w", err) } } -func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *verificationTransaction, evt *event.Event) { startEvt := evt.Content.AsVerificationStart() log := vh.getLog(ctx).With(). Str("verification_action", "verification start"). Str("method", string(startEvt.Method)). - Stringer("their_device_id", txn.TheirDeviceID). + Stringer("their_device_id", txn.TheirDevice). Any("their_supported_methods", txn.TheirSupportedMethods). Bool("started_by_us", txn.StartedByUs). Logger() @@ -741,7 +808,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState == VerificationStateSASStarted || txn.VerificationState == VerificationStateOurQRScanned || txn.VerificationState == VerificationStateTheirQRScanned { + if txn.VerificationState == verificationStateSASStarted || txn.VerificationState == verificationStateOurQRScanned || txn.VerificationState == verificationStateTheirQRScanned { // We might have sent the event, and they also sent an event. if txn.StartEventContent == nil || !txn.StartedByUs { // We didn't sent a start event yet, so we have gotten ourselves @@ -773,12 +840,12 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif return } - if txn.TheirUserID < vh.client.UserID || (txn.TheirUserID == vh.client.UserID && txn.TheirDeviceID < vh.client.DeviceID) { + if txn.TheirUser < vh.client.UserID || (txn.TheirUser == vh.client.UserID && txn.TheirDevice < vh.client.DeviceID) { log.Debug().Msg("Using their start event instead of ours because they are alphabetically before us") txn.StartedByUs = false txn.StartEventContent = startEvt } - } else if txn.VerificationState != VerificationStateReady { + } else if txn.VerificationState != verificationStateReady { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got start event for transaction that is not in ready state") return } @@ -786,7 +853,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif switch startEvt.Method { case event.VerificationMethodSAS: log.Info().Msg("Received SAS start event") - txn.VerificationState = VerificationStateSASStarted + txn.VerificationState = verificationStateSASStarted if err := vh.onVerificationStartSAS(ctx, txn, evt); err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to handle SAS verification start: %w", err) } @@ -796,11 +863,8 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "reciprocated shared secret does not match") return } - txn.VerificationState = VerificationStateOurQRScanned + txn.VerificationState = verificationStateOurQRScanned vh.qrCodeScaned(ctx, txn.TransactionID) - if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - log.Err(err).Msg("failed to save verification transaction") - } default: // Note that we should never get m.qr_code.show.v1 or m.qr_code.scan.v1 // here, since the start command for scanning and showing QR codes @@ -810,18 +874,17 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif } } -func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn VerificationTransaction, evt *event.Event) { - log := vh.getLog(ctx).With(). +func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verificationTransaction, evt *event.Event) { + vh.getLog(ctx).Info(). Str("verification_action", "done"). Stringer("transaction_id", txn.TransactionID). Bool("sent_our_done", txn.SentOurDone). - Logger() - log.Info().Msg("Verification done") + Msg("Verification done") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if !slices.Contains([]VerificationState{ - VerificationStateTheirQRScanned, VerificationStateOurQRScanned, VerificationStateSASMACExchanged, + if !slices.Contains([]verificationState{ + verificationStateTheirQRScanned, verificationStateOurQRScanned, verificationStateSASMACExchanged, }, txn.VerificationState) { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got done event for transaction that is not in QR-scanned or MAC-exchanged state") return @@ -829,16 +892,12 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn Verifi txn.ReceivedTheirDone = true if txn.SentOurDone { - if err := vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { - log.Err(err).Msg("Delete verification failed") - } + delete(vh.activeTransactions, txn.TransactionID) vh.verificationDone(ctx, txn.TransactionID) - } else if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - log.Err(err).Msg("failed to save verification transaction") } } -func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *verificationTransaction, evt *event.Event) { cancelEvt := evt.Content.AsVerificationCancel() log := vh.getLog(ctx).With(). Str("verification_action", "cancel"). @@ -864,7 +923,7 @@ func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn Veri // that is currently in the REQUESTED state, then we will send // cancellations to all of the devices that we sent the request to. This // will ensure that all of the clients know that the request was cancelled. - if txn.VerificationState == VerificationStateRequested && len(txn.SentToDeviceIDs) > 0 { + if txn.VerificationState == verificationStateRequested && len(txn.SentToDeviceIDs) > 0 { content := &event.Content{ Parsed: &event.VerificationCancelEventContent{ ToDeviceVerificationEvent: event.ToDeviceVerificationEvent{TransactionID: txn.TransactionID}, @@ -872,9 +931,9 @@ func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn Veri Reason: "The verification was rejected from another device.", }, } - req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUserID: {}}} + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} for _, deviceID := range txn.SentToDeviceIDs { - req.Messages[txn.TheirUserID][deviceID] = content + req.Messages[txn.TheirUser][deviceID] = content } _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { @@ -882,8 +941,6 @@ func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn Veri } } - if err := vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { - log.Err(err).Msg("Delete verification failed") - } + delete(vh.activeTransactions, txn.TransactionID) vh.verificationCancelledCallback(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) } diff --git a/crypto/verificationhelper/verificationhelper_qr_self_test.go b/crypto/verificationhelper/verificationhelper_qr_self_test.go index 937cc414..11358b88 100644 --- a/crypto/verificationhelper/verificationhelper_qr_self_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_self_test.go @@ -278,12 +278,12 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { // Emulate scanning the QR code shown by the receiving device // on the sending device. err = sendingHelper.HandleScannedQRData(ctx, receivingShownQRCodeBytes) - assert.ErrorContains(t, err, "unknown transaction ID") + assert.ErrorContains(t, err, "unknown transaction ID found in QR code") // Emulate scanning the QR code shown by the sending device on // the receiving device. err = receivingHelper.HandleScannedQRData(ctx, sendingShownQRCodeBytes) - assert.ErrorContains(t, err, "unknown transaction ID") + assert.ErrorContains(t, err, "unknown transaction ID found in QR code") } func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index d0bf2298..273042c3 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -65,11 +65,11 @@ func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockServ func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, receivingClient *mautrix.Client, sendingMachine, receivingMachine *crypto.OlmMachine) (sendingCallbacks, receivingCallbacks *allVerificationCallbacks, sendingHelper, receivingHelper *verificationhelper.VerificationHelper) { t.Helper() sendingCallbacks = newAllVerificationCallbacks() - sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, sendingCallbacks, true) + sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, sendingCallbacks, true) require.NoError(t, sendingHelper.Init(ctx)) receivingCallbacks = newAllVerificationCallbacks() - receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, receivingCallbacks, true) + receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receivingCallbacks, true) require.NoError(t, receivingHelper.Init(ctx)) return } @@ -104,7 +104,7 @@ func TestVerification_Start(t *testing.T) { addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID) addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID2) - senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, tc.callbacks, tc.supportsScan) + senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), tc.callbacks, tc.supportsScan) err := senderHelper.Init(ctx) require.NoError(t, err) @@ -151,7 +151,7 @@ func TestVerification_StartThenCancel(t *testing.T) { bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID) bystanderMachine := bystanderClient.Crypto.(*cryptohelper.CryptoHelper).Machine() - bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, nil, newAllVerificationCallbacks(), true) + bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, newAllVerificationCallbacks(), true) require.NoError(t, bystanderHelper.Init(ctx)) require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, bystanderMachine.OwnIdentity())) @@ -241,12 +241,12 @@ func TestVerification_Accept_NoSupportedMethods(t *testing.T) { assert.NotEmpty(t, recoveryKey) assert.NotNil(t, cache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, newAllVerificationCallbacks(), true) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, newAllVerificationCallbacks(), true) err = sendingHelper.Init(ctx) require.NoError(t, err) receivingCallbacks := newBaseVerificationCallbacks() - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, receivingCallbacks, false) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), receivingCallbacks, false) err = receivingHelper.Init(ctx) require.NoError(t, err) @@ -289,11 +289,11 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { assert.NotEmpty(t, recoveryKey) assert.NotNil(t, sendingCrossSigningKeysCache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, tc.sendingCallbacks, tc.sendingSupportsScan) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, tc.sendingCallbacks, tc.sendingSupportsScan) err = sendingHelper.Init(ctx) require.NoError(t, err) - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, tc.receivingCallbacks, tc.receivingSupportsScan) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, tc.receivingCallbacks, tc.receivingSupportsScan) err = receivingHelper.Init(ctx) require.NoError(t, err) diff --git a/crypto/verificationhelper/verificationstore.go b/crypto/verificationhelper/verificationstore.go deleted file mode 100644 index 725a66a6..00000000 --- a/crypto/verificationhelper/verificationstore.go +++ /dev/null @@ -1,187 +0,0 @@ -package verificationhelper - -import ( - "context" - "crypto/ecdh" - "encoding/json" - "errors" - "fmt" - - "go.mau.fi/util/jsontime" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -var ErrUnknownVerificationTransaction = errors.New("unknown transaction ID") - -type VerificationState int - -const ( - VerificationStateRequested VerificationState = iota - VerificationStateReady - - VerificationStateTheirQRScanned // We scanned their QR code - VerificationStateOurQRScanned // They scanned our QR code - - VerificationStateSASStarted // An SAS verification has been started - VerificationStateSASAccepted // An SAS verification has been accepted - VerificationStateSASKeysExchanged // An SAS verification has exchanged keys - VerificationStateSASMACExchanged // An SAS verification has exchanged MACs -) - -func (step VerificationState) String() string { - switch step { - case VerificationStateRequested: - return "requested" - case VerificationStateReady: - return "ready" - case VerificationStateTheirQRScanned: - return "their_qr_scanned" - case VerificationStateOurQRScanned: - return "our_qr_scanned" - case VerificationStateSASStarted: - return "sas_started" - case VerificationStateSASAccepted: - return "sas_accepted" - case VerificationStateSASKeysExchanged: - return "sas_keys_exchanged" - case VerificationStateSASMACExchanged: - return "sas_mac" - default: - return fmt.Sprintf("VerificationState(%d)", step) - } -} - -type ECDHPrivateKey struct { - *ecdh.PrivateKey -} - -func (e *ECDHPrivateKey) UnmarshalJSON(data []byte) (err error) { - e.PrivateKey, err = ecdh.P256().NewPrivateKey(data) - return -} - -func (e *ECDHPrivateKey) MarshalJSON() ([]byte, error) { - return json.Marshal(e.Bytes()) -} - -type ECDHPublicKey struct { - *ecdh.PublicKey -} - -func (e *ECDHPublicKey) UnmarshalJSON(data []byte) (err error) { - e.PublicKey, err = ecdh.P256().NewPublicKey(data) - return -} - -func (e *ECDHPublicKey) MarshalJSON() ([]byte, error) { - return json.Marshal(e.Bytes()) -} - -type VerificationTransaction struct { - ExpirationTime jsontime.UnixMilli `json:"expiration_time"` - - // RoomID is the room ID if the verification is happening in a room or - // empty if it is a to-device verification. - RoomID id.RoomID `json:"room_id"` - - // VerificationState is the current step of the verification flow. - VerificationState VerificationState `json:"verification_state"` - // TransactionID is the ID of the verification transaction. - TransactionID id.VerificationTransactionID `json:"transaction_id"` - - // TheirDeviceID is the device ID of the device that either made the - // initial request or accepted our request. - TheirDeviceID id.DeviceID `json:"their_device_id"` - // TheirUserID is the user ID of the other user. - TheirUserID id.UserID `json:"their_user_id"` - // TheirSupportedMethods is a list of verification methods that the other - // device supports. - TheirSupportedMethods []event.VerificationMethod `json:"their_supported_methods"` - - // SentToDeviceIDs is a list of devices which the initial request was sent - // to. This is only used for to-device verification requests, and is meant - // to be used to send cancellation requests to all other devices when a - // verification request is accepted via a m.key.verification.ready event. - SentToDeviceIDs []id.DeviceID `json:"sent_to_device_ids"` - - // QRCodeSharedSecret is the shared secret that was encoded in the QR code - // that we showed. - QRCodeSharedSecret []byte `json:"qr_code_shared_secret"` - - StartedByUs bool `json:"started_by_us"` // Whether the verification was started by us - StartEventContent *event.VerificationStartEventContent `json:"start_event_content"` // The m.key.verification.start event content - Commitment []byte `json:"committment"` // The commitment from the m.key.verification.accept event - MACMethod event.MACMethod `json:"mac_method"` // The method used to calculate the MAC - EphemeralKey *ECDHPrivateKey `json:"ephemeral_key"` // The ephemeral key - EphemeralPublicKeyShared bool `json:"ephemeral_public_key_shared"` // Whether this device's ephemeral public key has been shared - OtherPublicKey *ECDHPublicKey `json:"other_public_key"` // The other device's ephemeral public key - ReceivedTheirMAC bool `json:"received_their_mac"` // Whether we have received their MAC - SentOurMAC bool `json:"sent_our_mac"` // Whether we have sent our MAC - ReceivedTheirDone bool `json:"received_their_done"` // Whether we have received their done event - SentOurDone bool `json:"sent_our_done"` // Whether we have sent our done event -} - -type VerificationStore interface { - // DeleteVerification deletes a verification transaction by ID - DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) error - // GetVerificationTransaction gets a verification transaction by ID - GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (VerificationTransaction, error) - // SaveVerificationTransaction saves a verification transaction by ID - SaveVerificationTransaction(ctx context.Context, txn VerificationTransaction) error - // FindVerificationTransactionForUserDevice finds a verification - // transaction by user and device ID - FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (VerificationTransaction, error) - // GetAllVerificationTransactions returns all of the verification - // transactions. This is used to reset the cancellation timeouts. - GetAllVerificationTransactions(ctx context.Context) ([]VerificationTransaction, error) -} - -type InMemoryVerificationStore struct { - txns map[id.VerificationTransactionID]VerificationTransaction -} - -var _ VerificationStore = (*InMemoryVerificationStore)(nil) - -func NewInMemoryVerificationStore() *InMemoryVerificationStore { - return &InMemoryVerificationStore{ - txns: map[id.VerificationTransactionID]VerificationTransaction{}, - } -} - -func (i *InMemoryVerificationStore) DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) error { - if _, ok := i.txns[txnID]; !ok { - return ErrUnknownVerificationTransaction - } - delete(i.txns, txnID) - return nil -} - -func (i *InMemoryVerificationStore) GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (VerificationTransaction, error) { - if _, ok := i.txns[txnID]; !ok { - return VerificationTransaction{}, ErrUnknownVerificationTransaction - } - return i.txns[txnID], nil -} - -func (i *InMemoryVerificationStore) SaveVerificationTransaction(ctx context.Context, txn VerificationTransaction) error { - i.txns[txn.TransactionID] = txn - return nil -} - -func (i *InMemoryVerificationStore) FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (VerificationTransaction, error) { - for _, existingTxn := range i.txns { - if existingTxn.TheirUserID == userID && existingTxn.TheirDeviceID == deviceID { - return existingTxn, nil - } - } - return VerificationTransaction{}, ErrUnknownVerificationTransaction -} - -func (i *InMemoryVerificationStore) GetAllVerificationTransactions(ctx context.Context) (txns []VerificationTransaction, err error) { - for _, txn := range i.txns { - txns = append(txns, txn) - } - return -}