Skip to content

Commit

Permalink
refactor: deriveAddressSeed and deriveAddress utilities
Browse files Browse the repository at this point in the history
**Problems**

- The distinction between `derive_address_seed` and `derive_address` was
  unclear and we were inconsistent in it:
  - We ended up applying address Merkle tree public key in both
    functions, which is confusing.
- Before this change, there was no TypeScript function for deriving
  **address seed**. There was only `deriveAddress`, but deriving the
  unified seed was a mystery for developers.
- We have two utilities for hashing and truncating to BN254:
  - `hash_to_bn254_field_size_be` - the older one, which:
    - Searches for a bump in a loop, adds it to the hash inputs and then
      truncates the hash. That doesn't make sense, because truncating
      the hash should be sufficient, adding a bump is unnecessary.
    - Another limitation is that it takes only one sequence of bytes,
      making it difficult to provide multiple inputs without
      concatenating them.
  - `hashv_to_bn254_field_size` - the newer one, which:
    - Just truncates the hash result, without the bump mechanism.
    - Takes 2D byte slice as input, making it possible to pass multiple
      inputs.

**Changes**

- Don't add MT pubkey in `derive_address_seed`. It's not a correct place
  for it to be applied. The distinction between `derive_address_seed`
  and `derive_address` should be:
  - `derive_address_seed` takes provided seeds (defined by the
    developer) and hashes them together with the program ID. This
    operation is done only in the third-party program.
  - `derive_address` takes the address seed (result of
    `address_address_seed`) and hashes it together with the address
    Merkle tree public key. This is done both in the third-party program
    and in light-system-program. light-system-program does that as a
    check whether the correct Merkle tree is used.
- Adjust the stateless.js API:
  - Provide `deriveAddressSeed` function.
  - Add unit tests, make sure that `deriveAddressSeed` and
    `deriveAddress` provide the same results as the equivalent functions
    in Rust SDK.
  • Loading branch information
vadorovsky committed Sep 30, 2024
1 parent 12f1750 commit 862b455
Show file tree
Hide file tree
Showing 10 changed files with 265 additions and 87 deletions.
6 changes: 1 addition & 5 deletions examples/name-service/programs/name-service/tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,7 @@ async fn test_name_service() {
address_queue_pubkey: env.address_merkle_tree_queue_pubkey,
};

let address_seed = derive_address_seed(
&[b"name-service", name.as_bytes()],
&name_service::ID,
&address_merkle_context,
);
let address_seed = derive_address_seed(&[b"name-service", name.as_bytes()], &name_service::ID);
let address = derive_address(&address_seed, &address_merkle_context);

