Skip to content

Commit

Permalink
PackedSha256 + Sha256AndBlake2sInput hints (#603)
Browse files Browse the repository at this point in the history
* packedSha256

* add hintcode''

* fmt

* test template

* add sha256 helpers

* initialize slice

* add unit test

* improve unit test

* integration test

* add check for length of input

* add unit test for Sha256AndBlake2sInput

* fmt

* use AddOfsset method
  • Loading branch information
TAdev0 authored Jul 29, 2024
1 parent 1afc25a commit 3bacdc9
Show file tree
Hide file tree
Showing 12 changed files with 767 additions and 9 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
%builtins range_check bitwise keccak

from starkware.cairo.common.cairo_builtins import KeccakBuiltin, BitwiseBuiltin
from starkware.cairo.common.builtin_keccak.keccak import keccak_uint256s
from starkware.cairo.common.alloc import alloc
Expand Down
5 changes: 5 additions & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ const (
blake2sFinalizeV3Code string = "# Add dummy pairs of input and output.\nfrom starkware.cairo.common.cairo_blake2s.blake2s_utils import IV, blake2s_compress\n\n_n_packed_instances = int(ids.N_PACKED_INSTANCES)\nassert 0 <= _n_packed_instances < 20\n_blake2s_input_chunk_size_felts = int(ids.BLAKE2S_INPUT_CHUNK_SIZE_FELTS)\nassert 0 <= _blake2s_input_chunk_size_felts < 100\n\nmessage = [0] * _blake2s_input_chunk_size_felts\nmodified_iv = [IV[0] ^ 0x01010020] + IV[1:]\noutput = blake2s_compress(\n message=message,\n h=modified_iv,\n t0=0,\n t1=0,\n f0=0xffffffff,\n f1=0,\n)\npadding = (message + modified_iv + [0, 0xffffffff] + output) * (_n_packed_instances - 1)\nsegments.write_arg(ids.blake2s_ptr_end, padding)"
blake2sComputeCode string = "from starkware.cairo.common.cairo_blake2s.blake2s_utils import compute_blake2s_func\ncompute_blake2s_func(segments=segments, output_ptr=ids.output)"

// ------ Sha256 Hash hints related code ------

packedSha256Code string = "from starkware.cairo.common.cairo_sha256.sha256_utils import (\n IV, compute_message_schedule, sha2_compress_function)\n\n_sha256_input_chunk_size_felts = int(ids.SHA256_INPUT_CHUNK_SIZE_FELTS)\nassert 0 <= _sha256_input_chunk_size_felts < 100\n\nw = compute_message_schedule(memory.get_range(\n ids.sha256_start, _sha256_input_chunk_size_felts))\nnew_state = sha2_compress_function(IV, w)\nsegments.write_arg(ids.output, new_state)"

// ------ Keccak hints related code ------
unsafeKeccakFinalizeCode string = "from eth_hash.auto import keccak\nkeccak_input = bytearray()\nn_elms = ids.keccak_state.end_ptr - ids.keccak_state.start_ptr\nfor word in memory.get_range(ids.keccak_state.start_ptr, n_elms):\n keccak_input += word.to_bytes(16, 'big')\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')"
unsafeKeccakCode string = "from eth_hash.auto import keccak\n\ndata, length = ids.data, ids.length\n\nif '__keccak_max_size' in globals():\n assert length <= __keccak_max_size, \\\n f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \\\n f'Got: length={length}.'\n\nkeccak_input = bytearray()\nfor word_i, byte_i in enumerate(range(0, length, 16)):\n word = memory[data + word_i]\n n_bytes = min(16, length - byte_i)\n assert 0 <= word < 2 ** (8 * n_bytes)\n keccak_input += word.to_bytes(n_bytes, 'big')\n\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')"
Expand Down Expand Up @@ -168,4 +172,5 @@ const (
setAddCode string = "assert ids.elm_size > 0\nassert ids.set_ptr <= ids.set_end_ptr\nelm_list = memory.get_range(ids.elm_ptr, ids.elm_size)\nfor i in range(0, ids.set_end_ptr - ids.set_ptr, ids.elm_size):\n if memory.get_range(ids.set_ptr + i, ids.elm_size) == elm_list:\n ids.index = i // ids.elm_size\n ids.is_elm_in_set = 1\n break\nelse:\n ids.is_elm_in_set = 0"
searchSortedLowerCode string = "array_ptr = ids.array_ptr\nelm_size = ids.elm_size\nassert isinstance(elm_size, int) and elm_size > 0, \\\n f'Invalid value for elm_size. Got: {elm_size}.'\n\nn_elms = ids.n_elms\nassert isinstance(n_elms, int) and n_elms >= 0, \\\n f'Invalid value for n_elms. Got: {n_elms}.'\nif '__find_element_max_size' in globals():\n assert n_elms <= __find_element_max_size, \\\n f'find_element() can only be used with n_elms<={__find_element_max_size}. ' \\\n f'Got: n_elms={n_elms}.'\n\nfor i in range(n_elms):\n if memory[array_ptr + elm_size * i] >= ids.key:\n ids.index = i\n break\nelse:\n ids.index = n_elms"
normalizeAddressCode string = "# Verify the assumptions on the relationship between 2**250, ADDR_BOUND and PRIME.\nADDR_BOUND = ids.ADDR_BOUND % PRIME\nassert (2**250 < ADDR_BOUND <= 2**251) and (2 * 2**250 < PRIME) and (\n ADDR_BOUND * 2 > PRIME), \\\n 'normalize_address() cannot be used with the current constants.'\nids.is_small = 1 if ids.addr < ADDR_BOUND else 0"
sha256AndBlake2sInputCode string = "ids.full_word = int(ids.n_bytes >= 4)"
)
5 changes: 5 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint) (hinter.Hinte
return createBlake2sFinalizeV3Hinter(resolver)
case blake2sComputeCode:
return createBlake2sComputeHinter(resolver)
// Sha256 hints
case packedSha256Code:
return createPackedSha256Hinter(resolver)
// Keccak hints
case keccakWriteArgsCode:
return createKeccakWriteArgsHinter(resolver)
Expand Down Expand Up @@ -314,6 +317,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint) (hinter.Hinte
return createNondetElementsOverXHinter(resolver, 10)
case normalizeAddressCode:
return createNormalizeAddressHinter(resolver)
case sha256AndBlake2sInputCode:
return createSha256AndBlake2sInputHinter(resolver)
default:
return nil, fmt.Errorf("not identified hint: \n%s", rawHint.Code)
}
Expand Down
47 changes: 47 additions & 0 deletions pkg/hintrunner/zero/zerohint_others.go
Original file line number Diff line number Diff line change
Expand Up @@ -751,3 +751,50 @@ func createNormalizeAddressHinter(resolver hintReferenceResolver) (hinter.Hinter

return newNormalizeAddressHint(isSmall, addr), nil
}

// Sha256AndBlake2sInput hint writes 1 or 0 at `full_word` address, wether `n_bytes“
// is greater than or equal to 4 or not
//
// `newSha256AndBlake2sInputHint` takes 2 arguments
// - `full_word` represents the address where the result of the comparison is stored
// - `n_bytes` represents the value that will be compared to 4
func newSha256AndBlake2sInputHint(fullWord, nbytes hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "Sha256AndBlake2sInput",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> ids.full_word = int(ids.n_bytes >= 4)

n_bytes, err := hinter.ResolveAsFelt(vm, nbytes)
if err != nil {
return err
}

fullWordAddr, err := fullWord.GetAddress(vm)
if err != nil {
return err
}

var resultMv memory.MemoryValue
if n_bytes.Cmp(new(fp.Element).SetUint64(4)) >= 0 {
resultMv = memory.MemoryValueFromFieldElement(&utils.FeltOne)
} else {
resultMv = memory.MemoryValueFromFieldElement(&utils.FeltZero)
}
return vm.Memory.WriteToAddress(&fullWordAddr, &resultMv)
},
}
}

