From f596998c9b959ae2b620a667a35ae99eb7145e37 Mon Sep 17 00:00:00 2001 From: Facundo Date: Thu, 25 Jan 2024 20:51:08 +0000 Subject: [PATCH] feat(avm): tagged memory (#4213) - Added tagged memory model for public VM - Updated opcodes and tests to pass Some things still missing - Checking in/dstTag in most opcodes - Default values for uninitialized memory (had a discussion with @Maddiaa0 for now this is good and might be ok long term but we need to double check) - Of course addressing modes, etc --- .../src/avm/avm_machine_state.ts | 40 +-- .../src/avm/avm_memory_types.test.ts | 22 ++ .../src/avm/avm_memory_types.ts | 276 +++++++++++++++++ .../src/avm/interpreter/interpreter.ts | 4 +- .../src/avm/opcodes/arithmetic.test.ts | 79 +++-- .../src/avm/opcodes/arithmetic.ts | 29 +- .../src/avm/opcodes/bitwise.test.ts | 186 ++++++------ .../acir-simulator/src/avm/opcodes/bitwise.ts | 86 +++--- .../src/avm/opcodes/comparators.ts | 31 +- .../src/avm/opcodes/control_flow.test.ts | 33 +- .../src/avm/opcodes/control_flow.ts | 29 +- .../src/avm/opcodes/instruction.ts | 17 ++ .../src/avm/opcodes/memory.test.ts | 286 ++++++++++++------ .../acir-simulator/src/avm/opcodes/memory.ts | 57 ++-- 14 files changed, 769 insertions(+), 406 deletions(-) create mode 100644 yarn-project/acir-simulator/src/avm/avm_memory_types.test.ts create mode 100644 yarn-project/acir-simulator/src/avm/avm_memory_types.ts diff --git a/yarn-project/acir-simulator/src/avm/avm_machine_state.ts b/yarn-project/acir-simulator/src/avm/avm_machine_state.ts index 53e8599b7c3..bdb9dfcc7c8 100644 --- a/yarn-project/acir-simulator/src/avm/avm_machine_state.ts +++ b/yarn-project/acir-simulator/src/avm/avm_machine_state.ts @@ -1,5 +1,7 @@ import { Fr } from '@aztec/foundation/fields'; +import { TaggedMemory } from './avm_memory_types.js'; + /** * Store's data for an Avm execution frame */ @@ -8,9 +10,8 @@ export class AvmMachineState { public readonly calldata: Fr[]; private returnData: Fr[]; - // TODO: implement tagged memory /** - */ - public memory: Fr[]; + public readonly memory: TaggedMemory; /** * When an internal_call is invoked, the internal call stack is added to with the current pc + 1 @@ -35,7 +36,7 @@ export class AvmMachineState { constructor(calldata: Fr[]) { this.calldata = calldata; this.returnData = []; - this.memory = []; + this.memory = new TaggedMemory(); this.internalCallStack = []; this.pc = 0; @@ -57,37 +58,4 @@ export class AvmMachineState { public getReturnData(): Fr[] { return this.returnData; } - - /** - - * @param offset - - */ - public readMemory(offset: number): Fr { - // TODO: check offset is within bounds - return this.memory[offset] ?? Fr.ZERO; - } - - /** - - * @param offset - - * @param size - - */ - public readMemoryChunk(offset: number, size: number): Fr[] { - // TODO: bounds -> initialise to 0 - return this.memory.slice(offset, offset + size); - } - - /** - - * @param offset - - * @param value - - */ - public writeMemory(offset: number, value: Fr): void { - this.memory[offset] = value; - } - - /** - - * @param offset - - * @param values - - */ - public writeMemoryChunk(offset: number, values: Fr[]): void { - this.memory.splice(offset, values.length, ...values); - } } diff --git a/yarn-project/acir-simulator/src/avm/avm_memory_types.test.ts b/yarn-project/acir-simulator/src/avm/avm_memory_types.test.ts new file mode 100644 index 00000000000..6df75930fe1 --- /dev/null +++ b/yarn-project/acir-simulator/src/avm/avm_memory_types.test.ts @@ -0,0 +1,22 @@ +import { Field, Uint8 } from './avm_memory_types.js'; + +// TODO: complete +describe('Uint8', () => { + it('Unsigned 8 max value', () => { + expect(new Uint8(255).toBigInt()).toEqual(255n); + }); + + it('Unsigned 8 bit add', () => { + expect(new Uint8(50).add(new Uint8(20))).toEqual(new Uint8(70)); + }); + + it('Unsigned 8 bit add wraps', () => { + expect(new Uint8(200).add(new Uint8(100))).toEqual(new Uint8(44)); + }); +}); + +describe('Field', () => { + it('Add correctly without wrapping', () => { + expect(new Field(27).add(new Field(48))).toEqual(new Field(75)); + }); +}); diff --git a/yarn-project/acir-simulator/src/avm/avm_memory_types.ts b/yarn-project/acir-simulator/src/avm/avm_memory_types.ts new file mode 100644 index 00000000000..69ed95c4a66 --- /dev/null +++ b/yarn-project/acir-simulator/src/avm/avm_memory_types.ts @@ -0,0 +1,276 @@ +import { Fr } from '@aztec/foundation/fields'; + +import { strict as assert } from 'assert'; + +export interface MemoryValue { + add(rhs: MemoryValue): MemoryValue; + sub(rhs: MemoryValue): MemoryValue; + mul(rhs: MemoryValue): MemoryValue; + div(rhs: MemoryValue): MemoryValue; + + // Use sparingly. + toBigInt(): bigint; +} + +export interface IntegralValue extends MemoryValue { + shl(rhs: IntegralValue): IntegralValue; + shr(rhs: IntegralValue): IntegralValue; + and(rhs: IntegralValue): IntegralValue; + or(rhs: IntegralValue): IntegralValue; + xor(rhs: IntegralValue): IntegralValue; + not(): IntegralValue; +} + +// TODO: Optimize calculation of mod, etc. Can only do once per class? +abstract class UnsignedInteger implements IntegralValue { + private readonly bitmask: bigint; + private readonly mod: bigint; + + protected constructor(private n: bigint, private bits: bigint) { + assert(bits > 0); + // x % 2^n == x & (2^n - 1) + this.mod = 1n << bits; + this.bitmask = this.mod - 1n; + assert(n < this.mod); + } + + // We need this to be able to build an instance of the subclass + // and not of type UnsignedInteger. + protected abstract build(n: bigint): UnsignedInteger; + + public add(rhs: UnsignedInteger): UnsignedInteger { + assert(this.bits == rhs.bits); + return this.build((this.n + rhs.n) & this.bitmask); + } + + public sub(rhs: UnsignedInteger): UnsignedInteger { + assert(this.bits == rhs.bits); + const res: bigint = this.n - rhs.n; + return this.build(res >= 0 ? res : res + this.mod); + } + + public mul(rhs: UnsignedInteger): UnsignedInteger { + assert(this.bits == rhs.bits); + return this.build((this.n * rhs.n) & this.bitmask); + } + + public div(rhs: UnsignedInteger): UnsignedInteger { + assert(this.bits == rhs.bits); + return this.build(this.n / rhs.n); + } + + // No sign extension. + public shr(rhs: UnsignedInteger): UnsignedInteger { + assert(this.bits == rhs.bits); + // Note that this.n is > 0 by class invariant. + return this.build(this.n >> rhs.n); + } + + public shl(rhs: UnsignedInteger): UnsignedInteger { + assert(this.bits == rhs.bits); + return this.build((this.n << rhs.n) & this.bitmask); + } + + public and(rhs: UnsignedInteger): UnsignedInteger { + assert(this.bits == rhs.bits); + return this.build(this.n & rhs.n); + } + + public or(rhs: UnsignedInteger): UnsignedInteger { + assert(this.bits == rhs.bits); + return this.build(this.n | rhs.n); + } + + public xor(rhs: UnsignedInteger): UnsignedInteger { + assert(this.bits == rhs.bits); + return this.build(this.n ^ rhs.n); + } + + public not(): UnsignedInteger { + return this.build(~this.n & this.bitmask); + } + + public toBigInt(): bigint { + return this.n; + } + + public equals(rhs: UnsignedInteger) { + return this.bits == rhs.bits && this.toBigInt() == rhs.toBigInt(); + } +} + +export class Uint8 extends UnsignedInteger { + constructor(n: number | bigint) { + super(BigInt(n), 8n); + } + + protected build(n: bigint): Uint8 { + return new Uint8(n); + } +} + +export class Uint16 extends UnsignedInteger { + constructor(n: number | bigint) { + super(BigInt(n), 16n); + } + + protected build(n: bigint): Uint16 { + return new Uint16(n); + } +} + +export class Uint32 extends UnsignedInteger { + constructor(n: number | bigint) { + super(BigInt(n), 32n); + } + + protected build(n: bigint): Uint32 { + return new Uint32(n); + } +} + +export class Uint64 extends UnsignedInteger { + constructor(n: number | bigint) { + super(BigInt(n), 64n); + } + + protected build(n: bigint): Uint64 { + return new Uint64(n); + } +} + +export class Uint128 extends UnsignedInteger { + constructor(n: number | bigint) { + super(BigInt(n), 128n); + } + + protected build(n: bigint): Uint128 { + return new Uint128(n); + } +} + +export class Field implements MemoryValue { + public static readonly MODULUS: bigint = Fr.MODULUS; + private readonly rep: Fr; + + constructor(v: number | bigint | Fr) { + this.rep = new Fr(v); + } + + public add(rhs: Field): Field { + return new Field(this.rep.add(rhs.rep)); + } + + public sub(rhs: Field): Field { + return new Field(this.rep.sub(rhs.rep)); + } + + public mul(rhs: Field): Field { + return new Field(this.rep.mul(rhs.rep)); + } + + public div(rhs: Field): Field { + return new Field(this.rep.div(rhs.rep)); + } + + public toBigInt(): bigint { + return this.rep.toBigInt(); + } +} + +export enum TypeTag { + UNINITIALIZED, + UINT8, + UINT16, + UINT32, + UINT64, + UINT128, + FIELD, + INVALID, +} + +// TODO: Consider automatic conversion when getting undefined values. +export class TaggedMemory { + static readonly MAX_MEMORY_SIZE = 1n << 32n; + private _mem: MemoryValue[]; + + constructor() { + this._mem = []; + } + + public get(offset: number): MemoryValue { + return this.getAs(offset); + } + + public getAs(offset: number): T { + assert(offset < TaggedMemory.MAX_MEMORY_SIZE); + const e = this._mem[offset]; + return e; + } + + public getSlice(offset: number, size: number): MemoryValue[] { + assert(offset < TaggedMemory.MAX_MEMORY_SIZE); + return this._mem.slice(offset, offset + size); + } + + public getSliceTags(offset: number, size: number): TypeTag[] { + assert(offset < TaggedMemory.MAX_MEMORY_SIZE); + return this._mem.slice(offset, offset + size).map(TaggedMemory.getTag); + } + + public set(offset: number, v: MemoryValue) { + assert(offset < TaggedMemory.MAX_MEMORY_SIZE); + this._mem[offset] = v; + } + + public setSlice(offset: number, vs: MemoryValue[]) { + assert(offset < TaggedMemory.MAX_MEMORY_SIZE); + this._mem.splice(offset, vs.length, ...vs); + } + + public getTag(offset: number): TypeTag { + return TaggedMemory.getTag(this._mem[offset]); + } + + // TODO: this might be slow, but I don't want to have the types know of their tags. + // It might be possible to have a map. + public static getTag(v: MemoryValue | undefined): TypeTag { + let tag = TypeTag.INVALID; + + if (v === undefined) { + tag = TypeTag.UNINITIALIZED; + } else if (v instanceof Field) { + tag = TypeTag.FIELD; + } else if (v instanceof Uint8) { + tag = TypeTag.UINT8; + } else if (v instanceof Uint16) { + tag = TypeTag.UINT16; + } else if (v instanceof Uint32) { + tag = TypeTag.UINT32; + } else if (v instanceof Uint64) { + tag = TypeTag.UINT64; + } else if (v instanceof Uint128) { + tag = TypeTag.UINT128; + } + + return tag; + } + + // Truncates the value to fit the type. + public static integralFromTag(v: bigint, tag: TypeTag): IntegralValue { + switch (tag) { + case TypeTag.UINT8: + return new Uint8(v & ((1n << 8n) - 1n)); + case TypeTag.UINT16: + return new Uint16(v & ((1n << 16n) - 1n)); + case TypeTag.UINT32: + return new Uint32(v & ((1n << 32n) - 1n)); + case TypeTag.UINT64: + return new Uint64(v & ((1n << 64n) - 1n)); + case TypeTag.UINT128: + return new Uint128(v & ((1n << 128n) - 1n)); + default: + throw new Error(`${TypeTag[tag]} is not a valid integral type.`); + } + } +} diff --git a/yarn-project/acir-simulator/src/avm/interpreter/interpreter.ts b/yarn-project/acir-simulator/src/avm/interpreter/interpreter.ts index f3970ade43f..55204f88b47 100644 --- a/yarn-project/acir-simulator/src/avm/interpreter/interpreter.ts +++ b/yarn-project/acir-simulator/src/avm/interpreter/interpreter.ts @@ -71,8 +71,8 @@ export class AvmInterpreter { * Avm-specific errors should derive from this */ export abstract class AvmInterpreterError extends Error { - constructor(message: string) { - super(message); + constructor(message: string, ...rest: any[]) { + super(message, ...rest); this.name = 'AvmInterpreterError'; } } diff --git a/yarn-project/acir-simulator/src/avm/opcodes/arithmetic.test.ts b/yarn-project/acir-simulator/src/avm/opcodes/arithmetic.test.ts index 76a802231a3..51059a38854 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/arithmetic.test.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/arithmetic.test.ts @@ -1,8 +1,7 @@ -import { Fr } from '@aztec/foundation/fields'; - import { mock } from 'jest-mock-extended'; import { AvmMachineState } from '../avm_machine_state.js'; +import { Field } from '../avm_memory_types.js'; import { AvmStateManager } from '../avm_state_manager.js'; import { Add, Div, Mul, Sub } from './arithmetic.js'; @@ -16,93 +15,93 @@ describe('Arithmetic Instructions', () => { }); describe('Add', () => { - it('Should add correctly over Fr type', () => { - const a = new Fr(1n); - const b = new Fr(2n); + it('Should add correctly over field elements', () => { + const a = new Field(1n); + const b = new Field(2n); - machineState.writeMemory(0, a); - machineState.writeMemory(1, b); + machineState.memory.set(0, a); + machineState.memory.set(1, b); new Add(0, 1, 2).execute(machineState, stateManager); - const expected = new Fr(3n); - const actual = machineState.readMemory(2); + const expected = new Field(3n); + const actual = machineState.memory.get(2); expect(actual).toEqual(expected); }); it('Should wrap around on addition', () => { - const a = new Fr(1n); - const b = new Fr(Fr.MODULUS - 1n); + const a = new Field(1n); + const b = new Field(Field.MODULUS - 1n); - machineState.writeMemory(0, a); - machineState.writeMemory(1, b); + machineState.memory.set(0, a); + machineState.memory.set(1, b); new Add(0, 1, 2).execute(machineState, stateManager); - const expected = new Fr(0n); - const actual = machineState.readMemory(3); + const expected = new Field(0n); + const actual = machineState.memory.get(2); expect(actual).toEqual(expected); }); }); describe('Sub', () => { - it('Should subtract correctly over Fr type', () => { - const a = new Fr(1n); - const b = new Fr(2n); + it('Should subtract correctly over field elements', () => { + const a = new Field(1n); + const b = new Field(2n); - machineState.writeMemory(0, a); - machineState.writeMemory(1, b); + machineState.memory.set(0, a); + machineState.memory.set(1, b); new Sub(0, 1, 2).execute(machineState, stateManager); - const expected = new Fr(Fr.MODULUS - 1n); - const actual = machineState.readMemory(2); + const expected = new Field(Field.MODULUS - 1n); + const actual = machineState.memory.get(2); expect(actual).toEqual(expected); }); }); describe('Mul', () => { - it('Should multiply correctly over Fr type', () => { - const a = new Fr(2n); - const b = new Fr(3n); + it('Should multiply correctly over field elements', () => { + const a = new Field(2n); + const b = new Field(3n); - machineState.writeMemory(0, a); - machineState.writeMemory(1, b); + machineState.memory.set(0, a); + machineState.memory.set(1, b); new Mul(0, 1, 2).execute(machineState, stateManager); - const expected = new Fr(6n); - const actual = machineState.readMemory(2); + const expected = new Field(6n); + const actual = machineState.memory.get(2); expect(actual).toEqual(expected); }); it('Should wrap around on multiplication', () => { - const a = new Fr(2n); - const b = new Fr(Fr.MODULUS / 2n - 1n); + const a = new Field(2n); + const b = new Field(Field.MODULUS / 2n - 1n); - machineState.writeMemory(0, a); - machineState.writeMemory(1, b); + machineState.memory.set(0, a); + machineState.memory.set(1, b); new Mul(0, 1, 2).execute(machineState, stateManager); - const expected = new Fr(Fr.MODULUS - 3n); - const actual = machineState.readMemory(2); + const expected = new Field(Field.MODULUS - 3n); + const actual = machineState.memory.get(2); expect(actual).toEqual(expected); }); }); describe('Div', () => { it('Should perform field division', () => { - const a = new Fr(2n); - const b = new Fr(3n); + const a = new Field(2n); + const b = new Field(3n); - machineState.writeMemory(0, a); - machineState.writeMemory(1, b); + machineState.memory.set(0, a); + machineState.memory.set(1, b); new Div(0, 1, 2).execute(machineState, stateManager); // Note - const actual = machineState.readMemory(2); + const actual = machineState.memory.get(2); const recovered = actual.mul(b); expect(recovered).toEqual(a); }); diff --git a/yarn-project/acir-simulator/src/avm/opcodes/arithmetic.ts b/yarn-project/acir-simulator/src/avm/opcodes/arithmetic.ts index 94df9713c83..6bfc61ff584 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/arithmetic.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/arithmetic.ts @@ -1,10 +1,7 @@ -import { Fr } from '@aztec/foundation/fields'; - import { AvmMachineState } from '../avm_machine_state.js'; import { AvmStateManager } from '../avm_state_manager.js'; import { Instruction } from './instruction.js'; -/** -*/ export class Add extends Instruction { static type: string = 'ADD'; static numberOfOperands = 3; @@ -14,17 +11,16 @@ export class Add extends Instruction { } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const a = machineState.readMemory(this.aOffset); - const b = machineState.readMemory(this.bOffset); + const a = machineState.memory.get(this.aOffset); + const b = machineState.memory.get(this.bOffset); const dest = a.add(b); - machineState.writeMemory(this.destOffset, dest); + machineState.memory.set(this.destOffset, dest); this.incrementPc(machineState); } } -/** -*/ export class Sub extends Instruction { static type: string = 'SUB'; static numberOfOperands = 3; @@ -34,17 +30,16 @@ export class Sub extends Instruction { } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const a = machineState.readMemory(this.aOffset); - const b = machineState.readMemory(this.bOffset); + const a = machineState.memory.get(this.aOffset); + const b = machineState.memory.get(this.bOffset); const dest = a.sub(b); - machineState.writeMemory(this.destOffset, dest); + machineState.memory.set(this.destOffset, dest); this.incrementPc(machineState); } } -/** -*/ export class Mul extends Instruction { static type: string = 'MUL'; static numberOfOperands = 3; @@ -54,11 +49,11 @@ export class Mul extends Instruction { } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const a: Fr = machineState.readMemory(this.aOffset); - const b: Fr = machineState.readMemory(this.bOffset); + const a = machineState.memory.get(this.aOffset); + const b = machineState.memory.get(this.bOffset); const dest = a.mul(b); - machineState.writeMemory(this.destOffset, dest); + machineState.memory.set(this.destOffset, dest); this.incrementPc(machineState); } @@ -74,11 +69,11 @@ export class Div extends Instruction { } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const a: Fr = machineState.readMemory(this.aOffset); - const b: Fr = machineState.readMemory(this.bOffset); + const a = machineState.memory.get(this.aOffset); + const b = machineState.memory.get(this.bOffset); const dest = a.div(b); - machineState.writeMemory(this.destOffset, dest); + machineState.memory.set(this.destOffset, dest); this.incrementPc(machineState); } diff --git a/yarn-project/acir-simulator/src/avm/opcodes/bitwise.test.ts b/yarn-project/acir-simulator/src/avm/opcodes/bitwise.test.ts index e1de7356041..6fe969513fa 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/bitwise.test.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/bitwise.test.ts @@ -1,17 +1,9 @@ -import { Fr } from '@aztec/foundation/fields'; - import { mock } from 'jest-mock-extended'; import { AvmMachineState } from '../avm_machine_state.js'; +import { TypeTag, Uint16, Uint32 } from '../avm_memory_types.js'; import { AvmStateManager } from '../avm_state_manager.js'; -import { - And, - /*Not,*/ - Or, - Shl, - Shr, - Xor, -} from './bitwise.js'; +import { And, Not, Or, Shl, Shr, Xor } from './bitwise.js'; describe('Bitwise instructions', () => { let machineState: AvmMachineState; @@ -22,145 +14,155 @@ describe('Bitwise instructions', () => { stateManager = mock(); }); - it('Should AND correctly over Fr type', () => { - const a = new Fr(0b11111110010011100100n); - const b = new Fr(0b11100100111001001111n); - - machineState.writeMemory(0, a); - machineState.writeMemory(1, b); + it('Should AND correctly over integral types', () => { + machineState.memory.set(0, new Uint32(0b11111110010011100100n)); + machineState.memory.set(1, new Uint32(0b11100100111001001111n)); - new And(0, 1, 2).execute(machineState, stateManager); + new And(0, 1, 2, TypeTag.UINT32).execute(machineState, stateManager); - const expected = new Fr(0b11100100010001000100n); - const actual = machineState.readMemory(2); - expect(actual).toEqual(expected); + const actual = machineState.memory.get(2); + expect(actual).toEqual(new Uint32(0b11100100010001000100n)); }); - it('Should OR correctly over Fr type', () => { - const a = new Fr(0b11111110010011100100n); - const b = new Fr(0b11100100111001001111n); + it('Should OR correctly over integral types', () => { + const a = new Uint32(0b11111110010011100100n); + const b = new Uint32(0b11100100111001001111n); - machineState.writeMemory(0, a); - machineState.writeMemory(1, b); + machineState.memory.set(0, a); + machineState.memory.set(1, b); - new Or(0, 1, 2).execute(machineState, stateManager); + new Or(0, 1, 2, TypeTag.UINT32).execute(machineState, stateManager); - const expected = new Fr(0b11111110111011101111n); - const actual = machineState.readMemory(2); + const expected = new Uint32(0b11111110111011101111n); + const actual = machineState.memory.get(2); expect(actual).toEqual(expected); }); - it('Should XOR correctly over Fr type', () => { - const a = new Fr(0b11111110010011100100n); - const b = new Fr(0b11100100111001001111n); + it('Should XOR correctly over integral types', () => { + const a = new Uint32(0b11111110010011100100n); + const b = new Uint32(0b11100100111001001111n); - machineState.writeMemory(0, a); - machineState.writeMemory(1, b); + machineState.memory.set(0, a); + machineState.memory.set(1, b); - new Xor(0, 1, 2).execute(machineState, stateManager); + new Xor(0, 1, 2, TypeTag.UINT32).execute(machineState, stateManager); - const expected = new Fr(0b00011010101010101011n); - const actual = machineState.readMemory(2); + const expected = new Uint32(0b00011010101010101011n); + const actual = machineState.memory.get(2); expect(actual).toEqual(expected); }); describe('SHR', () => { - it('Should shift correctly 0 positions over Fr type', () => { - const a = new Fr(0b11111110010011100100n); - const b = new Fr(0n); + it('Should shift correctly 0 positions over integral types', () => { + const a = new Uint32(0b11111110010011100100n); + const b = new Uint32(0n); - machineState.writeMemory(0, a); - machineState.writeMemory(1, b); + machineState.memory.set(0, a); + machineState.memory.set(1, b); - new Shr(0, 1, 2).execute(machineState, stateManager); + new Shr(0, 1, 2, TypeTag.UINT32).execute(machineState, stateManager); const expected = a; - const actual = machineState.readMemory(2); + const actual = machineState.memory.get(2); expect(actual).toEqual(expected); }); - it('Should shift correctly 2 positions over Fr type', () => { - const a = new Fr(0b11111110010011100100n); - const b = new Fr(2n); + it('Should shift correctly 2 positions over integral types', () => { + const a = new Uint32(0b11111110010011100100n); + const b = new Uint32(2n); - machineState.writeMemory(0, a); - machineState.writeMemory(1, b); + machineState.memory.set(0, a); + machineState.memory.set(1, b); - new Shr(0, 1, 2).execute(machineState, stateManager); + new Shr(0, 1, 2, TypeTag.UINT32).execute(machineState, stateManager); - const expected = new Fr(0b00111111100100111001n); - const actual = machineState.readMemory(2); + const expected = new Uint32(0b00111111100100111001n); + const actual = machineState.memory.get(2); expect(actual).toEqual(expected); }); - it('Should shift correctly 19 positions over Fr type', () => { - const a = new Fr(0b11111110010011100100n); - const b = new Fr(19n); + it('Should shift correctly 19 positions over integral types', () => { + const a = new Uint32(0b11111110010011100100n); + const b = new Uint32(19n); - machineState.writeMemory(0, a); - machineState.writeMemory(1, b); + machineState.memory.set(0, a); + machineState.memory.set(1, b); - new Shr(0, 1, 2).execute(machineState, stateManager); + new Shr(0, 1, 2, TypeTag.UINT32).execute(machineState, stateManager); - const expected = new Fr(0b01n); - const actual = machineState.readMemory(2); + const expected = new Uint32(0b01n); + const actual = machineState.memory.get(2); expect(actual).toEqual(expected); }); }); describe('SHL', () => { - it('Should shift correctly 0 positions over Fr type', () => { - const a = new Fr(0b11111110010011100100n); - const b = new Fr(0n); + it('Should shift correctly 0 positions over integral types', () => { + const a = new Uint32(0b11111110010011100100n); + const b = new Uint32(0n); - machineState.writeMemory(0, a); - machineState.writeMemory(1, b); + machineState.memory.set(0, a); + machineState.memory.set(1, b); - new Shl(0, 1, 2).execute(machineState, stateManager); + new Shl(0, 1, 2, TypeTag.UINT32).execute(machineState, stateManager); const expected = a; - const actual = machineState.readMemory(2); + const actual = machineState.memory.get(2); expect(actual).toEqual(expected); }); - it('Should shift correctly 2 positions over Fr type', () => { - const a = new Fr(0b11111110010011100100n); - const b = new Fr(2n); + it('Should shift correctly 2 positions over integral types', () => { + const a = new Uint32(0b11111110010011100100n); + const b = new Uint32(2n); - machineState.writeMemory(0, a); - machineState.writeMemory(1, b); + machineState.memory.set(0, a); + machineState.memory.set(1, b); - new Shl(0, 1, 2).execute(machineState, stateManager); + new Shl(0, 1, 2, TypeTag.UINT32).execute(machineState, stateManager); - const expected = new Fr(0b1111111001001110010000n); - const actual = machineState.readMemory(2); + const expected = new Uint32(0b1111111001001110010000n); + const actual = machineState.memory.get(2); expect(actual).toEqual(expected); }); - // it('Should shift correctly over bit limit over Fr type', () => { - // const a = new Fr(0b11111110010011100100n); - // const b = new Fr(19n); + it('Should shift correctly over bit limit over integral types', () => { + const a = new Uint16(0b1110010011100111n); + const b = new Uint16(17n); + + machineState.memory.set(0, a); + machineState.memory.set(1, b); - // machineState.writeMemory(0, a); - // machineState.writeMemory(1, b); + new Shl(0, 1, 2, TypeTag.UINT16).execute(machineState, stateManager); + + const expected = new Uint16(0n); + const actual = machineState.memory.get(2); + expect(actual).toEqual(expected); + }); - // new Shl(0, 1, 2).execute(machineState, stateManager); + it('Should truncate when shifting over bit size over integral types', () => { + const a = new Uint16(0b1110010011100111n); + const b = new Uint16(2n); - // const expected = new Fr(0b01n); - // const actual = machineState.readMemory(2); - // expect(actual).toEqual(expected); - // }); + machineState.memory.set(0, a); + machineState.memory.set(1, b); + + new Shl(0, 1, 2, TypeTag.UINT16).execute(machineState, stateManager); + + const expected = new Uint16(0b1001001110011100n); + const actual = machineState.memory.get(2); + expect(actual).toEqual(expected); + }); }); - // it('Should NOT correctly over Fr type', () => { - // const a = new Fr(0b11111110010011100100n); + it('Should NOT correctly over integral types', () => { + const a = new Uint16(0b0110010011100100n); - // machineState.writeMemory(0, a); + machineState.memory.set(0, a); - // new Not(0, 1).execute(machineState, stateManager); + new Not(0, 1, TypeTag.UINT16).execute(machineState, stateManager); - // const expected = new Fr(0b00000001101100011011n); // high bits! - // const actual = machineState.readMemory(1); - // expect(actual).toEqual(expected); - // }); + const expected = new Uint16(0b1001101100011011n); // high bits! + const actual = machineState.memory.get(1); + expect(actual).toEqual(expected); + }); }); diff --git a/yarn-project/acir-simulator/src/avm/opcodes/bitwise.ts b/yarn-project/acir-simulator/src/avm/opcodes/bitwise.ts index ff7802aac61..e788abcc1f5 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/bitwise.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/bitwise.ts @@ -1,130 +1,128 @@ -import { Fr } from '@aztec/foundation/fields'; - import { AvmMachineState } from '../avm_machine_state.js'; +import { IntegralValue, TypeTag } from '../avm_memory_types.js'; import { AvmStateManager } from '../avm_state_manager.js'; import { Instruction } from './instruction.js'; -/** - */ export class And extends Instruction { static type: string = 'AND'; static numberOfOperands = 3; - constructor(private aOffset: number, private bOffset: number, private destOffset: number) { + constructor(private aOffset: number, private bOffset: number, private destOffset: number, private inTag: TypeTag) { super(); } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const a: Fr = machineState.readMemory(this.aOffset); - const b: Fr = machineState.readMemory(this.bOffset); + Instruction.checkTags(machineState, this.inTag, this.aOffset, this.bOffset); + + const a = machineState.memory.getAs(this.aOffset); + const b = machineState.memory.getAs(this.bOffset); - const dest = new Fr(a.toBigInt() & b.toBigInt()); - machineState.writeMemory(this.destOffset, dest); + const res = a.and(b); + machineState.memory.set(this.destOffset, res); this.incrementPc(machineState); } } -/** - */ export class Or extends Instruction { static type: string = 'OR'; static numberOfOperands = 3; - constructor(private aOffset: number, private bOffset: number, private destOffset: number) { + constructor(private aOffset: number, private bOffset: number, private destOffset: number, private inTag: TypeTag) { super(); } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const a: Fr = machineState.readMemory(this.aOffset); - const b: Fr = machineState.readMemory(this.bOffset); + Instruction.checkTags(machineState, this.inTag, this.aOffset, this.bOffset); - const dest = new Fr(a.toBigInt() | b.toBigInt()); - machineState.writeMemory(this.destOffset, dest); + const a = machineState.memory.getAs(this.aOffset); + const b = machineState.memory.getAs(this.bOffset); + + const res = a.or(b); + machineState.memory.set(this.destOffset, res); this.incrementPc(machineState); } } -/** - */ export class Xor extends Instruction { static type: string = 'XOR'; static numberOfOperands = 3; - constructor(private aOffset: number, private bOffset: number, private destOffset: number) { + constructor(private aOffset: number, private bOffset: number, private destOffset: number, private inTag: TypeTag) { super(); } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const a: Fr = machineState.readMemory(this.aOffset); - const b: Fr = machineState.readMemory(this.bOffset); + Instruction.checkTags(machineState, this.inTag, this.aOffset, this.bOffset); + + const a = machineState.memory.getAs(this.aOffset); + const b = machineState.memory.getAs(this.bOffset); - const dest = new Fr(a.toBigInt() ^ b.toBigInt()); - machineState.writeMemory(this.destOffset, dest); + const res = a.xor(b); + machineState.memory.set(this.destOffset, res); this.incrementPc(machineState); } } -/** - */ export class Not extends Instruction { static type: string = 'NOT'; static numberOfOperands = 2; - constructor(private aOffset: number, private destOffset: number) { + constructor(private aOffset: number, private destOffset: number, private inTag: TypeTag) { super(); } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const a: Fr = machineState.readMemory(this.aOffset); + Instruction.checkTags(machineState, this.inTag, this.aOffset); - // TODO: hack -> Bitwise operations should not occur over field elements - // It should only work over integers - const result = ~a.toBigInt(); + const a = machineState.memory.getAs(this.aOffset); - const dest = new Fr(result < 0 ? Fr.MODULUS + /* using a + as result is -ve*/ result : result); - machineState.writeMemory(this.destOffset, dest); + const res = a.not(); + machineState.memory.set(this.destOffset, res); this.incrementPc(machineState); } } -/** -*/ export class Shl extends Instruction { static type: string = 'SHL'; static numberOfOperands = 3; - constructor(private aOffset: number, private bOffset: number, private destOffset: number) { + constructor(private aOffset: number, private bOffset: number, private destOffset: number, private inTag: TypeTag) { super(); } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const a: Fr = machineState.readMemory(this.aOffset); - const b: Fr = machineState.readMemory(this.bOffset); + Instruction.checkTags(machineState, this.inTag, this.aOffset, this.bOffset); - const dest = new Fr(a.toBigInt() << b.toBigInt()); - machineState.writeMemory(this.destOffset, dest); + const a = machineState.memory.getAs(this.aOffset); + const b = machineState.memory.getAs(this.bOffset); + + const res = a.shl(b); + machineState.memory.set(this.destOffset, res); this.incrementPc(machineState); } } -/** -*/ export class Shr extends Instruction { static type: string = 'SHR'; static numberOfOperands = 3; - constructor(private aOffset: number, private bOffset: number, private destOffset: number) { + constructor(private aOffset: number, private bOffset: number, private destOffset: number, private inTag: TypeTag) { super(); } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const a: Fr = machineState.readMemory(this.aOffset); - const b: Fr = machineState.readMemory(this.bOffset); - - // Here we are assuming that the field element maps to a positive number. - // The >> operator is *signed* in JS (and it sign extends). - // E.g.: -1n >> 3n == -1n. - const dest = new Fr(a.toBigInt() >> b.toBigInt()); - machineState.writeMemory(this.destOffset, dest); + Instruction.checkTags(machineState, this.inTag, this.aOffset, this.bOffset); + + const a = machineState.memory.getAs(this.aOffset); + const b = machineState.memory.getAs(this.bOffset); + + const res = a.shr(b); + machineState.memory.set(this.destOffset, res); this.incrementPc(machineState); } diff --git a/yarn-project/acir-simulator/src/avm/opcodes/comparators.ts b/yarn-project/acir-simulator/src/avm/opcodes/comparators.ts index 685a5438c42..b4f65a66f7b 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/comparators.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/comparators.ts @@ -1,10 +1,8 @@ -import { Fr } from '@aztec/foundation/fields'; - import { AvmMachineState } from '../avm_machine_state.js'; +import { Field } from '../avm_memory_types.js'; import { AvmStateManager } from '../avm_state_manager.js'; import { Instruction } from './instruction.js'; -/** -*/ export class Eq extends Instruction { static type: string = 'EQ'; static numberOfOperands = 3; @@ -14,16 +12,16 @@ export class Eq extends Instruction { } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const a: Fr = machineState.readMemory(this.aOffset); - const b: Fr = machineState.readMemory(this.bOffset); + const a = machineState.memory.get(this.aOffset); + const b = machineState.memory.get(this.bOffset); - const dest = new Fr(a.toBigInt() == b.toBigInt()); - machineState.writeMemory(this.destOffset, dest); + const dest = new Field(a.toBigInt() == b.toBigInt() ? 1 : 0); + machineState.memory.set(this.destOffset, dest); this.incrementPc(machineState); } } -/** -*/ + export class Lt extends Instruction { static type: string = 'Lt'; static numberOfOperands = 3; @@ -33,17 +31,16 @@ export class Lt extends Instruction { } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const a: Fr = machineState.readMemory(this.aOffset); - const b: Fr = machineState.readMemory(this.bOffset); + const a = machineState.memory.get(this.aOffset); + const b = machineState.memory.get(this.bOffset); - const dest = new Fr(a.toBigInt() < b.toBigInt()); - machineState.writeMemory(this.destOffset, dest); + const dest = new Field(a.toBigInt() < b.toBigInt() ? 1 : 0); + machineState.memory.set(this.destOffset, dest); this.incrementPc(machineState); } } -/** -*/ export class Lte extends Instruction { static type: string = 'LTE'; static numberOfOperands = 3; @@ -53,11 +50,11 @@ export class Lte extends Instruction { } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const a: Fr = machineState.readMemory(this.aOffset); - const b: Fr = machineState.readMemory(this.bOffset); + const a = machineState.memory.get(this.aOffset); + const b = machineState.memory.get(this.bOffset); - const dest = new Fr(a.toBigInt() < b.toBigInt()); - machineState.writeMemory(this.destOffset, dest); + const dest = new Field(a.toBigInt() < b.toBigInt() ? 1 : 0); + machineState.memory.set(this.destOffset, dest); this.incrementPc(machineState); } diff --git a/yarn-project/acir-simulator/src/avm/opcodes/control_flow.test.ts b/yarn-project/acir-simulator/src/avm/opcodes/control_flow.test.ts index 906584b3493..c30825ba912 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/control_flow.test.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/control_flow.test.ts @@ -1,13 +1,13 @@ -import { Fr } from '@aztec/foundation/fields'; - import { mock } from 'jest-mock-extended'; import { AvmMachineState } from '../avm_machine_state.js'; +import { TypeTag, Uint16 } from '../avm_memory_types.js'; import { AvmStateManager } from '../avm_state_manager.js'; import { Add, Mul, Sub } from './arithmetic.js'; import { And, Not, Or, Shl, Shr, Xor } from './bitwise.js'; import { Eq, Lt, Lte } from './comparators.js'; -import { InternalCall, InternalCallStackEmptyError, InternalReturn, Jump, JumpI } from './control_flow.js'; +import { InternalCall, InternalReturn, Jump, JumpI } from './control_flow.js'; +import { InstructionExecutionError } from './instruction.js'; import { CMov, CalldataCopy, Cast, Mov, Set } from './memory.js'; describe('Control Flow Opcodes', () => { @@ -35,8 +35,8 @@ describe('Control Flow Opcodes', () => { expect(machineState.pc).toBe(0); - machineState.writeMemory(0, new Fr(1n)); - machineState.writeMemory(1, new Fr(2n)); + machineState.memory.set(0, new Uint16(1n)); + machineState.memory.set(1, new Uint16(2n)); const instruction = new JumpI(jumpLocation, 0); instruction.execute(machineState, stateManager); @@ -53,7 +53,7 @@ describe('Control Flow Opcodes', () => { expect(machineState.pc).toBe(0); - machineState.writeMemory(0, new Fr(0n)); + machineState.memory.set(0, new Uint16(0n)); const instruction = new JumpI(jumpLocation, 0); instruction.execute(machineState, stateManager); @@ -112,7 +112,7 @@ describe('Control Flow Opcodes', () => { it('Should error if Internal Return is called without a corresponding Internal Call', () => { const returnInstruction = new InternalReturn(); - expect(() => returnInstruction.execute(machineState, stateManager)).toThrow(InternalCallStackEmptyError); + expect(() => returnInstruction.execute(machineState, stateManager)).toThrow(InstructionExecutionError); }); it('Should increment PC on All other Instructions', () => { @@ -123,22 +123,25 @@ describe('Control Flow Opcodes', () => { new Lt(0, 1, 2), new Lte(0, 1, 2), new Eq(0, 1, 2), - new Xor(0, 1, 2), - new And(0, 1, 2), - new Or(0, 1, 2), - new Shl(0, 1, 2), - new Shr(0, 1, 2), - new Not(0, 2), + new Xor(0, 1, 2, TypeTag.UINT16), + new And(0, 1, 2, TypeTag.UINT16), + new Or(0, 1, 2, TypeTag.UINT16), + new Shl(0, 1, 2, TypeTag.UINT16), + new Shr(0, 1, 2, TypeTag.UINT16), + new Not(0, 2, TypeTag.UINT16), new CalldataCopy(0, 1, 2), - new Set(0n, 1), + new Set(0n, 1, TypeTag.UINT16), new Mov(0, 1), new CMov(0, 1, 2, 3), - new Cast(0, 1), + new Cast(0, 1, TypeTag.UINT16), ]; for (const instruction of instructions) { // Use a fresh machine state each run const innerMachineState = new AvmMachineState([]); + innerMachineState.memory.set(0, new Uint16(4n)); + innerMachineState.memory.set(1, new Uint16(8n)); + innerMachineState.memory.set(2, new Uint16(12n)); expect(machineState.pc).toBe(0); instruction.execute(innerMachineState, stateManager); expect(innerMachineState.pc).toBe(1); diff --git a/yarn-project/acir-simulator/src/avm/opcodes/control_flow.ts b/yarn-project/acir-simulator/src/avm/opcodes/control_flow.ts index b988979076e..7967a417c8a 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/control_flow.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/control_flow.ts @@ -1,8 +1,10 @@ +import { Fr } from '@aztec/foundation/fields'; + import { AvmMachineState } from '../avm_machine_state.js'; +import { IntegralValue } from '../avm_memory_types.js'; import { AvmStateManager } from '../avm_state_manager.js'; -import { Instruction } from './instruction.js'; +import { Instruction, InstructionExecutionError } from './instruction.js'; -/** - */ export class Return extends Instruction { static type: string = 'RETURN'; static numberOfOperands = 2; @@ -12,14 +14,16 @@ export class Return extends Instruction { } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const returnData = machineState.readMemoryChunk(this.returnOffset, this.returnOffset + this.copySize); + const returnData = machineState.memory + .getSlice(this.returnOffset, this.copySize) + .map(fvt => new Fr(fvt.toBigInt())); + machineState.setReturnData(returnData); this.halt(machineState); } } -/** -*/ export class Jump extends Instruction { static type: string = 'JUMP'; static numberOfOperands = 1; @@ -33,7 +37,6 @@ export class Jump extends Instruction { } } -/** -*/ export class JumpI extends Instruction { static type: string = 'JUMPI'; static numberOfOperands = 1; @@ -43,8 +46,9 @@ export class JumpI extends Instruction { } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const condition = machineState.readMemory(this.condOffset); + const condition = machineState.memory.getAs(this.condOffset); + // TODO: reconsider this casting if (condition.toBigInt() == 0n) { this.incrementPc(machineState); } else { @@ -53,7 +57,6 @@ export class JumpI extends Instruction { } } -/** -*/ export class InternalCall extends Instruction { static type: string = 'INTERNALCALL'; static numberOfOperands = 1; @@ -68,7 +71,6 @@ export class InternalCall extends Instruction { } } -/** -*/ export class InternalReturn extends Instruction { static type: string = 'INTERNALRETURN'; static numberOfOperands = 0; @@ -80,17 +82,8 @@ export class InternalReturn extends Instruction { execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { const jumpOffset = machineState.internalCallStack.pop(); if (jumpOffset === undefined) { - throw new InternalCallStackEmptyError(); + throw new InstructionExecutionError('Internal call empty!'); } machineState.pc = jumpOffset; } } - -/** - * Thrown if the internal call stack is popped when it is empty - */ -export class InternalCallStackEmptyError extends Error { - constructor() { - super('Internal call stack is empty'); - } -} diff --git a/yarn-project/acir-simulator/src/avm/opcodes/instruction.ts b/yarn-project/acir-simulator/src/avm/opcodes/instruction.ts index c77a1a9ddd3..766e5ee158e 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/instruction.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/instruction.ts @@ -1,4 +1,5 @@ import { AvmMachineState } from '../avm_machine_state.js'; +import { TypeTag } from '../avm_memory_types.js'; import { AvmStateManager } from '../avm_state_manager.js'; export const AVM_OPERAND_BYTE_LENGTH = 4; @@ -17,4 +18,20 @@ export abstract class Instruction { halt(machineState: AvmMachineState): void { machineState.halted = true; } + + static checkTags(machineState: AvmMachineState, tag: TypeTag, ...offsets: number[]) { + for (const off of offsets) { + if (machineState.memory.getTag(off) !== tag) { + const error = `Offset ${off} has tag ${TypeTag[machineState.memory.getTag(off)]}, expected ${TypeTag[tag]}`; + throw new InstructionExecutionError(error); + } + } + } +} + +export class InstructionExecutionError extends Error { + constructor(message: string) { + super(message); + this.name = 'InstructionExecutionError'; + } } diff --git a/yarn-project/acir-simulator/src/avm/opcodes/memory.test.ts b/yarn-project/acir-simulator/src/avm/opcodes/memory.test.ts index 3b02ef36c32..1c8b49ea678 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/memory.test.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/memory.test.ts @@ -3,6 +3,7 @@ import { Fr } from '@aztec/foundation/fields'; import { mock } from 'jest-mock-extended'; import { AvmMachineState } from '../avm_machine_state.js'; +import { Field, TypeTag, Uint8, Uint16, Uint32, Uint64, Uint128 } from '../avm_memory_types.js'; import { AvmStateManager } from '../avm_state_manager.js'; import { CMov, CalldataCopy, Cast, Mov, Set } from './memory.js'; @@ -15,166 +16,261 @@ describe('Memory instructions', () => { stateManager = mock(); }); - it('Should SET memory correctly', () => { - const value = 123456n; + describe('SET', () => { + it('should correctly set value and tag (uninitialized)', () => { + new Set(/*value=*/ 1234n, /*offset=*/ 1, TypeTag.UINT16).execute(machineState, stateManager); - new Set(value, 1).execute(machineState, stateManager); + const actual = machineState.memory.get(1); + const tag = machineState.memory.getTag(1); - const expected = new Fr(value); - const actual = machineState.readMemory(1); - expect(actual).toEqual(expected); + expect(actual).toEqual(new Uint16(1234n)); + expect(tag).toEqual(TypeTag.UINT16); + }); + + it('should correctly set value and tag (overwriting)', () => { + machineState.memory.set(1, new Field(27)); + + new Set(/*value=*/ 1234n, /*offset=*/ 1, TypeTag.UINT32).execute(machineState, stateManager); + + const actual = machineState.memory.get(1); + const tag = machineState.memory.getTag(1); + + expect(actual).toEqual(new Uint32(1234n)); + expect(tag).toEqual(TypeTag.UINT32); + }); }); - // TODO(https://github.com/AztecProtocol/aztec-packages/issues/3987): tags are not implemented yet - this will behave as a mov describe('CAST', () => { - it('Should work correctly on different memory cells', () => { - const value = new Fr(123456n); - - machineState.writeMemory(0, value); + it('Should upcast between integral types', () => { + machineState.memory.set(0, new Uint8(20n)); + machineState.memory.set(1, new Uint16(65000n)); + machineState.memory.set(2, new Uint32(1n << 30n)); + machineState.memory.set(3, new Uint64(1n << 50n)); + machineState.memory.set(4, new Uint128(1n << 100n)); + + [ + new Cast(/*aOffset=*/ 0, /*dstOffset=*/ 10, TypeTag.UINT16), + new Cast(/*aOffset=*/ 1, /*dstOffset=*/ 11, TypeTag.UINT32), + new Cast(/*aOffset=*/ 2, /*dstOffset=*/ 12, TypeTag.UINT64), + new Cast(/*aOffset=*/ 3, /*dstOffset=*/ 13, TypeTag.UINT128), + new Cast(/*aOffset=*/ 4, /*dstOffset=*/ 14, TypeTag.UINT128), + ].forEach(i => i.execute(machineState, stateManager)); + + const actual = machineState.memory.getSlice(/*offset=*/ 10, /*size=*/ 5); + expect(actual).toEqual([ + new Uint16(20n), + new Uint32(65000n), + new Uint64(1n << 30n), + new Uint128(1n << 50n), + new Uint128(1n << 100n), + ]); + const tags = machineState.memory.getSliceTags(/*offset=*/ 10, /*size=*/ 5); + expect(tags).toEqual([TypeTag.UINT16, TypeTag.UINT32, TypeTag.UINT64, TypeTag.UINT128, TypeTag.UINT128]); + }); - new Cast(/*aOffset=*/ 0, /*dstOffset=*/ 1).execute(machineState, stateManager); + it('Should downcast (truncating) between integral types', () => { + machineState.memory.set(0, new Uint8(20n)); + machineState.memory.set(1, new Uint16(65000n)); + machineState.memory.set(2, new Uint32((1n << 30n) - 1n)); + machineState.memory.set(3, new Uint64((1n << 50n) - 1n)); + machineState.memory.set(4, new Uint128((1n << 100n) - 1n)); + + [ + new Cast(/*aOffset=*/ 0, /*dstOffset=*/ 10, TypeTag.UINT8), + new Cast(/*aOffset=*/ 1, /*dstOffset=*/ 11, TypeTag.UINT8), + new Cast(/*aOffset=*/ 2, /*dstOffset=*/ 12, TypeTag.UINT16), + new Cast(/*aOffset=*/ 3, /*dstOffset=*/ 13, TypeTag.UINT32), + new Cast(/*aOffset=*/ 4, /*dstOffset=*/ 14, TypeTag.UINT64), + ].forEach(i => i.execute(machineState, stateManager)); + + const actual = machineState.memory.getSlice(/*offset=*/ 10, /*size=*/ 5); + expect(actual).toEqual([ + new Uint8(20n), + new Uint8(232), + new Uint16((1n << 16n) - 1n), + new Uint32((1n << 32n) - 1n), + new Uint64((1n << 64n) - 1n), + ]); + const tags = machineState.memory.getSliceTags(/*offset=*/ 10, /*size=*/ 5); + expect(tags).toEqual([TypeTag.UINT8, TypeTag.UINT8, TypeTag.UINT16, TypeTag.UINT32, TypeTag.UINT64]); + }); - const actual = machineState.readMemory(1); - expect(actual).toEqual(value); + it('Should upcast from integral types to field', () => { + machineState.memory.set(0, new Uint8(20n)); + machineState.memory.set(1, new Uint16(65000n)); + machineState.memory.set(2, new Uint32(1n << 30n)); + machineState.memory.set(3, new Uint64(1n << 50n)); + machineState.memory.set(4, new Uint128(1n << 100n)); + + [ + new Cast(/*aOffset=*/ 0, /*dstOffset=*/ 10, TypeTag.FIELD), + new Cast(/*aOffset=*/ 1, /*dstOffset=*/ 11, TypeTag.FIELD), + new Cast(/*aOffset=*/ 2, /*dstOffset=*/ 12, TypeTag.FIELD), + new Cast(/*aOffset=*/ 3, /*dstOffset=*/ 13, TypeTag.FIELD), + new Cast(/*aOffset=*/ 4, /*dstOffset=*/ 14, TypeTag.FIELD), + ].forEach(i => i.execute(machineState, stateManager)); + + const actual = machineState.memory.getSlice(/*offset=*/ 10, /*size=*/ 5); + expect(actual).toEqual([ + new Field(20n), + new Field(65000n), + new Field(1n << 30n), + new Field(1n << 50n), + new Field(1n << 100n), + ]); + const tags = machineState.memory.getSliceTags(/*offset=*/ 10, /*size=*/ 5); + expect(tags).toEqual([TypeTag.FIELD, TypeTag.FIELD, TypeTag.FIELD, TypeTag.FIELD, TypeTag.FIELD]); }); - it('Should work correctly on same memory cell', () => { - const value = new Fr(123456n); + it('Should downcast (truncating) from field to integral types', () => { + machineState.memory.set(0, new Field((1n << 200n) - 1n)); + machineState.memory.set(1, new Field((1n << 200n) - 1n)); + machineState.memory.set(2, new Field((1n << 200n) - 1n)); + machineState.memory.set(3, new Field((1n << 200n) - 1n)); + machineState.memory.set(4, new Field((1n << 200n) - 1n)); + + [ + new Cast(/*aOffset=*/ 0, /*dstOffset=*/ 10, TypeTag.UINT8), + new Cast(/*aOffset=*/ 1, /*dstOffset=*/ 11, TypeTag.UINT16), + new Cast(/*aOffset=*/ 2, /*dstOffset=*/ 12, TypeTag.UINT32), + new Cast(/*aOffset=*/ 3, /*dstOffset=*/ 13, TypeTag.UINT64), + new Cast(/*aOffset=*/ 4, /*dstOffset=*/ 14, TypeTag.UINT128), + ].forEach(i => i.execute(machineState, stateManager)); + + const actual = machineState.memory.getSlice(/*offset=*/ 10, /*size=*/ 5); + expect(actual).toEqual([ + new Uint8((1n << 8n) - 1n), + new Uint16((1n << 16n) - 1n), + new Uint32((1n << 32n) - 1n), + new Uint64((1n << 64n) - 1n), + new Uint128((1n << 128n) - 1n), + ]); + const tags = machineState.memory.getSliceTags(/*offset=*/ 10, /*size=*/ 5); + expect(tags).toEqual([TypeTag.UINT8, TypeTag.UINT16, TypeTag.UINT32, TypeTag.UINT64, TypeTag.UINT128]); + }); - machineState.writeMemory(0, value); + it('Should cast between field elements', () => { + machineState.memory.set(0, new Field(12345678n)); - new Cast(/*aOffset=*/ 0, /*dstOffset=*/ 0).execute(machineState, stateManager); + new Cast(/*aOffset=*/ 0, /*dstOffset=*/ 1, TypeTag.FIELD).execute(machineState, stateManager); - const actual = machineState.readMemory(0); - expect(actual).toEqual(value); + const actual = machineState.memory.get(1); + expect(actual).toEqual(new Field(12345678n)); + const tags = machineState.memory.getTag(1); + expect(tags).toEqual(TypeTag.FIELD); }); }); describe('MOV', () => { - it('Should work correctly on different memory cells', () => { - const value = new Fr(123456n); - - machineState.writeMemory(0, value); + it('Should move integrals on different memory cells', () => { + machineState.memory.set(1, new Uint16(27)); + new Mov(/*offsetA=*/ 1, /*offsetA=*/ 2).execute(machineState, stateManager); - new Mov(/*aOffset=*/ 0, /*dstOffset=*/ 1).execute(machineState, stateManager); + const actual = machineState.memory.get(2); + const tag = machineState.memory.getTag(2); - const actual = machineState.readMemory(1); - expect(actual).toEqual(value); + expect(actual).toEqual(new Uint16(27n)); + expect(tag).toEqual(TypeTag.UINT16); }); - it('Should work correctly on same memory cell', () => { - const value = new Fr(123456n); + it('Should move field elements on different memory cells', () => { + machineState.memory.set(1, new Field(27)); + new Mov(/*offsetA=*/ 1, /*offsetA=*/ 2).execute(machineState, stateManager); - machineState.writeMemory(0, value); + const actual = machineState.memory.get(2); + const tag = machineState.memory.getTag(2); - new Mov(/*aOffset=*/ 0, /*dstOffset=*/ 0).execute(machineState, stateManager); - - const actual = machineState.readMemory(0); - expect(actual).toEqual(value); + expect(actual).toEqual(new Field(27n)); + expect(tag).toEqual(TypeTag.FIELD); }); }); - describe('MOV', () => { - it('Should move A if COND is true, on different memory cells', () => { - const valueA = new Fr(123456n); - const valueB = new Fr(80n); - const valueCondition = new Fr(22n); - - machineState.writeMemory(0, valueA); - machineState.writeMemory(1, valueB); - machineState.writeMemory(2, valueCondition); + describe('CMOV', () => { + it('Should move A if COND is true, on different memory cells (integral condition)', () => { + machineState.memory.set(0, new Uint32(123)); // A + machineState.memory.set(1, new Uint16(456)); // B + machineState.memory.set(2, new Uint8(2)); // Condition new CMov(/*aOffset=*/ 0, /*bOffset=*/ 1, /*condOffset=*/ 2, /*dstOffset=*/ 3).execute(machineState, stateManager); - const actual = machineState.readMemory(3); - expect(actual).toEqual(valueA); + const actual = machineState.memory.get(3); + const tag = machineState.memory.getTag(3); + expect(actual).toEqual(new Uint32(123)); + expect(tag).toEqual(TypeTag.UINT32); }); - it('Should move B if COND is false, on different memory cells', () => { - const valueA = new Fr(123456n); - const valueB = new Fr(80n); - const valueCondition = new Fr(0n); - - machineState.writeMemory(0, valueA); - machineState.writeMemory(1, valueB); - machineState.writeMemory(2, valueCondition); + it('Should move B if COND is false, on different memory cells (integral condition)', () => { + machineState.memory.set(0, new Uint32(123)); // A + machineState.memory.set(1, new Uint16(456)); // B + machineState.memory.set(2, new Uint8(0)); // Condition new CMov(/*aOffset=*/ 0, /*bOffset=*/ 1, /*condOffset=*/ 2, /*dstOffset=*/ 3).execute(machineState, stateManager); - const actual = machineState.readMemory(3); - expect(actual).toEqual(valueB); + const actual = machineState.memory.get(3); + const tag = machineState.memory.getTag(3); + expect(actual).toEqual(new Uint16(456)); + expect(tag).toEqual(TypeTag.UINT16); }); - it('Should move A if COND is true, on overlapping memory cells', () => { - const valueA = new Fr(123456n); - const valueB = new Fr(80n); - const valueCondition = new Fr(22n); - - machineState.writeMemory(0, valueA); - machineState.writeMemory(1, valueB); - machineState.writeMemory(2, valueCondition); + it('Should move A if COND is true, on different memory cells (field condition)', () => { + machineState.memory.set(0, new Uint32(123)); // A + machineState.memory.set(1, new Uint16(456)); // B + machineState.memory.set(2, new Field(1)); // Condition - new CMov(/*aOffset=*/ 0, /*bOffset=*/ 1, /*condOffset=*/ 2, /*dstOffset=*/ 2).execute(machineState, stateManager); + new CMov(/*aOffset=*/ 0, /*bOffset=*/ 1, /*condOffset=*/ 2, /*dstOffset=*/ 3).execute(machineState, stateManager); - const actual = machineState.readMemory(2); - expect(actual).toEqual(valueA); + const actual = machineState.memory.get(3); + const tag = machineState.memory.getTag(3); + expect(actual).toEqual(new Uint32(123)); + expect(tag).toEqual(TypeTag.UINT32); }); - it('Should move B if COND is false, on overlapping memory cells', () => { - const valueA = new Fr(123456n); - const valueB = new Fr(80n); - const valueCondition = new Fr(0n); + it('Should move B if COND is false, on different memory cells (integral condition)', () => { + machineState.memory.set(0, new Uint32(123)); // A + machineState.memory.set(1, new Uint16(456)); // B + machineState.memory.set(2, new Field(0)); // Condition - machineState.writeMemory(0, valueA); - machineState.writeMemory(1, valueB); - machineState.writeMemory(2, valueCondition); - - new CMov(/*aOffset=*/ 0, /*bOffset=*/ 1, /*condOffset=*/ 2, /*dstOffset=*/ 2).execute(machineState, stateManager); + new CMov(/*aOffset=*/ 0, /*bOffset=*/ 1, /*condOffset=*/ 2, /*dstOffset=*/ 3).execute(machineState, stateManager); - const actual = machineState.readMemory(2); - expect(actual).toEqual(valueB); + const actual = machineState.memory.get(3); + const tag = machineState.memory.getTag(3); + expect(actual).toEqual(new Uint16(456)); + expect(tag).toEqual(TypeTag.UINT16); }); }); - describe('CALLDATA', () => { + describe('CALLDATACOPY', () => { it('Writes nothing if size is 0', () => { - const previousValue = new Fr(123456n); const calldata = [new Fr(1n), new Fr(2n), new Fr(3n)]; - machineState = new AvmMachineState(calldata); - machineState.writeMemory(0, previousValue); + machineState.memory.set(0, new Uint16(12)); // Some previous data to be overwritten - new CalldataCopy(/*cdOffset=*/ 2, /*copySize=*/ 0, /*dstOffset=*/ 0).execute(machineState, stateManager); + new CalldataCopy(/*cdOffset=*/ 0, /*copySize=*/ 0, /*dstOffset=*/ 0).execute(machineState, stateManager); - const actual = machineState.readMemory(0); - expect(actual).toEqual(previousValue); + const actual = machineState.memory.get(0); + expect(actual).toEqual(new Uint16(12)); }); it('Copies all calldata', () => { - const previousValue = new Fr(123456n); const calldata = [new Fr(1n), new Fr(2n), new Fr(3n)]; - machineState = new AvmMachineState(calldata); - machineState.writeMemory(0, previousValue); + machineState.memory.set(0, new Uint16(12)); // Some previous data to be overwritten new CalldataCopy(/*cdOffset=*/ 0, /*copySize=*/ 3, /*dstOffset=*/ 0).execute(machineState, stateManager); - const actual = machineState.readMemoryChunk(/*offset=*/ 0, /*size=*/ 3); - expect(actual).toEqual(calldata); + const actual = machineState.memory.getSlice(/*offset=*/ 0, /*size=*/ 3); + expect(actual).toEqual([new Field(1), new Field(2), new Field(3)]); }); it('Copies slice of calldata', () => { - const previousValue = new Fr(123456n); const calldata = [new Fr(1n), new Fr(2n), new Fr(3n)]; - machineState = new AvmMachineState(calldata); - machineState.writeMemory(0, previousValue); + machineState.memory.set(0, new Uint16(12)); // Some previous data to be overwritten new CalldataCopy(/*cdOffset=*/ 1, /*copySize=*/ 2, /*dstOffset=*/ 0).execute(machineState, stateManager); - const expected = calldata.slice(1); - const actual = machineState.readMemoryChunk(/*offset=*/ 0, /*size=*/ 2); - expect(actual).toEqual(expected); + const actual = machineState.memory.getSlice(/*offset=*/ 0, /*size=*/ 2); + expect(actual).toEqual([new Field(2), new Field(3)]); }); // TODO: check bad cases (i.e., out of bounds) diff --git a/yarn-project/acir-simulator/src/avm/opcodes/memory.ts b/yarn-project/acir-simulator/src/avm/opcodes/memory.ts index 3e6bbc62afd..cd546a448c6 100644 --- a/yarn-project/acir-simulator/src/avm/opcodes/memory.ts +++ b/yarn-project/acir-simulator/src/avm/opcodes/memory.ts @@ -1,99 +1,96 @@ -import { Fr } from '@aztec/foundation/fields'; - import { AvmMachineState } from '../avm_machine_state.js'; +import { Field, TaggedMemory, TypeTag } from '../avm_memory_types.js'; import { AvmStateManager } from '../avm_state_manager.js'; import { Instruction } from './instruction.js'; -/** - */ export class Set extends Instruction { static type: string = 'SET'; static numberOfOperands = 2; - constructor(private value: bigint, private destOffset: number) { + constructor(private value: bigint, private dstOffset: number, private dstTag: TypeTag) { super(); } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - machineState.writeMemory(this.destOffset, new Fr(this.value)); + const res = TaggedMemory.integralFromTag(this.value, this.dstTag); + + machineState.memory.set(this.dstOffset, res); this.incrementPc(machineState); } } -// TODO(https://github.com/AztecProtocol/aztec-packages/issues/3987): tags are not implemented yet - this will behave as a mov -/** - */ export class Cast extends Instruction { static type: string = 'CAST'; static numberOfOperands = 2; - constructor(private aOffset: number, private destOffset: number) { + constructor(private aOffset: number, private dstOffset: number, private dstTag: TypeTag) { super(); } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const a = machineState.readMemory(this.aOffset); + const a = machineState.memory.get(this.aOffset); - machineState.writeMemory(this.destOffset, a); + // TODO: consider not using toBigInt() + const casted = + this.dstTag == TypeTag.FIELD ? new Field(a.toBigInt()) : TaggedMemory.integralFromTag(a.toBigInt(), this.dstTag); + + machineState.memory.set(this.dstOffset, casted); this.incrementPc(machineState); } } -/** - */ export class Mov extends Instruction { static type: string = 'MOV'; static numberOfOperands = 2; - constructor(private aOffset: number, private destOffset: number) { + constructor(private aOffset: number, private dstOffset: number) { super(); } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const a = machineState.readMemory(this.aOffset); + const a = machineState.memory.get(this.aOffset); - machineState.writeMemory(this.destOffset, a); + machineState.memory.set(this.dstOffset, a); this.incrementPc(machineState); } } -/** - */ export class CMov extends Instruction { - static type: string = 'MOV'; + static type: string = 'CMOV'; static numberOfOperands = 4; - constructor( - private aOffset: number, - private bOffset: number, - private condOffset: number, - private destOffset: number, - ) { + constructor(private aOffset: number, private bOffset: number, private condOffset: number, private dstOffset: number) { super(); } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const a = machineState.readMemory(this.aOffset); - const b = machineState.readMemory(this.bOffset); - const cond = machineState.readMemory(this.condOffset); + const a = machineState.memory.get(this.aOffset); + const b = machineState.memory.get(this.bOffset); + const cond = machineState.memory.get(this.condOffset); - machineState.writeMemory(this.destOffset, cond.toBigInt() ? a : b); + // TODO: reconsider toBigInt() here + machineState.memory.set(this.dstOffset, cond.toBigInt() > 0 ? a : b); this.incrementPc(machineState); } } -/** - */ export class CalldataCopy extends Instruction { static type: string = 'CALLDATACOPY'; static numberOfOperands = 3; - constructor(private cdOffset: number, private copySize: number, private destOffset: number) { + constructor(private cdOffset: number, private copySize: number, private dstOffset: number) { super(); } execute(machineState: AvmMachineState, _stateManager: AvmStateManager): void { - const calldata = machineState.calldata.slice(this.cdOffset, this.cdOffset + this.copySize); - machineState.writeMemoryChunk(this.destOffset, calldata); + const transformedData = machineState.calldata + .slice(this.cdOffset, this.cdOffset + this.copySize) + .map(f => new Field(f)); + machineState.memory.setSlice(this.dstOffset, transformedData); this.incrementPc(machineState); }