Skip to content

Commit

Permalink
feat: Point::fromXandSign(...) (#7455)
Browse files Browse the repository at this point in the history
  • Loading branch information
benesjan authored Jul 15, 2024
1 parent a3f6feb commit 225c6f6
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 4 deletions.
19 changes: 19 additions & 0 deletions barretenberg/cpp/src/barretenberg/ecc/curves/bn254/c_bind.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include "../bn254/fr.hpp"
#include "barretenberg/common/wasm_export.hpp"

using namespace bb;

WASM_EXPORT void bn254_fr_sqrt(uint8_t const* input, uint8_t* result)
{
using serialize::write;
auto input_fr = from_buffer<bb::fr>(input);
auto [is_sqr, root] = input_fr.sqrt();

uint8_t* is_sqrt_result_ptr = result;
uint8_t* root_result_ptr = result + 1;

write(is_sqrt_result_ptr, is_sqr);
write(root_result_ptr, root);
}

// NOLINTEND(cert-dcl37-c, cert-dcl51-cpp, bugprone-reserved-identifier)
9 changes: 9 additions & 0 deletions yarn-project/foundation/src/crypto/random/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,12 @@ export const randomBigInt = (max: bigint) => {
const randomBigInt = BigInt(`0x${randomBuffer.toString('hex')}`); // Convert buffer to a large integer.
return randomBigInt % max; // Use modulo to ensure the result is less than max.
};

/**
* Generate a random boolean value.
* @returns A random boolean value.
*/
export const randomBoolean = () => {
const randomByte = randomBytes(1)[0]; // Generate a single random byte.
return randomByte % 2 === 0; // Use modulo to determine if the byte is even or odd.
};
26 changes: 25 additions & 1 deletion yarn-project/foundation/src/fields/fields.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ describe('Bn254 arithmetic', () => {
expect(actual).toEqual(expected);
});

it('High Bonudary', () => {
it('High Boundary', () => {
// -1 - (-1) = 0
const a = new Fr(Fr.MODULUS - 1n);
const b = new Fr(Fr.MODULUS - 1n);
Expand Down Expand Up @@ -184,6 +184,30 @@ describe('Bn254 arithmetic', () => {
});
});

describe('Square root', () => {
it.each([
[new Fr(0), 0n],
[new Fr(4), 2n],
[new Fr(9), 3n],
[new Fr(16), 4n],
])('Should return the correct square root for %p', (input, expected) => {
const actual = input.sqrt()!.toBigInt();

// The square root can be either the expected value or the modulus - expected value
const isValid = actual == expected || actual == Fr.MODULUS - expected;

expect(isValid).toBeTruthy();
});

it('Should return the correct square root for random value', () => {
const a = Fr.random();
const squared = a.mul(a);

const actual = squared.sqrt();
expect(actual!.mul(actual!)).toEqual(squared);
});
});

describe('Comparison', () => {
it.each([
[new Fr(5), new Fr(10), -1],
Expand Down
21 changes: 21 additions & 0 deletions yarn-project/foundation/src/fields/fields.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { BarretenbergSync } from '@aztec/bb.js';

import { inspect } from 'util';

import { toBigIntBE, toBufferBE } from '../bigint-buffer/index.js';
Expand Down Expand Up @@ -280,6 +282,25 @@ export class Fr extends BaseField {
return new Fr(this.toBigInt() / rhs.toBigInt());
}

/**
* Computes a square root of the field element.
* @returns A square root of the field element (null if it does not exist).
*/
sqrt(): Fr | null {
const wasm = BarretenbergSync.getSingleton().getWasm();
wasm.writeMemory(0, this.toBuffer());
wasm.call('bn254_fr_sqrt', 0, Fr.SIZE_IN_BYTES);
const isSqrtBuf = Buffer.from(wasm.getMemorySlice(Fr.SIZE_IN_BYTES, Fr.SIZE_IN_BYTES + 1));
const isSqrt = isSqrtBuf[0] === 1;
if (!isSqrt) {
// Field element is not a quadratic residue mod p so it has no square root.
return null;
}

const rootBuf = Buffer.from(wasm.getMemorySlice(Fr.SIZE_IN_BYTES + 1, Fr.SIZE_IN_BYTES * 2 + 1));
return Fr.fromBuffer(rootBuf);
}

toJSON() {
return {
type: 'Fr',
Expand Down
35 changes: 35 additions & 0 deletions yarn-project/foundation/src/fields/point.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import { Fr } from './fields.js';
import { Point } from './point.js';

describe('Point', () => {
it('converts to and from x and sign of y coordinate', () => {
const p = new Point(
new Fr(0x30426e64aee30e998c13c8ceecda3a77807dbead52bc2f3bf0eae851b4b710c1n),
new Fr(0x113156a068f603023240c96b4da5474667db3b8711c521c748212a15bc034ea6n),
false,
);

const [x, sign] = p.toXAndSign();
const p2 = Point.fromXAndSign(x, sign);

expect(p.equals(p2)).toBeTruthy();
});

it('creates a valid random point', () => {
expect(Point.random().isOnGrumpkin()).toBeTruthy();
});

it('converts to and from buffer', () => {
const p = Point.random();
const p2 = Point.fromBuffer(p.toBuffer());

expect(p.equals(p2)).toBeTruthy();
});

it('converts to and from compressed buffer', () => {
const p = Point.random();
const p2 = Point.fromCompressedBuffer(p.toCompressedBuffer());

expect(p.equals(p2)).toBeTruthy();
});
});
83 changes: 80 additions & 3 deletions yarn-project/foundation/src/fields/point.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { poseidon2Hash } from '../crypto/index.js';
import { poseidon2Hash, randomBoolean } from '../crypto/index.js';
import { BufferReader, FieldReader, serializeToBuffer } from '../serialize/index.js';
import { Fr } from './fields.js';

Expand All @@ -10,6 +10,7 @@ import { Fr } from './fields.js';
export class Point {
static ZERO = new Point(Fr.ZERO, Fr.ZERO, false);
static SIZE_IN_BYTES = Fr.SIZE_IN_BYTES * 2;
static COMPRESSED_SIZE_IN_BYTES = Fr.SIZE_IN_BYTES + 1;

/** Used to differentiate this class from AztecAddress */
public readonly kind = 'point';
Expand Down Expand Up @@ -37,8 +38,17 @@ export class Point {
* @returns A randomly generated Point instance.
*/
static random() {
// TODO make this return an actual point on curve.
return new Point(Fr.random(), Fr.random(), false);
while (true) {
try {
return Point.fromXAndSign(Fr.random(), randomBoolean());
} catch (e: any) {
if (!(e instanceof NotOnCurveError)) {
throw e;
}
// The random point is not on the curve - we try again
continue;
}
}
}

/**
Expand All @@ -53,6 +63,18 @@ export class Point {
return new this(Fr.fromBuffer(reader), Fr.fromBuffer(reader), false);
}

/**
* Create a Point instance from a compressed buffer.
* The input 'buffer' should have exactly 33 bytes representing the x coordinate and the sign of the y coordinate.
*
* @param buffer - The buffer containing the x coordinate and the sign of the y coordinate.
* @returns A Point instance.
*/
static fromCompressedBuffer(buffer: Buffer | BufferReader) {
const reader = BufferReader.asReader(buffer);
return this.fromXAndSign(Fr.fromBuffer(reader), reader.readBoolean());
}

/**
* Create a Point instance from a hex-encoded string.
* The input 'address' should be prefixed with '0x' or not, and have exactly 128 hex characters representing the x and y coordinates.
Expand All @@ -78,6 +100,46 @@ export class Point {
return new this(reader.readField(), reader.readField(), reader.readBoolean());
}

/**
* Uses the x coordinate and isPositive flag (+/-) to reconstruct the point.
* @dev The y coordinate can be derived from the x coordinate and the "sign" flag by solving the grumpkin curve
* equation for y.
* @param x - The x coordinate of the point
* @param sign - The "sign" of the y coordinate - note that this is not a sign as is known in integer arithmetic.
* Instead it is a boolean flag that determines whether the y coordinate is <= (Fr.MODULUS - 1) / 2
* @returns The point as an array of 2 fields
*/
static fromXAndSign(x: Fr, sign: boolean) {
// Calculate y^2 = x^3 - 17
const ySquared = x.square().mul(x).sub(new Fr(17));

// Calculate the square root of ySquared
const y = ySquared.sqrt();

// If y is null, the x-coordinate is not on the curve
if (y === null) {
throw new NotOnCurveError();
}

const yPositiveBigInt = y.toBigInt() > (Fr.MODULUS - 1n) / 2n ? Fr.MODULUS - y.toBigInt() : y.toBigInt();
const yNegativeBigInt = Fr.MODULUS - yPositiveBigInt;

// Choose the positive or negative root based on isPositive
const finalY = sign ? new Fr(yPositiveBigInt) : new Fr(yNegativeBigInt);

// Create and return the new Point
return new this(x, finalY, false);
}

/**
* Returns the x coordinate and the sign of the y coordinate.
* @dev The y sign can be determined by checking if the y coordinate is greater than half of the modulus.
* @returns The x coordinate and the sign of the y coordinate.
*/
toXAndSign(): [Fr, boolean] {
return [this.x, this.y.toBigInt() <= (Fr.MODULUS - 1n) / 2n];
}

/**
* Returns the contents of the point as BigInts.
* @returns The point as BigInts
Expand Down Expand Up @@ -111,6 +173,14 @@ export class Point {
return buf;
}

/**
* Converts the Point instance to a compressed Buffer representation of the coordinates.
* @returns A Buffer representation of the Point instance
*/
toCompressedBuffer() {
return serializeToBuffer(this.toXAndSign());
}

/**
* Convert the Point instance to a hexadecimal string representation.
* The output string is prefixed with '0x' and consists of exactly 128 hex characters,
Expand Down Expand Up @@ -194,3 +264,10 @@ export function isPoint(obj: object): obj is Point {
const point = obj as Point;
return point.kind === 'point' && point.x !== undefined && point.y !== undefined;
}

class NotOnCurveError extends Error {
constructor() {
super('The given x-coordinate is not on the Grumpkin curve');
this.name = 'NotOnCurveError';
}
}

1 comment on commit 225c6f6

@AztecBot
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'C++ Benchmark'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 1.05.

Benchmark suite Current: 225c6f6 Previous: 3acef09 Ratio
Goblin::merge(t) 208096682 ns/iter 197048349 ns/iter 1.06

This comment was automatically generated by workflow using github-action-benchmark.

CC: @ludamad @codygunton

Please sign in to comment.