func createSha256AndBlake2sInputHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
fullWord, err := resolver.GetResOperander("full_word")
if err != nil {
return nil, err
}

nBytes, err := resolver.GetResOperander("n_bytes")
if err != nil {
return nil, err
}

return newSha256AndBlake2sInputHint(fullWord, nBytes), nil
}
41 changes: 41 additions & 0 deletions pkg/hintrunner/zero/zerohint_others_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -634,5 +634,46 @@ func TestZeroHintOthers(t *testing.T) {
check: varValueEquals("is_small", feltUint64(1)),
},
},
"Sha256AndBlake2sInput": {
{
operanders: []*hintOperander{
{Name: "full_word", Kind: uninitialized},
{Name: "n_bytes", Kind: apRelative, Value: feltUint64(3)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSha256AndBlake2sInputHint(
ctx.operanders["full_word"],
ctx.operanders["n_bytes"],
)
},
check: varValueEquals("full_word", feltUint64(0)),
},
{
operanders: []*hintOperander{
{Name: "full_word", Kind: uninitialized},
{Name: "n_bytes", Kind: apRelative, Value: feltUint64(4)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSha256AndBlake2sInputHint(
ctx.operanders["full_word"],
ctx.operanders["n_bytes"],
)
},
check: varValueEquals("full_word", feltUint64(1)),
},
{
operanders: []*hintOperander{
{Name: "full_word", Kind: uninitialized},
{Name: "n_bytes", Kind: apRelative, Value: feltUint64(5)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSha256AndBlake2sInputHint(
ctx.operanders["full_word"],
ctx.operanders["n_bytes"],
)
},
check: varValueEquals("full_word", feltUint64(1)),
},
},
})
}
89 changes: 89 additions & 0 deletions pkg/hintrunner/zero/zerohint_sha256.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package zero