let address_merkle_context =
Expand Down
17 changes: 9 additions & 8 deletions js/stateless.js/src/actions/create-account.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
NewAddressParams,
buildAndSignTx,
deriveAddress,
deriveAddressSeed,
sendAndConfirmTx,
} from '../utils';
import { defaultTestStateTreeAccounts } from '../constants';
Expand All @@ -25,7 +26,7 @@ import { BN } from '@coral-xyz/anchor';
*
* @param rpc RPC to use
* @param payer Payer of the transaction and initialization fees
* @param seed Seed to derive the new account address
* @param seeds Seeds to derive the new account address
* @param programId Owner of the new account
* @param addressTree Optional address tree. Defaults to a current shared
* address tree.
Expand All @@ -40,7 +41,7 @@ import { BN } from '@coral-xyz/anchor';
export async function createAccount(
rpc: Rpc,
payer: Signer,
seed: Uint8Array,
seeds: Uint8Array[],
programId: PublicKey,
addressTree?: PublicKey,
addressQueue?: PublicKey,
Expand All @@ -52,8 +53,8 @@ export async function createAccount(
addressTree = addressTree ?? defaultTestStateTreeAccounts().addressTree;
addressQueue = addressQueue ?? defaultTestStateTreeAccounts().addressQueue;

/// TODO: enforce program-derived
const address = await deriveAddress(seed, addressTree);
const seed = deriveAddressSeed(seeds, programId);
const address = deriveAddress(seed, addressTree);

const proof = await rpc.getValidityProofV0(undefined, [
{
Expand Down Expand Up @@ -96,7 +97,7 @@ export async function createAccount(
*
* @param rpc RPC to use
* @param payer Payer of the transaction and initialization fees
* @param seed Seed to derive the new account address
* @param seeds Seeds to derive the new account address
* @param lamports Number of compressed lamports to initialize the
* account with
* @param programId Owner of the new account
Expand All @@ -114,7 +115,7 @@ export async function createAccount(
export async function createAccountWithLamports(
rpc: Rpc,
payer: Signer,
seed: Uint8Array,
seeds: Uint8Array[],
lamports: number | BN,
programId: PublicKey,
addressTree?: PublicKey,
Expand All @@ -138,8 +139,8 @@ export async function createAccountWithLamports(
addressTree = addressTree ?? defaultTestStateTreeAccounts().addressTree;
addressQueue = addressQueue ?? defaultTestStateTreeAccounts().addressQueue;

/// TODO: enforce program-derived
const address = await deriveAddress(seed, addressTree);
const seed = deriveAddressSeed(seeds, programId);
const address = deriveAddress(seed, addressTree);

const proof = await rpc.getValidityProof(
inputAccounts.map(account => bn(account.hash)),
Expand Down
78 changes: 72 additions & 6 deletions js/stateless.js/src/utils/address.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import { AccountMeta, PublicKey } from '@solana/web3.js';
import { hashToBn254FieldSizeBe } from './conversion';
import { hashToBn254FieldSizeBe, hashvToBn254FieldSizeBe } from './conversion';
import { defaultTestStateTreeAccounts } from '../constants';
import { getIndexOrAdd } from '../instruction';

export function deriveAddressSeed(
seeds: Uint8Array[],
programId: PublicKey,
): Uint8Array {
const combinedSeeds: Uint8Array[] = [programId.toBytes(), ...seeds];
const hash = hashvToBn254FieldSizeBe(combinedSeeds);
return hash;
}

/**
* Derive an address for a compressed account from a seed and a merkle tree
* public key.
Expand All @@ -12,13 +21,13 @@ import { getIndexOrAdd } from '../instruction';
* defaultTestStateTreeAccounts().merkleTree
* @returns Derived address
*/
export async function deriveAddress(
export function deriveAddress(
seed: Uint8Array,
merkleTreePubkey: PublicKey = defaultTestStateTreeAccounts().merkleTree,
): Promise<PublicKey> {
): PublicKey {
const bytes = merkleTreePubkey.toBytes();
const combined = Buffer.from([...bytes, ...seed]);
const hash = await hashToBn254FieldSizeBe(combined);
const hash = hashToBn254FieldSizeBe(combined);

if (hash === null) {
throw new Error('DeriveAddressError');
Expand Down Expand Up @@ -115,14 +124,71 @@ if (import.meta.vitest) {
//@ts-ignore
const { it, expect, describe } = import.meta.vitest;

const programId = new PublicKey(
'7yucc7fL3JGbyMwg4neUaenNSdySS39hbAk89Ao3t1Hz',
);

describe('derive address seed', () => {
it('should derive a valid address seed', () => {
const seeds: Uint8Array[] = [
new TextEncoder().encode('foo'),
new TextEncoder().encode('bar'),
];
expect(deriveAddressSeed(seeds, programId)).toStrictEqual(
new Uint8Array([
0, 246, 150, 3, 192, 95, 53, 123, 56, 139, 206, 179, 253,
133, 115, 103, 120, 155, 251, 72, 250, 47, 117, 217, 118,
59, 174, 207, 49, 101, 201, 110,
]),
);
});

it('should derive a valid address seed', () => {
const seeds: Uint8Array[] = [
new TextEncoder().encode('ayy'),
new TextEncoder().encode('lmao'),
];
expect(deriveAddressSeed(seeds, programId)).toStrictEqual(
new Uint8Array([
0, 202, 44, 25, 221, 74, 144, 92, 69, 168, 38, 19, 206, 208,
29, 162, 53, 27, 120, 214, 152, 116, 15, 107, 212, 168, 33,
121, 187, 10, 76, 233,
]),
);
});
});

describe('deriveAddress function', () => {
it('should derive a valid address from a seed and a merkle tree public key', async () => {
const seed = new Uint8Array([1, 2, 3, 4]);
const seeds: Uint8Array[] = [
new TextEncoder().encode('foo'),
new TextEncoder().encode('bar'),
];
const seed = deriveAddressSeed(seeds, programId);
const merkleTreePubkey = new PublicKey(
'11111111111111111111111111111111',
);
const derivedAddress = await deriveAddress(seed, merkleTreePubkey);
const derivedAddress = deriveAddress(seed, merkleTreePubkey);
expect(derivedAddress).toBeInstanceOf(PublicKey);
expect(derivedAddress).toStrictEqual(
new PublicKey('139uhyyBtEh4e1CBDJ68ooK5nCeWoncZf9HPyAfRrukA'),
);
});

it('should derive a valid address from a seed and a merkle tree public key', async () => {
const seeds: Uint8Array[] = [
new TextEncoder().encode('ayy'),
new TextEncoder().encode('lmao'),
];
const seed = deriveAddressSeed(seeds, programId);
const merkleTreePubkey = new PublicKey(
'11111111111111111111111111111111',
);
const derivedAddress = deriveAddress(seed, merkleTreePubkey);
expect(derivedAddress).toBeInstanceOf(PublicKey);
expect(derivedAddress).toStrictEqual(
new PublicKey('12bhHm6PQjbNmEn3Yu1Gq9k7XwVn2rZpzYokmLwbFazN'),
);
});
});

Expand Down
38 changes: 35 additions & 3 deletions js/stateless.js/src/utils/conversion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,23 @@ function isSmallerThanBn254FieldSizeBe(bytes: Buffer): boolean {
return bigint.lt(FIELD_SIZE);
}

export async function hashToBn254FieldSizeBe(
bytes: Buffer,
): Promise<[Buffer, number] | null> {
/**
* Hash the provided `bytes` with Keccak256 and ensure the result fits in the
* BN254 prime field by repeatedly hashing the inputs with various "bump seeds"
* and truncating the resulting hash to 31 bytes.
*
* @deprecated Use `hashvToBn254FieldSizeBe` instead.
*/
export function hashToBn254FieldSizeBe(bytes: Buffer): [Buffer, number] | null {
// TODO(vadorovsky, affects-onchain): Get rid of the bump mechanism, it
// makes no sense. Doing the same as in the `hashvToBn254FieldSizeBe` below
// - overwriting the most significant byte with zero - is sufficient for
// truncation, it's also faster, doesn't force us to return `Option` and
// care about handling an error which is practically never returned.
//
// The reason we can't do it now is that it would affect on-chain programs.
// Once we can update programs, we can get rid of the seed bump (or even of
// this function all together in favor of the `hashv` variant).
let bumpSeed = 255;
while (bumpSeed >= 0) {
const inputWithBumpSeed = Buffer.concat([
Expand All @@ -51,6 +65,24 @@ export async function hashToBn254FieldSizeBe(
return null;
}

/**
* Hash the provided `bytes` with Keccak256 and ensure that the result fits in
* the BN254 prime field by truncating the resulting hash to 31 bytes.
*
* @param bytes Input bytes
*
* @returns Hash digest
*/
export function hashvToBn254FieldSizeBe(bytes: Uint8Array[]): Uint8Array {
const hasher = keccak_256.create();
for (const input of bytes) {
hasher.update(input);
}
const hash = hasher.digest();
hash[0] = 0;
return hash;
}

/** Mutates array in place */
export function pushUniqueItems<T>(items: T[], map: T[]): void {
items.forEach(item => {
Expand Down
61 changes: 37 additions & 24 deletions js/stateless.js/tests/e2e/compress.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,51 +78,62 @@ describe('compress', () => {
await createAccount(
rpc as TestRpc,
payer,
new Uint8Array([
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
]),
[
new Uint8Array([
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
]),
],
LightSystemProgram.programId,
);

await createAccountWithLamports(
rpc as TestRpc,
payer,
new Uint8Array([
1, 2, 255, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
]),
[
new Uint8Array([
1, 2, 255, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
]),
],
0,
LightSystemProgram.programId,
);

await createAccount(
rpc as TestRpc,
payer,
new Uint8Array([
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 1,
]),
[
new Uint8Array([
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 1,
]),
],
LightSystemProgram.programId,
);

await createAccount(
rpc as TestRpc,
payer,
new Uint8Array([
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 2,
]),
[
new Uint8Array([
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 2,
]),
],
LightSystemProgram.programId,
);
await expect(
createAccount(
rpc as TestRpc,
payer,
new Uint8Array([
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 2,
]),
[
new Uint8Array([
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 2,
]),
],
LightSystemProgram.programId,
),
).rejects.toThrow();
Expand Down Expand Up @@ -169,10 +180,12 @@ describe('compress', () => {
await createAccountWithLamports(
rpc as TestRpc,
payer,
new Uint8Array([
1, 255, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
]),
[
new Uint8Array([
1, 255, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
]),
],
100,
LightSystemProgram.programId,
);
Expand Down
Loading

0 comments on commit 862b455

Please sign in to comment.