Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add Mersenne to v3 #579

Merged
merged 11 commits into from
Aug 27, 2024
2 changes: 2 additions & 0 deletions .github/workflows/cpp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ jobs:
build_args: -DEXT_FIELD=ON
- name: stark252
build_args: -DEXT_FIELD=OFF
- name: m31
build_args: -DEXT_FIELD=ON
steps:
- name: Checkout Repo
uses: actions/checkout@v4
Expand Down
1 change: 1 addition & 0 deletions icicle/cmake/fields_and_curves.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
set(ICICLE_FIELDS
1001:babybear:NTT,EXT_FIELD
1002:stark252:NTT
1003:m31:EXT_FIELD
)

# Define available curves with an index and their supported features
Expand Down
3 changes: 1 addition & 2 deletions icicle/include/icicle/backend/vec_ops_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,7 @@ namespace icicle {
const VecOpsConfig& config,
extension_t* output)>;

extern "C" void
register_extension_scalar_convert_montgomery(const std::string& deviceType, extFieldConvertMontgomeryImpl);
void register_extension_scalar_convert_montgomery(const std::string& deviceType, extFieldConvertMontgomeryImpl);

#define REGISTER_CONVERT_MONTGOMERY_EXT_FIELD_BACKEND(DEVICE_TYPE, FUNC) \
namespace { \
Expand Down
3 changes: 3 additions & 0 deletions icicle/include/icicle/fields/field_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,7 @@ namespace field_config = babybear;
#elif FIELD_ID == STARK_252
#include "icicle/fields/stark_fields/stark252.h"
namespace field_config = stark252;
#elif FIELD_ID == M31
#include "icicle/fields/stark_fields/m31.h"
namespace field_config = m31;
#endif
1 change: 1 addition & 0 deletions icicle/include/icicle/fields/id.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@

#define BABY_BEAR 1001
#define STARK_252 1002
#define M31 1003

#endif
225 changes: 225 additions & 0 deletions icicle/include/icicle/fields/stark_fields/m31.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
#pragma once

#include "icicle/fields/storage.h"
#include "icicle/fields/field.h"
#include "icicle/fields/quartic_extension.h"
#include "icicle/fields/params_gen.h"

namespace m31 {
template <class CONFIG>
class MersenneField : public Field<CONFIG>
{
public:
HOST_DEVICE_INLINE MersenneField(const MersenneField& other) : Field<CONFIG>(other) {}
HOST_DEVICE_INLINE MersenneField(const uint32_t& x = 0) : Field<CONFIG>({x}) {}
HOST_DEVICE_INLINE MersenneField(storage<CONFIG::limbs_count> x) : Field<CONFIG>{x} {}
HOST_DEVICE_INLINE MersenneField(const Field<CONFIG>& other) : Field<CONFIG>(other) {}

static constexpr HOST_DEVICE_INLINE MersenneField zero() { return MersenneField{CONFIG::zero}; }

static constexpr HOST_DEVICE_INLINE MersenneField one() { return MersenneField{CONFIG::one}; }

static constexpr HOST_DEVICE_INLINE MersenneField from(uint32_t value) { return MersenneField(value); }

static HOST_INLINE MersenneField rand_host() { return MersenneField(Field<CONFIG>::rand_host()); }

static void rand_host_many(MersenneField* out, int size)
{
for (int i = 0; i < size; i++)
out[i] = rand_host();
}

HOST_DEVICE_INLINE MersenneField& operator=(const Field<CONFIG>& other)
{
if (this != &other) { Field<CONFIG>::operator=(other); }
return *this;
}

HOST_DEVICE_INLINE uint32_t get_limb() const { return this->limbs_storage.limbs[0]; }

// The `Wide` struct represents a redundant 32-bit form of the Mersenne Field.
struct Wide {
uint32_t storage;
static constexpr HOST_DEVICE_INLINE Wide from_field(const MersenneField& xs)
{
Wide out{};
out.storage = xs.get_limb();
return out;
}
static constexpr HOST_DEVICE_INLINE Wide from_number(const uint32_t& xs)
{
Wide out{};
out.storage = xs;
return out;
}
friend HOST_DEVICE_INLINE Wide operator+(Wide xs, const Wide& ys)
{
uint64_t tmp = (uint64_t)xs.storage + ys.storage; // max: 2^33 - 2 = 2^32(1) + (2^32 - 2)
tmp = ((tmp >> 32) << 1) + (uint32_t)(tmp); // 2(1)+(2^32-2) = 2^32(1)+(0)
return from_number((uint32_t)((tmp >> 32) << 1) + (uint32_t)(tmp)); // max: 2(1) + 0 = 2
}
friend HOST_DEVICE_INLINE Wide operator-(Wide xs, const Wide& ys)
{
uint64_t tmp = CONFIG::modulus_3 + xs.storage -
ys.storage; // max: 3(2^31-1) + 2^32-1 - 0 = 2^33 + 2^31-4 = 2^32(2) + (2^31-4)
return from_number(((uint32_t)(tmp >> 32) << 1) + (uint32_t)(tmp)); // max: 2(2)+(2^31-4) = 2^31
}
template <unsigned MODULUS_MULTIPLE = 1>
static constexpr HOST_DEVICE_INLINE Wide neg(const Wide& xs)
{
uint64_t tmp = CONFIG::modulus_3 - xs.storage; // max: 3(2^31-1) - 0 = 2^32(1) + (2^31 - 3)
return from_number(((uint32_t)(tmp >> 32) << 1) + (uint32_t)(tmp)); // max: 2(1)+(2^31-3) = 2^31 - 1
}
friend HOST_DEVICE_INLINE Wide operator*(Wide xs, const Wide& ys)
{
uint64_t t1 = (uint64_t)xs.storage * ys.storage; // max: 2^64 - 2^33+1 = 2^32(2^32 - 2) + 1
t1 = ((t1 >> 32) << 1) + (uint32_t)(t1); // max: 2(2^32 - 2) + 1 = 2^32(1) + (2^32 - 3)
return from_number((((uint32_t)(t1 >> 32)) << 1) + (uint32_t)(t1)); // max: 2(1) - (2^32 - 3) = 2^32 - 1
}
};

static constexpr HOST_DEVICE_INLINE MersenneField div2(const MersenneField& xs, const uint32_t& power = 1)
{
uint32_t t = xs.get_limb();
return MersenneField{{((t >> power) | (t << (31 - power))) & MersenneField::get_modulus().limbs[0]}};
}

static constexpr HOST_DEVICE_INLINE MersenneField neg(const MersenneField& xs)
{
uint32_t t = xs.get_limb();
return MersenneField{{t == 0 ? t : MersenneField::get_modulus().limbs[0] - t}};
}

template <unsigned MODULUS_MULTIPLE = 1>
static constexpr HOST_DEVICE_INLINE MersenneField reduce(Wide xs)
{
const uint32_t modulus = MersenneField::get_modulus().limbs[0];
uint32_t tmp = (xs.storage >> 31) + (xs.storage & modulus); // max: 1 + 2^31-1 = 2^31
tmp = (tmp >> 31) + (tmp & modulus); // max: 1 + 0 = 1
return MersenneField{{tmp == modulus ? 0 : tmp}};
}

static constexpr HOST_DEVICE_INLINE MersenneField inverse(const MersenneField& x)
{
uint32_t xs = x.limbs_storage.limbs[0];
if (xs <= 1) return xs;
uint32_t a = 1, b = 0, y = xs, z = MersenneField::get_modulus().limbs[0], e, m = z;
while (1) {
#ifdef __CUDA_ARCH__
e = __ffs(y) - 1;
#else
e = __builtin_ctz(y);
#endif
y >>= e;
if (a >= m) {
a = (a & m) + (a >> 31);
if (a == m) a = 0;
}
a = ((a >> e) | (a << (31 - e))) & m;
if (y == 1) return a;
e = a + b;
b = a;
a = e;
e = y + z;
z = y;
y = e;
}
}

friend HOST_DEVICE_INLINE MersenneField operator+(MersenneField xs, const MersenneField& ys)
{
uint32_t m = MersenneField::get_modulus().limbs[0];
uint32_t t = xs.get_limb() + ys.get_limb();
if (t > m) t = (t & m) + (t >> 31);
if (t == m) t = 0;
return MersenneField{{t}};
}

friend HOST_DEVICE_INLINE MersenneField operator-(MersenneField xs, const MersenneField& ys)
{
return xs + neg(ys);
}

friend HOST_DEVICE_INLINE MersenneField operator*(MersenneField xs, const MersenneField& ys)
{
uint64_t x = (uint64_t)(xs.get_limb()) * ys.get_limb();
uint32_t t = ((x >> 31) + (x & MersenneField::get_modulus().limbs[0]));
uint32_t m = MersenneField::get_modulus().limbs[0];
if (t > m) t = (t & m) + (t >> 31);
if (t > m) t = (t & m) + (t >> 31);
if (t == m) t = 0;
return MersenneField{{t}};
}

static constexpr HOST_DEVICE_INLINE Wide mul_wide(const MersenneField& xs, const MersenneField& ys)
{
return Wide::from_field(xs) * Wide::from_field(ys);
}

template <unsigned MODULUS_MULTIPLE = 1>
static constexpr HOST_DEVICE_INLINE Wide sqr_wide(const MersenneField& xs)
{
return mul_wide(xs, xs);
}

static constexpr HOST_DEVICE_INLINE MersenneField sqr(const MersenneField& xs) { return xs * xs; }

static constexpr HOST_DEVICE_INLINE MersenneField to_montgomery(const MersenneField& xs) { return xs; }

static constexpr HOST_DEVICE_INLINE MersenneField from_montgomery(const MersenneField& xs) { return xs; }

static constexpr HOST_DEVICE_INLINE MersenneField pow(MersenneField base, int exp)
{
MersenneField res = one();
while (exp > 0) {
if (exp & 1) res = res * base;
base = base * base;
exp >>= 1;
}
return res;
}
};
struct fp_config {
static constexpr unsigned limbs_count = 1;
static constexpr unsigned omegas_count = 1;
static constexpr unsigned modulus_bit_count = 31;
static constexpr unsigned num_of_reductions = 1;

static constexpr storage<limbs_count> modulus = {0x7fffffff};
static constexpr storage<limbs_count> modulus_2 = {0xfffffffe};
static constexpr uint64_t modulus_3 = 0x17ffffffd;
static constexpr storage<limbs_count> modulus_4 = {0xfffffffc};
static constexpr storage<limbs_count> neg_modulus = {0x87ffffff};
static constexpr storage<2 * limbs_count> modulus_wide = {0x7fffffff, 0x00000000};
static constexpr storage<2 * limbs_count> modulus_squared = {0x00000001, 0x3fffffff};
static constexpr storage<2 * limbs_count> modulus_squared_2 = {0x00000002, 0x7ffffffe};
static constexpr storage<2 * limbs_count> modulus_squared_4 = {0x00000004, 0xfffffffc};

static constexpr storage<limbs_count> m = {0x80000001};
static constexpr storage<limbs_count> one = {0x00000001};
static constexpr storage<limbs_count> zero = {0x00000000};
static constexpr storage<limbs_count> montgomery_r = {0x00000001};
static constexpr storage<limbs_count> montgomery_r_inv = {0x00000001};

static constexpr storage_array<omegas_count, limbs_count> omega = {{{0x7ffffffe}}};

static constexpr storage_array<omegas_count, limbs_count> omega_inv = {{{0x7ffffffe}}};

static constexpr storage_array<omegas_count, limbs_count> inv = {{{0x40000000}}};

// nonresidue to generate the extension field
static constexpr uint32_t nonresidue = 1;
// true if nonresidue is negative.
static constexpr bool nonresidue_is_negative = true;
};

/**
* Scalar field. Is always a prime field.
*/
typedef MersenneField<fp_config> scalar_t;

/**
* Extension field of `scalar_t` enabled if `-DEXT_FIELD` env variable is.
*/
typedef ExtensionField<fp_config, scalar_t> extension_t;
} // namespace m31
19 changes: 19 additions & 0 deletions wrappers/rust_v3/icicle-fields/icicle-m31/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[package]
name = "icicle-m31"
version.workspace = true
edition.workspace = true
authors.workspace = true
description = "Rust wrapper the implementation of m31 prime field by Ingonyama"
homepage.workspace = true
repository.workspace = true

[dependencies]
icicle-core = { workspace = true }
icicle-runtime = { workspace = true }

[build-dependencies]
cmake = "0.1.50"

[features]
cuda_backend = ["icicle-runtime/cuda_backend"]
pull_cuda_backend = ["icicle-runtime/pull_cuda_backend"]
54 changes: 54 additions & 0 deletions wrappers/rust_v3/icicle-fields/icicle-m31/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use cmake::Config;
use std::{env, path::PathBuf};

fn main() {
// Construct the path to the deps directory
let out_dir = env::var("OUT_DIR").expect("OUT_DIR is not set");
let build_dir = PathBuf::from(format!("{}/../../../", &out_dir));
let deps_dir = build_dir.join("deps");

// Construct the path to icicle source directory
let main_dir = env::current_dir().expect("Failed to get current directory");
let icicle_src_dir = PathBuf::from(format!("{}/../../../../icicle_v3", main_dir.display()));

println!("cargo:rerun-if-env-changed=CXXFLAGS");
println!("cargo:rerun-if-changed={}", icicle_src_dir.display());

// Base config
let mut config = Config::new(format!("{}", icicle_src_dir.display()));
// Check if ICICLE_INSTALL_DIR is defined
let icicle_install_dir = if let Ok(dir) = env::var("ICICLE_INSTALL_DIR") {
PathBuf::from(dir)
} else {
// Define the default install directory to be under the build directory
PathBuf::from(format!("{}/icicle/", deps_dir.display()))
};
config
.define("FIELD", "m31")
.define("EXT_FIELD", "ON")
.define("CMAKE_BUILD_TYPE", "Release")
.define("CMAKE_INSTALL_PREFIX", &icicle_install_dir);

#[cfg(feature = "cuda_backend")]
config.define("CUDA_BACKEND", "local");

#[cfg(feature = "pull_cuda_backend")]
config.define("CUDA_BACKEND", "main");

// Build
let _ = config
.build_target("install")
.build();

println!("cargo:rustc-link-search={}/lib", icicle_install_dir.display());
println!("cargo:rustc-link-lib=icicle_field_m31");
println!("cargo:rustc-link-arg=-Wl,-rpath,{}/lib", icicle_install_dir.display()); // Add RPATH linker arguments

// default backends dir
if cfg!(feature = "cuda_backend") || cfg!(feature = "pull_cuda_backend") {
println!(
"cargo:rustc-env=ICICLE_BACKEND_INSTALL_DIR={}/lib/backend",
icicle_install_dir.display()
);
}
}
32 changes: 32 additions & 0 deletions wrappers/rust_v3/icicle-fields/icicle-m31/src/field.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use icicle_core::field::{Field, MontgomeryConvertibleField};
use icicle_core::traits::{FieldConfig, FieldImpl, GenerateRandom};
use icicle_core::{impl_field, impl_scalar_field};
use icicle_runtime::errors::eIcicleError;
use icicle_runtime::memory::{DeviceSlice, HostOrDeviceSlice};
use icicle_runtime::stream::IcicleStream;

pub(crate) const SCALAR_LIMBS: usize = 1;
pub(crate) const EXTENSION_LIMBS: usize = 4;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we need to support both extensions of m31? Afaik both Quadratic and Quartic extensions are used (2 and 4 elements)


impl_scalar_field!("m31", m31, SCALAR_LIMBS, ScalarField, ScalarCfg);
impl_scalar_field!(
"m31_extension",
m31_extension,
EXTENSION_LIMBS,
ExtensionField,
ExtensionCfg
);

#[cfg(test)]
mod tests {
use super::{ExtensionField, ScalarField};
use icicle_core::impl_field_tests;
use icicle_core::tests::*;

impl_field_tests!(ScalarField);
mod extension {
use super::*;

impl_field_tests!(ExtensionField);
}
}
2 changes: 2 additions & 0 deletions wrappers/rust_v3/icicle-fields/icicle-m31/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod field;
pub mod vec_ops;
Loading
Loading