Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add automatic non-native witness element limb constraining #446

Merged
merged 8 commits into from
Jan 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion frontend/internal/expr/linear_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ func (t Term) WireID() int {
return t.VID
}

func (t Term) HashCode() uint64 {
return t.Coeff[0]*29 + uint64(t.VID<<12)
}

// Len return the lenght of the Variable (implements Sort interface)
func (l LinearExpression) Len() int {
return len(l)
Expand Down Expand Up @@ -77,7 +81,7 @@ func (l LinearExpression) Less(i, j int) bool {
func (l LinearExpression) HashCode() uint64 {
h := uint64(17)
for _, val := range l {
h = h*23 + val.Coeff[0]*29 + uint64(val.VID<<12) // TODO @gbotrel revisit
h = h*23 + val.HashCode() // TODO @gbotrel revisit
}
return h
}
6 changes: 5 additions & 1 deletion frontend/internal/expr/linear_expression_scs_torefactor.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ func (t TermToRefactor) WireID() int {
return t.VID
}

func (t TermToRefactor) HashCode() uint64 {
return uint64(t.CID) + uint64(t.VID<<32)
}

// Len return the lenght of the Variable (implements Sort interface)
func (l LinearExpressionToRefactor) Len() int {
return len(l)
Expand Down Expand Up @@ -68,7 +72,7 @@ func (l LinearExpressionToRefactor) Less(i, j int) bool {
func (l LinearExpressionToRefactor) HashCode() uint64 {
h := uint64(17)
for _, val := range l {
h = h*23 + uint64(val.CID) + uint64(val.VID<<32) // TODO @gbotrel revisit
h = h*23 + val.HashCode() // TODO @gbotrel revisit
}
return h
}
Binary file modified internal/stats/latest.stats
Binary file not shown.
110 changes: 60 additions & 50 deletions std/math/emulated/element.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,48 +13,54 @@ import (
// a slice of limbs. The type parameter defines the field this element belongs
// to.
type Element[T FieldParams] struct {
Limbs []frontend.Variable // in little-endian (least significant limb first) encoding
// Limbs is the decomposition of the integer value into limbs in the native
// field. To enforce that the limbs are of expected width, use Pack...
// methods on the Field. Uses little-endian (least significant limb first)
// encoding.
Limbs []frontend.Variable

// overflow indicates the number of additions on top of the normal form. To
// ensure that none of the limbs overflow the scalar field of the snark
// curve, we must check that nbBits+overflow < floor(log2(fr modulus))
overflow uint

// internal indicates if the element is returned from [Field] methods. If
// so, then we can assume that the limbs are already constrained to be
// correct width. If the flag is not set, then the Element most probably
// comes from the witness (or constructed by the user). Then we have to
// ensure that the limbs are width-constrained. We do not store the
// enforcement info in the Element to prevent modifying the witness.
internal bool
}

// NewElement builds a new emulated element from input. The inputs can be:
// - of type Element[T] (or a pointer to it). Then, the limbs are cloned and packed into new Element[T],
// - integer-like. Then it is cast to [*big.Int], decomposed into limbs and packed into new Element[T],
// - a variable in circuit. Then, it is packed into new Element[T]
func NewElement[T FieldParams](v interface{}) Element[T] {
r := Element[T]{}
var fp T

if v == nil {
r.Limbs = make([]frontend.Variable, fp.NbLimbs())
for i := 0; i < len(r.Limbs); i++ {
r.Limbs[i] = 0
}

return r
}
switch tv := v.(type) {
case Element[T]:
r.Limbs = make([]frontend.Variable, len(tv.Limbs))
copy(r.Limbs, tv.Limbs)
r.overflow = tv.overflow
return r
case *Element[T]:
r.Limbs = make([]frontend.Variable, len(tv.Limbs))
copy(r.Limbs, tv.Limbs)
r.overflow = tv.overflow
return r
}
if frontend.IsCanonical(v) {
// TODO @gbotrel @ivokub check this -- seems oddd.
r.Limbs = []frontend.Variable{v}
return r
if e, ok := v.(Element[T]); ok {
return *e.copy()
} else if e, ok := v.(*Element[T]); ok {
return *e.copy()
} else if frontend.IsCanonical(v) {
// We can not force that v is correct width. Better use PackFullLimbs
panic("can not enforec limb width. Use PackLimbs instead")
} else if v == nil {
r := newConstElement[T](0)
r.internal = false
return *r
} else {
r := newConstElement[T](v)
r.internal = false
return *r
}
}

// newConstElement is shorthand for initialising new element using NewElement and
// taking pointer to it. We only want to have a public method for initialising
// an element which return a value because the user uses this only for witness
// creation and it mess up schema parsing.
func newConstElement[T FieldParams](v interface{}) *Element[T] {
var fp T
// convert to big.Int
bValue := utils.FromInterface(v)

Expand All @@ -65,41 +71,45 @@ func NewElement[T FieldParams](v interface{}) Element[T] {

// decompose into limbs
// TODO @gbotrel use big.Int pool here
limbs := make([]*big.Int, fp.NbLimbs())
for i := range limbs {
limbs[i] = new(big.Int)
blimbs := make([]*big.Int, fp.NbLimbs())
for i := range blimbs {
blimbs[i] = new(big.Int)
}
if err := decompose(&bValue, fp.BitsPerLimb(), limbs); err != nil {
if err := decompose(&bValue, fp.BitsPerLimb(), blimbs); err != nil {
panic(fmt.Errorf("decompose value: %w", err))
}

// assign limb values
r.Limbs = make([]frontend.Variable, fp.NbLimbs())
limbs := make([]frontend.Variable, len(blimbs))
for i := range limbs {
r.Limbs[i] = frontend.Variable(limbs[i])
limbs[i] = frontend.Variable(blimbs[i])
}
return &Element[T]{
Limbs: limbs,
overflow: 0,
internal: true,
}

return r
}

// newElementPtr is shorthand for initialising new element using NewElement and
// taking pointer to it. We only want to have a public method for initialising
// an element which return a value because the user uses this only for witness
// creation and it mess up schema parsing.
func newElementPtr[T FieldParams](v interface{}) *Element[T] {
el := NewElement[T](v)
return &el
}

// newElementLimbs sets the limbs and overflow. Given as a function for later
// newInternalElement sets the limbs and overflow. Given as a function for later
// possible refactor.
func newElementLimbs[T FieldParams](limbs []frontend.Variable, overflow uint) *Element[T] {
return &Element[T]{Limbs: limbs, overflow: overflow}
func (f *Field[T]) newInternalElement(limbs []frontend.Variable, overflow uint) *Element[T] {
return &Element[T]{Limbs: limbs, overflow: overflow, internal: true}
}

// GnarkInitHook describes how to initialise the element.
func (e *Element[T]) GnarkInitHook() {
if e.Limbs == nil {
*e = NewElement[T](nil)
*e = NewElement[T](0)
}
}

// copy makes a deep copy of the element.
func (e *Element[T]) copy() *Element[T] {
r := Element[T]{}
r.Limbs = make([]frontend.Variable, len(e.Limbs))
copy(r.Limbs, e.Limbs)
r.overflow = e.overflow
r.internal = e.internal
return &r
}
6 changes: 3 additions & 3 deletions std/math/emulated/element_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ func (c *FromBinaryCircuit[T]) Define(api frontend.API) error {
if err != nil {
return err
}
res := f.FromBinary(c.Bits)
res := f.FromBinary(c.Bits...)
f.AssertIsEqual(res, c.Res)
return nil
}
Expand Down Expand Up @@ -724,10 +724,10 @@ func TestOptimisation(t *testing.T) {
}
ccs, err := frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit)
assert.NoError(err)
assert.LessOrEqual(ccs.GetNbConstraints(), 3757)
assert.LessOrEqual(ccs.GetNbConstraints(), 5577)
ccs2, err := frontend.Compile(testCurve.ScalarField(), scs.NewBuilder, &circuit)
assert.NoError(err)
assert.LessOrEqual(ccs2.GetNbConstraints(), 10483)
assert.LessOrEqual(ccs2.GetNbConstraints(), 12977)
}

type FourMulsCircuit[T FieldParams] struct {
Expand Down
73 changes: 60 additions & 13 deletions std/math/emulated/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ type Field[T FieldParams] struct {

// constants for often used elements n, 0 and 1. Allocated only once
nConstOnce sync.Once
nConst Element[T]
nConst *Element[T]
zeroConstOnce sync.Once
zeroConst Element[T]
zeroConst *Element[T]
oneConstOnce sync.Once
oneConst Element[T]
oneConst *Element[T]

log zerolog.Logger

constrainedLimbs map[uint64]struct{}
}

// NewField returns an object to be used in-circuit to perform emulated
Expand All @@ -47,8 +49,9 @@ type Field[T FieldParams] struct {
// is extremly costly. See package doc for more info.
func NewField[T FieldParams](native frontend.API) (*Field[T], error) {
f := &Field[T]{
api: native,
log: logger.Logger(),
api: native,
log: logger.Logger(),
constrainedLimbs: make(map[uint64]struct{}),
}

// ensure prime is correctly set
Expand Down Expand Up @@ -86,43 +89,86 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) {
// Zero returns zero as a constant.
func (f *Field[T]) Zero() *Element[T] {
f.zeroConstOnce.Do(func() {
f.zeroConst = NewElement[T](nil)
f.zeroConst = newConstElement[T](0)
})
return &f.zeroConst
return f.zeroConst
}

// One returns one as a constant.
func (f *Field[T]) One() *Element[T] {
f.oneConstOnce.Do(func() {
f.oneConst = NewElement[T](1)
f.oneConst = newConstElement[T](1)
})
return &f.oneConst
return f.oneConst
}

// Modulus returns the modulus of the emulated ring as a constant.
func (f *Field[T]) Modulus() *Element[T] {
f.nConstOnce.Do(func() {
f.nConst = NewElement[T](f.fParams.Modulus())
f.nConst = newConstElement[T](f.fParams.Modulus())
})
return &f.nConst
return f.nConst
}

// PackElementLimbs returns an element from the given limbs. The method
// constrains the limbs to have same width as the modulus of the field.
func (f *Field[T]) PackElementLimbs(limbs []frontend.Variable) *Element[T] {
e := newElementLimbs[T](limbs, 0)
e := f.newInternalElement(limbs, 0)
f.enforceWidth(e, true)
return e
}

// PackFullLimbs creates an element from the given limbs and enforces every limb
// to have NbBits bits.
func (f *Field[T]) PackFullLimbs(limbs []frontend.Variable) *Element[T] {
e := newElementLimbs[T](limbs, 0)
e := f.newInternalElement(limbs, 0)
f.enforceWidth(e, false)
return e
}

func (f *Field[T]) enforceWidthConditional(a *Element[T]) (didConstrain bool) {
if a == nil {
// for some reason called on nil
return false
}
if a.internal {
// internal elements are already constrained in the method which returned it
return false
}
if _, isConst := f.constantValue(a); isConst {
// constant values are constant
return false
}
for i := range a.Limbs {
if !frontend.IsCanonical(a.Limbs[i]) {
// this is not a variable. This may happen when some limbs are
// constant and some variables. A strange case but lets try to cover
// it anyway.
continue
}
if vv, ok := a.Limbs[i].(interface{ HashCode() uint64 }); ok {
// okay, this is a canonical variable and it has a hashcode. We use
// it to see if the limb is already constrained.
h := vv.HashCode()
if _, ok := f.constrainedLimbs[h]; !ok {
// we found a limb which hasn't yet been constrained. This means
// that we should enforce width for the whole element. But we
// still iterate over all limbs just to mark them in the table.
didConstrain = true
f.constrainedLimbs[h] = struct{}{}
}
} else {
// we have no way of knowing if the limb has been constrained. To be
// on the safe side constrain the whole element again.
didConstrain = true
}
}
if didConstrain {
f.enforceWidth(a, false)
}
return
}

func (f *Field[T]) constantValue(v *Element[T]) (*big.Int, bool) {
var ok bool

Expand All @@ -147,6 +193,7 @@ func (f *Field[T]) constantValue(v *Element[T]) (*big.Int, bool) {
// combination in a single new limb.
// compact returns a and b minimal (in number of limbs) representation that fits in the snark field
func (f *Field[T]) compact(a, b *Element[T]) (ac, bc []frontend.Variable, bitsPerLimb uint) {
// omit width reduction as is done in the calling method already
maxOverflow := max(a.overflow, b.overflow)
// subtract one bit as can not potentially use all bits of Fr and one bit as
// grouping may overflow
Expand Down
4 changes: 4 additions & 0 deletions std/math/emulated/field_assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ func rsh(api frontend.API, v frontend.Variable, startDigit, endDigit int) fronte
// to overflow). This method does not ensure that the values are equal modulo
// the field order. For strict equality, use AssertIsEqual.
func (f *Field[T]) AssertLimbsEquality(a, b *Element[T]) {
f.enforceWidthConditional(a)
f.enforceWidthConditional(b)
ba, aConst := f.constantValue(a)
bb, bConst := f.constantValue(b)
if aConst && bConst {
Expand Down Expand Up @@ -141,6 +143,7 @@ func (f *Field[T]) enforceWidth(a *Element[T], modWidth bool) {

// AssertIsEqual ensures that a is equal to b modulo the modulus.
func (f *Field[T]) AssertIsEqual(a, b *Element[T]) {
// we omit width assertion as it is done in Sub below
ba, aConst := f.constantValue(a)
bb, bConst := f.constantValue(b)
if aConst && bConst {
Expand Down Expand Up @@ -170,6 +173,7 @@ func (f *Field[T]) AssertIsEqual(a, b *Element[T]) {

// AssertIsEqualLessThan ensures that e is less or equal than e.
func (f *Field[T]) AssertIsLessEqualThan(e, a *Element[T]) {
// we omit conditional width assertion as is done in ToBits below
if e.overflow+a.overflow > 0 {
panic("inputs must have 0 overflow")
}
Expand Down
3 changes: 2 additions & 1 deletion std/math/emulated/field_binary.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
// representation of Element, reduce Element first and take less significant
// bits corresponding to the bitwidth of the emulated modulus.
func (f *Field[T]) ToBits(a *Element[T]) []frontend.Variable {
f.enforceWidthConditional(a)
ba, aConst := f.constantValue(a)
if aConst {
return f.api.ToBinary(ba, int(f.fParams.BitsPerLimb()*f.fParams.NbLimbs()))
Expand All @@ -37,5 +38,5 @@ func (f *Field[T]) FromBits(bs ...frontend.Variable) *Element[T] {
limbs[i] = bits.FromBinary(f.api, bs[i*f.fParams.BitsPerLimb():(i+1)*f.fParams.BitsPerLimb()])
}
limbs[nbLimbs-1] = bits.FromBinary(f.api, bs[(nbLimbs-1)*f.fParams.BitsPerLimb():])
return newElementLimbs[T](limbs, 0)
return f.newInternalElement(limbs, 0)
}
Loading