Skip to content

Commit

Permalink
chore(avm): Handle specific MSM errors (#11068)
Browse files Browse the repository at this point in the history
Resolves #10854
  • Loading branch information
jeanmon authored Jan 6, 2025
1 parent ecbd59e commit a5097a9
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 5 deletions.
2 changes: 2 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/errors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ enum class AvmError : uint32_t {
OUT_OF_GAS,
STATIC_CALL_ALTERATION,
FAILED_BYTECODE_RETRIEVAL,
MSM_POINTS_LEN_INVALID,
MSM_POINT_NOT_ON_CURVE,
};

} // namespace bb::avm_trace
8 changes: 8 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ std::string to_name(AvmError error)
return "TAG CHECKING ERROR";
case AvmError::ADDR_RES_TAG_ERROR:
return "ADDRESS RESOLUTION TAG ERROR";
case AvmError::MEM_SLICE_OUT_OF_RANGE:
return "MEMORY SLICE OUT OF RANGE";
case AvmError::REL_ADDR_OUT_OF_RANGE:
return "RELATIVE ADDRESS IS OUT OF RANGE";
case AvmError::DIV_ZERO:
Expand All @@ -135,8 +137,14 @@ std::string to_name(AvmError error)
return "SIDE EFFECT LIMIT REACHED";
case AvmError::OUT_OF_GAS:
return "OUT OF GAS";
case AvmError::STATIC_CALL_ALTERATION:
return "STATIC CALL ALTERATION";
case AvmError::FAILED_BYTECODE_RETRIEVAL:
return "FAILED BYTECODE RETRIEVAL";
case AvmError::MSM_POINTS_LEN_INVALID:
return "MSM POINTS LEN INVALID";
case AvmError::MSM_POINT_NOT_ON_CURVE:
return "MSM POINT NOT ON CURVE";
default:
throw std::runtime_error("Invalid error type");
break;
Expand Down
9 changes: 9 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4780,6 +4780,11 @@ AvmError AvmTraceBuilder::op_variable_msm(uint8_t indirect,

const FF points_length = is_ok(error) ? unconstrained_read_from_memory(resolved_point_length_offset) : 0;

// Unconstrained check that points_length must be a multiple of 3.
if (is_ok(error) && static_cast<uint32_t>(points_length) % 3 != 0) {
error = AvmError::MSM_POINTS_LEN_INVALID;
}

if (is_ok(error) && !check_slice_mem_range(resolved_points_offset, static_cast<uint32_t>(points_length))) {
error = AvmError::MEM_SLICE_OUT_OF_RANGE;
}
Expand Down Expand Up @@ -4863,6 +4868,10 @@ AvmError AvmTraceBuilder::op_variable_msm(uint8_t indirect,
points.emplace_back(grumpkin::g1::affine_element::infinity());
} else {
points.emplace_back(x, y);
// Unconstrained check that this point lies on the Grumpkin curve.
if (!points.back().on_curve()) {
return AvmError::MSM_POINT_NOT_ON_CURVE;
}
}
}
// Reconstruct Grumpkin scalars
Expand Down
22 changes: 21 additions & 1 deletion yarn-project/simulator/src/avm/errors.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { type FailingFunction, type NoirCallStack } from '@aztec/circuit-types';
import { type AztecAddress, type Fr } from '@aztec/circuits.js';
import { type AztecAddress, type Fr, type Point } from '@aztec/circuits.js';

import { ExecutionError } from '../common/errors.js';
import { type AvmContext } from './avm_context.js';
Expand Down Expand Up @@ -128,6 +128,26 @@ export class OutOfGasError extends AvmExecutionError {
}
}

/**
* Error is thrown when the supplied points length is not a multiple of 3. Specific for MSM opcode.
*/
export class MSMPointsLengthError extends AvmExecutionError {
constructor(pointsReadLength: number) {
super(`Points vector length should be a multiple of 3, was ${pointsReadLength}`);
this.name = 'MSMPointsLengthError';
}
}

/**
* Error is thrown when one of the supplied points does not lie on the Grumpkin curve. Specific for MSM opcode.
*/
export class MSMPointNotOnCurveError extends AvmExecutionError {
constructor(point: Point) {
super(`Point ${point.toString()} is not on the curve.`);
this.name = 'MSMPointNotOnCurveError';
}
}