import (
"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
hintrunnerUtils "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/utils"
"github.com/NethermindEth/cairo-vm-go/pkg/utils"
mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"

VM "github.com/NethermindEth/cairo-vm-go/pkg/vm"
)

func newPackedSha256Hint(sha256Start, output hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "PackedSha256",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> from starkware.cairo.common.cairo_sha256.sha256_utils import (
//> IV, compute_message_schedule, sha2_compress_function)
//>
//> _sha256_input_chunk_size_felts = int(ids.SHA256_INPUT_CHUNK_SIZE_FELTS)
//> assert 0 <= _sha256_input_chunk_size_felts < 100
//>
//> w = compute_message_schedule(memory.get_range(
//> ids.sha256_start, _sha256_input_chunk_size_felts))
//> new_state = sha2_compress_function(IV, w)
//> segments.write_arg(ids.output, new_state)

Sha256InputChunkSize := uint64(16)

sha256Start, err := hinter.ResolveAsAddress(vm, sha256Start)
if err != nil {
return err
}

w, err := vm.Memory.GetConsecutiveMemoryValues(*sha256Start, Sha256InputChunkSize)
if err != nil {
return err
}

wUint32 := make([]uint32, len(w))
for i := 0; i < len(w); i++ {
value, err := hintrunnerUtils.ToSafeUint32(&w[i])
if err != nil {
return err
}
wUint32[i] = value
}

messageSchedule, err := utils.ComputeMessageSchedule(wUint32)
if err != nil {
return err
}
newState := utils.Sha256Compress(utils.IV(), messageSchedule)

output, err := hinter.ResolveAsAddress(vm, output)
if err != nil {
return err
}

for i := 0; i < len(newState); i++ {
newStateValue := mem.MemoryValueFromInt(newState[i])
outputOffset, err := output.AddOffset(int16(i))
if err != nil {
return err
}

err = vm.Memory.WriteToAddress(&outputOffset, &newStateValue)
if err != nil {
return err
}
}

return nil
},
}
}

