Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: overhead free type conversion for load_poly_to_gpu #516

Merged
merged 10 commits into from
Apr 1, 2024
Merged
38 changes: 32 additions & 6 deletions primitives/src/pcs/univariate_kzg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -818,16 +818,18 @@ pub(crate) mod icicle {

#[cfg(feature = "kzg-print-trace")]
let conv_time = start_timer!(|| "Type Conversion: ark->ICICLE: Scalar");
let scalars: Vec<<Self::IC as IcicleCurve>::ScalarField> = poly.coeffs()[..size]
.par_iter()
.map(|&s| Self::ark_field_to_icicle(s))
.collect();
// We assume that two types use the same underline repr.
let scalars = unsafe {
poly.coeffs()[..size]
.align_to::<<Self::IC as IcicleCurve>::ScalarField>()
.1
};
#[cfg(feature = "kzg-print-trace")]
end_timer!(conv_time);

#[cfg(feature = "kzg-print-trace")]
let load_time = start_timer!(|| "Load scalars: CPU->GPU");
scalars_on_device.copy_from_host(&scalars)?;
scalars_on_device.copy_from_host(scalars)?;
#[cfg(feature = "kzg-print-trace")]
end_timer!(load_time);

Expand Down Expand Up @@ -1306,7 +1308,11 @@ mod tests {
use super::*;
#[cfg(feature = "kzg-print-trace")]
use crate::icicle_deps::warmup_new_stream;
use crate::{icicle_deps::curves::*, pcs::univariate_kzg::icicle::GPUCommittable};
use crate::{
icicle_deps::{curves::*, IcicleCurve},
pcs::univariate_kzg::icicle::GPUCommittable,
};
use core::mem::size_of;

#[cfg(feature = "kzg-print-trace")]
fn gpu_profiling<E: Pairing>() -> Result<(), PCSError>
Expand Down Expand Up @@ -1401,6 +1407,26 @@ mod tests {
test_gpu_e2e_template::<Bn254>().unwrap();
}

fn test_gpu_ark_conversion_template<E: Pairing>()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this test. But can we do more? Earlier code comment says:

We assume that two types use the same underline repr.

In that case, we should be able to assert that the bytes are exactly the same. That way, if an incompatibility is introduced then the test will detect it immediately, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit annoying but doable: c82e230

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! I don't fully understand this new test but here's what I see:

  1. make a slice of random new ark scalars in scalars
  2. convert them into a IC scalars slice ic_scalars via from_ark.
  3. copy the slice into a new GPU slice called d_scalars.
  4. transform the GPU slice to Montgomery form on the GPU
  5. copy the GPU scalars out of GPU and back into ic_scalars
  6. these ic_scalars (in Montgomery form) should be byte-for-byte identical to the original scalars. Check this via align_to

Not sure why you need to copy to/from GPU but AFIACT this test does what it's supposed to.

where
UnivariateKzgPCS<E>: GPUCommittable<E>,
{
assert_eq!(
size_of::<E::ScalarField>(),
size_of::<
<<UnivariateKzgPCS<E> as GPUCommittable<E>>::IC as IcicleCurve>::ScalarField,
>()
);
}

#[test]
/// This test checks whether the scalar field type in Ark has the size
/// with the one in icicle. So that we could do direct reinterpret_cast
/// between them.
fn test_gpu_ark_conversion() {
test_gpu_ark_conversion_template::<Bn254>();
}

#[cfg(feature = "kzg-print-trace")]
#[test]
fn profile_gpu_commit() {
Expand Down
20 changes: 10 additions & 10 deletions primitives/src/vid/advz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,9 @@ where
UnivariateKzgPCS<E>: GPUCommittable<E>,
{
/// Like [`Advz::new`] except with SRS loaded to GPU
pub fn with_multiplicity(
num_storage_nodes: usize,
recovery_threshold: usize,
pub fn new(
num_storage_nodes: u32,
recovery_threshold: u32,
srs: impl Borrow<KzgSrs<E>>,
) -> VidResult<Self> {
let mut advz = Self::new_internal(num_storage_nodes, recovery_threshold, srs)?;
Expand All @@ -260,9 +260,9 @@ where
}
/// Like [`Advz::with_multiplicity`] except with SRS loaded to GPU
pub fn with_multiplicity(
num_storage_nodes: usize,
recovery_threshold: usize,
multiplicity: usize,
num_storage_nodes: u32,
recovery_threshold: u32,
multiplicity: u32,
srs: impl Borrow<KzgSrs<E>>,
) -> VidResult<Self> {
let mut advz = Self::with_multiplicity_internal(
Expand Down Expand Up @@ -985,11 +985,11 @@ mod tests {
let (recovery_threshold, num_storage_nodes) = (256, 512);
let mut rng = jf_utils::test_rng();
let srs = init_srs(recovery_threshold as usize, &mut rng);
let mut advz =
Advz::<Bn254, Sha256>::new(num_storage_nodes, recovery_threshold, srs).unwrap();
#[cfg(feature = "gpu-vid")]
let mut advz_gpu =
AdvzGPU::<'_, Bn254, Sha256>::new(num_storage_nodes, recovery_threshold, &srs).unwrap();
let mut advz =
Advz::<Bn254, Sha256>::new(num_storage_nodes, recovery_threshold, srs).unwrap();

let payload_random = init_random_payload(1 << 25, &mut rng);

Expand All @@ -1005,11 +1005,11 @@ mod tests {
let (recovery_threshold, num_storage_nodes) = (256, 512);
let mut rng = jf_utils::test_rng();
let srs = init_srs(recovery_threshold as usize, &mut rng);
let mut advz =
Advz::<Bn254, Sha256>::new(num_storage_nodes, recovery_threshold, srs).unwrap();
#[cfg(feature = "gpu-vid")]
let mut advz_gpu =
AdvzGPU::<'_, Bn254, Sha256>::new(num_storage_nodes, recovery_threshold, &srs).unwrap();
let mut advz =
Advz::<Bn254, Sha256>::new(num_storage_nodes, recovery_threshold, srs).unwrap();

let payload_random = init_random_payload(1 << 25, &mut rng);

Expand Down
Loading