Skip to content

Commit

Permalink
feat: impl invModPUint512 hint
Browse files Browse the repository at this point in the history
  • Loading branch information
MartianGreed committed Jul 25, 2024
1 parent 4c37c85 commit 17e4938
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ const (
uint256SqrtCode string = "from starkware.python.math_utils import isqrt\nn = (ids.n.high << 128) + ids.n.low\nroot = isqrt(n)\nassert 0 <= root < 2 ** 128\nids.root.low = root\nids.root.high = 0"
uint256MulDivModCode string = "a = (ids.a.high << 128) + ids.a.low\nb = (ids.b.high << 128) + ids.b.low\ndiv = (ids.div.high << 128) + ids.div.low\nquotient, remainder = divmod(a * b, div)\n\nids.quotient_low.low = quotient & ((1 << 128) - 1)\nids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1)\nids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1)\nids.quotient_high.high = quotient >> 384\nids.remainder.low = remainder & ((1 << 128) - 1)\nids.remainder.high = remainder >> 128"
uint256SubCode string = "def split(num: int, num_bits_shift: int = 128, length: int = 2):\n a = []\n for _ in range(length):\n a.append( num & ((1 << num_bits_shift) - 1) )\n num = num >> num_bits_shift\n return tuple(a)\n\ndef pack(z, num_bits_shift: int = 128) -> int:\n limbs = (z.low, z.high)\n return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))\n\na = pack(ids.a)\nb = pack(ids.b)\nres = (a - b)%2**256\nres_split = split(res)\nids.res.low = res_split[0]\nids.res.high = res_split[1]"
// ------ Uint512 hints related code -------
invModPUint512Code string = "def pack_512(u, num_bits_shift: int) -> int:\n limbs = (u.d0, u.d1, u.d2, u.d3)\n return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))\n\nx = pack_512(ids.x, num_bits_shift = 128)\np = ids.p.low + (ids.p.high << 128)\nx_inverse_mod_p = pow(x,-1, p)\n\nx_inverse_mod_p_split = (x_inverse_mod_p & ((1 << 128) - 1), x_inverse_mod_p >> 128)\n\nids.x_inverse_mod_p.low = x_inverse_mod_p_split[0]\nids.x_inverse_mod_p.high = x_inverse_mod_p_split[1]"
// ------ Usort hints related code ------
usortBodyCode string = "from collections import defaultdict\n\ninput_ptr = ids.input\ninput_len = int(ids.input_len)\nif __usort_max_size is not None:\n assert input_len <= __usort_max_size, (\n f\"usort() can only be used with input_len<={__usort_max_size}. \"\n f\"Got: input_len={input_len}.\"\n )\n\npositions_dict = defaultdict(list)\nfor i in range(input_len):\n val = memory[input_ptr + i]\n positions_dict[val].append(i)\n\noutput = sorted(positions_dict.keys())\nids.output_len = len(output)\nids.output = segments.gen_arg(output)\nids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])"
usortEnterScopeCode string = "vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size')))"
Expand Down
3 changes: 3 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createUint256MulDivModHinter(resolver)
case uint256SubCode:
return createUint256SubHinter(resolver)
// Uint512 hints
case invModPUint512Code:
return createInvModPUint512Hinter(resolver)
// Signature hints
case verifyECDSASignatureCode:
return createVerifyECDSASignatureHinter(resolver)
Expand Down
102 changes: 102 additions & 0 deletions pkg/hintrunner/zero/zerohint_uint512.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package zero

import (
"math/big"

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
VM "github.com/NethermindEth/cairo-vm-go/pkg/vm"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)

const (
P_LOW = "201385395114098847380338600778089168199"
P_HIGH = "64323764613183177041862057485226039389"
)

func createInvModPUint512Hinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
x, err := resolver.GetResOperander("x")
if err != nil {
return nil, err
}

xInverseModP, err := resolver.GetResOperander("x_inverse_mod_p")
if err != nil {
return nil, err
}

return newInvModPUint512Hint(x, xInverseModP), nil
}

