Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: scalarmul in sumcheck #1189

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
XXX: full sumcheck-scalarmul
ivokub committed Jul 4, 2024
commit 1ed7dd7befb18044a7cbf7e8627c3c3c7667298d
79 changes: 66 additions & 13 deletions std/recursion/sumcheck/scalarmul_test.go
Original file line number Diff line number Diff line change
@@ -11,12 +11,17 @@ import (
fr_secp256k1 "github.com/consensys/gnark-crypto/ecc/secp256k1/fr"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/scs"
"github.com/consensys/gnark/std/algebra/emulated/sw_emulated"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/math/emulated/emparams"
"github.com/consensys/gnark/test"
)

type ProjectivePoint[Base emulated.FieldParams] struct {
X, Y, Z emulated.Element[Base]
}

type ScalarMulCircuit[Base, Scalars emulated.FieldParams] struct {
Points []sw_emulated.AffinePoint[Base]
Scalars []emulated.Element[Scalars]
@@ -37,19 +42,19 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error {
return fmt.Errorf("new scalar field: %w", err)
}
for i := range c.Points {
step, err := callHintScalarMulSteps[B, S](api, baseApi, scalarApi, c.nbScalarBits, c.Points[i], c.Scalars[i])
results, accs, err := callHintScalarMulSteps[B, S](api, baseApi, scalarApi, c.nbScalarBits, c.Points[i], c.Scalars[i])
if err != nil {
return fmt.Errorf("hint scalar mul steps: %w", err)
}
_ = step
_, _ = results, accs
}
return nil
}

func callHintScalarMulSteps[B, S emulated.FieldParams](api frontend.API,
baseApi *emulated.Field[B], scalarApi *emulated.Field[S],
nbScalarBits int,
point sw_emulated.AffinePoint[B], scalar emulated.Element[S]) ([][6]*emulated.Element[B], error) {
point sw_emulated.AffinePoint[B], scalar emulated.Element[S]) (results []ProjectivePoint[B], accumulators []ProjectivePoint[B], err error) {
var fp B
var fr S
inputs := []frontend.Variable{fp.BitsPerLimb(), fp.NbLimbs()}
@@ -62,16 +67,28 @@ func callHintScalarMulSteps[B, S emulated.FieldParams](api frontend.API,
nbRes := nbScalarBits * int(fp.NbLimbs()) * 6
hintRes, err := api.Compiler().NewHint(hintScalarMulSteps, nbRes, inputs...)
if err != nil {
return nil, fmt.Errorf("new hint: %w", err)
return nil, nil, fmt.Errorf("new hint: %w", err)
}
res := make([][6]*emulated.Element[B], nbScalarBits)
res := make([]ProjectivePoint[B], nbScalarBits)
acc := make([]ProjectivePoint[B], nbScalarBits)
for i := range res {
for j := 0; j < 6; j++ {
coords := make([]*emulated.Element[B], 6)
for j := range coords {
limbs := hintRes[i*(6*int(fp.NbLimbs()))+j*int(fp.NbLimbs()) : i*(6*int(fp.NbLimbs()))+(j+1)*int(fp.NbLimbs())]
res[i][j] = baseApi.NewElement(limbs)
coords[j] = baseApi.NewElement(limbs)
}
res[i] = ProjectivePoint[B]{
X: *coords[0],
Y: *coords[1],
Z: *coords[2],
}
acc[i] = ProjectivePoint[B]{
X: *coords[3],
Y: *coords[4],
Z: *coords[5],
}
}
return res, nil
return res, acc, nil
}

func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error {
@@ -105,9 +122,44 @@ func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) err
if err := recompose(scalarLimbs, uint(nbScalarBits), scalar); err != nil {
return fmt.Errorf("recompose scalar: %w", err)
}
fmt.Println(fp, fr, x, y, scalar)

scalarLength := len(outputs) / (6 * nbLimbs)
accX := new(big.Int).Set(x)
accY := new(big.Int).Set(y)
accZ := big.NewInt(1)
resultX := big.NewInt(0)
resultY := big.NewInt(1)
resultZ := big.NewInt(0)
api := newBigIntEngine(fp)
selector := new(big.Int)

for i := 0; i < scalarLength; i++ {
// selector := scalar.And()
selector.And(scalar, big.NewInt(1))
scalar.Rsh(scalar, 1)
tmpX, tmpY, tmpZ := projAdd(api, accX, accY, accZ, resultX, resultY, resultZ)
resultX, resultY, resultZ = projSelect(api, selector, tmpX, tmpY, tmpZ, resultX, resultY, resultZ)
accX, accY, accZ = projDbl(api, accX, accY, accZ)
if err := decompose(resultX, uint(nbBits), outputs[i*6*nbLimbs:i*6*nbLimbs+nbLimbs]); err != nil {
return fmt.Errorf("decompose resultX: %w", err)
}
if err := decompose(resultY, uint(nbBits), outputs[i*6*nbLimbs+nbLimbs:i*6*nbLimbs+2*nbLimbs]); err != nil {
return fmt.Errorf("decompose resultY: %w", err)
}
if err := decompose(resultZ, uint(nbBits), outputs[i*6*nbLimbs+2*nbLimbs:i*6*nbLimbs+3*nbLimbs]); err != nil {
return fmt.Errorf("decompose resultZ: %w", err)
}
if err := decompose(accX, uint(nbBits), outputs[i*6*nbLimbs+3*nbLimbs:i*6*nbLimbs+4*nbLimbs]); err != nil {
return fmt.Errorf("decompose accX: %w", err)
}
if err := decompose(accY, uint(nbBits), outputs[i*6*nbLimbs+4*nbLimbs:i*6*nbLimbs+5*nbLimbs]); err != nil {
return fmt.Errorf("decompose accY: %w", err)
}
if err := decompose(accZ, uint(nbBits), outputs[i*6*nbLimbs+5*nbLimbs:(i+1)*6*nbLimbs]); err != nil {
return fmt.Errorf("decompose accZ: %w", err)
}
}

return nil
}

@@ -150,19 +202,19 @@ func TestScalarMul(t *testing.T) {
assert := test.NewAssert(t)
type B = emparams.Secp256k1Fp
type S = emparams.Secp256k1Fr
t.Log(B{}.Modulus(), S{}.Modulus())
var P secp256k1.G1Affine
var s fr_secp256k1.Element
nbInputs := 1 << 0
nbScalarBits := 2
nbInputs := 1 << 2
nbScalarBits := 256
scalarBound := new(big.Int).Lsh(big.NewInt(1), uint(nbScalarBits))
points := make([]sw_emulated.AffinePoint[B], nbInputs)
scalars := make([]emulated.Element[S], nbInputs)
for i := range points {
P.ScalarMultiplicationBase(big.NewInt(1))
s.SetRandom()
P.ScalarMultiplicationBase(s.BigInt(new(big.Int)))
sc, _ := rand.Int(rand.Reader, scalarBound)
t.Log(P.X.String(), P.Y.String(), sc.String())
// t.Log(P.X.String(), P.Y.String(), sc.String())
points[i] = sw_emulated.AffinePoint[B]{
X: emulated.ValueOf[B](P.X),
Y: emulated.ValueOf[B](P.Y),
@@ -180,4 +232,5 @@ func TestScalarMul(t *testing.T) {
}
err := test.IsSolved(&circuit, &witness, ecc.BLS12_377.ScalarField())
assert.NoError(err)
frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit)
}