/**
* Error is thrown when a static call attempts to alter some state
*/
Expand Down
53 changes: 52 additions & 1 deletion yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { Fq, Fr } from '@aztec/circuits.js';
import { Fq, Fr, Point } from '@aztec/circuits.js';
import { Grumpkin } from '@aztec/circuits.js/barretenberg';

import { type AvmContext } from '../avm_context.js';
import { Field, type MemoryValue, Uint1, Uint32 } from '../avm_memory_types.js';
import { MSMPointNotOnCurveError, MSMPointsLengthError } from '../errors.js';
import { initContext } from '../fixtures/index.js';
import { MultiScalarMul } from './multi_scalar_mul.js';

Expand Down Expand Up @@ -127,4 +128,54 @@ describe('MultiScalarMul Opcode', () => {

expect(result).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]);
});

it('Should throw an error if points length is not a multiple of 3', async () => {
const indirect = 0;

// No need to set up points nor scalars as it is expected to fail before any processing of them.
const pointsReadLength = 17; // Not multiple of 3
const pointsOffset = 0;
const scalarsOffset = 20;
const pointsLengthOffset = 100;
const outputOffset = 120;

context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength));

await expect(
new MultiScalarMul(indirect, pointsOffset, scalarsOffset, outputOffset, pointsLengthOffset).execute(context),
).rejects.toThrow(MSMPointsLengthError);
});

it('Should throw an error if a point is not on Grumpkin curve', async () => {
const indirect = 0;
const grumpkin = new Grumpkin();
// We need to ensure points are actually on curve, so we just use the generator
// In future we could use a random point, for now we create an array of [G, 2G, NOT_ON_CURVE]
const points = Array.from({ length: 2 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1)));
points.push(new Point(new Fr(13), new Fr(14), false));

const scalars = [new Fq(5n), new Fq(3n), new Fq(1n)];
const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory
const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory
// Transform the points and scalars into the format that we will write to memory
// We just store the x and y coordinates here, and handle the infinities when we write to memory
const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.lo), new Field(s.hi)]);
// Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...]
const storedPoints: MemoryValue[] = points
.map(p => p.toFields())
.flatMap(([x, y, inf]) => [new Field(x), new Field(y), new Uint1(inf.toNumber())]);
const pointsOffset = 0;
context.machineState.memory.setSlice(pointsOffset, storedPoints);
// Store scalars
const scalarsOffset = pointsOffset + pointsReadLength;
context.machineState.memory.setSlice(scalarsOffset, storedScalars);
// Store length of points to read
const pointsLengthOffset = scalarsOffset + scalarsLength;
context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength));
const outputOffset = pointsLengthOffset + 1;

await expect(
new MultiScalarMul(indirect, pointsOffset, scalarsOffset, outputOffset, pointsLengthOffset).execute(context),
).rejects.toThrow(MSMPointNotOnCurveError);
});
});
6 changes: 3 additions & 3 deletions yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { Grumpkin } from '@aztec/circuits.js/barretenberg';

import { type AvmContext } from '../avm_context.js';
import { Field, TypeTag, Uint1 } from '../avm_memory_types.js';
import { InstructionExecutionError } from '../errors.js';
import { MSMPointNotOnCurveError, MSMPointsLengthError } from '../errors.js';
import { Opcode, OperandType } from '../serialization/instruction_serialization.js';
import { Addressing } from './addressing_mode.js';
import { Instruction } from './instruction.js';
Expand Down Expand Up @@ -44,7 +44,7 @@ export class MultiScalarMul extends Instruction {
// Get the size of the unrolled (x, y , inf) points vector
const pointsReadLength = memory.get(pointsLengthOffset).toNumber();
if (pointsReadLength % 3 !== 0) {
throw new InstructionExecutionError(`Points vector offset should be a multiple of 3, was ${pointsReadLength}`);
throw new MSMPointsLengthError(pointsReadLength);
}

// Get the unrolled (x, y, inf) representing the points
Expand Down Expand Up @@ -76,7 +76,7 @@ export class MultiScalarMul extends Instruction {
const isInf = pointsVector[3 * i + 2].toNumber() === 1;
const p: Point = new Point(pointsVector[3 * i].toFr(), pointsVector[3 * i + 1].toFr(), isInf);
if (!p.isOnGrumpkin()) {
throw new InstructionExecutionError(`Point ${p.toString()} is not on the curve.`);
throw new MSMPointNotOnCurveError(p);
}
grumpkinPoints.push(p);
}
Expand Down

0 comments on commit a5097a9

Please sign in to comment.