diff --git a/pkg/hintrunner/zero/hintcode.go b/pkg/hintrunner/zero/hintcode.go index 1300484fb..faf2c0781 100644 --- a/pkg/hintrunner/zero/hintcode.go +++ b/pkg/hintrunner/zero/hintcode.go @@ -57,6 +57,8 @@ const ( uint256SqrtCode string = "from starkware.python.math_utils import isqrt\nn = (ids.n.high << 128) + ids.n.low\nroot = isqrt(n)\nassert 0 <= root < 2 ** 128\nids.root.low = root\nids.root.high = 0" uint256MulDivModCode string = "a = (ids.a.high << 128) + ids.a.low\nb = (ids.b.high << 128) + ids.b.low\ndiv = (ids.div.high << 128) + ids.div.low\nquotient, remainder = divmod(a * b, div)\n\nids.quotient_low.low = quotient & ((1 << 128) - 1)\nids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1)\nids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1)\nids.quotient_high.high = quotient >> 384\nids.remainder.low = remainder & ((1 << 128) - 1)\nids.remainder.high = remainder >> 128" uint256SubCode string = "def split(num: int, num_bits_shift: int = 128, length: int = 2):\n a = []\n for _ in range(length):\n a.append( num & ((1 << num_bits_shift) - 1) )\n num = num >> num_bits_shift\n return tuple(a)\n\ndef pack(z, num_bits_shift: int = 128) -> int:\n limbs = (z.low, z.high)\n return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))\n\na = pack(ids.a)\nb = pack(ids.b)\nres = (a - b)%2**256\nres_split = split(res)\nids.res.low = res_split[0]\nids.res.high = res_split[1]" + // ------ Uint512 hints related code ------- + invModPUint512Code string = "from starkware.cairo.common.uint512 import Uint512\n\nvalue = Uint512(ids.value, ids.modulus)\nvalue_inv = value.inv_mod(ids.modulus)\ni" // ------ Usort hints related code ------ usortBodyCode string = "from collections import defaultdict\n\ninput_ptr = ids.input\ninput_len = int(ids.input_len)\nif __usort_max_size is not None:\n assert input_len <= __usort_max_size, (\n f\"usort() can only be used with input_len<={__usort_max_size}. \"\n f\"Got: input_len={input_len}.\"\n )\n\npositions_dict = defaultdict(list)\nfor i in range(input_len):\n val = memory[input_ptr + i]\n positions_dict[val].append(i)\n\noutput = sorted(positions_dict.keys())\nids.output_len = len(output)\nids.output = segments.gen_arg(output)\nids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])" usortEnterScopeCode string = "vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size')))" diff --git a/pkg/hintrunner/zero/zerohint.go b/pkg/hintrunner/zero/zerohint.go index adb88b982..8c1e41178 100644 --- a/pkg/hintrunner/zero/zerohint.go +++ b/pkg/hintrunner/zero/zerohint.go @@ -111,6 +111,9 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64 return createUint256MulDivModHinter(resolver) case uint256SubCode: return createUint256SubHinter(resolver) + // Uint512 hints + case invModPUint512Code: + return createInvModPUint512Hinter(resolver) // Signature hints case verifyECDSASignatureCode: return createVerifyECDSASignatureHinter(resolver) diff --git a/pkg/hintrunner/zero/zerohint_uint512.go b/pkg/hintrunner/zero/zerohint_uint512.go new file mode 100644 index 000000000..6fb797617 --- /dev/null +++ b/pkg/hintrunner/zero/zerohint_uint512.go @@ -0,0 +1,101 @@ +package zero + +import ( + "math/big" + + "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" + VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" + "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" +) + +const ( + P_LOW = "201385395114098847380338600778089168199" + P_HIGH = "64323764613183177041862057485226039389" +) + +func createInvModPUint512Hinter(resolver hintReferenceResolver) (hinter.Hinter, error) { + x, err := resolver.GetResOperander("x") + if err != nil { + return nil, err + } + + xInverseModP, err := resolver.GetResOperander("x_inverse_mod_p") + if err != nil { + return nil, err + } + + return newInvModPUint512Hint(x, xInverseModP), nil +} + +func newInvModPUint512Hint(x, xInverseModP hinter.ResOperander) hinter.Hinter { + return &GenericZeroHinter{ + Name: "InvModPUint512", + Op: func(vm *VM.VirtualMachine, _ *hinter.HintRunnerContext) error { + // def pack_512(u, num_bits_shift: int) -> int: + // limbs = (u.d0, u.d1, u.d2, u.d3) + // return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + // + // x = pack_512(ids.x, num_bits_shift = 128) + // p = ids.p.low + (ids.p.high << 128) + // x_inverse_mod_p = pow(x,-1, p) + // + // x_inverse_mod_p_split = (x_inverse_mod_p & ((1 << 128) - 1), x_inverse_mod_p >> 128) + // + // ids.x_inverse_mod_p.low = x_inverse_mod_p_split[0] + // ids.x_inverse_mod_p.high = x_inverse_mod_p_split[1] + pack512 := func(lolow, loHigh, hiLow, hiHigh *fp.Element, numBitsShift int) big.Int { + var loLowBig, loHighBig, hiLowBig, hiHighBig big.Int + lolow.BigInt(&loLowBig) + loHigh.BigInt(&loHighBig) + hiLow.BigInt(&hiLowBig) + hiHigh.BigInt(&hiHighBig) + + return *new(big.Int).Add(new(big.Int).Lsh(&hiHighBig, uint(numBitsShift)), &loLowBig).Add(new(big.Int).Lsh(&hiLowBig, uint(numBitsShift)), &loHighBig) + } + pack := func(low, high *fp.Element, numBitsShift int) big.Int { + var lowBig, highBig big.Int + low.BigInt(&lowBig) + high.BigInt(&highBig) + + return *new(big.Int).Add(new(big.Int).Lsh(&highBig, uint(numBitsShift)), &lowBig) + } + + xLoLow, xLoHigh, xHiLow, xHiHigh, err := GetUint512AsFelts(vm, x) + if err != nil { + return err + } + pLow, err := new(fp.Element).SetString(P_LOW) + if err != nil { + return err + } + pHigh, err := new(fp.Element).SetString(P_HIGH) + if err != nil { + return err + } + + x := pack512(xLoLow, xLoHigh, xHiLow, xHiHigh, 128) + p := pack(pLow, pHigh, 128) + xInverseModPBig := new(big.Int).Exp(&x, big.NewInt(-1), &p) + + split := func(num big.Int, numBitsShift uint16, length int) []fp.Element { + a := make([]fp.Element, length) + mask := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), uint(numBitsShift)), big.NewInt(1)) + + for i := 0; i < length; i++ { + a[i] = *new(fp.Element).SetBigInt(new(big.Int).And(&num, mask)) + num.Rsh(&num, uint(numBitsShift)) + } + + return a + } + + xInverseModPSplit := split(*xInverseModPBig, 128, 2) + + resAddr, err := xInverseModP.GetAddress(vm) + if err != nil { + return err + } + return vm.Memory.WriteUint256ToAddress(resAddr, &xInverseModPSplit[0], &xInverseModPSplit[1]) + }, + } +} diff --git a/pkg/hintrunner/zero/zerohint_uint512_test.go b/pkg/hintrunner/zero/zerohint_uint512_test.go new file mode 100644 index 000000000..b8d21d9d7 --- /dev/null +++ b/pkg/hintrunner/zero/zerohint_uint512_test.go @@ -0,0 +1,32 @@ +package zero + +import ( + "testing" + + "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" + "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" +) + +func TestInvModPUint512(t *testing.T) { + runHinterTests(t, map[string][]hintTestCase{ + "InvModPUint512": { + { + operanders: []*hintOperander{ + {Name: "x.d0", Kind: apRelative, Value: feltUint64(101)}, + {Name: "x.d1", Kind: apRelative, Value: feltUint64(2)}, + {Name: "x.d2", Kind: apRelative, Value: feltUint64(15)}, + {Name: "x.d3", Kind: apRelative, Value: feltUint64(61)}, + {Name: "x_inverse_mod_p.low", Kind: uninitialized}, + {Name: "x_inverse_mod_p.high", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newInvModPUint512Hint(ctx.operanders["x.d0"], ctx.operanders["x_inverse_mod_p.low"]) + }, + check: allVarValueEquals(map[string]*fp.Element{ + "x_inverse_mod_p.low": feltString("80275402838848031859800366538378848249"), + "x_inverse_mod_p.high": feltString("5810892639608724280512701676461676039"), + }), + }, + }, + }) +} diff --git a/pkg/hintrunner/zero/zerohint_utils.go b/pkg/hintrunner/zero/zerohint_utils.go index cf0435de0..bdba9fdbb 100644 --- a/pkg/hintrunner/zero/zerohint_utils.go +++ b/pkg/hintrunner/zero/zerohint_utils.go @@ -48,6 +48,70 @@ func GetUint256AsFelts(vm *VM.VirtualMachine, ref hinter.ResOperander) (*fp.Elem return low, high, nil } +func GetUint512AsFelts(vm *VM.VirtualMachine, ref hinter.ResOperander) (*fp.Element, *fp.Element, *fp.Element, *fp.Element, error) { + lowRefAddr, err := ref.GetAddress(vm) + if err != nil { + return nil, nil, nil, nil, err + } + + lowPart, err := vm.Memory.ReadFromAddress(&lowRefAddr) + if err != nil { + return nil, nil, nil, nil, err + } + + highRefAddr, err := lowRefAddr.AddOffset(1) + if err != nil { + return nil, nil, nil, nil, err + } + + highPart, err := vm.Memory.ReadFromAddress(&highRefAddr) + if err != nil { + return nil, nil, nil, nil, err + } + + highLowRefAddr, err := highRefAddr.AddOffset(1) + if err != nil { + return nil, nil, nil, nil, err + } + + highLowPart, err := vm.Memory.ReadFromAddress(&highLowRefAddr) + if err != nil { + return nil, nil, nil, nil, err + } + + highHighRefAddr, err := highLowRefAddr.AddOffset(1) + if err != nil { + return nil, nil, nil, nil, err + } + + highHighPart, err := vm.Memory.ReadFromAddress(&highHighRefAddr) + if err != nil { + return nil, nil, nil, nil, err + } + + lowLow, err := lowPart.FieldElement() + if err != nil { + return nil, nil, nil, nil, err + } + + lowHigh, err := highPart.FieldElement() + if err != nil { + return nil, nil, nil, nil, err + } + + highLow, err := highLowPart.FieldElement() + if err != nil { + return nil, nil, nil, nil, err + } + + highHigh, err := highHighPart.FieldElement() + if err != nil { + return nil, nil, nil, nil, err + } + + return lowLow, lowHigh, highLow, highHigh, nil +} + // This helper function is used in FastEcAddAssignNewY and // EcDoubleAssignNewYV1 hints to compute the y-coordinate of // a point on an elliptic curve