diff --git a/integration_tests/cairo_zero_hint_tests/packed_sha256.starknet_with_keccak.cairo b/integration_tests/cairo_zero_hint_tests/packed_sha256.starknet_with_keccak.cairo new file mode 100644 index 000000000..1fb7b76f9 --- /dev/null +++ b/integration_tests/cairo_zero_hint_tests/packed_sha256.starknet_with_keccak.cairo @@ -0,0 +1,417 @@ +// The content of this file has been borrowed from LambdaClass Cairo VM in Rust +// See https://github.com/lambdaclass/cairo-vm/blob/24c2349cc19832fd8c1552304fe0439765ed82c6/cairo_programs/packed_sha256_test.cairo + +%builtins range_check bitwise + +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.cairo_builtins import BitwiseBuiltin +from starkware.cairo.common.registers import get_fp_and_pc +from starkware.cairo.common.math import assert_nn_le, unsigned_div_rem +from starkware.cairo.common.memset import memset +from starkware.cairo.common.pow import pow + +const BLOCK_SIZE = 7; +const ALL_ONES = 2 ** 251 - 1; +// Pack the different instances with offsets of 35 bits. This is the maximal possible offset for +// 7 32-bit words and it allows space for carry bits in integer addition operations (up to +// 8 summands). +const SHIFTS = 1 + 2 ** 35 + 2 ** (35 * 2) + 2 ** (35 * 3) + 2 ** (35 * 4) + 2 ** (35 * 5) + 2 ** ( + 35 * 6 +); + +// Given an array of size 16, extends it to the message schedule array (of size 64) by writing +// 48 more values. +// Each element represents 7 32-bit words from 7 difference instances, starting at bits +// 0, 35, 35 * 2, ..., 35 * 6. +func compute_message_schedule{bitwise_ptr: BitwiseBuiltin*}(message: felt*) { + alloc_locals; + + // Defining the following constants as local variables saves some instructions. + local shift_mask3 = SHIFTS * (2 ** 32 - 2 ** 3); + local shift_mask7 = SHIFTS * (2 ** 32 - 2 ** 7); + local shift_mask10 = SHIFTS * (2 ** 32 - 2 ** 10); + local shift_mask17 = SHIFTS * (2 ** 32 - 2 ** 17); + local shift_mask18 = SHIFTS * (2 ** 32 - 2 ** 18); + local shift_mask19 = SHIFTS * (2 ** 32 - 2 ** 19); + local mask32ones = SHIFTS * (2 ** 32 - 1); + + // Loop variables. + tempvar bitwise_ptr = bitwise_ptr; + tempvar message = message + 16; + tempvar n = 64 - 16; + + loop: + // Compute s0 = right_rot(w[i - 15], 7) ^ right_rot(w[i - 15], 18) ^ (w[i - 15] >> 3). + tempvar w0 = message[-15]; + assert bitwise_ptr[0].x = w0; + assert bitwise_ptr[0].y = shift_mask7; + let w0_rot7 = (2 ** (32 - 7)) * w0 + (1 / 2 ** 7 - 2 ** (32 - 7)) * bitwise_ptr[0].x_and_y; + assert bitwise_ptr[1].x = w0; + assert bitwise_ptr[1].y = shift_mask18; + let w0_rot18 = (2 ** (32 - 18)) * w0 + (1 / 2 ** 18 - 2 ** (32 - 18)) * bitwise_ptr[1].x_and_y; + assert bitwise_ptr[2].x = w0; + assert bitwise_ptr[2].y = shift_mask3; + let w0_shift3 = (1 / 2 ** 3) * bitwise_ptr[2].x_and_y; + assert bitwise_ptr[3].x = w0_rot7; + assert bitwise_ptr[3].y = w0_rot18; + assert bitwise_ptr[4].x = bitwise_ptr[3].x_xor_y; + assert bitwise_ptr[4].y = w0_shift3; + let s0 = bitwise_ptr[4].x_xor_y; + let bitwise_ptr = bitwise_ptr + 5 * BitwiseBuiltin.SIZE; + + // Compute s1 = right_rot(w[i - 2], 17) ^ right_rot(w[i - 2], 19) ^ (w[i - 2] >> 10). + tempvar w1 = message[-2]; + assert bitwise_ptr[0].x = w1; + assert bitwise_ptr[0].y = shift_mask17; + let w1_rot17 = (2 ** (32 - 17)) * w1 + (1 / 2 ** 17 - 2 ** (32 - 17)) * bitwise_ptr[0].x_and_y; + assert bitwise_ptr[1].x = w1; + assert bitwise_ptr[1].y = shift_mask19; + let w1_rot19 = (2 ** (32 - 19)) * w1 + (1 / 2 ** 19 - 2 ** (32 - 19)) * bitwise_ptr[1].x_and_y; + assert bitwise_ptr[2].x = w1; + assert bitwise_ptr[2].y = shift_mask10; + let w1_shift10 = (1 / 2 ** 10) * bitwise_ptr[2].x_and_y; + assert bitwise_ptr[3].x = w1_rot17; + assert bitwise_ptr[3].y = w1_rot19; + assert bitwise_ptr[4].x = bitwise_ptr[3].x_xor_y; + assert bitwise_ptr[4].y = w1_shift10; + let s1 = bitwise_ptr[4].x_xor_y; + let bitwise_ptr = bitwise_ptr + 5 * BitwiseBuiltin.SIZE; + + assert bitwise_ptr[0].x = message[-16] + s0 + message[-7] + s1; + assert bitwise_ptr[0].y = mask32ones; + assert message[0] = bitwise_ptr[0].x_and_y; + let bitwise_ptr = bitwise_ptr + BitwiseBuiltin.SIZE; + + tempvar bitwise_ptr = bitwise_ptr; + tempvar message = message + 1; + tempvar n = n - 1; + jmp loop if n != 0; + + return (); +} + +func sha2_compress{bitwise_ptr: BitwiseBuiltin*}( + state: felt*, message: felt*, round_constants: felt* +) -> (new_state: felt*) { + alloc_locals; + + // Defining the following constants as local variables saves some instructions. + local shift_mask2 = SHIFTS * (2 ** 32 - 2 ** 2); + local shift_mask13 = SHIFTS * (2 ** 32 - 2 ** 13); + local shift_mask22 = SHIFTS * (2 ** 32 - 2 ** 22); + local shift_mask6 = SHIFTS * (2 ** 32 - 2 ** 6); + local shift_mask11 = SHIFTS * (2 ** 32 - 2 ** 11); + local shift_mask25 = SHIFTS * (2 ** 32 - 2 ** 25); + local mask32ones = SHIFTS * (2 ** 32 - 1); + + tempvar a = state[0]; + tempvar b = state[1]; + tempvar c = state[2]; + tempvar d = state[3]; + tempvar e = state[4]; + tempvar f = state[5]; + tempvar g = state[6]; + tempvar h = state[7]; + tempvar round_constants = round_constants; + tempvar message = message; + tempvar bitwise_ptr = bitwise_ptr; + tempvar n = 64; + + loop: + // Compute s0 = right_rot(a, 2) ^ right_rot(a, 13) ^ right_rot(a, 22). + assert bitwise_ptr[0].x = a; + assert bitwise_ptr[0].y = shift_mask2; + let a_rot2 = (2 ** (32 - 2)) * a + (1 / 2 ** 2 - 2 ** (32 - 2)) * bitwise_ptr[0].x_and_y; + assert bitwise_ptr[1].x = a; + assert bitwise_ptr[1].y = shift_mask13; + let a_rot13 = (2 ** (32 - 13)) * a + (1 / 2 ** 13 - 2 ** (32 - 13)) * bitwise_ptr[1].x_and_y; + assert bitwise_ptr[2].x = a; + assert bitwise_ptr[2].y = shift_mask22; + let a_rot22 = (2 ** (32 - 22)) * a + (1 / 2 ** 22 - 2 ** (32 - 22)) * bitwise_ptr[2].x_and_y; + assert bitwise_ptr[3].x = a_rot2; + assert bitwise_ptr[3].y = a_rot13; + assert bitwise_ptr[4].x = bitwise_ptr[3].x_xor_y; + assert bitwise_ptr[4].y = a_rot22; + let s0 = bitwise_ptr[4].x_xor_y; + let bitwise_ptr = bitwise_ptr + 5 * BitwiseBuiltin.SIZE; + + // Compute s1 = right_rot(e, 6) ^ right_rot(e, 11) ^ right_rot(e, 25). + assert bitwise_ptr[0].x = e; + assert bitwise_ptr[0].y = shift_mask6; + let e_rot6 = (2 ** (32 - 6)) * e + (1 / 2 ** 6 - 2 ** (32 - 6)) * bitwise_ptr[0].x_and_y; + assert bitwise_ptr[1].x = e; + assert bitwise_ptr[1].y = shift_mask11; + let e_rot11 = (2 ** (32 - 11)) * e + (1 / 2 ** 11 - 2 ** (32 - 11)) * bitwise_ptr[1].x_and_y; + assert bitwise_ptr[2].x = e; + assert bitwise_ptr[2].y = shift_mask25; + let e_rot25 = (2 ** (32 - 25)) * e + (1 / 2 ** 25 - 2 ** (32 - 25)) * bitwise_ptr[2].x_and_y; + assert bitwise_ptr[3].x = e_rot6; + assert bitwise_ptr[3].y = e_rot11; + assert bitwise_ptr[4].x = bitwise_ptr[3].x_xor_y; + assert bitwise_ptr[4].y = e_rot25; + let s1 = bitwise_ptr[4].x_xor_y; + let bitwise_ptr = bitwise_ptr + 5 * BitwiseBuiltin.SIZE; + + // Compute ch = (e & f) ^ ((~e) & g). + assert bitwise_ptr[0].x = e; + assert bitwise_ptr[0].y = f; + assert bitwise_ptr[1].x = ALL_ONES - e; + assert bitwise_ptr[1].y = g; + let ch = bitwise_ptr[0].x_and_y + bitwise_ptr[1].x_and_y; + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE; + + // Compute maj = (a & b) ^ (a & c) ^ (b & c). + assert bitwise_ptr[0].x = a; + assert bitwise_ptr[0].y = b; + assert bitwise_ptr[1].x = bitwise_ptr[0].x_xor_y; + assert bitwise_ptr[1].y = c; + let maj = (a + b + c - bitwise_ptr[1].x_xor_y) / 2; + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE; + + tempvar temp1 = h + s1 + ch + round_constants[0] + message[0]; + tempvar temp2 = s0 + maj; + + assert bitwise_ptr[0].x = temp1 + temp2; + assert bitwise_ptr[0].y = mask32ones; + let new_a = bitwise_ptr[0].x_and_y; + assert bitwise_ptr[1].x = d + temp1; + assert bitwise_ptr[1].y = mask32ones; + let new_e = bitwise_ptr[1].x_and_y; + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE; + + tempvar new_a = new_a; + tempvar new_b = a; + tempvar new_c = b; + tempvar new_d = c; + tempvar new_e = new_e; + tempvar new_f = e; + tempvar new_g = f; + tempvar new_h = g; + tempvar round_constants = round_constants + 1; + tempvar message = message + 1; + tempvar bitwise_ptr = bitwise_ptr; + tempvar n = n - 1; + jmp loop if n != 0; + + // Add the compression result to the original state: + let (res) = alloc(); + assert bitwise_ptr[0].x = state[0] + new_a; + assert bitwise_ptr[0].y = mask32ones; + assert res[0] = bitwise_ptr[0].x_and_y; + assert bitwise_ptr[1].x = state[1] + new_b; + assert bitwise_ptr[1].y = mask32ones; + assert res[1] = bitwise_ptr[1].x_and_y; + assert bitwise_ptr[2].x = state[2] + new_c; + assert bitwise_ptr[2].y = mask32ones; + assert res[2] = bitwise_ptr[2].x_and_y; + assert bitwise_ptr[3].x = state[3] + new_d; + assert bitwise_ptr[3].y = mask32ones; + assert res[3] = bitwise_ptr[3].x_and_y; + assert bitwise_ptr[4].x = state[4] + new_e; + assert bitwise_ptr[4].y = mask32ones; + assert res[4] = bitwise_ptr[4].x_and_y; + assert bitwise_ptr[5].x = state[5] + new_f; + assert bitwise_ptr[5].y = mask32ones; + assert res[5] = bitwise_ptr[5].x_and_y; + assert bitwise_ptr[6].x = state[6] + new_g; + assert bitwise_ptr[6].y = mask32ones; + assert res[6] = bitwise_ptr[6].x_and_y; + assert bitwise_ptr[7].x = state[7] + new_h; + assert bitwise_ptr[7].y = mask32ones; + assert res[7] = bitwise_ptr[7].x_and_y; + let bitwise_ptr = bitwise_ptr + 8 * BitwiseBuiltin.SIZE; + + return (res,); +} + +// Returns the 64 round constants of SHA256. +func get_round_constants() -> (round_constants: felt*) { + alloc_locals; + let (__fp__, _) = get_fp_and_pc(); + local round_constants = 0x428A2F98 * SHIFTS; + local a = 0x71374491 * SHIFTS; + local a = 0xB5C0FBCF * SHIFTS; + local a = 0xE9B5DBA5 * SHIFTS; + local a = 0x3956C25B * SHIFTS; + local a = 0x59F111F1 * SHIFTS; + local a = 0x923F82A4 * SHIFTS; + local a = 0xAB1C5ED5 * SHIFTS; + local a = 0xD807AA98 * SHIFTS; + local a = 0x12835B01 * SHIFTS; + local a = 0x243185BE * SHIFTS; + local a = 0x550C7DC3 * SHIFTS; + local a = 0x72BE5D74 * SHIFTS; + local a = 0x80DEB1FE * SHIFTS; + local a = 0x9BDC06A7 * SHIFTS; + local a = 0xC19BF174 * SHIFTS; + local a = 0xE49B69C1 * SHIFTS; + local a = 0xEFBE4786 * SHIFTS; + local a = 0x0FC19DC6 * SHIFTS; + local a = 0x240CA1CC * SHIFTS; + local a = 0x2DE92C6F * SHIFTS; + local a = 0x4A7484AA * SHIFTS; + local a = 0x5CB0A9DC * SHIFTS; + local a = 0x76F988DA * SHIFTS; + local a = 0x983E5152 * SHIFTS; + local a = 0xA831C66D * SHIFTS; + local a = 0xB00327C8 * SHIFTS; + local a = 0xBF597FC7 * SHIFTS; + local a = 0xC6E00BF3 * SHIFTS; + local a = 0xD5A79147 * SHIFTS; + local a = 0x06CA6351 * SHIFTS; + local a = 0x14292967 * SHIFTS; + local a = 0x27B70A85 * SHIFTS; + local a = 0x2E1B2138 * SHIFTS; + local a = 0x4D2C6DFC * SHIFTS; + local a = 0x53380D13 * SHIFTS; + local a = 0x650A7354 * SHIFTS; + local a = 0x766A0ABB * SHIFTS; + local a = 0x81C2C92E * SHIFTS; + local a = 0x92722C85 * SHIFTS; + local a = 0xA2BFE8A1 * SHIFTS; + local a = 0xA81A664B * SHIFTS; + local a = 0xC24B8B70 * SHIFTS; + local a = 0xC76C51A3 * SHIFTS; + local a = 0xD192E819 * SHIFTS; + local a = 0xD6990624 * SHIFTS; + local a = 0xF40E3585 * SHIFTS; + local a = 0x106AA070 * SHIFTS; + local a = 0x19A4C116 * SHIFTS; + local a = 0x1E376C08 * SHIFTS; + local a = 0x2748774C * SHIFTS; + local a = 0x34B0BCB5 * SHIFTS; + local a = 0x391C0CB3 * SHIFTS; + local a = 0x4ED8AA4A * SHIFTS; + local a = 0x5B9CCA4F * SHIFTS; + local a = 0x682E6FF3 * SHIFTS; + local a = 0x748F82EE * SHIFTS; + local a = 0x78A5636F * SHIFTS; + local a = 0x84C87814 * SHIFTS; + local a = 0x8CC70208 * SHIFTS; + local a = 0x90BEFFFA * SHIFTS; + local a = 0xA4506CEB * SHIFTS; + local a = 0xBEF9A3F7 * SHIFTS; + local a = 0xC67178F2 * SHIFTS; + return (&round_constants,); +} + +const SHA256_INPUT_CHUNK_SIZE_FELTS = 16; +const SHA256_STATE_SIZE_FELTS = 8; +// Each instance consists of 16 words of message, 8 words for the input state and 8 words +// for the output state. +const SHA256_INSTANCE_SIZE = SHA256_INPUT_CHUNK_SIZE_FELTS + 2 * SHA256_STATE_SIZE_FELTS; + +// Computes SHA256 of 'input'. Inputs of up to 55 bytes are supported. +// To use this function, split the input into (up to) 14 words of 32 bits (big endian). +// For example, to compute sha256('Hello world'), use: +// input = [1214606444, 1864398703, 1919706112] +// where: +// 1214606444 == int.from_bytes(b'Hell', 'big') +// 1864398703 == int.from_bytes(b'o wo', 'big') +// 1919706112 == int.from_bytes(b'rld\x00', 'big') # Note the '\x00' padding. +// +// output is an array of 8 32-bit words (big endian). +// +// Assumption: n_bytes <= 55. +// +// Note: You must call finalize_sha2() at the end of the program. Otherwise, this function +// is not sound and a malicious prover may return a wrong result. +// Note: the interface of this function may change in the future. +func sha256{range_check_ptr, sha256_ptr: felt*}(input: felt*, n_bytes: felt) -> (output: felt*) { + assert_nn_le(n_bytes, 55); + let sha256_start = sha256_ptr; + _sha256_input(input=input, n_bytes=n_bytes, n_words=SHA256_INPUT_CHUNK_SIZE_FELTS - 2); + assert sha256_ptr[0] = 0; + assert sha256_ptr[1] = n_bytes * 8; + let sha256_ptr = sha256_ptr + 2; + + // Set the initial state to IV. + assert sha256_ptr[0] = 0x6A09E667; + assert sha256_ptr[1] = 0xBB67AE85; + assert sha256_ptr[2] = 0x3C6EF372; + assert sha256_ptr[3] = 0xA54FF53A; + assert sha256_ptr[4] = 0x510E527F; + assert sha256_ptr[5] = 0x9B05688C; + assert sha256_ptr[6] = 0x1F83D9AB; + assert sha256_ptr[7] = 0x5BE0CD19; + let sha256_ptr = sha256_ptr + SHA256_STATE_SIZE_FELTS; + + let output = sha256_ptr; + %{ + from starkware.cairo.common.cairo_sha256.sha256_utils import ( + IV, compute_message_schedule, sha2_compress_function) + + _sha256_input_chunk_size_felts = int(ids.SHA256_INPUT_CHUNK_SIZE_FELTS) + assert 0 <= _sha256_input_chunk_size_felts < 100 + + w = compute_message_schedule(memory.get_range( + ids.sha256_start, _sha256_input_chunk_size_felts)) + new_state = sha2_compress_function(IV, w) + segments.write_arg(ids.output, new_state) + %} + let sha256_ptr = sha256_ptr + SHA256_STATE_SIZE_FELTS; + return (output,); +} + +func _sha256_input{range_check_ptr, sha256_ptr: felt*}(input: felt*, n_bytes: felt, n_words: felt) { + alloc_locals; + + local full_word; + %{ ids.full_word = int(ids.n_bytes >= 4) %} + + if (full_word != 0) { + assert sha256_ptr[0] = input[0]; + let sha256_ptr = sha256_ptr + 1; + return _sha256_input(input=input + 1, n_bytes=n_bytes - 4, n_words=n_words - 1); + } + + // This is the last input word, so we should add a byte '0x80' at the end and fill the rest with + // zeros. + + if (n_bytes == 0) { + assert sha256_ptr[0] = 0x80000000; + memset(dst=sha256_ptr + 1, value=0, n=n_words - 1); + let sha256_ptr = sha256_ptr + n_words; + return (); + } + + assert_nn_le(n_bytes, 3); + let (padding) = pow(256, 3 - n_bytes); + local range_check_ptr = range_check_ptr; + + assert sha256_ptr[0] = input[0] + padding * 0x80; + + memset(dst=sha256_ptr + 1, value=0, n=n_words - 1); + let sha256_ptr = sha256_ptr + n_words; + return (); +} + +func test_packed_sha256{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}() { + alloc_locals; + let input_len = 3; + let input: felt* = alloc(); + assert input[0] = 1214606444; + assert input[1] = 1864398703; + assert input[2] = 1919706112; + let n_bytes = 11; + + let (local sha256_ptr_start: felt*) = alloc(); + let sha256_ptr = sha256_ptr_start; + + let (local output: felt*) = sha256{sha256_ptr=sha256_ptr}(input, n_bytes); + assert output[0] = 1693223114; + assert output[1] = 11692261; + assert output[2] = 3122279783; + assert output[3] = 2317046550; + assert output[4] = 3524457715; + assert output[5] = 1722959730; + assert output[6] = 844319370; + assert output[7] = 3970137916; + + return (); +} + +func main{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}() { + test_packed_sha256(); + return (); +} \ No newline at end of file diff --git a/integration_tests/cairo_zero_hint_tests_in_progress/keccak_uint256s.starknet_with_keccak.cairo b/integration_tests/cairo_zero_hint_tests_in_progress/keccak_uint256s.starknet_with_keccak.cairo index e6e69066c..a7a4f9540 100644 --- a/integration_tests/cairo_zero_hint_tests_in_progress/keccak_uint256s.starknet_with_keccak.cairo +++ b/integration_tests/cairo_zero_hint_tests_in_progress/keccak_uint256s.starknet_with_keccak.cairo @@ -1,4 +1,5 @@ %builtins range_check bitwise keccak + from starkware.cairo.common.cairo_builtins import KeccakBuiltin, BitwiseBuiltin from starkware.cairo.common.builtin_keccak.keccak import keccak_uint256s from starkware.cairo.common.alloc import alloc diff --git a/pkg/hintrunner/zero/hintcode.go b/pkg/hintrunner/zero/hintcode.go index a8837dab6..a1464d040 100644 --- a/pkg/hintrunner/zero/hintcode.go +++ b/pkg/hintrunner/zero/hintcode.go @@ -117,6 +117,10 @@ const ( blake2sFinalizeV3Code string = "# Add dummy pairs of input and output.\nfrom starkware.cairo.common.cairo_blake2s.blake2s_utils import IV, blake2s_compress\n\n_n_packed_instances = int(ids.N_PACKED_INSTANCES)\nassert 0 <= _n_packed_instances < 20\n_blake2s_input_chunk_size_felts = int(ids.BLAKE2S_INPUT_CHUNK_SIZE_FELTS)\nassert 0 <= _blake2s_input_chunk_size_felts < 100\n\nmessage = [0] * _blake2s_input_chunk_size_felts\nmodified_iv = [IV[0] ^ 0x01010020] + IV[1:]\noutput = blake2s_compress(\n message=message,\n h=modified_iv,\n t0=0,\n t1=0,\n f0=0xffffffff,\n f1=0,\n)\npadding = (message + modified_iv + [0, 0xffffffff] + output) * (_n_packed_instances - 1)\nsegments.write_arg(ids.blake2s_ptr_end, padding)" blake2sComputeCode string = "from starkware.cairo.common.cairo_blake2s.blake2s_utils import compute_blake2s_func\ncompute_blake2s_func(segments=segments, output_ptr=ids.output)" + // ------ Sha256 Hash hints related code ------ + + packedSha256Code string = "from starkware.cairo.common.cairo_sha256.sha256_utils import (\n IV, compute_message_schedule, sha2_compress_function)\n\n_sha256_input_chunk_size_felts = int(ids.SHA256_INPUT_CHUNK_SIZE_FELTS)\nassert 0 <= _sha256_input_chunk_size_felts < 100\n\nw = compute_message_schedule(memory.get_range(\n ids.sha256_start, _sha256_input_chunk_size_felts))\nnew_state = sha2_compress_function(IV, w)\nsegments.write_arg(ids.output, new_state)" + // ------ Keccak hints related code ------ unsafeKeccakFinalizeCode string = "from eth_hash.auto import keccak\nkeccak_input = bytearray()\nn_elms = ids.keccak_state.end_ptr - ids.keccak_state.start_ptr\nfor word in memory.get_range(ids.keccak_state.start_ptr, n_elms):\n keccak_input += word.to_bytes(16, 'big')\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')" unsafeKeccakCode string = "from eth_hash.auto import keccak\n\ndata, length = ids.data, ids.length\n\nif '__keccak_max_size' in globals():\n assert length <= __keccak_max_size, \\\n f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \\\n f'Got: length={length}.'\n\nkeccak_input = bytearray()\nfor word_i, byte_i in enumerate(range(0, length, 16)):\n word = memory[data + word_i]\n n_bytes = min(16, length - byte_i)\n assert 0 <= word < 2 ** (8 * n_bytes)\n keccak_input += word.to_bytes(n_bytes, 'big')\n\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')" @@ -168,4 +172,5 @@ const ( setAddCode string = "assert ids.elm_size > 0\nassert ids.set_ptr <= ids.set_end_ptr\nelm_list = memory.get_range(ids.elm_ptr, ids.elm_size)\nfor i in range(0, ids.set_end_ptr - ids.set_ptr, ids.elm_size):\n if memory.get_range(ids.set_ptr + i, ids.elm_size) == elm_list:\n ids.index = i // ids.elm_size\n ids.is_elm_in_set = 1\n break\nelse:\n ids.is_elm_in_set = 0" searchSortedLowerCode string = "array_ptr = ids.array_ptr\nelm_size = ids.elm_size\nassert isinstance(elm_size, int) and elm_size > 0, \\\n f'Invalid value for elm_size. Got: {elm_size}.'\n\nn_elms = ids.n_elms\nassert isinstance(n_elms, int) and n_elms >= 0, \\\n f'Invalid value for n_elms. Got: {n_elms}.'\nif '__find_element_max_size' in globals():\n assert n_elms <= __find_element_max_size, \\\n f'find_element() can only be used with n_elms<={__find_element_max_size}. ' \\\n f'Got: n_elms={n_elms}.'\n\nfor i in range(n_elms):\n if memory[array_ptr + elm_size * i] >= ids.key:\n ids.index = i\n break\nelse:\n ids.index = n_elms" normalizeAddressCode string = "# Verify the assumptions on the relationship between 2**250, ADDR_BOUND and PRIME.\nADDR_BOUND = ids.ADDR_BOUND % PRIME\nassert (2**250 < ADDR_BOUND <= 2**251) and (2 * 2**250 < PRIME) and (\n ADDR_BOUND * 2 > PRIME), \\\n 'normalize_address() cannot be used with the current constants.'\nids.is_small = 1 if ids.addr < ADDR_BOUND else 0" + sha256AndBlake2sInputCode string = "ids.full_word = int(ids.n_bytes >= 4)" ) diff --git a/pkg/hintrunner/zero/zerohint.go b/pkg/hintrunner/zero/zerohint.go index 043715ef0..a08ca998d 100644 --- a/pkg/hintrunner/zero/zerohint.go +++ b/pkg/hintrunner/zero/zerohint.go @@ -206,6 +206,9 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint) (hinter.Hinte return createBlake2sFinalizeV3Hinter(resolver) case blake2sComputeCode: return createBlake2sComputeHinter(resolver) + // Sha256 hints + case packedSha256Code: + return createPackedSha256Hinter(resolver) // Keccak hints case keccakWriteArgsCode: return createKeccakWriteArgsHinter(resolver) @@ -314,6 +317,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint) (hinter.Hinte return createNondetElementsOverXHinter(resolver, 10) case normalizeAddressCode: return createNormalizeAddressHinter(resolver) + case sha256AndBlake2sInputCode: + return createSha256AndBlake2sInputHinter(resolver) default: return nil, fmt.Errorf("not identified hint: \n%s", rawHint.Code) } diff --git a/pkg/hintrunner/zero/zerohint_others.go b/pkg/hintrunner/zero/zerohint_others.go index adade9ccb..994846830 100644 --- a/pkg/hintrunner/zero/zerohint_others.go +++ b/pkg/hintrunner/zero/zerohint_others.go @@ -751,3 +751,50 @@ func createNormalizeAddressHinter(resolver hintReferenceResolver) (hinter.Hinter return newNormalizeAddressHint(isSmall, addr), nil } + +// Sha256AndBlake2sInput hint writes 1 or 0 at `full_word` address, wether `n_bytes“ +// is greater than or equal to 4 or not +// +// `newSha256AndBlake2sInputHint` takes 2 arguments +// - `full_word` represents the address where the result of the comparison is stored +// - `n_bytes` represents the value that will be compared to 4 +func newSha256AndBlake2sInputHint(fullWord, nbytes hinter.ResOperander) hinter.Hinter { + return &GenericZeroHinter{ + Name: "Sha256AndBlake2sInput", + Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error { + //> ids.full_word = int(ids.n_bytes >= 4) + + n_bytes, err := hinter.ResolveAsFelt(vm, nbytes) + if err != nil { + return err + } + + fullWordAddr, err := fullWord.GetAddress(vm) + if err != nil { + return err + } + + var resultMv memory.MemoryValue + if n_bytes.Cmp(new(fp.Element).SetUint64(4)) >= 0 { + resultMv = memory.MemoryValueFromFieldElement(&utils.FeltOne) + } else { + resultMv = memory.MemoryValueFromFieldElement(&utils.FeltZero) + } + return vm.Memory.WriteToAddress(&fullWordAddr, &resultMv) + }, + } +} + +func createSha256AndBlake2sInputHinter(resolver hintReferenceResolver) (hinter.Hinter, error) { + fullWord, err := resolver.GetResOperander("full_word") + if err != nil { + return nil, err + } + + nBytes, err := resolver.GetResOperander("n_bytes") + if err != nil { + return nil, err + } + + return newSha256AndBlake2sInputHint(fullWord, nBytes), nil +} diff --git a/pkg/hintrunner/zero/zerohint_others_test.go b/pkg/hintrunner/zero/zerohint_others_test.go index 6966f54f8..61a552568 100644 --- a/pkg/hintrunner/zero/zerohint_others_test.go +++ b/pkg/hintrunner/zero/zerohint_others_test.go @@ -634,5 +634,46 @@ func TestZeroHintOthers(t *testing.T) { check: varValueEquals("is_small", feltUint64(1)), }, }, + "Sha256AndBlake2sInput": { + { + operanders: []*hintOperander{ + {Name: "full_word", Kind: uninitialized}, + {Name: "n_bytes", Kind: apRelative, Value: feltUint64(3)}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSha256AndBlake2sInputHint( + ctx.operanders["full_word"], + ctx.operanders["n_bytes"], + ) + }, + check: varValueEquals("full_word", feltUint64(0)), + }, + { + operanders: []*hintOperander{ + {Name: "full_word", Kind: uninitialized}, + {Name: "n_bytes", Kind: apRelative, Value: feltUint64(4)}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSha256AndBlake2sInputHint( + ctx.operanders["full_word"], + ctx.operanders["n_bytes"], + ) + }, + check: varValueEquals("full_word", feltUint64(1)), + }, + { + operanders: []*hintOperander{ + {Name: "full_word", Kind: uninitialized}, + {Name: "n_bytes", Kind: apRelative, Value: feltUint64(5)}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSha256AndBlake2sInputHint( + ctx.operanders["full_word"], + ctx.operanders["n_bytes"], + ) + }, + check: varValueEquals("full_word", feltUint64(1)), + }, + }, }) } diff --git a/pkg/hintrunner/zero/zerohint_sha256.go b/pkg/hintrunner/zero/zerohint_sha256.go new file mode 100644 index 000000000..4565e6300 --- /dev/null +++ b/pkg/hintrunner/zero/zerohint_sha256.go @@ -0,0 +1,89 @@ +package zero + +import ( + "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" + hintrunnerUtils "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/utils" + "github.com/NethermindEth/cairo-vm-go/pkg/utils" + mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" + + VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" +) + +func newPackedSha256Hint(sha256Start, output hinter.ResOperander) hinter.Hinter { + return &GenericZeroHinter{ + Name: "PackedSha256", + Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error { + //> from starkware.cairo.common.cairo_sha256.sha256_utils import ( + //> IV, compute_message_schedule, sha2_compress_function) + //> + //> _sha256_input_chunk_size_felts = int(ids.SHA256_INPUT_CHUNK_SIZE_FELTS) + //> assert 0 <= _sha256_input_chunk_size_felts < 100 + //> + //> w = compute_message_schedule(memory.get_range( + //> ids.sha256_start, _sha256_input_chunk_size_felts)) + //> new_state = sha2_compress_function(IV, w) + //> segments.write_arg(ids.output, new_state) + + Sha256InputChunkSize := uint64(16) + + sha256Start, err := hinter.ResolveAsAddress(vm, sha256Start) + if err != nil { + return err + } + + w, err := vm.Memory.GetConsecutiveMemoryValues(*sha256Start, Sha256InputChunkSize) + if err != nil { + return err + } + + wUint32 := make([]uint32, len(w)) + for i := 0; i < len(w); i++ { + value, err := hintrunnerUtils.ToSafeUint32(&w[i]) + if err != nil { + return err + } + wUint32[i] = value + } + + messageSchedule, err := utils.ComputeMessageSchedule(wUint32) + if err != nil { + return err + } + newState := utils.Sha256Compress(utils.IV(), messageSchedule) + + output, err := hinter.ResolveAsAddress(vm, output) + if err != nil { + return err + } + + for i := 0; i < len(newState); i++ { + newStateValue := mem.MemoryValueFromInt(newState[i]) + outputOffset, err := output.AddOffset(int16(i)) + if err != nil { + return err + } + + err = vm.Memory.WriteToAddress(&outputOffset, &newStateValue) + if err != nil { + return err + } + } + + return nil + }, + } +} + +func createPackedSha256Hinter(resolver hintReferenceResolver) (hinter.Hinter, error) { + sha256Start, err := resolver.GetResOperander("sha256_start") + if err != nil { + return nil, err + } + + output, err := resolver.GetResOperander("output") + if err != nil { + return nil, err + } + + return newPackedSha256Hint(sha256Start, output), nil +} diff --git a/pkg/hintrunner/zero/zerohint_sha256_test.go b/pkg/hintrunner/zero/zerohint_sha256_test.go new file mode 100644 index 000000000..2a338a213 --- /dev/null +++ b/pkg/hintrunner/zero/zerohint_sha256_test.go @@ -0,0 +1,52 @@ +package zero + +import ( + "testing" + + "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" + "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" +) + +func TestZeroHintSha256(t *testing.T) { + runHinterTests(t, map[string][]hintTestCase{ + "PackedSha256": { + { + operanders: []*hintOperander{ + {Name: "sha256_start", Kind: apRelative, Value: addr(6)}, + {Name: "output", Kind: apRelative, Value: addr(22)}, + {Name: "buffer", Kind: apRelative, Value: feltUint64(0)}, + {Name: "buffer", Kind: apRelative, Value: feltUint64(0)}, + {Name: "buffer", Kind: apRelative, Value: feltUint64(0)}, + {Name: "buffer", Kind: apRelative, Value: feltUint64(0)}, + {Name: "buffer", Kind: apRelative, Value: feltUint64(0)}, + {Name: "buffer", Kind: apRelative, Value: feltUint64(0)}, + {Name: "buffer", Kind: apRelative, Value: feltUint64(0)}, + {Name: "buffer", Kind: apRelative, Value: feltUint64(0)}, + {Name: "buffer", Kind: apRelative, Value: feltUint64(0)}, + {Name: "buffer", Kind: apRelative, Value: feltUint64(0)}, + {Name: "buffer", Kind: apRelative, Value: feltUint64(0)}, + {Name: "buffer", Kind: apRelative, Value: feltUint64(0)}, + {Name: "buffer", Kind: apRelative, Value: feltUint64(0)}, + {Name: "buffer", Kind: apRelative, Value: feltUint64(0)}, + {Name: "buffer", Kind: apRelative, Value: feltUint64(0)}, + {Name: "buffer", Kind: apRelative, Value: feltUint64(0)}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newPackedSha256Hint(ctx.operanders["sha256_start"], ctx.operanders["output"]) + }, + check: consecutiveVarAddrResolvedValueEquals( + "output", + []*fp.Element{ + feltString("3663108286"), + feltString("398046313"), + feltString("1647531929"), + feltString("2006957770"), + feltString("2363872401"), + feltString("3235013187"), + feltString("3137272298"), + feltString("406301144"), + }), + }, + }, + }) +} diff --git a/pkg/utils/blake.go b/pkg/utils/blake.go index 7db61de86..a3309e0ff 100644 --- a/pkg/utils/blake.go +++ b/pkg/utils/blake.go @@ -30,19 +30,15 @@ func SIGMA() [10][16]uint8 { } } -func rightRot(value uint32, n uint32) uint32 { - return (value >> n) | ((value & ((1 << n) - 1)) << (32 - n)) -} - func mix(a uint32, b uint32, c uint32, d uint32, m0 uint32, m1 uint32) (uint32, uint32, uint32, uint32) { a = a + b + m0 - d = rightRot(d^a, 16) + d = RightRot(d^a, 16) c = c + d - b = rightRot(b^c, 12) + b = RightRot(b^c, 12) a = a + b + m1 - d = rightRot(d^a, 8) + d = RightRot(d^a, 8) c = c + d - b = rightRot(b^c, 7) + b = RightRot(b^c, 7) return a, b, c, d } diff --git a/pkg/utils/blake_test.go b/pkg/utils/blake_test.go index a4033f13d..a374e82b1 100644 --- a/pkg/utils/blake_test.go +++ b/pkg/utils/blake_test.go @@ -51,7 +51,7 @@ func TestRightRot(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - result := rightRot(tc.value, tc.n) + result := RightRot(tc.value, tc.n) if result != tc.expected { t.Errorf("Expected %08X, got %08X", tc.expected, result) } diff --git a/pkg/utils/math.go b/pkg/utils/math.go index 27589f314..4f76a616b 100644 --- a/pkg/utils/math.go +++ b/pkg/utils/math.go @@ -125,3 +125,7 @@ func FeltDivRem(a, b *fp.Element) (div fp.Element, rem fp.Element) { return div, rem } + +func RightRot(value uint32, n uint32) uint32 { + return (value >> n) | ((value & ((1 << n) - 1)) << (32 - n)) +} diff --git a/pkg/utils/sha256.go b/pkg/utils/sha256.go new file mode 100644 index 000000000..a25976c4c --- /dev/null +++ b/pkg/utils/sha256.go @@ -0,0 +1,101 @@ +package utils + +import "fmt" + +func ComputeMessageSchedule(input []uint32) ([]uint32, error) { + // def compute_message_schedule(message: List[int]) -> List[int]: + // w = list(message) + // assert len(w) == 16 + + // for i in range(16, 64): + // s0 = right_rot(w[i - 15], 7) ^ right_rot(w[i - 15], 18) ^ (w[i - 15] >> 3) + // s1 = right_rot(w[i - 2], 17) ^ right_rot(w[i - 2], 19) ^ (w[i - 2] >> 10) + // w.append((w[i - 16] + s0 + w[i - 7] + s1) % 2**32) + + // return w + if len(input) != 16 { + return nil, fmt.Errorf("input length must be 16, got %d", len(input)) + } + + w := make([]uint32, 64) + copy(w, input) + + for i := 16; i < 64; i++ { + s0 := RightRot(w[i-15], 7) ^ RightRot(w[i-15], 18) ^ (w[i-15] >> 3) + s1 := RightRot(w[i-2], 17) ^ RightRot(w[i-2], 19) ^ (w[i-2] >> 10) + w[i] = w[i-16] + s0 + w[i-7] + s1 + } + + return w, nil +} + +func Sha256Compress(state [8]uint32, w []uint32) []uint32 { + // def sha2_compress_function(state: List[int], w: List[int]) -> List[int]: + // a, b, c, d, e, f, g, h = state + + // for i in range(64): + // s0 = right_rot(a, 2) ^ right_rot(a, 13) ^ right_rot(a, 22) + // s1 = right_rot(e, 6) ^ right_rot(e, 11) ^ right_rot(e, 25) + // ch = (e & f) ^ ((~e) & g) + // temp1 = (h + s1 + ch + ROUND_CONSTANTS[i] + w[i]) % 2**32 + // maj = (a & b) ^ (a & c) ^ (b & c) + // temp2 = (s0 + maj) % 2**32 + + // h = g + // g = f + // f = e + // e = (d + temp1) % 2**32 + // d = c + // c = b + // b = a + // a = (temp1 + temp2) % 2**32 + + // # Add the compression result to the original state. + // return [ + // + // (state[0] + a) % 2**32, + // (state[1] + b) % 2**32, + // (state[2] + c) % 2**32, + // (state[3] + d) % 2**32, + // (state[4] + e) % 2**32, + // (state[5] + f) % 2**32, + // (state[6] + g) % 2**32, + // (state[7] + h) % 2**32, + // + // ] + k := []uint32{ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, + } + + a, b, c, d, e, f, g, h := state[0], state[1], state[2], state[3], state[4], state[5], state[6], state[7] + + for i := 0; i < 64; i++ { + S1 := RightRot(e, 6) ^ RightRot(e, 11) ^ RightRot(e, 25) + ch := (e & f) ^ ((^e) & g) + temp1 := h + S1 + ch + k[i] + w[i] + S0 := RightRot(a, 2) ^ RightRot(a, 13) ^ RightRot(a, 22) + maj := (a & b) ^ (a & c) ^ (b & c) + temp2 := S0 + maj + + h = g + g = f + f = e + e = d + temp1 + d = c + c = b + b = a + a = temp1 + temp2 + } + + return []uint32{ + state[0] + a, state[1] + b, state[2] + c, state[3] + d, + state[4] + e, state[5] + f, state[6] + g, state[7] + h, + } +}