diff --git a/prover/server/integration_test.go b/prover/server/integration_test.go index 8f0086b070..74b9dc5158 100644 --- a/prover/server/integration_test.go +++ b/prover/server/integration_test.go @@ -465,7 +465,7 @@ func testBatchAddressAppendWithPreviousState40_100(t *testing.T) { } func runBatchAddressAppendWithPreviousStateTest(t *testing.T, treeHeight uint32, batchSize uint32) { - startIndex := uint32(2) + startIndex := uint64(2) params1, err := prover.BuildTestAddressTree(treeHeight, batchSize, nil, startIndex) if err != nil { t.Fatalf("Failed to build first test tree: %v", err) @@ -487,7 +487,7 @@ func runBatchAddressAppendWithPreviousStateTest(t *testing.T, treeHeight uint32, } response1.Body.Close() - startIndex += batchSize + startIndex += uint64(batchSize) params2, err := prover.BuildTestAddressTree(treeHeight, batchSize, params1.Tree, startIndex) if err != nil { t.Fatalf("Failed to build second test tree: %v", err) @@ -521,7 +521,7 @@ func runBatchAddressAppendWithPreviousStateTest(t *testing.T, treeHeight uint32, func testBatchAddressAppendInvalidInput40_10(t *testing.T) { treeHeight := uint32(40) batchSize := uint32(10) - startIndex := uint32(0) + startIndex := uint64(0) params, err := prover.BuildTestAddressTree(treeHeight, batchSize, nil, startIndex) if err != nil { diff --git a/prover/server/prover/batch_address_append_circuit.go b/prover/server/prover/batch_address_append_circuit.go index 058e51dc00..6c1617591f 100644 --- a/prover/server/prover/batch_address_append_circuit.go +++ b/prover/server/prover/batch_address_append_circuit.go @@ -248,7 +248,7 @@ type BatchAddressAppendParameters struct { OldRoot *big.Int NewRoot *big.Int HashchainHash *big.Int - StartIndex uint32 + StartIndex uint64 LowElementValues []big.Int LowElementIndices []big.Int diff --git a/prover/server/prover/batch_address_append_circuit_test.go b/prover/server/prover/batch_address_append_circuit_test.go index 5bad93a9d2..d5a22fbb4e 100644 --- a/prover/server/prover/batch_address_append_circuit_test.go +++ b/prover/server/prover/batch_address_append_circuit_test.go @@ -45,7 +45,7 @@ func TestBatchAddressAppendCircuit(t *testing.T) { name string treeHeight uint32 batchSize uint32 - startIndex uint32 + startIndex uint64 shouldPass bool }{ {"Single insert height 4", 4, 1, 2, true}, @@ -83,8 +83,9 @@ func TestBatchAddressAppendCircuit(t *testing.T) { name string treeHeight uint32 batchSize uint32 - startIndex uint32 + startIndex uint64 modifyParams func(*BatchAddressAppendParameters) + wantPanic bool }{ { name: "Invalid OldRoot", @@ -122,6 +123,138 @@ func TestBatchAddressAppendCircuit(t *testing.T) { p.LowElementValues[0].Add(&p.LowElementValues[0], big.NewInt(1)) }, }, + { + name: "StartIndex too large", + treeHeight: 4, + batchSize: 1, + startIndex: 0, + modifyParams: func(p *BatchAddressAppendParameters) { + p.StartIndex = ^uint64(0) + }, + }, + { + name: "Mismatched array length", + treeHeight: 4, + batchSize: 2, + startIndex: 0, + modifyParams: func(p *BatchAddressAppendParameters) { + p.LowElementValues = p.LowElementValues[:len(p.LowElementValues)-1] + }, + wantPanic: true, + }, + { + name: "Invalid proof length", + treeHeight: 4, + batchSize: 2, + startIndex: 0, + modifyParams: func(p *BatchAddressAppendParameters) { + p.LowElementProofs[0] = p.LowElementProofs[0][:len(p.LowElementProofs[0])-1] + }, + wantPanic: true, + }, + { + name: "Empty arrays", + treeHeight: 4, + batchSize: 2, + startIndex: 0, + modifyParams: func(p *BatchAddressAppendParameters) { + p.LowElementValues = make([]big.Int, p.BatchSize) + p.NewElementValues = make([]big.Int, p.BatchSize) + }, + }, + { + name: "Max values", + treeHeight: 4, + batchSize: 1, + startIndex: 0, + modifyParams: func(p *BatchAddressAppendParameters) { + maxBigInt := new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil), big.NewInt(1)) + p.NewElementValues[0] = *maxBigInt + }, + }, + { + name: "Inconsistent start index with proofs", + treeHeight: 4, + batchSize: 1, + startIndex: 0, + modifyParams: func(p *BatchAddressAppendParameters) { + p.StartIndex = 5 + }, + }, + { + name: "Low element below expected range", + treeHeight: 4, + batchSize: 1, + startIndex: 2, + modifyParams: func(p *BatchAddressAppendParameters) { + p.LowElementValues[0].Sub(&p.LowElementValues[0], big.NewInt(1)) + }, + }, + { + name: "Low element above expected range", + treeHeight: 4, + batchSize: 1, + startIndex: 2, + modifyParams: func(p *BatchAddressAppendParameters) { + // Set low element value above valid range + maxVal := new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil) + p.LowElementValues[0].Add(&p.LowElementValues[0], maxVal) + }, + }, + { + name: "Invalid low element next indices", + treeHeight: 4, + batchSize: 1, + startIndex: 2, + modifyParams: func(p *BatchAddressAppendParameters) { + p.LowElementNextIndices[0].Add(&p.LowElementNextIndices[0], big.NewInt(5)) + }, + }, + { + name: "Invalid low element next values", + treeHeight: 4, + batchSize: 1, + startIndex: 2, + modifyParams: func(p *BatchAddressAppendParameters) { + p.LowElementNextValues[0].Add(&p.LowElementNextValues[0], big.NewInt(1)) + }, + }, + { + name: "Invalid low element indices", + treeHeight: 4, + batchSize: 1, + startIndex: 2, + modifyParams: func(p *BatchAddressAppendParameters) { + p.LowElementIndices[0].Add(&p.LowElementIndices[0], big.NewInt(3)) + }, + }, + { + name: "Invalid low element proofs", + treeHeight: 4, + batchSize: 1, + startIndex: 2, + modifyParams: func(p *BatchAddressAppendParameters) { + p.LowElementProofs[0][0].Add(&p.LowElementProofs[0][0], big.NewInt(1)) + }, + }, + { + name: "Invalid new element proofs", + treeHeight: 4, + batchSize: 1, + startIndex: 2, + modifyParams: func(p *BatchAddressAppendParameters) { + p.NewElementProofs[0][0].Add(&p.NewElementProofs[0][0], big.NewInt(1)) + }, + }, + { + name: "Invalid new element values", + treeHeight: 4, + batchSize: 1, + startIndex: 2, + modifyParams: func(p *BatchAddressAppendParameters) { + p.NewElementValues[0].Add(&p.NewElementValues[0], big.NewInt(1)) + }, + }, } for _, tc := range testCases { @@ -135,6 +268,14 @@ func TestBatchAddressAppendCircuit(t *testing.T) { tc.modifyParams(params) + if tc.wantPanic { + assert.Panics(func() { + witness, _ := params.CreateWitness() + test.IsSolved(&circuit, witness, ecc.BN254.ScalarField()) + }) + return + } + witness, err := params.CreateWitness() if err != nil { return diff --git a/prover/server/prover/batch_append_with_proofs_circuit.go b/prover/server/prover/batch_append_with_proofs_circuit.go index f9d8e04d4b..fbc8496b28 100644 --- a/prover/server/prover/batch_append_with_proofs_circuit.go +++ b/prover/server/prover/batch_append_with_proofs_circuit.go @@ -80,7 +80,7 @@ type BatchAppendWithProofsParameters struct { LeavesHashchainHash *big.Int Leaves []*big.Int MerkleProofs [][]big.Int - StartIndex uint32 + StartIndex uint64 Height uint32 BatchSize uint32 Tree *merkle_tree.PoseidonTree diff --git a/prover/server/prover/batch_append_with_proofs_circuit_test.go b/prover/server/prover/batch_append_with_proofs_circuit_test.go index f91383f450..19275b29ca 100644 --- a/prover/server/prover/batch_append_with_proofs_circuit_test.go +++ b/prover/server/prover/batch_append_with_proofs_circuit_test.go @@ -1,6 +1,7 @@ package prover import ( + "math/big" "testing" "github.com/consensys/gnark-crypto/ecc" @@ -111,4 +112,157 @@ func TestBatchAppendWithProofsCircuit(t *testing.T) { err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) assert.NoError(err) }) + + t.Run("Invalid public input hash", func(t *testing.T) { + treeDepth := 10 + batchSize := 2 + params := BuildTestBatchAppendWithProofsTree(treeDepth, batchSize, nil, 0, false) + params.PublicInputHash = big.NewInt(999) + + witness := createTestWitness(params) + circuit := createTestCircuit(treeDepth, batchSize) + + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.Error(err) + }) + + t.Run("Invalid old root", func(t *testing.T) { + treeDepth := 10 + batchSize := 2 + params := BuildTestBatchAppendWithProofsTree(treeDepth, batchSize, nil, 0, false) + params.OldRoot = big.NewInt(999) + + witness := createTestWitness(params) + circuit := createTestCircuit(treeDepth, batchSize) + + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.Error(err) + }) + + t.Run("Invalid new root", func(t *testing.T) { + treeDepth := 10 + batchSize := 2 + params := BuildTestBatchAppendWithProofsTree(treeDepth, batchSize, nil, 0, false) + params.NewRoot = big.NewInt(999) + + witness := createTestWitness(params) + circuit := createTestCircuit(treeDepth, batchSize) + + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.Error(err) + }) + + t.Run("Invalid leaves hashchain", func(t *testing.T) { + treeDepth := 10 + batchSize := 2 + params := BuildTestBatchAppendWithProofsTree(treeDepth, batchSize, nil, 0, false) + params.LeavesHashchainHash = big.NewInt(999) + + witness := createTestWitness(params) + circuit := createTestCircuit(treeDepth, batchSize) + + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.Error(err) + }) + + t.Run("Invalid merkle proof", func(t *testing.T) { + treeDepth := 10 + batchSize := 2 + params := BuildTestBatchAppendWithProofsTree(treeDepth, batchSize, nil, 0, false) + params.MerkleProofs[0][0] = *big.NewInt(999) + + witness := createTestWitness(params) + circuit := createTestCircuit(treeDepth, batchSize) + + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.Error(err) + }) + + t.Run("Invalid start index", func(t *testing.T) { + treeDepth := 10 + batchSize := 2 + params := BuildTestBatchAppendWithProofsTree(treeDepth, batchSize, nil, 0, false) + params.StartIndex = uint64(1 << treeDepth) + + witness := createTestWitness(params) + circuit := createTestCircuit(treeDepth, batchSize) + + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.Error(err) + }) + + t.Run("Invalid old leaves", func(t *testing.T) { + assert := test.NewAssert(t) + treeDepth := 10 + batchSize := 2 + params := BuildTestBatchAppendWithProofsTree(treeDepth, batchSize, nil, 0, false) + + params.OldLeaves[0] = big.NewInt(999) + + witness := createTestWitness(params) + circuit := createTestCircuit(treeDepth, batchSize) + + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.Error(err) + }) + + t.Run("Invalid leaves", func(t *testing.T) { + assert := test.NewAssert(t) + treeDepth := 10 + batchSize := 2 + params := BuildTestBatchAppendWithProofsTree(treeDepth, batchSize, nil, 0, false) + + params.Leaves[0] = big.NewInt(999) + + witness := createTestWitness(params) + circuit := createTestCircuit(treeDepth, batchSize) + + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.Error(err) + }) +} + +func createTestCircuit(treeDepth, batchSize int) BatchAppendWithProofsCircuit { + circuit := BatchAppendWithProofsCircuit{ + PublicInputHash: frontend.Variable(0), + OldRoot: frontend.Variable(0), + NewRoot: frontend.Variable(0), + LeavesHashchainHash: frontend.Variable(0), + OldLeaves: make([]frontend.Variable, batchSize), + Leaves: make([]frontend.Variable, batchSize), + StartIndex: frontend.Variable(0), + MerkleProofs: make([][]frontend.Variable, batchSize), + Height: uint32(treeDepth), + BatchSize: uint32(batchSize), + } + + for i := range circuit.MerkleProofs { + circuit.MerkleProofs[i] = make([]frontend.Variable, treeDepth) + } + return circuit +} + +func createTestWitness(params *BatchAppendWithProofsParameters) BatchAppendWithProofsCircuit { + witness := BatchAppendWithProofsCircuit{ + PublicInputHash: frontend.Variable(params.PublicInputHash), + OldRoot: frontend.Variable(params.OldRoot), + NewRoot: frontend.Variable(params.NewRoot), + LeavesHashchainHash: frontend.Variable(params.LeavesHashchainHash), + OldLeaves: make([]frontend.Variable, int(params.BatchSize)), + Leaves: make([]frontend.Variable, int(params.BatchSize)), + MerkleProofs: make([][]frontend.Variable, int(params.BatchSize)), + StartIndex: frontend.Variable(params.StartIndex), + Height: params.Height, + BatchSize: params.BatchSize, + } + + for i := 0; i < int(params.BatchSize); i++ { + witness.Leaves[i] = frontend.Variable(params.Leaves[i]) + witness.OldLeaves[i] = frontend.Variable(params.OldLeaves[i]) + witness.MerkleProofs[i] = make([]frontend.Variable, params.Height) + for j := 0; j < int(params.Height); j++ { + witness.MerkleProofs[i][j] = frontend.Variable(params.MerkleProofs[i][j]) + } + } + return witness } diff --git a/prover/server/prover/batch_update_circuit_test.go b/prover/server/prover/batch_update_circuit_test.go index 829ab8cbae..8d6aab8bf8 100644 --- a/prover/server/prover/batch_update_circuit_test.go +++ b/prover/server/prover/batch_update_circuit_test.go @@ -107,6 +107,20 @@ func TestBatchUpdateCircuit(t *testing.T) { } }) + t.Run("Invalid OldRoot", func(t *testing.T) { + treeDepth := 10 + batchSize := 5 + params := BuildTestBatchUpdateTree(treeDepth, batchSize, nil, nil) + + circuit := createBatchUpdateCircuit(treeDepth, batchSize) + witness := createBatchUpdateWitness(params, 0, batchSize) + + witness.OldRoot = frontend.Variable(new(big.Int).Add(params.OldRoot, big.NewInt(1))) + + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.Error(err) + }) + t.Run("Invalid NewRoot", func(t *testing.T) { treeDepth := 10 batchSize := 5 @@ -122,6 +136,65 @@ func TestBatchUpdateCircuit(t *testing.T) { assert.Error(err) }) + t.Run("Invalid old leaf", func(t *testing.T) { + treeDepth := 10 + batchSize := 5 + params := BuildTestBatchUpdateTree(treeDepth, batchSize, nil, nil) + + circuit := createBatchUpdateCircuit(treeDepth, batchSize) + witness := createBatchUpdateWitness(params, 0, batchSize) + + // Modify one old leaf to make it invalid + witness.OldLeaves[0] = frontend.Variable(new(big.Int).Add(params.OldLeaves[0], big.NewInt(1))) + + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.Error(err) + }) + + t.Run("Invalid PublicInputHash", func(t *testing.T) { + treeDepth := 10 + batchSize := 5 + params := BuildTestBatchUpdateTree(treeDepth, batchSize, nil, nil) + + circuit := createBatchUpdateCircuit(treeDepth, batchSize) + witness := createBatchUpdateWitness(params, 0, batchSize) + + witness.PublicInputHash = frontend.Variable(new(big.Int).Add(params.PublicInputHash, big.NewInt(1))) + + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.Error(err) + }) + + t.Run("Invalid PathIndex", func(t *testing.T) { + treeDepth := 10 + batchSize := 5 + params := BuildTestBatchUpdateTree(treeDepth, batchSize, nil, nil) + + circuit := createBatchUpdateCircuit(treeDepth, batchSize) + witness := createBatchUpdateWitness(params, 0, batchSize) + + // Set invalid path index + witness.PathIndices[0] = frontend.Variable(uint32(1 << treeDepth)) + + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.Error(err) + }) + + t.Run("Invalid MerkleProof", func(t *testing.T) { + treeDepth := 10 + batchSize := 5 + params := BuildTestBatchUpdateTree(treeDepth, batchSize, nil, nil) + + circuit := createBatchUpdateCircuit(treeDepth, batchSize) + witness := createBatchUpdateWitness(params, 0, batchSize) + + // Corrupt merkle proof + witness.MerkleProofs[0][0] = frontend.Variable(new(big.Int).Add(big.NewInt(0).SetBytes(params.MerkleProofs[0][0].Bytes()), big.NewInt(1))) + + err := test.IsSolved(&circuit, &witness, ecc.BN254.ScalarField()) + assert.Error(err) + }) + t.Run("Invalid LeavesHashchainHash", func(t *testing.T) { treeDepth := 10 batchSize := 5 diff --git a/prover/server/prover/marshal_batch_address_append.go b/prover/server/prover/marshal_batch_address_append.go index e84010dbed..b8c5e23afe 100644 --- a/prover/server/prover/marshal_batch_address_append.go +++ b/prover/server/prover/marshal_batch_address_append.go @@ -12,7 +12,7 @@ type BatchAddressAppendParametersJSON struct { OldRoot string `json:"oldRoot"` NewRoot string `json:"newRoot"` HashchainHash string `json:"hashchainHash"` - StartIndex uint32 `json:"startIndex"` + StartIndex uint64 `json:"startIndex"` LowElementValues []string `json:"lowElementValues"` LowElementIndices []string `json:"lowElementIndices"` LowElementNextIndices []string `json:"lowElementNextIndices"` diff --git a/prover/server/prover/marshal_batch_append_with_proofs.go b/prover/server/prover/marshal_batch_append_with_proofs.go index eea2f13665..7f16597af9 100644 --- a/prover/server/prover/marshal_batch_append_with_proofs.go +++ b/prover/server/prover/marshal_batch_append_with_proofs.go @@ -12,7 +12,7 @@ type BatchAppendWithProofsInputsJSON struct { OldRoot string `json:"oldRoot"` NewRoot string `json:"newRoot"` LeavesHashchainHash string `json:"leavesHashchainHash"` - StartIndex uint32 `json:"startIndex"` + StartIndex uint64 `json:"startIndex"` OldLeaves []string `json:"oldLeaves"` Leaves []string `json:"leaves"` MerkleProofs [][]string `json:"merkleProofs"` diff --git a/prover/server/prover/test_data_helpers.go b/prover/server/prover/test_data_helpers.go index d7064eaf71..fc3828fd85 100644 --- a/prover/server/prover/test_data_helpers.go +++ b/prover/server/prover/test_data_helpers.go @@ -362,11 +362,11 @@ func BuildTestBatchAppendWithProofsTree(treeDepth int, batchSize int, previousTr Height: uint32(treeDepth), BatchSize: uint32(batchSize), Tree: &tree, - StartIndex: uint32(startIndex), + StartIndex: uint64(startIndex), } } -func BuildTestAddressTree(treeHeight uint32, batchSize uint32, previousTree *merkletree.IndexedMerkleTree, startIndex uint32) (*BatchAddressAppendParameters, error) { +func BuildTestAddressTree(treeHeight uint32, batchSize uint32, previousTree *merkletree.IndexedMerkleTree, startIndex uint64) (*BatchAddressAppendParameters, error) { var tree *merkletree.IndexedMerkleTree if previousTree == nil { @@ -409,8 +409,7 @@ func BuildTestAddressTree(treeHeight uint32, batchSize uint32, previousTree *mer newValues := make([]*big.Int, batchSize) for i := uint32(0); i < batchSize; i++ { - newValues[i] = new(big.Int).SetUint64(uint64(startIndex + i + 2)) - + newValues[i] = new(big.Int).SetUint64(startIndex + uint64(i) + 2) lowElementIndex, _ := tree.IndexArray.FindLowElementIndex(newValues[i]) lowElement := tree.IndexArray.Get(lowElementIndex) @@ -427,7 +426,7 @@ func BuildTestAddressTree(treeHeight uint32, batchSize uint32, previousTree *mer return nil, fmt.Errorf("failed to get low element proof: %v", err) } - newIndex := startIndex + i + newIndex := startIndex + uint64(i) if err := tree.Append(newValues[i]); err != nil { return nil, fmt.Errorf("failed to append value: %v", err) @@ -462,12 +461,12 @@ func computeNewElementsHashChain(values []big.Int) *big.Int { return result } -func computePublicInputHash(oldRoot *big.Int, newRoot *big.Int, hashchainHash *big.Int, startIndex uint32) *big.Int { +func computePublicInputHash(oldRoot *big.Int, newRoot *big.Int, hashchainHash *big.Int, startIndex uint64) *big.Int { inputs := []*big.Int{ oldRoot, newRoot, hashchainHash, - big.NewInt(int64(startIndex)), + new(big.Int).SetUint64(startIndex), } return calculateHashChain(inputs, 4)