func newInvModPUint512Hint(x, xInverseModP hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "InvModPUint512",
Op: func(vm *VM.VirtualMachine, _ *hinter.HintRunnerContext) error {
// def pack_512(u, num_bits_shift: int) -> int:
// limbs = (u.d0, u.d1, u.d2, u.d3)
// return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
//
// x = pack_512(ids.x, num_bits_shift = 128)
// p = ids.p.low + (ids.p.high << 128)
// x_inverse_mod_p = pow(x,-1, p)
//
// x_inverse_mod_p_split = (x_inverse_mod_p & ((1 << 128) - 1), x_inverse_mod_p >> 128)
//
// ids.x_inverse_mod_p.low = x_inverse_mod_p_split[0]
// ids.x_inverse_mod_p.high = x_inverse_mod_p_split[1]
pack512 := func(lolow, loHigh, hiLow, hiHigh *fp.Element, numBitsShift int) big.Int {
var loLowBig, loHighBig, hiLowBig, hiHighBig big.Int
lolow.BigInt(&loLowBig)
loHigh.BigInt(&loHighBig)
hiLow.BigInt(&hiLowBig)
hiHigh.BigInt(&hiHighBig)

return *new(big.Int).Add(new(big.Int).Lsh(&hiHighBig, uint(numBitsShift)), &loLowBig).Add(new(big.Int).Lsh(&hiLowBig, uint(numBitsShift)), &loHighBig)
}
pack := func(low, high *fp.Element, numBitsShift int) big.Int {
var lowBig, highBig big.Int
low.BigInt(&lowBig)
high.BigInt(&highBig)

return *new(big.Int).Add(new(big.Int).Lsh(&highBig, uint(numBitsShift)), &lowBig)
}

xLoLow, xLoHigh, xHiLow, xHiHigh, err := GetUint512AsFelts(vm, x)
if err != nil {
return err
}
pLow, err := new(fp.Element).SetString(P_LOW)
if err != nil {
return err
}
pHigh, err := new(fp.Element).SetString(P_HIGH)
if err != nil {
return err
}

x := pack512(xLoLow, xLoHigh, xHiLow, xHiHigh, 128)
p := pack(pLow, pHigh, 128)

xInverseModPBig := new(big.Int).Exp(&x, big.NewInt(-1), &p)

split := func(num big.Int, numBitsShift uint16, length int) []fp.Element {
a := make([]fp.Element, length)
mask := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), uint(numBitsShift)), big.NewInt(1))

for i := 0; i < length; i++ {
a[i] = *new(fp.Element).SetBigInt(new(big.Int).And(&num, mask))
num.Rsh(&num, uint(numBitsShift))
}

return a
}

xInverseModPSplit := split(*xInverseModPBig, 128, 2)

resAddr, err := xInverseModP.GetAddress(vm)
if err != nil {
return err
}
return vm.Memory.WriteUint256ToAddress(resAddr, &xInverseModPSplit[0], &xInverseModPSplit[1])
},
}
}
32 changes: 32 additions & 0 deletions pkg/hintrunner/zero/zerohint_uint512_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package zero

import (
"testing"

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)

func TestInvModPUint512(t *testing.T) {
runHinterTests(t, map[string][]hintTestCase{
"InvModPUint512": {
{
operanders: []*hintOperander{
{Name: "x.d0", Kind: apRelative, Value: feltUint64(101)},
{Name: "x.d1", Kind: apRelative, Value: feltUint64(2)},
{Name: "x.d2", Kind: apRelative, Value: feltUint64(15)},
{Name: "x.d3", Kind: apRelative, Value: feltUint64(61)},
{Name: "x_inverse_mod_p.low", Kind: uninitialized},
{Name: "x_inverse_mod_p.high", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newInvModPUint512Hint(ctx.operanders["x.d0"], ctx.operanders["x_inverse_mod_p.low"])
},
check: allVarValueEquals(map[string]*fp.Element{
"x_inverse_mod_p.low": feltString("80275402838848031859800366538378848249"),
"x_inverse_mod_p.high": feltString("5810892639608724280512701676461676039"),
}),
},
},
})
}
64 changes: 64 additions & 0 deletions pkg/hintrunner/zero/zerohint_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,70 @@ func GetUint256AsFelts(vm *VM.VirtualMachine, ref hinter.ResOperander) (*fp.Elem
return low, high, nil
}

func GetUint512AsFelts(vm *VM.VirtualMachine, ref hinter.ResOperander) (*fp.Element, *fp.Element, *fp.Element, *fp.Element, error) {
lowRefAddr, err := ref.GetAddress(vm)
if err != nil {
return nil, nil, nil, nil, err
}

lowPart, err := vm.Memory.ReadFromAddress(&lowRefAddr)
if err != nil {
return nil, nil, nil, nil, err
}

highRefAddr, err := lowRefAddr.AddOffset(1)
if err != nil {
return nil, nil, nil, nil, err
}

highPart, err := vm.Memory.ReadFromAddress(&highRefAddr)
if err != nil {
return nil, nil, nil, nil, err
}

highLowRefAddr, err := highRefAddr.AddOffset(1)
if err != nil {
return nil, nil, nil, nil, err
}

highLowPart, err := vm.Memory.ReadFromAddress(&highLowRefAddr)
if err != nil {
return nil, nil, nil, nil, err
}

highHighRefAddr, err := highLowRefAddr.AddOffset(1)
if err != nil {
return nil, nil, nil, nil, err
}

highHighPart, err := vm.Memory.ReadFromAddress(&highHighRefAddr)
if err != nil {
return nil, nil, nil, nil, err
}

lowLow, err := lowPart.FieldElement()
if err != nil {
return nil, nil, nil, nil, err
}

lowHigh, err := highPart.FieldElement()
if err != nil {
return nil, nil, nil, nil, err
}

highLow, err := highLowPart.FieldElement()
if err != nil {
return nil, nil, nil, nil, err
}

highHigh, err := highHighPart.FieldElement()
if err != nil {
return nil, nil, nil, nil, err
}

return lowLow, lowHigh, highLow, highHigh, nil
}

// This helper function is used in FastEcAddAssignNewY and
// EcDoubleAssignNewYV1 hints to compute the y-coordinate of
// a point on an elliptic curve
Expand Down

0 comments on commit 17e4938

Please sign in to comment.