func createPackedSha256Hinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
sha256Start, err := resolver.GetResOperander("sha256_start")
if err != nil {
return nil, err
}

output, err := resolver.GetResOperander("output")
if err != nil {
return nil, err
}

return newPackedSha256Hint(sha256Start, output), nil
}
52 changes: 52 additions & 0 deletions pkg/hintrunner/zero/zerohint_sha256_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package zero

import (
"testing"

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)

func TestZeroHintSha256(t *testing.T) {
runHinterTests(t, map[string][]hintTestCase{
"PackedSha256": {
{
operanders: []*hintOperander{
{Name: "sha256_start", Kind: apRelative, Value: addr(6)},
{Name: "output", Kind: apRelative, Value: addr(22)},
{Name: "buffer", Kind: apRelative, Value: feltUint64(0)},
{Name: "buffer", Kind: apRelative, Value: feltUint64(0)},
{Name: "buffer", Kind: apRelative, Value: feltUint64(0)},
{Name: "buffer", Kind: apRelative, Value: feltUint64(0)},
{Name: "buffer", Kind: apRelative, Value: feltUint64(0)},
{Name: "buffer", Kind: apRelative, Value: feltUint64(0)},
{Name: "buffer", Kind: apRelative, Value: feltUint64(0)},
{Name: "buffer", Kind: apRelative, Value: feltUint64(0)},
{Name: "buffer", Kind: apRelative, Value: feltUint64(0)},
{Name: "buffer", Kind: apRelative, Value: feltUint64(0)},
{Name: "buffer", Kind: apRelative, Value: feltUint64(0)},
{Name: "buffer", Kind: apRelative, Value: feltUint64(0)},
{Name: "buffer", Kind: apRelative, Value: feltUint64(0)},
{Name: "buffer", Kind: apRelative, Value: feltUint64(0)},
{Name: "buffer", Kind: apRelative, Value: feltUint64(0)},
{Name: "buffer", Kind: apRelative, Value: feltUint64(0)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newPackedSha256Hint(ctx.operanders["sha256_start"], ctx.operanders["output"])
},
check: consecutiveVarAddrResolvedValueEquals(
"output",
[]*fp.Element{
feltString("3663108286"),
feltString("398046313"),
feltString("1647531929"),
feltString("2006957770"),
feltString("2363872401"),
feltString("3235013187"),
feltString("3137272298"),
feltString("406301144"),
}),
},
},
})
}
12 changes: 4 additions & 8 deletions pkg/utils/blake.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,15 @@ func SIGMA() [10][16]uint8 {
}
}

func rightRot(value uint32, n uint32) uint32 {
return (value >> n) | ((value & ((1 << n) - 1)) << (32 - n))
}

func mix(a uint32, b uint32, c uint32, d uint32, m0 uint32, m1 uint32) (uint32, uint32, uint32, uint32) {
a = a + b + m0
d = rightRot(d^a, 16)
d = RightRot(d^a, 16)
c = c + d
b = rightRot(b^c, 12)
b = RightRot(b^c, 12)
a = a + b + m1
d = rightRot(d^a, 8)
d = RightRot(d^a, 8)
c = c + d
b = rightRot(b^c, 7)
b = RightRot(b^c, 7)
return a, b, c, d
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/utils/blake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func TestRightRot(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := rightRot(tc.value, tc.n)
result := RightRot(tc.value, tc.n)
if result != tc.expected {
t.Errorf("Expected %08X, got %08X", tc.expected, result)
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/utils/math.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,7 @@ func FeltDivRem(a, b *fp.Element) (div fp.Element, rem fp.Element) {

return div, rem
}

func RightRot(value uint32, n uint32) uint32 {
return (value >> n) | ((value & ((1 << n) - 1)) << (32 - n))
}
Loading

0 comments on commit 3bacdc9

Please sign in to comment.