-
Notifications
You must be signed in to change notification settings - Fork 123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FEAT] Add Mersenne to v3 #579
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
2c81e1e
add mersenne to v3
emirsoyturk b925017
clang-format fix
emirsoyturk bf867b9
Merge branch 'yshekel/V3' into mersenne-V3
emirsoyturk 045433a
add mersenne to the github workflows
emirsoyturk f208068
add mersenne to v3
emirsoyturk 7a72127
clang-format fix
emirsoyturk fb32202
add mersenne to the github workflows
emirsoyturk a38301d
m31 reduce fix
nonam3e c9ba833
Merge branch 'mersenne-V3' of github.com:ingonyama-zk/icicle into mer…
emirsoyturk a840759
Merge branch 'yshekel/V3' into mersenne-V3
emirsoyturk f06f8f3
add m31 to icicle fields with ext fields
emirsoyturk File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,5 +10,6 @@ | |
|
||
#define BABY_BEAR 1001 | ||
#define STARK_252 1002 | ||
#define M31 1003 | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
#pragma once | ||
|
||
#include "icicle/fields/storage.h" | ||
#include "icicle/fields/field.h" | ||
#include "icicle/fields/quartic_extension.h" | ||
#include "icicle/fields/params_gen.h" | ||
|
||
namespace m31 { | ||
template <class CONFIG> | ||
class MersenneField : public Field<CONFIG> | ||
{ | ||
public: | ||
HOST_DEVICE_INLINE MersenneField(const MersenneField& other) : Field<CONFIG>(other) {} | ||
HOST_DEVICE_INLINE MersenneField(const uint32_t& x = 0) : Field<CONFIG>({x}) {} | ||
HOST_DEVICE_INLINE MersenneField(storage<CONFIG::limbs_count> x) : Field<CONFIG>{x} {} | ||
HOST_DEVICE_INLINE MersenneField(const Field<CONFIG>& other) : Field<CONFIG>(other) {} | ||
|
||
static constexpr HOST_DEVICE_INLINE MersenneField zero() { return MersenneField{CONFIG::zero}; } | ||
|
||
static constexpr HOST_DEVICE_INLINE MersenneField one() { return MersenneField{CONFIG::one}; } | ||
|
||
static constexpr HOST_DEVICE_INLINE MersenneField from(uint32_t value) { return MersenneField(value); } | ||
|
||
static HOST_INLINE MersenneField rand_host() { return MersenneField(Field<CONFIG>::rand_host()); } | ||
|
||
static void rand_host_many(MersenneField* out, int size) | ||
{ | ||
for (int i = 0; i < size; i++) | ||
out[i] = rand_host(); | ||
} | ||
|
||
HOST_DEVICE_INLINE MersenneField& operator=(const Field<CONFIG>& other) | ||
{ | ||
if (this != &other) { Field<CONFIG>::operator=(other); } | ||
return *this; | ||
} | ||
|
||
HOST_DEVICE_INLINE uint32_t get_limb() const { return this->limbs_storage.limbs[0]; } | ||
|
||
// The `Wide` struct represents a redundant 32-bit form of the Mersenne Field. | ||
struct Wide { | ||
uint32_t storage; | ||
static constexpr HOST_DEVICE_INLINE Wide from_field(const MersenneField& xs) | ||
{ | ||
Wide out{}; | ||
out.storage = xs.get_limb(); | ||
return out; | ||
} | ||
static constexpr HOST_DEVICE_INLINE Wide from_number(const uint32_t& xs) | ||
{ | ||
Wide out{}; | ||
out.storage = xs; | ||
return out; | ||
} | ||
friend HOST_DEVICE_INLINE Wide operator+(Wide xs, const Wide& ys) | ||
{ | ||
uint64_t tmp = (uint64_t)xs.storage + ys.storage; // max: 2^33 - 2 = 2^32(1) + (2^32 - 2) | ||
tmp = ((tmp >> 32) << 1) + (uint32_t)(tmp); // 2(1)+(2^32-2) = 2^32(1)+(0) | ||
return from_number((uint32_t)((tmp >> 32) << 1) + (uint32_t)(tmp)); // max: 2(1) + 0 = 2 | ||
} | ||
friend HOST_DEVICE_INLINE Wide operator-(Wide xs, const Wide& ys) | ||
{ | ||
uint64_t tmp = CONFIG::modulus_3 + xs.storage - | ||
ys.storage; // max: 3(2^31-1) + 2^32-1 - 0 = 2^33 + 2^31-4 = 2^32(2) + (2^31-4) | ||
return from_number(((uint32_t)(tmp >> 32) << 1) + (uint32_t)(tmp)); // max: 2(2)+(2^31-4) = 2^31 | ||
} | ||
template <unsigned MODULUS_MULTIPLE = 1> | ||
static constexpr HOST_DEVICE_INLINE Wide neg(const Wide& xs) | ||
{ | ||
uint64_t tmp = CONFIG::modulus_3 - xs.storage; // max: 3(2^31-1) - 0 = 2^32(1) + (2^31 - 3) | ||
return from_number(((uint32_t)(tmp >> 32) << 1) + (uint32_t)(tmp)); // max: 2(1)+(2^31-3) = 2^31 - 1 | ||
} | ||
friend HOST_DEVICE_INLINE Wide operator*(Wide xs, const Wide& ys) | ||
{ | ||
uint64_t t1 = (uint64_t)xs.storage * ys.storage; // max: 2^64 - 2^33+1 = 2^32(2^32 - 2) + 1 | ||
t1 = ((t1 >> 32) << 1) + (uint32_t)(t1); // max: 2(2^32 - 2) + 1 = 2^32(1) + (2^32 - 3) | ||
return from_number((((uint32_t)(t1 >> 32)) << 1) + (uint32_t)(t1)); // max: 2(1) - (2^32 - 3) = 2^32 - 1 | ||
} | ||
}; | ||
|
||
static constexpr HOST_DEVICE_INLINE MersenneField div2(const MersenneField& xs, const uint32_t& power = 1) | ||
{ | ||
uint32_t t = xs.get_limb(); | ||
return MersenneField{{((t >> power) | (t << (31 - power))) & MersenneField::get_modulus().limbs[0]}}; | ||
} | ||
|
||
static constexpr HOST_DEVICE_INLINE MersenneField neg(const MersenneField& xs) | ||
{ | ||
uint32_t t = xs.get_limb(); | ||
return MersenneField{{t == 0 ? t : MersenneField::get_modulus().limbs[0] - t}}; | ||
} | ||
|
||
template <unsigned MODULUS_MULTIPLE = 1> | ||
static constexpr HOST_DEVICE_INLINE MersenneField reduce(Wide xs) | ||
{ | ||
const uint32_t modulus = MersenneField::get_modulus().limbs[0]; | ||
uint32_t tmp = (xs.storage >> 31) + (xs.storage & modulus); // max: 1 + 2^31-1 = 2^31 | ||
tmp = (tmp >> 31) + (tmp & modulus); // max: 1 + 0 = 1 | ||
return MersenneField{{tmp == modulus ? 0 : tmp}}; | ||
} | ||
|
||
static constexpr HOST_DEVICE_INLINE MersenneField inverse(const MersenneField& x) | ||
{ | ||
uint32_t xs = x.limbs_storage.limbs[0]; | ||
if (xs <= 1) return xs; | ||
uint32_t a = 1, b = 0, y = xs, z = MersenneField::get_modulus().limbs[0], e, m = z; | ||
while (1) { | ||
#ifdef __CUDA_ARCH__ | ||
e = __ffs(y) - 1; | ||
#else | ||
e = __builtin_ctz(y); | ||
#endif | ||
y >>= e; | ||
if (a >= m) { | ||
a = (a & m) + (a >> 31); | ||
if (a == m) a = 0; | ||
} | ||
a = ((a >> e) | (a << (31 - e))) & m; | ||
if (y == 1) return a; | ||
e = a + b; | ||
b = a; | ||
a = e; | ||
e = y + z; | ||
z = y; | ||
y = e; | ||
} | ||
} | ||
|
||
friend HOST_DEVICE_INLINE MersenneField operator+(MersenneField xs, const MersenneField& ys) | ||
{ | ||
uint32_t m = MersenneField::get_modulus().limbs[0]; | ||
uint32_t t = xs.get_limb() + ys.get_limb(); | ||
if (t > m) t = (t & m) + (t >> 31); | ||
if (t == m) t = 0; | ||
return MersenneField{{t}}; | ||
} | ||
|
||
friend HOST_DEVICE_INLINE MersenneField operator-(MersenneField xs, const MersenneField& ys) | ||
{ | ||
return xs + neg(ys); | ||
} | ||
|
||
friend HOST_DEVICE_INLINE MersenneField operator*(MersenneField xs, const MersenneField& ys) | ||
{ | ||
uint64_t x = (uint64_t)(xs.get_limb()) * ys.get_limb(); | ||
uint32_t t = ((x >> 31) + (x & MersenneField::get_modulus().limbs[0])); | ||
uint32_t m = MersenneField::get_modulus().limbs[0]; | ||
if (t > m) t = (t & m) + (t >> 31); | ||
if (t > m) t = (t & m) + (t >> 31); | ||
if (t == m) t = 0; | ||
return MersenneField{{t}}; | ||
} | ||
|
||
static constexpr HOST_DEVICE_INLINE Wide mul_wide(const MersenneField& xs, const MersenneField& ys) | ||
{ | ||
return Wide::from_field(xs) * Wide::from_field(ys); | ||
} | ||
|
||
template <unsigned MODULUS_MULTIPLE = 1> | ||
static constexpr HOST_DEVICE_INLINE Wide sqr_wide(const MersenneField& xs) | ||
{ | ||
return mul_wide(xs, xs); | ||
} | ||
|
||
static constexpr HOST_DEVICE_INLINE MersenneField sqr(const MersenneField& xs) { return xs * xs; } | ||
|
||
static constexpr HOST_DEVICE_INLINE MersenneField to_montgomery(const MersenneField& xs) { return xs; } | ||
|
||
static constexpr HOST_DEVICE_INLINE MersenneField from_montgomery(const MersenneField& xs) { return xs; } | ||
|
||
static constexpr HOST_DEVICE_INLINE MersenneField pow(MersenneField base, int exp) | ||
{ | ||
MersenneField res = one(); | ||
while (exp > 0) { | ||
if (exp & 1) res = res * base; | ||
base = base * base; | ||
exp >>= 1; | ||
} | ||
return res; | ||
} | ||
}; | ||
struct fp_config { | ||
static constexpr unsigned limbs_count = 1; | ||
static constexpr unsigned omegas_count = 1; | ||
static constexpr unsigned modulus_bit_count = 31; | ||
static constexpr unsigned num_of_reductions = 1; | ||
|
||
static constexpr storage<limbs_count> modulus = {0x7fffffff}; | ||
static constexpr storage<limbs_count> modulus_2 = {0xfffffffe}; | ||
static constexpr uint64_t modulus_3 = 0x17ffffffd; | ||
static constexpr storage<limbs_count> modulus_4 = {0xfffffffc}; | ||
static constexpr storage<limbs_count> neg_modulus = {0x87ffffff}; | ||
static constexpr storage<2 * limbs_count> modulus_wide = {0x7fffffff, 0x00000000}; | ||
static constexpr storage<2 * limbs_count> modulus_squared = {0x00000001, 0x3fffffff}; | ||
static constexpr storage<2 * limbs_count> modulus_squared_2 = {0x00000002, 0x7ffffffe}; | ||
static constexpr storage<2 * limbs_count> modulus_squared_4 = {0x00000004, 0xfffffffc}; | ||
|
||
static constexpr storage<limbs_count> m = {0x80000001}; | ||
static constexpr storage<limbs_count> one = {0x00000001}; | ||
static constexpr storage<limbs_count> zero = {0x00000000}; | ||
static constexpr storage<limbs_count> montgomery_r = {0x00000001}; | ||
static constexpr storage<limbs_count> montgomery_r_inv = {0x00000001}; | ||
|
||
static constexpr storage_array<omegas_count, limbs_count> omega = {{{0x7ffffffe}}}; | ||
|
||
static constexpr storage_array<omegas_count, limbs_count> omega_inv = {{{0x7ffffffe}}}; | ||
|
||
static constexpr storage_array<omegas_count, limbs_count> inv = {{{0x40000000}}}; | ||
|
||
// nonresidue to generate the extension field | ||
static constexpr uint32_t nonresidue = 1; | ||
// true if nonresidue is negative. | ||
static constexpr bool nonresidue_is_negative = true; | ||
}; | ||
|
||
/** | ||
* Scalar field. Is always a prime field. | ||
*/ | ||
typedef MersenneField<fp_config> scalar_t; | ||
|
||
/** | ||
* Extension field of `scalar_t` enabled if `-DEXT_FIELD` env variable is. | ||
*/ | ||
typedef ExtensionField<fp_config, scalar_t> extension_t; | ||
} // namespace m31 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
[package] | ||
name = "icicle-m31" | ||
version.workspace = true | ||
edition.workspace = true | ||
authors.workspace = true | ||
description = "Rust wrapper the implementation of m31 prime field by Ingonyama" | ||
homepage.workspace = true | ||
repository.workspace = true | ||
|
||
[dependencies] | ||
icicle-core = { workspace = true } | ||
icicle-runtime = { workspace = true } | ||
|
||
[build-dependencies] | ||
cmake = "0.1.50" | ||
|
||
[features] | ||
cuda_backend = ["icicle-runtime/cuda_backend"] | ||
pull_cuda_backend = ["icicle-runtime/pull_cuda_backend"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
use cmake::Config; | ||
use std::{env, path::PathBuf}; | ||
|
||
fn main() { | ||
// Construct the path to the deps directory | ||
let out_dir = env::var("OUT_DIR").expect("OUT_DIR is not set"); | ||
let build_dir = PathBuf::from(format!("{}/../../../", &out_dir)); | ||
let deps_dir = build_dir.join("deps"); | ||
|
||
// Construct the path to icicle source directory | ||
let main_dir = env::current_dir().expect("Failed to get current directory"); | ||
let icicle_src_dir = PathBuf::from(format!("{}/../../../../icicle_v3", main_dir.display())); | ||
|
||
println!("cargo:rerun-if-env-changed=CXXFLAGS"); | ||
println!("cargo:rerun-if-changed={}", icicle_src_dir.display()); | ||
|
||
// Base config | ||
let mut config = Config::new(format!("{}", icicle_src_dir.display())); | ||
// Check if ICICLE_INSTALL_DIR is defined | ||
let icicle_install_dir = if let Ok(dir) = env::var("ICICLE_INSTALL_DIR") { | ||
PathBuf::from(dir) | ||
} else { | ||
// Define the default install directory to be under the build directory | ||
PathBuf::from(format!("{}/icicle/", deps_dir.display())) | ||
}; | ||
config | ||
.define("FIELD", "m31") | ||
.define("EXT_FIELD", "ON") | ||
.define("CMAKE_BUILD_TYPE", "Release") | ||
.define("CMAKE_INSTALL_PREFIX", &icicle_install_dir); | ||
|
||
#[cfg(feature = "cuda_backend")] | ||
config.define("CUDA_BACKEND", "local"); | ||
|
||
#[cfg(feature = "pull_cuda_backend")] | ||
config.define("CUDA_BACKEND", "main"); | ||
|
||
// Build | ||
let _ = config | ||
.build_target("install") | ||
.build(); | ||
|
||
println!("cargo:rustc-link-search={}/lib", icicle_install_dir.display()); | ||
println!("cargo:rustc-link-lib=icicle_field_m31"); | ||
println!("cargo:rustc-link-arg=-Wl,-rpath,{}/lib", icicle_install_dir.display()); // Add RPATH linker arguments | ||
|
||
// default backends dir | ||
if cfg!(feature = "cuda_backend") || cfg!(feature = "pull_cuda_backend") { | ||
println!( | ||
"cargo:rustc-env=ICICLE_BACKEND_INSTALL_DIR={}/lib/backend", | ||
icicle_install_dir.display() | ||
); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
use icicle_core::field::{Field, MontgomeryConvertibleField}; | ||
use icicle_core::traits::{FieldConfig, FieldImpl, GenerateRandom}; | ||
use icicle_core::{impl_field, impl_scalar_field}; | ||
use icicle_runtime::errors::eIcicleError; | ||
use icicle_runtime::memory::{DeviceSlice, HostOrDeviceSlice}; | ||
use icicle_runtime::stream::IcicleStream; | ||
|
||
pub(crate) const SCALAR_LIMBS: usize = 1; | ||
pub(crate) const EXTENSION_LIMBS: usize = 4; | ||
|
||
impl_scalar_field!("m31", m31, SCALAR_LIMBS, ScalarField, ScalarCfg); | ||
impl_scalar_field!( | ||
"m31_extension", | ||
m31_extension, | ||
EXTENSION_LIMBS, | ||
ExtensionField, | ||
ExtensionCfg | ||
); | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::{ExtensionField, ScalarField}; | ||
use icicle_core::impl_field_tests; | ||
use icicle_core::tests::*; | ||
|
||
impl_field_tests!(ScalarField); | ||
mod extension { | ||
use super::*; | ||
|
||
impl_field_tests!(ExtensionField); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
pub mod field; | ||
pub mod vec_ops; |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we need to support both extensions of m31? Afaik both Quadratic and Quartic extensions are used (2 and 4 elements)