Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] ReleaseDomain API #465

Merged
merged 10 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions icicle/appUtils/ntt/ntt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ namespace ntt {
domain.fast_internal_twiddles_inv = nullptr;
CHK_IF_RETURN(cudaFreeAsync(domain.fast_basic_twiddles_inv, ctx.stream));
domain.fast_basic_twiddles_inv = nullptr;
domain.initialized = false;

return CHK_LAST();
}
Expand Down Expand Up @@ -747,6 +748,17 @@ namespace ntt {
return NTT<curve_config::scalar_t, curve_config::scalar_t>(input, size, dir, config, output);
}

/**
* Extern "C" version of [ReleaseDomain](@ref ReleaseDomain) function with the following values of template parameters
* (where the curve is given by `-DCURVE` env variable during build):
* - `S` is the [scalar field](@ref scalar_t) of the curve;
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*/
extern "C" cudaError_t CONCAT_EXPAND(CURVE, ReleaseDomain)(device_context::DeviceContext& ctx)
{
return ReleaseDomain<curve_config::scalar_t>(ctx);
}

#if defined(ECNTT_DEFINED)

/**
Expand Down
1 change: 1 addition & 0 deletions wrappers/golang/curves/bls12377/include/ntt.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ extern "C" {
cudaError_t bls12_377NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bls12_377ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t bls12_377InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
cudaError_t bls12_377ReleaseDomain(DeviceContext* ctx);

#ifdef __cplusplus
}
Expand Down
7 changes: 7 additions & 0 deletions wrappers/golang/curves/bls12377/ntt.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,10 @@ func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bo
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}

func ReleaseDomain(ctx cr.DeviceContext) core.IcicleError {
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
__ret := C.bls12_377ReleaseDomain(cCtx)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}
33 changes: 30 additions & 3 deletions wrappers/golang/curves/bls12377/ntt_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package bls12377

import (
"os"
"reflect"
"testing"

Expand All @@ -21,14 +22,15 @@ func init() {
initDomain(largestTestSize, cfg)
}

func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) {
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) core.IcicleError {
rouMont, _ := fft.Generator(uint64(1 << largestTestSize))
rou := rouMont.Bits()
rouIcicle := ScalarField{}
limbs := core.ConvertUint64ArrToUint32Arr(rou[:])

rouIcicle.FromLimbs(limbs)
InitDomain(rouIcicle, cfg.Ctx, false)
e := InitDomain(rouIcicle, cfg.Ctx, false)
return e
}

func testAgainstGnarkCryptoNtt(size int, scalars core.HostSlice[ScalarField], output core.HostSlice[ScalarField], order core.Ordering, direction core.NTTDir) bool {
Expand Down Expand Up @@ -78,7 +80,7 @@ func TestNTTGetDefaultConfig(t *testing.T) {
}

func TestInitDomain(t *testing.T) {
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the init() function")
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) })
}
Expand Down Expand Up @@ -201,6 +203,31 @@ func TestNttBatch(t *testing.T) {
}
}

func TestReleaseDomain(t *testing.T) {
jeremyfelder marked this conversation as resolved.
Show resolved Hide resolved
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
e := ReleaseDomain(cfg.Ctx)
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ReleasDomain failed")
}

func TestMain(m *testing.M) {
// setup domain
cfg := GetDefaultNttConfig()
e := initDomain(largestTestSize, cfg)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("initDomain failed")
}

// execute tests
os.Exit(m.Run())

// release domain
e = ReleaseDomain(cfg.Ctx)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("ReleaseDomain failed")
}
}

// func TestNttArbitraryCoset(t *testing.T) {
// for _, size := range []int{20} {
// for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
Expand Down
1 change: 1 addition & 0 deletions wrappers/golang/curves/bls12381/include/ntt.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ extern "C" {
cudaError_t bls12_381NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bls12_381ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t bls12_381InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
cudaError_t bls12_381ReleaseDomain(DeviceContext* ctx);

#ifdef __cplusplus
}
Expand Down
7 changes: 7 additions & 0 deletions wrappers/golang/curves/bls12381/ntt.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,10 @@ func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bo
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}

func ReleaseDomain(ctx cr.DeviceContext) core.IcicleError {
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
__ret := C.bls12_381ReleaseDomain(cCtx)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}
33 changes: 30 additions & 3 deletions wrappers/golang/curves/bls12381/ntt_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package bls12381

import (
"os"
"reflect"
"testing"

Expand All @@ -21,14 +22,15 @@ func init() {
initDomain(largestTestSize, cfg)
}

func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) {
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) core.IcicleError {
rouMont, _ := fft.Generator(uint64(1 << largestTestSize))
rou := rouMont.Bits()
rouIcicle := ScalarField{}
limbs := core.ConvertUint64ArrToUint32Arr(rou[:])

rouIcicle.FromLimbs(limbs)
InitDomain(rouIcicle, cfg.Ctx, false)
e := InitDomain(rouIcicle, cfg.Ctx, false)
return e
}

func testAgainstGnarkCryptoNtt(size int, scalars core.HostSlice[ScalarField], output core.HostSlice[ScalarField], order core.Ordering, direction core.NTTDir) bool {
Expand Down Expand Up @@ -78,7 +80,7 @@ func TestNTTGetDefaultConfig(t *testing.T) {
}

func TestInitDomain(t *testing.T) {
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the init() function")
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) })
}
Expand Down Expand Up @@ -201,6 +203,31 @@ func TestNttBatch(t *testing.T) {
}
}

func TestReleaseDomain(t *testing.T) {
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
e := ReleaseDomain(cfg.Ctx)
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ReleasDomain failed")
}

func TestMain(m *testing.M) {
// setup domain
cfg := GetDefaultNttConfig()
e := initDomain(largestTestSize, cfg)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("initDomain failed")
}

// execute tests
os.Exit(m.Run())

// release domain
e = ReleaseDomain(cfg.Ctx)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("ReleaseDomain failed")
}
}

// func TestNttArbitraryCoset(t *testing.T) {
// for _, size := range []int{20} {
// for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
Expand Down
1 change: 1 addition & 0 deletions wrappers/golang/curves/bn254/include/ntt.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ extern "C" {
cudaError_t bn254NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bn254ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t bn254InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
cudaError_t bn254ReleaseDomain(DeviceContext* ctx);

#ifdef __cplusplus
}
Expand Down
7 changes: 7 additions & 0 deletions wrappers/golang/curves/bn254/ntt.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,10 @@ func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bo
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}

func ReleaseDomain(ctx cr.DeviceContext) core.IcicleError {
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
__ret := C.bn254ReleaseDomain(cCtx)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}
33 changes: 30 additions & 3 deletions wrappers/golang/curves/bn254/ntt_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package bn254

import (
"os"
"reflect"
"testing"

Expand All @@ -21,14 +22,15 @@ func init() {
initDomain(largestTestSize, cfg)
}

func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) {
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) core.IcicleError {
rouMont, _ := fft.Generator(uint64(1 << largestTestSize))
rou := rouMont.Bits()
rouIcicle := ScalarField{}
limbs := core.ConvertUint64ArrToUint32Arr(rou[:])

rouIcicle.FromLimbs(limbs)
InitDomain(rouIcicle, cfg.Ctx, false)
e := InitDomain(rouIcicle, cfg.Ctx, false)
return e
}

func testAgainstGnarkCryptoNtt(size int, scalars core.HostSlice[ScalarField], output core.HostSlice[ScalarField], order core.Ordering, direction core.NTTDir) bool {
Expand Down Expand Up @@ -78,7 +80,7 @@ func TestNTTGetDefaultConfig(t *testing.T) {
}

func TestInitDomain(t *testing.T) {
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the init() function")
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) })
}
Expand Down Expand Up @@ -201,6 +203,31 @@ func TestNttBatch(t *testing.T) {
}
}

func TestReleaseDomain(t *testing.T) {
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
e := ReleaseDomain(cfg.Ctx)
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ReleasDomain failed")
}

func TestMain(m *testing.M) {
// setup domain
cfg := GetDefaultNttConfig()
e := initDomain(largestTestSize, cfg)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("initDomain failed")
}

// execute tests
os.Exit(m.Run())

// release domain
e = ReleaseDomain(cfg.Ctx)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("ReleaseDomain failed")
}
}

// func TestNttArbitraryCoset(t *testing.T) {
// for _, size := range []int{20} {
// for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
Expand Down
1 change: 1 addition & 0 deletions wrappers/golang/curves/bw6761/include/ntt.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ extern "C" {
cudaError_t bw6_761NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bw6_761ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t bw6_761InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
cudaError_t bw6_761ReleaseDomain(DeviceContext* ctx);

#ifdef __cplusplus
}
Expand Down
7 changes: 7 additions & 0 deletions wrappers/golang/curves/bw6761/ntt.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,10 @@ func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bo
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}

func ReleaseDomain(ctx cr.DeviceContext) core.IcicleError {
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
__ret := C.bw6_761ReleaseDomain(cCtx)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}
33 changes: 30 additions & 3 deletions wrappers/golang/curves/bw6761/ntt_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package bw6761

import (
"os"
"reflect"
"testing"

Expand All @@ -21,14 +22,15 @@ func init() {
initDomain(largestTestSize, cfg)
}

func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) {
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) core.IcicleError {
rouMont, _ := fft.Generator(uint64(1 << largestTestSize))
rou := rouMont.Bits()
rouIcicle := ScalarField{}
limbs := core.ConvertUint64ArrToUint32Arr(rou[:])

rouIcicle.FromLimbs(limbs)
InitDomain(rouIcicle, cfg.Ctx, false)
e := InitDomain(rouIcicle, cfg.Ctx, false)
return e
}

func testAgainstGnarkCryptoNtt(size int, scalars core.HostSlice[ScalarField], output core.HostSlice[ScalarField], order core.Ordering, direction core.NTTDir) bool {
Expand Down Expand Up @@ -78,7 +80,7 @@ func TestNTTGetDefaultConfig(t *testing.T) {
}

func TestInitDomain(t *testing.T) {
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the init() function")
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) })
}
Expand Down Expand Up @@ -201,6 +203,31 @@ func TestNttBatch(t *testing.T) {
}
}

func TestReleaseDomain(t *testing.T) {
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
e := ReleaseDomain(cfg.Ctx)
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ReleasDomain failed")
}

func TestMain(m *testing.M) {
// setup domain
cfg := GetDefaultNttConfig()
e := initDomain(largestTestSize, cfg)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("initDomain failed")
}

// execute tests
os.Exit(m.Run())

// release domain
e = ReleaseDomain(cfg.Ctx)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("ReleaseDomain failed")
}
}

// func TestNttArbitraryCoset(t *testing.T) {
// for _, size := range []int{20} {
// for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ extern "C" {
cudaError_t {{.Curve}}NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t {{.Curve}}ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t {{.Curve}}InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
cudaError_t {{.Curve}}ReleaseDomain(DeviceContext* ctx);

#ifdef __cplusplus
}
Expand Down
7 changes: 7 additions & 0 deletions wrappers/golang/internal/generator/templates/ntt.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,10 @@ func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bo
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}

func ReleaseDomain(ctx cr.DeviceContext) core.IcicleError {
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
__ret := C.{{.Curve}}ReleaseDomain(cCtx)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}
Loading
Loading