From bad1f2e425753ce23c2b4a132255fdb570aa4202 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Wed, 23 Mar 2022 09:30:10 -0500 Subject: [PATCH] refactor: remove StaticHint wrapper, log duplicate hints (#289) * refactor: remove statichint * feat: log.Warn when duplicate hint registration. * style: code cleaning --- backend/backend.go | 21 ++++++-- backend/hint/builtin.go | 12 ++--- backend/hint/hint.go | 50 +++++-------------- backend/hint/registry.go | 18 ++++--- frontend/cs/r1cs/builder.go | 4 +- frontend/cs/scs/builder.go | 4 +- internal/backend/bls12-377/cs/solution.go | 13 ++--- internal/backend/bls12-381/cs/solution.go | 13 ++--- internal/backend/bls24-315/cs/solution.go | 13 ++--- internal/backend/bn254/cs/solution.go | 13 ++--- internal/backend/bw6-633/cs/solution.go | 13 ++--- internal/backend/bw6-761/cs/solution.go | 13 ++--- internal/backend/circuits/hint.go | 25 ++-------- .../template/representations/solution.go.tmpl | 13 ++--- std/algebra/sw_bls12377/g1.go | 4 +- std/algebra/sw_bls24315/g1.go | 4 +- std/hints.go | 30 +++++++---- std/hints_test.go | 21 ++++++++ std/math/bits/conversion_binary.go | 21 +++----- std/math/bits/conversion_ternary.go | 2 +- std/math/bits/naf.go | 2 +- test/engine.go | 2 +- 22 files changed, 127 insertions(+), 184 deletions(-) create mode 100644 std/hints_test.go diff --git a/backend/backend.go b/backend/backend.go index 66078fca8a..a01ac0df19 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -54,16 +54,19 @@ type ProverOption func(*ProverConfig) error // ProverConfig is the configuration for the prover with the options applied. type ProverConfig struct { - Force bool // defaults to false - HintFunctions []hint.Function // defaults to all built-in hint functions - CircuitLogger zerolog.Logger // defaults to gnark.Logger + Force bool // defaults to false + HintFunctions map[hint.ID]hint.Function // defaults to all built-in hint functions + CircuitLogger zerolog.Logger // defaults to gnark.Logger } // NewProverConfig returns a default ProverConfig with given prover options opts // applied. func NewProverConfig(opts ...ProverOption) (ProverConfig, error) { log := logger.Logger() - opt := ProverConfig{CircuitLogger: log, HintFunctions: hint.GetAll()} + opt := ProverConfig{CircuitLogger: log, HintFunctions: make(map[hint.ID]hint.Function)} + for _, v := range hint.GetRegistered() { + opt.HintFunctions[hint.UUID(v)] = v + } for _, option := range opts { if err := option(&opt); err != nil { return ProverConfig{}, err @@ -86,10 +89,18 @@ func IgnoreSolverError() ProverOption { // WithHints is a prover option that specifies additional hint functions to be used // by the constraint solver. func WithHints(hintFunctions ...hint.Function) ProverOption { + log := logger.Logger() return func(opt *ProverConfig) error { // it is an error to register hint function several times, but as the // prover already checks it then omit here. - opt.HintFunctions = append(opt.HintFunctions, hintFunctions...) + for _, h := range hintFunctions { + uuid := hint.UUID(h) + if _, ok := opt.HintFunctions[uuid]; ok { + log.Warn().Int("hintID", int(uuid)).Str("name", hint.Name(h)).Msg("duplicate hint function") + } else { + opt.HintFunctions[uuid] = h + } + } return nil } } diff --git a/backend/hint/builtin.go b/backend/hint/builtin.go index f8d66f1ce8..313414ede6 100644 --- a/backend/hint/builtin.go +++ b/backend/hint/builtin.go @@ -6,18 +6,14 @@ import ( "github.com/consensys/gnark-crypto/ecc" ) -var ( - // IsZero computes the value 1 - a^(modulus-1) for the single input a. This - // corresponds to checking if a == 0 (for which the function returns 1) or a - // != 0 (for which the function returns 0). - IsZero = NewStaticHint(isZero) -) - func init() { Register(IsZero) } -func isZero(curveID ecc.ID, inputs []*big.Int, results []*big.Int) error { +// IsZero computes the value 1 - a^(modulus-1) for the single input a. This +// corresponds to checking if a == 0 (for which the function returns 1) or a +// != 0 (for which the function returns 0). +func IsZero(curveID ecc.ID, inputs []*big.Int, results []*big.Int) error { result := results[0] // get fr modulus diff --git a/backend/hint/hint.go b/backend/hint/hint.go index c660787764..1d09f1eda6 100644 --- a/backend/hint/hint.go +++ b/backend/hint/hint.go @@ -76,37 +76,19 @@ import ( // ID is a unique identifier for a hint function used for lookup. type ID uint32 -// StaticFunction is a function which takes a constant number of inputs and -// returns a constant number of outputs. Use NewStaticHint() to construct an -// instance compatible with Function interface. -type StaticFunction func(curveID ecc.ID, inputs []*big.Int, outputs []*big.Int) error - -// Function defines an annotated hint function. To initialize a hint function -// with static number of inputs and outputs, use NewStaticHint(). -type Function interface { - // UUID returns an unique identifier for the hint function. UUID is used for - // lookup of the hint function. - UUID() ID - - // Call is invoked by the framework to obtain the result from inputs. - // Elements in outputs are not guaranteed to be initialized to 0 - Call(curveID ecc.ID, inputs []*big.Int, outputs []*big.Int) error - - // String returns a human-readable description of the function used in logs - // and debug messages. - String() string -} - -func NewStaticHint(fn StaticFunction) Function { - return fn -} +// Function defines an annotated hint function; the number of inputs and outputs injected at solving +// time is defined in the circuit (compile time). +// +// For example: +// b := api.NewHint(hint, 2, a) +// --> at solving time, hint is going to be invoked with 1 input (a) and is expected to return 2 outputs +// b[0] and b[1]. +type Function func(curveID ecc.ID, inputs []*big.Int, outputs []*big.Int) error // UUID is a reference function for computing the hint ID based on a function name -func UUID(fn StaticFunction) ID { +func UUID(fn Function) ID { hf := fnv.New32a() - name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() - // using a name for identifying different hints should be enough as we get a - // solve-time error when there are duplicate hints with the same signature. + name := Name(fn) // TODO relying on name to derive UUID is risky; if fn is an anonymous func, wil be package.glob..funcN // and if new anonymous functions are added in the package, N may change, so will UUID. @@ -115,15 +97,7 @@ func UUID(fn StaticFunction) ID { return ID(hf.Sum32()) } -func (h StaticFunction) Call(curveID ecc.ID, inputs []*big.Int, res []*big.Int) error { - return h(curveID, inputs, res) -} - -func (h StaticFunction) UUID() ID { - return UUID(h) -} - -func (h StaticFunction) String() string { - fnptr := reflect.ValueOf(h).Pointer() +func Name(fn Function) string { + fnptr := reflect.ValueOf(fn).Pointer() return runtime.FuncForPC(fnptr).Name() } diff --git a/backend/hint/registry.go b/backend/hint/registry.go index 0813995a3e..0d59eb0e2d 100644 --- a/backend/hint/registry.go +++ b/backend/hint/registry.go @@ -1,28 +1,30 @@ package hint import ( - "fmt" "sync" + + "github.com/consensys/gnark/logger" ) var registry = make(map[ID]Function) var registryM sync.RWMutex -// Register registers an annotated hint function in the global registry. All -// registered hint functions can be retrieved with a call to GetAll(). It is an -// error to register a single function twice and results in a panic. +// Register registers an hint function in the global registry. func Register(hintFn Function) { registryM.Lock() defer registryM.Unlock() - key := hintFn.UUID() + key := UUID(hintFn) + name := Name(hintFn) if _, ok := registry[key]; ok { - panic(fmt.Sprintf("function %s registered twice", hintFn)) + log := logger.Logger() + log.Warn().Str("name", name).Msg("function registered multiple times") + return } registry[key] = hintFn } -// GetAll returns all registered hint functions. -func GetAll() []Function { +// GetRegistered returns all registered hint functions. +func GetRegistered() []Function { registryM.RLock() defer registryM.RUnlock() ret := make([]Function, 0, len(registry)) diff --git a/frontend/cs/r1cs/builder.go b/frontend/cs/r1cs/builder.go index 971d607aa2..f531b58652 100644 --- a/frontend/cs/r1cs/builder.go +++ b/frontend/cs/r1cs/builder.go @@ -601,7 +601,7 @@ func (system *r1cs) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.V } // register the hint as dependency - hintUUID, hintID := f.UUID(), f.String() + hintUUID, hintID := hint.UUID(f), hint.Name(f) if id, ok := system.MHintsDependencies[hintUUID]; ok { // hint already registered, let's ensure string id matches if id != hintID { @@ -636,7 +636,7 @@ func (system *r1cs) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.V res[i] = r } - ch := &compiled.Hint{ID: f.UUID(), Inputs: hintInputs, Wires: varIDs} + ch := &compiled.Hint{ID: hintUUID, Inputs: hintInputs, Wires: varIDs} for _, vID := range varIDs { system.MHints[vID] = ch } diff --git a/frontend/cs/scs/builder.go b/frontend/cs/scs/builder.go index be0afc41a4..4333c96ba2 100644 --- a/frontend/cs/scs/builder.go +++ b/frontend/cs/scs/builder.go @@ -509,7 +509,7 @@ func (system *scs) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Va } // register the hint as dependency - hintUUID, hintID := f.UUID(), f.String() + hintUUID, hintID := hint.UUID(f), hint.Name(f) if id, ok := system.MHintsDependencies[hintUUID]; ok { // hint already registered, let's ensure string id matches if id != hintID { @@ -541,7 +541,7 @@ func (system *scs) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Va res[i] = r } - ch := &compiled.Hint{ID: f.UUID(), Inputs: hintInputs, Wires: varIDs} + ch := &compiled.Hint{ID: hintUUID, Inputs: hintInputs, Wires: varIDs} for _, vID := range varIDs { system.MHints[vID] = ch } diff --git a/internal/backend/bls12-377/cs/solution.go b/internal/backend/bls12-377/cs/solution.go index f0f878fac6..be7811f7b8 100644 --- a/internal/backend/bls12-377/cs/solution.go +++ b/internal/backend/bls12-377/cs/solution.go @@ -42,20 +42,13 @@ type solution struct { mHintsFunctions map[hint.ID]hint.Function } -func newSolution(nbWires int, hintFunctions []hint.Function, hintsDependencies map[hint.ID]string, coefficients []fr.Element) (solution, error) { +func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, coefficients []fr.Element) (solution, error) { s := solution{ values: make([]fr.Element, nbWires), coefficients: coefficients, solved: make([]bool, nbWires), - mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - } - - for _, h := range hintFunctions { - if _, ok := s.mHintsFunctions[h.UUID()]; ok { - return s, fmt.Errorf("duplicate hint function %s", h) - } - s.mHintsFunctions[h.UUID()] = h + mHintsFunctions: hintFunctions, } // hintsDependencies is from compile time; it contains the list of hints the solver **needs** @@ -195,7 +188,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { } } - err := f.Call(curve.ID, inputs, outputs) + err := f(curve.ID, inputs, outputs) var v fr.Element for i := range outputs { diff --git a/internal/backend/bls12-381/cs/solution.go b/internal/backend/bls12-381/cs/solution.go index e79c620b87..cc0087a3a7 100644 --- a/internal/backend/bls12-381/cs/solution.go +++ b/internal/backend/bls12-381/cs/solution.go @@ -42,20 +42,13 @@ type solution struct { mHintsFunctions map[hint.ID]hint.Function } -func newSolution(nbWires int, hintFunctions []hint.Function, hintsDependencies map[hint.ID]string, coefficients []fr.Element) (solution, error) { +func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, coefficients []fr.Element) (solution, error) { s := solution{ values: make([]fr.Element, nbWires), coefficients: coefficients, solved: make([]bool, nbWires), - mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - } - - for _, h := range hintFunctions { - if _, ok := s.mHintsFunctions[h.UUID()]; ok { - return s, fmt.Errorf("duplicate hint function %s", h) - } - s.mHintsFunctions[h.UUID()] = h + mHintsFunctions: hintFunctions, } // hintsDependencies is from compile time; it contains the list of hints the solver **needs** @@ -195,7 +188,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { } } - err := f.Call(curve.ID, inputs, outputs) + err := f(curve.ID, inputs, outputs) var v fr.Element for i := range outputs { diff --git a/internal/backend/bls24-315/cs/solution.go b/internal/backend/bls24-315/cs/solution.go index 14f33d3fc0..94c1d1d5ef 100644 --- a/internal/backend/bls24-315/cs/solution.go +++ b/internal/backend/bls24-315/cs/solution.go @@ -42,20 +42,13 @@ type solution struct { mHintsFunctions map[hint.ID]hint.Function } -func newSolution(nbWires int, hintFunctions []hint.Function, hintsDependencies map[hint.ID]string, coefficients []fr.Element) (solution, error) { +func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, coefficients []fr.Element) (solution, error) { s := solution{ values: make([]fr.Element, nbWires), coefficients: coefficients, solved: make([]bool, nbWires), - mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - } - - for _, h := range hintFunctions { - if _, ok := s.mHintsFunctions[h.UUID()]; ok { - return s, fmt.Errorf("duplicate hint function %s", h) - } - s.mHintsFunctions[h.UUID()] = h + mHintsFunctions: hintFunctions, } // hintsDependencies is from compile time; it contains the list of hints the solver **needs** @@ -195,7 +188,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { } } - err := f.Call(curve.ID, inputs, outputs) + err := f(curve.ID, inputs, outputs) var v fr.Element for i := range outputs { diff --git a/internal/backend/bn254/cs/solution.go b/internal/backend/bn254/cs/solution.go index 8c26edc1b1..3ba226c875 100644 --- a/internal/backend/bn254/cs/solution.go +++ b/internal/backend/bn254/cs/solution.go @@ -42,20 +42,13 @@ type solution struct { mHintsFunctions map[hint.ID]hint.Function } -func newSolution(nbWires int, hintFunctions []hint.Function, hintsDependencies map[hint.ID]string, coefficients []fr.Element) (solution, error) { +func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, coefficients []fr.Element) (solution, error) { s := solution{ values: make([]fr.Element, nbWires), coefficients: coefficients, solved: make([]bool, nbWires), - mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - } - - for _, h := range hintFunctions { - if _, ok := s.mHintsFunctions[h.UUID()]; ok { - return s, fmt.Errorf("duplicate hint function %s", h) - } - s.mHintsFunctions[h.UUID()] = h + mHintsFunctions: hintFunctions, } // hintsDependencies is from compile time; it contains the list of hints the solver **needs** @@ -195,7 +188,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { } } - err := f.Call(curve.ID, inputs, outputs) + err := f(curve.ID, inputs, outputs) var v fr.Element for i := range outputs { diff --git a/internal/backend/bw6-633/cs/solution.go b/internal/backend/bw6-633/cs/solution.go index edfd9e7e70..d1f378ee28 100644 --- a/internal/backend/bw6-633/cs/solution.go +++ b/internal/backend/bw6-633/cs/solution.go @@ -42,20 +42,13 @@ type solution struct { mHintsFunctions map[hint.ID]hint.Function } -func newSolution(nbWires int, hintFunctions []hint.Function, hintsDependencies map[hint.ID]string, coefficients []fr.Element) (solution, error) { +func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, coefficients []fr.Element) (solution, error) { s := solution{ values: make([]fr.Element, nbWires), coefficients: coefficients, solved: make([]bool, nbWires), - mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - } - - for _, h := range hintFunctions { - if _, ok := s.mHintsFunctions[h.UUID()]; ok { - return s, fmt.Errorf("duplicate hint function %s", h) - } - s.mHintsFunctions[h.UUID()] = h + mHintsFunctions: hintFunctions, } // hintsDependencies is from compile time; it contains the list of hints the solver **needs** @@ -195,7 +188,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { } } - err := f.Call(curve.ID, inputs, outputs) + err := f(curve.ID, inputs, outputs) var v fr.Element for i := range outputs { diff --git a/internal/backend/bw6-761/cs/solution.go b/internal/backend/bw6-761/cs/solution.go index d2b6173827..084e519385 100644 --- a/internal/backend/bw6-761/cs/solution.go +++ b/internal/backend/bw6-761/cs/solution.go @@ -42,20 +42,13 @@ type solution struct { mHintsFunctions map[hint.ID]hint.Function } -func newSolution(nbWires int, hintFunctions []hint.Function, hintsDependencies map[hint.ID]string, coefficients []fr.Element) (solution, error) { +func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, coefficients []fr.Element) (solution, error) { s := solution{ values: make([]fr.Element, nbWires), coefficients: coefficients, solved: make([]bool, nbWires), - mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - } - - for _, h := range hintFunctions { - if _, ok := s.mHintsFunctions[h.UUID()]; ok { - return s, fmt.Errorf("duplicate hint function %s", h) - } - s.mHintsFunctions[h.UUID()] = h + mHintsFunctions: hintFunctions, } // hintsDependencies is from compile time; it contains the list of hints the solver **needs** @@ -195,7 +188,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { } } - err := f.Call(curve.ID, inputs, outputs) + err := f(curve.ID, inputs, outputs) var v fr.Element for i := range outputs { diff --git a/internal/backend/circuits/hint.go b/internal/backend/circuits/hint.go index 3b9e99294e..dd4586e001 100644 --- a/internal/backend/circuits/hint.go +++ b/internal/backend/circuits/hint.go @@ -5,7 +5,6 @@ import ( "math/big" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend" ) @@ -98,36 +97,20 @@ func init() { } } -var mulBy7 = hint.NewStaticHint(func(curveID ecc.ID, inputs []*big.Int, result []*big.Int) error { +var mulBy7 = func(curveID ecc.ID, inputs []*big.Int, result []*big.Int) error { result[0].Mul(inputs[0], big.NewInt(7)).Mod(result[0], curveID.Info().Fr.Modulus()) return nil -}) +} -var make3 = hint.NewStaticHint(func(curveID ecc.ID, inputs []*big.Int, result []*big.Int) error { +var make3 = func(curveID ecc.ID, inputs []*big.Int, result []*big.Int) error { result[0].SetUint64(3) return nil -}) - -var dvHint = &doubleVector{} - -type doubleVector struct{} - -func (dv *doubleVector) UUID() hint.ID { - return hint.UUID(dv.Call) } -func (dv *doubleVector) Call(curveID ecc.ID, inputs []*big.Int, res []*big.Int) error { +var dvHint = func(curveID ecc.ID, inputs []*big.Int, res []*big.Int) error { two := big.NewInt(2) for i := range inputs { res[i].Mul(two, inputs[i]) } return nil } - -func (dv *doubleVector) NbOutputs(curveID ecc.ID, nInputs int) (nOutputs int) { - return nInputs -} - -func (dv *doubleVector) String() string { - return "double" -} diff --git a/internal/generator/backend/template/representations/solution.go.tmpl b/internal/generator/backend/template/representations/solution.go.tmpl index b094f4091a..42d19d659f 100644 --- a/internal/generator/backend/template/representations/solution.go.tmpl +++ b/internal/generator/backend/template/representations/solution.go.tmpl @@ -22,20 +22,13 @@ type solution struct { mHintsFunctions map[hint.ID]hint.Function } -func newSolution(nbWires int, hintFunctions []hint.Function, hintsDependencies map[hint.ID]string, coefficients []fr.Element) (solution, error) { +func newSolution(nbWires int, hintFunctions map[hint.ID]hint.Function, hintsDependencies map[hint.ID]string, coefficients []fr.Element) (solution, error) { s := solution{ values: make([]fr.Element, nbWires), coefficients: coefficients, solved: make([]bool, nbWires), - mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - } - - for _, h := range hintFunctions { - if _, ok := s.mHintsFunctions[h.UUID()]; ok { - return s, fmt.Errorf("duplicate hint function %s", h) - } - s.mHintsFunctions[h.UUID()] = h + mHintsFunctions: hintFunctions, } // hintsDependencies is from compile time; it contains the list of hints the solver **needs** @@ -176,7 +169,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { } - err := f.Call(curve.ID, inputs, outputs) + err := f(curve.ID, inputs, outputs) var v fr.Element for i := range outputs { diff --git a/std/algebra/sw_bls12377/g1.go b/std/algebra/sw_bls12377/g1.go index 4f7c8c2ee5..bc680c9edd 100644 --- a/std/algebra/sw_bls12377/g1.go +++ b/std/algebra/sw_bls12377/g1.go @@ -206,7 +206,7 @@ func (P *G1Affine) ScalarMul(api frontend.API, Q G1Affine, s interface{}) *G1Aff } } -var DecomposeScalar = hint.NewStaticHint(func(curve ecc.ID, inputs []*big.Int, res []*big.Int) error { +var DecomposeScalar = func(curve ecc.ID, inputs []*big.Int, res []*big.Int) error { cc := innerCurve(curve) sp := ecc.SplitScalar(inputs[0], cc.glvBasis) res[0].Set(&(sp[0])) @@ -225,7 +225,7 @@ var DecomposeScalar = hint.NewStaticHint(func(curve ecc.ID, inputs []*big.Int, r res[2].Div(res[2], cc.fr) return nil -}) +} func init() { hint.Register(DecomposeScalar) diff --git a/std/algebra/sw_bls24315/g1.go b/std/algebra/sw_bls24315/g1.go index f1094fae26..aa381a0836 100644 --- a/std/algebra/sw_bls24315/g1.go +++ b/std/algebra/sw_bls24315/g1.go @@ -206,7 +206,7 @@ func (P *G1Affine) ScalarMul(api frontend.API, Q G1Affine, s interface{}) *G1Aff } } -var DecomposeScalar = hint.NewStaticHint(func(curve ecc.ID, inputs []*big.Int, res []*big.Int) error { +var DecomposeScalar = func(curve ecc.ID, inputs []*big.Int, res []*big.Int) error { cc := innerCurve(curve) sp := ecc.SplitScalar(inputs[0], cc.glvBasis) res[0].Set(&(sp[0])) @@ -225,7 +225,7 @@ var DecomposeScalar = hint.NewStaticHint(func(curve ecc.ID, inputs []*big.Int, r res[2].Div(res[2], cc.fr) return nil -}) +} func init() { hint.Register(DecomposeScalar) diff --git a/std/hints.go b/std/hints.go index fe0fe29f6c..565fb0afce 100644 --- a/std/hints.go +++ b/std/hints.go @@ -1,20 +1,30 @@ package std import ( + "sync" + "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/std/algebra/sw_bls12377" "github.com/consensys/gnark/std/algebra/sw_bls24315" "github.com/consensys/gnark/std/math/bits" ) -// GetHints return std hints that are always injected in gnark solvers -func GetHints() []hint.Function { - return []hint.Function{ - sw_bls24315.DecomposeScalar, - sw_bls12377.DecomposeScalar, - bits.NTrits, - bits.NNAF, - bits.IthBit, - bits.NBits, - } +var registerOnce sync.Once + +// RegisterHints register all gnark/std hints +// In the case where the Solver/Prover code is loaded alongside the circuit, this is not useful. +// However, if a Solver/Prover services consumes serialized constraint systems, it has no way to +// know which hints were registered; caller code may add them through backend.WithHints(...). +func RegisterHints() { + registerOnce.Do(registerHints) +} + +func registerHints() { + // note that importing these packages may already triggers a call to hint.Register(...) + hint.Register(sw_bls24315.DecomposeScalar) + hint.Register(sw_bls12377.DecomposeScalar) + hint.Register(bits.NTrits) + hint.Register(bits.NNAF) + hint.Register(bits.IthBit) + hint.Register(bits.NBits) } diff --git a/std/hints_test.go b/std/hints_test.go new file mode 100644 index 0000000000..b5ad54b92b --- /dev/null +++ b/std/hints_test.go @@ -0,0 +1,21 @@ +package std + +import ( + "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend" +) + +func ExampleRegisterHints() { + // this constraint system correspond to a circuit using gnark/std components which rely on hints + // like bits.ToNAF(...) + var ccs frontend.CompiledConstraintSystem + + // since package bits is not imported, the hint NNAF is not registered + // --> hint.Register(bits.NNAF) + // rather than to keep track on which hints are needed, a prover/solver service can register all + // gnark/std hints with this call + RegisterHints() + + // then --> + ccs.IsSolved(&witness.Witness{}) +} diff --git a/std/math/bits/conversion_binary.go b/std/math/bits/conversion_binary.go index ff19e6ec61..1ed77c0eb5 100644 --- a/std/math/bits/conversion_binary.go +++ b/std/math/bits/conversion_binary.go @@ -8,18 +8,8 @@ import ( "github.com/consensys/gnark/frontend" ) -var ( - // IthBit returns the i-tb bit the input. The function expects exactly two - // integer inputs i and n, takes the little-endian bit representation of n and - // returns its i-th bit. - IthBit = hint.NewStaticHint(ithBit) - - // NBits returns the first bits of the input. The number of returned bits is - // defined by the length of the results slice. - NBits = hint.NewStaticHint(nBits) -) - func init() { + // register hints hint.Register(IthBit) hint.Register(NBits) } @@ -106,7 +96,10 @@ func toBinary(api frontend.API, v frontend.Variable, opts ...BaseConversionOptio return bits } -func ithBit(_ ecc.ID, inputs []*big.Int, results []*big.Int) error { +// IthBit returns the i-tb bit the input. The function expects exactly two +// integer inputs i and n, takes the little-endian bit representation of n and +// returns its i-th bit. +func IthBit(_ ecc.ID, inputs []*big.Int, results []*big.Int) error { result := results[0] if !inputs[1].IsUint64() { result.SetUint64(0) @@ -117,7 +110,9 @@ func ithBit(_ ecc.ID, inputs []*big.Int, results []*big.Int) error { return nil } -func nBits(_ ecc.ID, inputs []*big.Int, results []*big.Int) error { +// NBits returns the first bits of the input. The number of returned bits is +// defined by the length of the results slice. +func NBits(_ ecc.ID, inputs []*big.Int, results []*big.Int) error { n := inputs[0] for i := 0; i < len(results); i++ { results[i].SetUint64(uint64(n.Bit(i))) diff --git a/std/math/bits/conversion_ternary.go b/std/math/bits/conversion_ternary.go index 0d366b66c0..f11b0eb212 100644 --- a/std/math/bits/conversion_ternary.go +++ b/std/math/bits/conversion_ternary.go @@ -11,7 +11,7 @@ import ( // NTrits returns the first trits of the input. The number of returned trits is // defined by the length of the results slice. -var NTrits = hint.NewStaticHint(nTrits) +var NTrits = nTrits func init() { hint.Register(NTrits) diff --git a/std/math/bits/naf.go b/std/math/bits/naf.go index 89d60de10e..8cab2acd74 100644 --- a/std/math/bits/naf.go +++ b/std/math/bits/naf.go @@ -11,7 +11,7 @@ import ( // NNAF returns the NAF decomposition of the input. The number of digits is // defined by the number of elements in the results slice. -var NNAF = hint.NewStaticHint(nNaf) +var NNAF = nNaf func init() { hint.Register(NNAF) diff --git a/test/engine.go b/test/engine.go index 4819834e50..2055166e11 100644 --- a/test/engine.go +++ b/test/engine.go @@ -344,7 +344,7 @@ func (e *engine) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Vari res[i] = new(big.Int) } - err := f.Call(e.curveID, in, res) + err := f(e.curveID, in, res) if err != nil { panic("NewHint: " + err.Error())