Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: Sumner Evans <[email protected]>
  • Loading branch information
sumnerevans committed Nov 22, 2024
1 parent d3df25e commit 0d838e2
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 1 deletion.
1 change: 1 addition & 0 deletions crypto/verificationhelper/sas.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific
return
}
} else {
fmt.Printf("txn %+v\n", txn.EphemeralKey)
err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationKey, &event.VerificationKeyEventContent{
Key: txn.EphemeralKey.PublicKey().Bytes(),
})
Expand Down
8 changes: 7 additions & 1 deletion crypto/verificationhelper/verificationhelper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package verificationhelper_test

import (
"context"
"database/sql"
"fmt"
"os"
"testing"
Expand Down Expand Up @@ -65,7 +66,12 @@ func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockServ
func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, receivingClient *mautrix.Client, sendingMachine, receivingMachine *crypto.OlmMachine) (sendingCallbacks, receivingCallbacks *allVerificationCallbacks, sendingHelper, receivingHelper *verificationhelper.VerificationHelper) {
t.Helper()
sendingCallbacks = newAllVerificationCallbacks()
sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, sendingCallbacks, true)
db, err := sql.Open("sqlite3", ":memory:")
require.NoError(t, err)
store, err := NewSQLiteVerificationStore(ctx, db)
require.NoError(t, err)

sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, store, sendingCallbacks, true)
require.NoError(t, sendingHelper.Init(ctx))

receivingCallbacks = newAllVerificationCallbacks()
Expand Down
99 changes: 99 additions & 0 deletions crypto/verificationhelper/verificationstore_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package verificationhelper_test

import (
"context"
"database/sql"
"encoding/json"
"fmt"

_ "github.com/mattn/go-sqlite3"
"go.mau.fi/util/dbutil"

"maunium.net/go/mautrix/crypto/verificationhelper"
"maunium.net/go/mautrix/id"
)

type SQLiteVerificationStore struct {
db *sql.DB
}

const (
selectVerifications = `SELECT transaction_data FROM verifications`
getVerificationByTransactionID = selectVerifications + ` WHERE transaction_id = ?1`
getVerificationByUserDeviceID = selectVerifications + `
WHERE transaction_data->>'their_user_id' = ?1
AND transaction_data->>'their_device_id' = ?2
`
deleteVerificationsQuery = `DELETE FROM verifications WHERE transaction_id = ?1`
)

var _ verificationhelper.VerificationStore = (*SQLiteVerificationStore)(nil)

func NewSQLiteVerificationStore(ctx context.Context, db *sql.DB) (*SQLiteVerificationStore, error) {
_, err := db.ExecContext(ctx, `
CREATE TABLE verifications (
transaction_id TEXT PRIMARY KEY NOT NULL,
transaction_data JSONB NOT NULL
);
CREATE INDEX verifications_user_device_id ON
verifications(transaction_data->>'their_user_id', transaction_data->>'their_device_id');
`)
return &SQLiteVerificationStore{db}, err
}

func (s *SQLiteVerificationStore) GetAllVerificationTransactions(ctx context.Context) ([]verificationhelper.VerificationTransaction, error) {
rows, err := s.db.QueryContext(ctx, selectVerifications)
if err != nil {
return nil, err
}
return dbutil.NewRowIter(rows, func(dbutil.Scannable) (txn verificationhelper.VerificationTransaction, err error) {
err = rows.Scan(&dbutil.JSON{Data: &txn})
return
}).AsList()
}

func (vq *SQLiteVerificationStore) GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (txn verificationhelper.VerificationTransaction, err error) {
// DEBUG
row := vq.db.QueryRowContext(ctx, getVerificationByTransactionID, txnID)
var x []byte
err = row.Scan(&x)
if err == sql.ErrNoRows {
err = verificationhelper.ErrUnknownVerificationTransaction
}
fmt.Printf("GET %s = %s\n", txnID, x)
// END DEBUG

row = vq.db.QueryRowContext(ctx, getVerificationByTransactionID, txnID)
err = row.Scan(&dbutil.JSON{Data: &txn})
if err == sql.ErrNoRows {
err = verificationhelper.ErrUnknownVerificationTransaction
}
return
}

func (vq *SQLiteVerificationStore) FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (txn verificationhelper.VerificationTransaction, err error) {
row := vq.db.QueryRowContext(ctx, getVerificationByUserDeviceID, userID, deviceID)
err = row.Scan(&dbutil.JSON{Data: &txn})
if err == sql.ErrNoRows {
err = verificationhelper.ErrUnknownVerificationTransaction
}
return
}

func (vq *SQLiteVerificationStore) SaveVerificationTransaction(ctx context.Context, txn verificationhelper.VerificationTransaction) (err error) {
x, _ := json.Marshal(txn)
fmt.Printf("SAVE %s = %s\n", txn.TransactionID, x)
// zerolog.Ctx(ctx).Debug().Any("transaction", txn).Msg("Saving verification transaction")
_, err = vq.db.ExecContext(ctx, `
INSERT INTO verifications (transaction_id, transaction_data)
VALUES (?1, ?2)
ON CONFLICT (transaction_id) DO UPDATE
SET transaction_data=excluded.transaction_data
`, txn.TransactionID, &dbutil.JSON{Data: txn})
return
}

func (vq *SQLiteVerificationStore) DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) (err error) {
_, err = vq.db.ExecContext(ctx, deleteVerificationsQuery, txnID)
return
}

0 comments on commit 0d838e2

Please sign in to comment.