From cf8266fd5d1ff66962579ff7967ac5cdcf699f77 Mon Sep 17 00:00:00 2001 From: Lior Goldberg Date: Wed, 22 Sep 2021 11:21:58 +0300 Subject: [PATCH] Cairo v0.4.1. --- README.md | 4 +- src/cmake_utils/gen_pip_cmake.py | 41 +- src/cmake_utils/gen_py_lib.py | 49 +- src/cmake_utils/gen_python_exe.py | 86 +- src/cmake_utils/gen_venv.py | 106 +- src/cmake_utils/unite_lib.py | 2 +- src/demo/amm_demo/demo.py | 125 +- src/demo/amm_demo/prove_batch.py | 38 +- .../feeder_gateway/feeder_gateway_client.py | 6 +- .../everest/api/gateway/gateway_client.py | 12 +- .../everest/api/gateway/transaction.py | 1 + .../business_logic/internal_transaction.py | 35 +- src/services/everest/business_logic/state.py | 28 +- src/services/everest/definitions/fields.py | 26 +- src/services/external_api/base_client.py | 49 +- src/services/external_api/has_uri_prefix.py | 3 +- .../cairo/bootloader/compute_fact.py | 53 +- .../cairo/bootloader/fact_topology.py | 2 +- .../cairo/bootloader/generate_fact.py | 66 +- .../cairo/bootloader/hash_program.py | 24 +- .../select_input_builtins_test.py | 40 +- .../validate_builtins_test.py | 25 +- src/starkware/cairo/common/CMakeLists.txt | 4 +- .../cairo/common/cairo_function_runner.py | 99 +- .../cairo/common/cairo_keccak/keccak_utils.py | 24 +- .../cairo/common/cairo_sha256/sha256_utils.py | 128 ++ src/starkware/cairo/common/dict.py | 11 +- src/starkware/cairo/common/math.cairo | 27 + src/starkware/cairo/common/math_utils.py | 4 +- src/starkware/cairo/common/memset.cairo | 33 + src/starkware/cairo/common/patricia_utils.py | 23 +- .../cairo/common/small_merkle_tree.py | 7 +- src/starkware/cairo/common/structs.py | 28 +- src/starkware/cairo/lang/VERSION | 2 +- .../bitwise/bitwise_builtin_runner.py | 71 +- .../builtins/builtin_runner_test_utils.py | 4 +- .../lang/builtins/ec/ec_op_builtin_runner.py | 50 +- .../lang/builtins/hash/hash_builtin_runner.py | 44 +- .../range_check/range_check_builtin_runner.py | 50 +- .../range_check_builtin_runner_test.py | 14 +- .../signature/signature_builtin_runner.py | 67 +- .../signature_builtin_runner_test.py | 85 +- src/starkware/cairo/lang/cairo_constants.py | 2 +- .../cairo/lang/compiler/CMakeLists.txt | 2 + .../cairo/lang/compiler/assembler.py | 58 +- .../cairo/lang/compiler/assembler_test.py | 66 +- .../lang/compiler/ast/aliased_identifier.py | 5 +- .../compiler/ast/ast_objects_test_utils.py | 12 +- .../cairo/lang/compiler/ast/bool_expr.py | 4 +- .../cairo/lang/compiler/ast/cairo_types.py | 9 +- .../cairo/lang/compiler/ast/code_elements.py | 211 ++- src/starkware/cairo/lang/compiler/ast/expr.py | 103 +- .../lang/compiler/ast/formatting_utils.py | 24 +- .../compiler/ast/formatting_utils_test.py | 88 +- .../cairo/lang/compiler/ast/instructions.py | 17 +- src/starkware/cairo/lang/compiler/ast/node.py | 2 +- .../cairo/lang/compiler/ast/notes.py | 23 +- .../cairo/lang/compiler/ast/rvalue.py | 32 +- .../cairo/lang/compiler/ast/types.py | 4 +- .../cairo/lang/compiler/ast/visitor.py | 41 +- .../cairo/lang/compiler/ast_objects_test.py | 103 +- .../cairo/lang/compiler/cairo_compile.py | 269 ++- .../cairo/lang/compiler/cairo_compile_test.py | 20 +- .../cairo/lang/compiler/cairo_format.py | 26 +- src/starkware/cairo/lang/compiler/conftest.py | 2 +- .../cairo/lang/compiler/const_expr_checker.py | 11 +- .../cairo/lang/compiler/constants.py | 8 +- .../cairo/lang/compiler/debug_info.py | 17 +- .../cairo/lang/compiler/debug_info_test.py | 31 +- src/starkware/cairo/lang/compiler/encode.py | 76 +- .../cairo/lang/compiler/encode_test.py | 36 +- .../cairo/lang/compiler/error_handling.py | 59 +- .../lang/compiler/error_handling_test.py | 26 +- .../lang/compiler/expression_evaluator.py | 18 +- .../compiler/expression_evaluator_test.py | 14 +- .../lang/compiler/expression_simplifier.py | 76 +- .../compiler/expression_simplifier_test.py | 99 +- .../lang/compiler/expression_transformer.py | 86 +- src/starkware/cairo/lang/compiler/fields.py | 12 +- .../lang/compiler/identifier_definition.py | 48 +- .../compiler/identifier_definition_test.py | 19 +- .../cairo/lang/compiler/identifier_manager.py | 73 +- .../lang/compiler/identifier_manager_field.py | 13 +- .../compiler/identifier_manager_field_test.py | 19 +- .../lang/compiler/identifier_manager_test.py | 178 +- .../cairo/lang/compiler/identifier_utils.py | 25 +- .../lang/compiler/identifier_utils_test.py | 33 +- .../cairo/lang/compiler/import_loader.py | 26 +- .../cairo/lang/compiler/import_loader_test.py | 119 +- .../cairo/lang/compiler/instruction.py | 8 +- .../lang/compiler/instruction_builder.py | 122 +- .../lang/compiler/instruction_builder_test.py | 825 +++++---- .../cairo/lang/compiler/instruction_test.py | 11 +- .../cairo/lang/compiler/location_utils.py | 7 +- .../cairo/lang/compiler/module_reader.py | 7 +- .../cairo/lang/compiler/module_reader_test.py | 21 +- .../cairo/lang/compiler/offset_reference.py | 25 +- .../lang/compiler/offset_reference_test.py | 32 +- src/starkware/cairo/lang/compiler/parser.py | 166 +- .../cairo/lang/compiler/parser_errors_test.py | 149 +- .../cairo/lang/compiler/parser_test.py | 748 ++++---- .../cairo/lang/compiler/parser_test_utils.py | 4 +- .../cairo/lang/compiler/parser_transformer.py | 218 ++- .../preprocessor/compound_expressions.py | 78 +- .../preprocessor/compound_expressions_test.py | 303 +-- .../lang/compiler/preprocessor/conftest.py | 2 +- .../preprocessor/default_pass_manager.py | 63 +- .../compiler/preprocessor/dependency_graph.py | 42 +- .../preprocessor/dependency_graph_test.py | 152 +- .../cairo/lang/compiler/preprocessor/flow.py | 53 +- .../lang/compiler/preprocessor/flow_test.py | 99 +- .../preprocessor/identifier_aware_visitor.py | 65 +- .../identifier_aware_visitor_test.py | 19 +- .../preprocessor/identifier_collector.py | 135 +- .../preprocessor/identifier_collector_test.py | 94 +- .../compiler/preprocessor/local_variables.py | 131 +- .../preprocessor/local_variables_test.py | 91 +- .../compiler/preprocessor/pass_manager.py | 2 +- .../compiler/preprocessor/preprocess_codes.py | 6 +- .../compiler/preprocessor/preprocessor.py | 925 ++++++---- .../preprocessor/preprocessor_test.py | 1636 +++++++++++------ .../preprocessor/preprocessor_test_utils.py | 53 +- .../preprocessor/preprocessor_utils.py | 20 +- .../compiler/preprocessor/reg_tracking.py | 13 +- .../preprocessor/reg_tracking_test.py | 25 +- .../compiler/preprocessor/struct_collector.py | 84 +- .../preprocessor/struct_collector_test.py | 129 +- .../compiler/preprocessor/unique_labels.py | 2 +- .../preprocessor/unique_labels_test.py | 12 +- src/starkware/cairo/lang/compiler/program.py | 74 +- .../cairo/lang/compiler/references.py | 59 +- .../cairo/lang/compiler/references_test.py | 14 +- .../lang/compiler/resolve_search_result.py | 26 +- .../compiler/resolve_search_result_test.py | 18 +- .../cairo/lang/compiler/scoped_name.py | 12 +- .../cairo/lang/compiler/scoped_name_test.py | 40 +- .../lang/compiler/substitute_identifiers.py | 66 +- .../cairo/lang/compiler/type_casts.py | 56 +- .../cairo/lang/compiler/type_casts_test.py | 39 +- .../cairo/lang/compiler/type_system.py | 22 +- .../lang/compiler/type_system_visitor.py | 159 +- .../lang/compiler/type_system_visitor_test.py | 353 ++-- .../cairo/lang/compiler/type_utils.py | 41 + .../cairo/lang/compiler/type_utils_test.py | 53 + .../cairo/lang/ide/vscode-cairo/package.json | 2 +- src/starkware/cairo/lang/instances.py | 24 +- src/starkware/cairo/lang/setup.py | 58 +- src/starkware/cairo/lang/tracer/profile.py | 27 +- src/starkware/cairo/lang/tracer/profiler.py | 26 +- src/starkware/cairo/lang/tracer/tracer.py | 80 +- .../cairo/lang/tracer/tracer_data.py | 138 +- .../cairo/lang/tracer/tracer_data_test.py | 84 +- src/starkware/cairo/lang/version.py | 2 +- .../cairo/lang/vm/air_public_input.py | 15 +- src/starkware/cairo/lang/vm/builtin_runner.py | 50 +- src/starkware/cairo/lang/vm/cairo_pie.py | 209 ++- src/starkware/cairo/lang/vm/cairo_pie_test.py | 91 +- src/starkware/cairo/lang/vm/cairo_run.py | 351 ++-- src/starkware/cairo/lang/vm/cairo_runner.py | 332 ++-- .../cairo/lang/vm/cairo_runner_test.py | 97 +- src/starkware/cairo/lang/vm/memory_dict.py | 67 +- .../cairo/lang/vm/memory_dict_test.py | 94 +- .../cairo/lang/vm/memory_segments.py | 68 +- .../cairo/lang/vm/memory_segments_test.py | 37 +- .../cairo/lang/vm/output_builtin_runner.py | 74 +- .../lang/vm/output_builtin_runner_test.py | 46 +- .../cairo/lang/vm/reconstruct_traceback.py | 48 +- .../lang/vm/reconstruct_traceback_test.py | 14 +- src/starkware/cairo/lang/vm/relocatable.py | 40 +- .../cairo/lang/vm/relocatable_fields.py | 8 +- .../cairo/lang/vm/relocatable_fields_test.py | 7 +- .../cairo/lang/vm/relocatable_test.py | 16 +- src/starkware/cairo/lang/vm/security.py | 11 +- src/starkware/cairo/lang/vm/security_test.py | 69 +- src/starkware/cairo/lang/vm/trace_entry.py | 30 +- .../cairo/lang/vm/trace_entry_test.py | 13 +- src/starkware/cairo/lang/vm/utils.py | 7 +- .../cairo/lang/vm/validated_memory_dict.py | 6 +- .../lang/vm/validated_memory_dict_test.py | 11 +- src/starkware/cairo/lang/vm/vm.py | 334 ++-- src/starkware/cairo/lang/vm/vm_consts.py | 193 +- src/starkware/cairo/lang/vm/vm_consts_test.py | 322 ++-- src/starkware/cairo/lang/vm/vm_test.py | 128 +- src/starkware/cairo/sharp/client_lib.py | 23 +- src/starkware/cairo/sharp/client_lib_test.py | 49 +- src/starkware/cairo/sharp/fact_checker.py | 29 +- .../cairo/sharp/fact_checker_test.py | 2 +- src/starkware/cairo/sharp/sharp_client.py | 164 +- .../cairo/sharp/sharp_client_test.py | 90 +- .../crypto/signature/fast_pedersen_hash.py | 27 +- .../starkware/crypto/signature/math_utils.py | 2 +- .../signature/nothing_up_my_sleeve_gen.py | 44 +- .../starkware/crypto/signature/signature.py | 92 +- src/starkware/python/async_subprocess.py | 7 +- src/starkware/python/expression_string.py | 28 +- .../python/expression_string_test.py | 70 +- src/starkware/python/json_rpc/client.py | 6 +- src/starkware/python/json_rpc/client_test.py | 22 +- src/starkware/python/math_utils.py | 5 +- src/starkware/python/math_utils_test.py | 14 +- src/starkware/python/merkle_tree.py | 8 +- src/starkware/python/object_utils.py | 33 +- src/starkware/python/python_dependencies.py | 27 +- src/starkware/python/random_test.py | 47 +- src/starkware/python/test_utils.py | 4 +- src/starkware/python/test_utils_test.py | 11 +- src/starkware/python/utils.py | 51 +- src/starkware/python/utils_test.py | 33 +- .../business_logic/internal_transaction.py | 11 +- .../internal_transaction_interface.py | 1 - .../starknet/business_logic/state.py | 2 +- .../starknet/business_logic/state_objects.py | 6 +- src/starkware/starknet/cli/CMakeLists.txt | 1 + src/starkware/starknet/cli/starknet_cli.py | 36 +- src/starkware/starknet/compiler/compile.py | 11 +- .../starknet/compiler/contract_interface.py | 25 +- .../starknet/compiler/data_encoder.py | 70 +- .../starknet/compiler/data_encoder_test.py | 71 +- .../compiler/starknet_preprocessor.py | 49 +- .../compiler/starknet_preprocessor_test.py | 72 +- .../starknet/compiler/storage_var.py | 49 +- .../starknet/compiler/storage_var_test.py | 50 - src/starkware/starknet/public/CMakeLists.txt | 3 + src/starkware/starknet/public/abi_structs.py | 109 ++ .../starknet/public/abi_structs_test.py | 68 + .../starknet/security/CMakeLists.txt | 3 +- .../starknet/security/hints_whitelist.py | 1 - .../starknet/security/starknet_common.cairo | 3 +- .../security/whitelists/cairo_sha256.json | 123 ++ .../starknet/security/whitelists/latest.json | 138 ++ .../starknet/storage/starknet_storage.py | 6 +- src/starkware/starknet/testing/CMakeLists.txt | 1 + src/starkware/starknet/testing/contract.py | 33 +- .../starknet/testing/contract_test.py | 9 +- src/starkware/starknet/testing/starknet.py | 146 +- .../starknet/testing/starknet_test.py | 19 +- src/starkware/starknet/testing/state.py | 130 ++ src/starkware/starkware_utils/CMakeLists.txt | 16 +- .../commitment_tree/CMakeLists.txt | 1 + .../__init__.py | 0 .../{ => commitment_tree}/binary_fact_tree.py | 25 +- .../binary_fact_tree_node.py | 112 +- .../merkle_tree/traverse_tree.py | 8 +- .../commitment_tree/patricia_tree/__init__.py | 0 .../patricia_tree/nodes.py | 48 +- .../patricia_tree/nodes_test.py | 13 +- .../patricia_tree/patricia_tree.py | 41 +- .../patricia_tree/virtual_patricia_node.py | 128 +- .../virtual_patricia_node_test.py | 77 +- src/starkware/starkware_utils/config_base.py | 20 +- .../starkware_utils/custom_raising_dict.py | 11 +- .../starkware_utils/error_handling.py | 28 +- .../starkware_utils/field_validators.py | 189 +- .../marshmallow_dataclass_fields.py | 40 +- src/starkware/starkware_utils/serializable.py | 27 +- src/starkware/starkware_utils/time/time.py | 3 +- .../starkware_utils/validated_dataclass.py | 76 +- .../starkware_utils/validated_fields.py | 80 +- src/starkware/storage/batch_store.py | 10 +- src/starkware/storage/batch_store_test.py | 7 +- src/starkware/storage/dict_storage.py | 2 +- src/starkware/storage/gated_storage.py | 20 +- src/starkware/storage/gated_storage_test.py | 10 +- src/starkware/storage/imm_storage.py | 21 +- .../storage/internal_proxy_storage.py | 4 +- .../storage/internal_proxy_storage_test.py | 4 +- src/starkware/storage/metrics.py | 12 +- src/starkware/storage/names.py | 12 +- src/starkware/storage/storage.py | 62 +- src/starkware/storage/storage_test.py | 12 +- src/starkware/storage/test_utils.py | 5 +- 271 files changed, 11153 insertions(+), 6784 deletions(-) create mode 100644 src/starkware/cairo/common/cairo_sha256/sha256_utils.py create mode 100644 src/starkware/cairo/common/memset.cairo create mode 100644 src/starkware/cairo/lang/compiler/type_utils.py create mode 100644 src/starkware/cairo/lang/compiler/type_utils_test.py create mode 100644 src/starkware/starknet/public/abi_structs.py create mode 100644 src/starkware/starknet/public/abi_structs_test.py create mode 100644 src/starkware/starknet/security/whitelists/cairo_sha256.json create mode 100644 src/starkware/starknet/testing/state.py create mode 100644 src/starkware/starkware_utils/commitment_tree/CMakeLists.txt rename src/starkware/starkware_utils/{patricia_tree => commitment_tree}/__init__.py (100%) rename src/starkware/starkware_utils/{ => commitment_tree}/binary_fact_tree.py (70%) rename src/starkware/starkware_utils/{ => commitment_tree}/binary_fact_tree_node.py (79%) rename src/starkware/starkware_utils/{ => commitment_tree}/merkle_tree/traverse_tree.py (89%) create mode 100644 src/starkware/starkware_utils/commitment_tree/patricia_tree/__init__.py rename src/starkware/starkware_utils/{ => commitment_tree}/patricia_tree/nodes.py (77%) rename src/starkware/starkware_utils/{ => commitment_tree}/patricia_tree/nodes_test.py (85%) rename src/starkware/starkware_utils/{ => commitment_tree}/patricia_tree/patricia_tree.py (64%) rename src/starkware/starkware_utils/{ => commitment_tree}/patricia_tree/virtual_patricia_node.py (70%) rename src/starkware/starkware_utils/{ => commitment_tree}/patricia_tree/virtual_patricia_node_test.py (81%) diff --git a/README.md b/README.md index 6245ce2c..7a51190f 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ We recommend starting from [Setting up the environment](https://cairo-lang.org/d # Installation instructions You should be able to download the python package zip file directly from -[github](https://github.com/starkware-libs/cairo-lang/releases/tag/v0.4.0) +[github](https://github.com/starkware-libs/cairo-lang/releases/tag/v0.4.1) and install it using ``pip``. See [Setting up the environment](https://cairo-lang.org/docs/quickstart.html). @@ -54,7 +54,7 @@ Once the docker image is built, you can fetch the python package zip file using: ```bash > container_id=$(docker create cairo) -> docker cp ${container_id}:/app/cairo-lang-0.4.0.zip . +> docker cp ${container_id}:/app/cairo-lang-0.4.1.zip . > docker rm -v ${container_id} ``` diff --git a/src/cmake_utils/gen_pip_cmake.py b/src/cmake_utils/gen_pip_cmake.py index 450c1fc0..b63af667 100755 --- a/src/cmake_utils/gen_pip_cmake.py +++ b/src/cmake_utils/gen_pip_cmake.py @@ -16,34 +16,41 @@ def main(): - parser = ArgumentParser( - description='Generates a CMake file declaring all pip targets.') + parser = ArgumentParser(description="Generates a CMake file declaring all pip targets.") parser.add_argument( - '--interpreter_deps', type=str, nargs='*', required=True, - help='Interpreters and dependency output JSON files. ' - 'Example: python3.7:python_deps.json ...') - parser.add_argument('--output', type=str, help='Output cmake file', required=True) + "--interpreter_deps", + type=str, + nargs="*", + required=True, + help="Interpreters and dependency output JSON files. " + "Example: python3.7:python_deps.json ...", + ) + parser.add_argument("--output", type=str, help="Output cmake file", required=True) args = parser.parse_args() - res = '' + res = "" package_libs = defaultdict(list) package_versions = defaultdict(list) # Load dependency files for each interpreter. for interpreter_dep in args.interpreter_deps: - interpreter, dep_file = interpreter_dep.split(':') - with open(dep_file, 'r') as fp: + interpreter, dep_file = interpreter_dep.split(":") + with open(dep_file, "r") as fp: for package in json.load(fp): # Extract name of package. - name = package['package']['key'].replace('-', '_').lower() + name = package["package"]["key"].replace("-", "_").lower() # Build a requirement line for current interpreter. - req = package['package']['package_name'] + \ - '==' + package['package']['installed_version'] + req = ( + package["package"]["package_name"] + + "==" + + package["package"]["installed_version"] + ) package_versions[name].append(f'"{interpreter} {req}"') # Append dependency libraries. dep_names = [ - dep['key'].replace('-', '_').lower() for dep in package['dependencies']] - package_libs[name] += [f'{interpreter}:pip_{name}' for name in dep_names] + dep["key"].replace("-", "_").lower() for dep in package["dependencies"] + ] + package_libs[name] += [f"{interpreter}:pip_{name}" for name in dep_names] # Create a united rule for each pip package. for package_name in sorted(package_versions.keys()): @@ -56,9 +63,9 @@ def main(): # Write the output file, only if it is changed, so that the timestamp will not be updated # otherwise. - if not os.path.exists(args.output) or open(args.output, 'r').read() != res: - open(args.output, 'w').write(res) + if not os.path.exists(args.output) or open(args.output, "r").read() != res: + open(args.output, "w").write(res) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/cmake_utils/gen_py_lib.py b/src/cmake_utils/gen_py_lib.py index 6cb2baae..269726ea 100755 --- a/src/cmake_utils/gen_py_lib.py +++ b/src/cmake_utils/gen_py_lib.py @@ -32,51 +32,56 @@ def extract_licenses(filename: str) -> List[str]: - prefix = 'License: ' + prefix = "License: " if os.path.isfile(filename): - with open(filename, encoding='utf8') as fp: + with open(filename, encoding="utf8") as fp: for line in fp.readlines(): if line.startswith(prefix): - return line.strip()[len(prefix):].split(',') + return line.strip()[len(prefix) :].split(",") return [] def main(): parser = ArgumentParser( - description='Generates a json file that holds all the information for a python library ' - 'target.') - parser.add_argument('--name', type=str, help='Python library target name', required=True) + description="Generates a json file that holds all the information for a python library " + "target." + ) + parser.add_argument("--name", type=str, help="Python library target name", required=True) parser.add_argument( - '--interpreters', type=str, nargs='*', help='Supported interpreters', - default=['python3.7']) - parser.add_argument('--lib_dir', type=str, nargs='*', help='Library directory', required=True) + "--interpreters", type=str, nargs="*", help="Supported interpreters", default=["python3.7"] + ) + parser.add_argument("--lib_dir", type=str, nargs="*", help="Library directory", required=True) parser.add_argument( - '--import_paths', type=str, nargs='*', default=[], help='Path to add to sys.path') + "--import_paths", type=str, nargs="*", default=[], help="Path to add to sys.path" + ) + parser.add_argument("--files", type=str, nargs="*", help="Library file list") parser.add_argument( - '--files', type=str, nargs='*', help='Library file list') + "--lib_deps", type=str, nargs="*", help="Dependency libraries list", required=True + ) + parser.add_argument("--output", type=str, help="Output info file", required=True) parser.add_argument( - '--lib_deps', type=str, nargs='*', help='Dependency libraries list', required=True) - parser.add_argument('--output', type=str, help='Output info file', required=True) + "--py_exe_deps", type=str, nargs="*", required=True, help="List of executable dependencies" + ) parser.add_argument( - '--py_exe_deps', type=str, nargs='*', required=True, help='List of executable dependencies') + "--cmake_dir", type=str, nargs="?", help="Directory of this CMake target", required=False + ) parser.add_argument( - '--cmake_dir', type=str, nargs='?', help='Directory of this CMake target', required=False) - parser.add_argument( - '--prefix', type=str, nargs='?', help='Prefix of this CMake target', required=False) + "--prefix", type=str, nargs="?", help="Prefix of this CMake target", required=False + ) args = parser.parse_args() # Try to extract license if possible. licenses = [] for d in args.lib_dir: # Remove filters if exist (like 'pypy:'). - d = d.split(':')[-1] - metadata_files = glob.glob(os.path.join(d, '*/METADATA')) + d = d.split(":")[-1] + metadata_files = glob.glob(os.path.join(d, "*/METADATA")) for filename in metadata_files: licenses += extract_licenses(filename) licenses = sorted(set(licenses)) os.makedirs(os.path.dirname(args.output), exist_ok=True) - with open(args.output, 'w') as fp: + with open(args.output, "w") as fp: json.dump( dict( name=args.name, @@ -94,8 +99,8 @@ def main(): sort_keys=True, indent=4, ) - fp.write('\n') + fp.write("\n") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/cmake_utils/gen_python_exe.py b/src/cmake_utils/gen_python_exe.py index 86123e96..baaf82ad 100755 --- a/src/cmake_utils/gen_python_exe.py +++ b/src/cmake_utils/gen_python_exe.py @@ -12,43 +12,41 @@ def main(): - parser = ArgumentParser( - description='Generates an executable file for python_exe().') + parser = ArgumentParser(description="Generates an executable file for python_exe().") + parser.add_argument("--name", help="The name of the target", required=True) + parser.add_argument("--exe_path", help="The path to the output script file", required=True) parser.add_argument( - '--name', help='The name of the target', required=True) + "--venv", help="The python virtual environment that will run the module", required=True + ) + parser.add_argument("--module", help="The name of the module to run", required=True) + parser.add_argument("--args", help="Additional arguments to pass to the module") + parser.add_argument("--info_dir", help="Directory for all libraries info files", required=True) parser.add_argument( - '--exe_path', help='The path to the output script file', required=True) + "--cmake_binary_dir", help="The path to the CMake binary root dir", required=True + ) + parser.add_argument("--working_dir", help="Working directory to run the executable from.") parser.add_argument( - '--venv', help='The python virtual environment that will run the module', required=True) - parser.add_argument( - '--module', help='The name of the module to run', required=True) - parser.add_argument( - '--args', help='Additional arguments to pass to the module') - parser.add_argument( - '--info_dir', help='Directory for all libraries info files', required=True) - parser.add_argument( - '--cmake_binary_dir', help='The path to the CMake binary root dir', required=True) - parser.add_argument( - '--working_dir', help='Working directory to run the executable from.') - parser.add_argument( - '--environment_variables', help='Environments variables for the executable.') + "--environment_variables", help="Environments variables for the executable." + ) args = parser.parse_args() - venv_info = json.load(open(os.path.join(args.info_dir, f'{args.venv}.info'))) + venv_info = json.load(open(os.path.join(args.info_dir, f"{args.venv}.info"))) # Fetch the location of the venv dir, relative to the executable script. - build_path_bash = os.path.relpath( - args.cmake_binary_dir, os.path.dirname(args.exe_path)) - assert 'venv_dir' in venv_info, \ - f'venv_dir not found, make sure "{args.venv}" is a valid virtual environment.' - venv_dir_rel = os.path.relpath(venv_info['venv_dir'], args.cmake_binary_dir) - cd_command = f'cd {args.working_dir}' if args.working_dir else '' + build_path_bash = os.path.relpath(args.cmake_binary_dir, os.path.dirname(args.exe_path)) + assert ( + "venv_dir" in venv_info + ), f'venv_dir not found, make sure "{args.venv}" is a valid virtual environment.' + venv_dir_rel = os.path.relpath(venv_info["venv_dir"], args.cmake_binary_dir) + cd_command = f"cd {args.working_dir}" if args.working_dir else "" exe_args = args.args.replace( - '{VENV_SITE_DIR}', - '${BUILD_ROOT}/' + os.path.relpath(venv_info['site_dir'], args.cmake_binary_dir)) + "{VENV_SITE_DIR}", + "${BUILD_ROOT}/" + os.path.relpath(venv_info["site_dir"], args.cmake_binary_dir), + ) - with open(args.exe_path, 'w') as fp: - fp.write(f"""\ + with open(args.exe_path, "w") as fp: + fp.write( + f"""\ #!/bin/bash # Find the directory of the executable using $(dirname $0), convert it to absolute path using # realpath, and use it to find build directory (e.g., .../build/Debug or /app/). @@ -60,22 +58,32 @@ def main(): CMAKE_TARGET_NAME={args.name} \ ${{BUILD_ROOT}}/{venv_dir_rel}/bin/python -u -m {args.module} \ {exe_args} $@ -""") +""" + ) os.chmod( args.exe_path, - stat.S_IXUSR | stat.S_IRUSR | stat.S_IWUSR | - stat.S_IXGRP | stat.S_IRGRP | - stat.S_IXOTH | stat.S_IROTH) + stat.S_IXUSR + | stat.S_IRUSR + | stat.S_IWUSR + | stat.S_IXGRP + | stat.S_IRGRP + | stat.S_IXOTH + | stat.S_IROTH, + ) # Generate info file. - with open(os.path.join(args.info_dir, f'{args.name}.info'), 'w') as fp: - json.dump({ - 'exe_path': args.exe_path, - 'venv': args.venv, - }, fp, indent=4) - fp.write('\n') + with open(os.path.join(args.info_dir, f"{args.name}.info"), "w") as fp: + json.dump( + { + "exe_path": args.exe_path, + "venv": args.venv, + }, + fp, + indent=4, + ) + fp.write("\n") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/cmake_utils/gen_venv.py b/src/cmake_utils/gen_venv.py index 2d236e83..b9b99538 100755 --- a/src/cmake_utils/gen_venv.py +++ b/src/cmake_utils/gen_venv.py @@ -22,7 +22,7 @@ def filter_interpreter(python: str, entries: List[str]): """ res = [] for x in entries: - parts = x.split(':') + parts = x.split(":") if len(parts) == 1: # Common entry. res.append(x) @@ -46,10 +46,10 @@ def find_dependency_libraries(python: str, libs: List[str], info_dir: str) -> Di lib = library_queue.pop() if lib in found_libraries: continue - filename = os.path.join(info_dir, f'{lib}.info') - with open(filename, 'r') as fp: + filename = os.path.join(info_dir, f"{lib}.info") + with open(filename, "r") as fp: found_libraries[lib] = json.load(fp) - library_queue += filter_interpreter(python, found_libraries[lib]['lib_deps']) + library_queue += filter_interpreter(python, found_libraries[lib]["lib_deps"]) return found_libraries @@ -60,37 +60,39 @@ def fill_init_files(site_dir): if dirpath == site_dir: continue - if dirpath in py_dirs or any(filename.endswith('.py') for filename in filenames): + if dirpath in py_dirs or any(filename.endswith(".py") for filename in filenames): py_dirs.add(os.path.dirname(dirpath)) - if '__init__.py' not in filenames: - with open(os.path.join(dirpath, '__init__.py'), 'w') as f: + if "__init__.py" not in filenames: + with open(os.path.join(dirpath, "__init__.py"), "w") as f: # Create namespace packages, to allow the import of starkware pip libraries. - f.write('__path__ = __import__(\'pkgutil\').extend_path(__path__, __name__)') + f.write("__path__ = __import__('pkgutil').extend_path(__path__, __name__)") def get_pth_dir(python: str, venv_dir: str): - if python == 'python3.7': - return os.path.join(venv_dir, 'lib/python3.7/site-packages') - elif python == 'pypy3': - pth_dir = os.path.join(venv_dir, 'site-packages') + if python == "python3.7": + return os.path.join(venv_dir, "lib/python3.7/site-packages") + elif python == "pypy3": + pth_dir = os.path.join(venv_dir, "site-packages") os.makedirs(pth_dir, exist_ok=True) return pth_dir else: - raise NotImplementedError(f'Unsupported python executable {python}') + raise NotImplementedError(f"Unsupported python executable {python}") def main(): - parser = ArgumentParser(description='Generates a virtual environment.') + parser = ArgumentParser(description="Generates a virtual environment.") parser.add_argument( - '--name', type=str, help='The name of the virtual environment', required=True) + "--name", type=str, help="The name of the virtual environment", required=True + ) + parser.add_argument("--libs", type=str, nargs="*", help="Library list", required=True) + parser.add_argument("--python", help="Python executable", type=str, required=True) + parser.add_argument("--site_dir", help="Site output directory", type=str, required=True) parser.add_argument( - '--libs', type=str, nargs='*', help='Library list', required=True) - parser.add_argument('--python', help='Python executable', type=str, required=True) - parser.add_argument('--site_dir', help='Site output directory', type=str, required=True) + "--venv_dir", help="Virtual environment output directory", type=str, required=True + ) parser.add_argument( - '--venv_dir', help='Virtual environment output directory', type=str, required=True) - parser.add_argument( - '--info_dir', help='Directory for all libraries info files', type=str, required=True) + "--info_dir", help="Directory for all libraries info files", type=str, required=True + ) args = parser.parse_args() # Clean directories. @@ -101,17 +103,17 @@ def main(): # Find python. lookup_paths = [ - '/usr/bin', - '/usr/local/bin', + "/usr/bin", + "/usr/local/bin", ] - python_exec = shutil.which(args.python, path=':'.join(lookup_paths)) + python_exec = shutil.which(args.python, path=":".join(lookup_paths)) # Prepare an empty virtual environment in the background. # --symlinks prefers symlinks of copying. # --without-pip installs a completely empty venv, with no pip. # --clear clears the old venv if exists. - venv_proc = subprocess.Popen([ - python_exec, '-m', 'venv', '--symlinks', '--without-pip', '--clear', - args.venv_dir]) + venv_proc = subprocess.Popen( + [python_exec, "-m", "venv", "--symlinks", "--without-pip", "--clear", args.venv_dir] + ) # Find all libraries. found_libraries = find_dependency_libraries(args.python, args.libs, args.info_dir) @@ -121,18 +123,18 @@ def main(): site_files = [] py_exe_deps = set() for lib_name, lib_info in found_libraries.items(): - imports_list += filter_interpreter(args.python, lib_info['import_paths']) - lib_dirs = filter_interpreter(args.python, lib_info['lib_dir']) - assert len(lib_dirs) == 1, f'Library {lib_name} has {len(lib_dirs)} library directories.' - for filename in filter_interpreter(args.python, lib_info['files']): + imports_list += filter_interpreter(args.python, lib_info["import_paths"]) + lib_dirs = filter_interpreter(args.python, lib_info["lib_dir"]) + assert len(lib_dirs) == 1, f"Library {lib_name} has {len(lib_dirs)} library directories." + for filename in filter_interpreter(args.python, lib_info["files"]): src = os.path.join(lib_dirs[0], filename) dst = os.path.join(args.site_dir, filename) os.makedirs(os.path.dirname(dst), exist_ok=True) - assert not os.path.exists(dst), f'Multiple entries for {filename} in site dir.' + assert not os.path.exists(dst), f"Multiple entries for {filename} in site dir." # Create a hardlink (symlinks don't work well with pytest and conftest.py). os.link(src, dst) site_files.append(src) - py_exe_deps.update(lib_info['py_exe_deps']) + py_exe_deps.update(lib_info["py_exe_deps"]) # Since pytest root discovery is base of __init__.py files, we need to fill dummy __init__.py # In site dir. @@ -141,23 +143,27 @@ def main(): # Generate pth. venv_proc.wait() pth_dir = get_pth_dir(args.python, args.venv_dir) - pth_path = os.path.join(pth_dir, 'venv.pth') - with open(pth_path, 'w') as fp: - fp.write(''.join(os.path.relpath(dirname, pth_dir) + '\n' for dirname in imports_list)) + pth_path = os.path.join(pth_dir, "venv.pth") + with open(pth_path, "w") as fp: + fp.write("".join(os.path.relpath(dirname, pth_dir) + "\n" for dirname in imports_list)) # Generate info file. - with open(os.path.join(args.info_dir, f'{args.name}.info'), 'w') as fp: - json.dump({ - 'python': args.python, - 'venv_dir': args.venv_dir, - 'site_dir': args.site_dir, - 'pth': pth_path, - 'site_files': site_files, - 'imports_list': imports_list, - 'py_exe_deps': sorted(py_exe_deps), - }, fp, indent=4) - fp.write('\n') - - -if __name__ == '__main__': + with open(os.path.join(args.info_dir, f"{args.name}.info"), "w") as fp: + json.dump( + { + "python": args.python, + "venv_dir": args.venv_dir, + "site_dir": args.site_dir, + "pth": pth_path, + "site_files": site_files, + "imports_list": imports_list, + "py_exe_deps": sorted(py_exe_deps), + }, + fp, + indent=4, + ) + fp.write("\n") + + +if __name__ == "__main__": main() diff --git a/src/cmake_utils/unite_lib.py b/src/cmake_utils/unite_lib.py index 2a99d1c2..856a91d8 100755 --- a/src/cmake_utils/unite_lib.py +++ b/src/cmake_utils/unite_lib.py @@ -10,4 +10,4 @@ import sys -sys.stdout.write(' '.join(sorted(set(x.split(':')[-1] for x in sys.argv[1:])))) +sys.stdout.write(" ".join(sorted(set(x.split(":")[-1] for x in sys.argv[1:])))) diff --git a/src/demo/amm_demo/demo.py b/src/demo/amm_demo/demo.py index ea41ba82..9eead9b7 100644 --- a/src/demo/amm_demo/demo.py +++ b/src/demo/amm_demo/demo.py @@ -19,11 +19,11 @@ N_ACCOUNTS = 5 N_BATCHES = 3 -MIN_OPERATOR_BALANCE = 0.1 * 10**18 +MIN_OPERATOR_BALANCE = 0.1 * 10 ** 18 BATCH_SIZE = 10 GAS_PRICE = 10000000000 -AMM_SOURCE_PATH = os.path.join(os.path.dirname(__file__), 'amm.cairo') -CONTRACT_SOURCE_PATH = os.path.join(os.path.dirname(__file__), 'amm_contract.sol') +AMM_SOURCE_PATH = os.path.join(os.path.dirname(__file__), "amm.cairo") +CONTRACT_SOURCE_PATH = os.path.join(os.path.dirname(__file__), "amm_contract.sol") def init_prover(bin_dir: str, node_rpc_url: str) -> BatchProver: @@ -32,16 +32,19 @@ def init_prover(bin_dir: str, node_rpc_url: str) -> BatchProver: node_rpc_url: a URL of an Ethereum node RPC. """ - balance = Balance(a=random.randint(10**6, 10**8), b=random.randint(10**6, 10**8)) + balance = Balance(a=random.randint(10 ** 6, 10 ** 8), b=random.randint(10 ** 6, 10 ** 8)) accounts = { i: Account( pub_key=i, - balance=Balance(a=random.randint(10**5, 10**7), b=random.randint(10**5, 10**7))) - for i in range(N_ACCOUNTS)} + balance=Balance(a=random.randint(10 ** 5, 10 ** 7), b=random.randint(10 ** 5, 10 ** 7)), + ) + for i in range(N_ACCOUNTS) + } sharp_client = init_client(bin_dir=bin_dir, node_rpc_url=node_rpc_url) program = sharp_client.compile_cairo(source_code_path=AMM_SOURCE_PATH) prover = BatchProver( - program=program, balance=balance, accounts=accounts, sharp_client=sharp_client) + program=program, balance=balance, accounts=accounts, sharp_client=sharp_client + ) return prover @@ -63,26 +66,35 @@ def deploy_contract(batch_prover: BatchProver, w3: Web3, operator: eth.Account) cairo_verifier = batch_prover.sharp_client.contract_client.contract.address # Compile the smart contract. - print('Compiling the AMM demo smart contract...') - artifacts = subprocess.check_output( - ['solc', '--bin', '--abi', CONTRACT_SOURCE_PATH]).decode('utf-8').split('\n') + print("Compiling the AMM demo smart contract...") + artifacts = ( + subprocess.check_output(["solc", "--bin", "--abi", CONTRACT_SOURCE_PATH]) + .decode("utf-8") + .split("\n") + ) bytecode = artifacts[3] abi = artifacts[5] new_contract = w3.eth.contract(abi=abi, bytecode=bytecode) transaction = new_contract.constructor( - accountTreeRoot=account_tree_root, amountTokenA=amount_token_a, amountTokenB=amount_token_b, - cairoProgramHash=program_hash, cairoVerifier=cairo_verifier) - print('Deploying the AMM demo smart contract...') + accountTreeRoot=account_tree_root, + amountTokenA=amount_token_a, + amountTokenB=amount_token_b, + cairoProgramHash=program_hash, + cairoVerifier=cairo_verifier, + ) + print("Deploying the AMM demo smart contract...") tx_receipt = send_transaction(w3, transaction, operator) - assert tx_receipt['status'] == 1, \ - f'Failed to deploy contract. Transaction hash: {tx_receipt["transactionHash"]}.' + assert ( + tx_receipt["status"] == 1 + ), f'Failed to deploy contract. Transaction hash: {tx_receipt["transactionHash"]}.' - contract_address = tx_receipt['contractAddress'] + contract_address = tx_receipt["contractAddress"] input( - f'AMM demo smart contract successfully deployed to address {contract_address}. ' - 'You can track the contract state through this link ' - f'https://goerli.etherscan.io/address/{contract_address} .' - 'Press enter to continue.') + f"AMM demo smart contract successfully deployed to address {contract_address}. " + "You can track the contract state through this link " + f"https://goerli.etherscan.io/address/{contract_address} ." + "Press enter to continue." + ) return w3.eth.contract(abi=abi, address=contract_address) @@ -92,41 +104,46 @@ def main(): The main demonstration program. """ - parser = argparse.ArgumentParser(description='AMM demo') + parser = argparse.ArgumentParser(description="AMM demo") parser.add_argument( - '--bin_dir', type=str, default='', - help='The path to a directory that contains the cairo-compile and cairo-run scripts. ' - "If not specified, files are assumed to be in the system's PATH.") + "--bin_dir", + type=str, + default="", + help="The path to a directory that contains the cairo-compile and cairo-run scripts. " + "If not specified, files are assumed to be in the system's PATH.", + ) args = parser.parse_args() # Connect to an Ethereum node. node_rpc_url = input( - 'Please provide an RPC URL to communicate with an Ethereum node on Goerli: ') + "Please provide an RPC URL to communicate with an Ethereum node on Goerli: " + ) w3 = Web3(HTTPProvider(node_rpc_url)) if not w3.isConnected(): - print('Error: could not connect to the Ethereum node.') + print("Error: could not connect to the Ethereum node.") exit(1) # Initialize Ethereum account for on-chain transaction sending. operator_private_key_str = input( - 'Please enter an operator private key, ' - 'or press Enter to generate a new private key: ') + "Please enter an operator private key, " "or press Enter to generate a new private key: " + ) try: operator_private_key = int(operator_private_key_str, 16) except ValueError: - print('Generating a random key...') - operator_private_key = random.randint(0, 2**256) - operator_private_key = '0x{:064x}'.format(operator_private_key) + print("Generating a random key...") + operator_private_key = random.randint(0, 2 ** 256) + operator_private_key = "0x{:064x}".format(operator_private_key) operator = eth.Account.from_key(operator_private_key) # Ask for funds to be transferred to the operator account id its balance is too low. if w3.eth.getBalance(operator.address) < MIN_OPERATOR_BALANCE: input( - f'Please send funds (at least {MIN_OPERATOR_BALANCE * 10**-18} Goerli ETH) ' - f'to {operator.address} and press enter.') + f"Please send funds (at least {MIN_OPERATOR_BALANCE * 10**-18} Goerli ETH) " + f"to {operator.address} and press enter." + ) while w3.eth.getBalance(operator.address) < MIN_OPERATOR_BALANCE: - print('Funds not received yet...') + print("Funds not received yet...") sleep(15) # Initialize the system. @@ -137,29 +154,31 @@ def main(): for _ in range(N_BATCHES): batch = [rand_transaction() for _ in range(BATCH_SIZE)] - print('Sending batch to SHARP...') + print("Sending batch to SHARP...") job_id, fact, program_output = prover.prove_batch(batch) print() - print(f'Waiting for the fact {fact} to be registered on-chain...') + print(f"Waiting for the fact {fact} to be registered on-chain...") mins = 0.0 while not prover.sharp_client.fact_registered(fact): status = prover.sharp_client.get_job_status(job_id) print( f"Elapsed: {mins} minutes. Status of job id '{job_id}' " - f"and fact '{fact}' is '{status}'.") + f"and fact '{fact}' is '{status}'." + ) sleep(15) mins += 0.25 print() - print('Updating on-chain state...') + print("Updating on-chain state...") transaction = amm_contract.functions.updateState(programOutput=program_output) tx_receipt = send_transaction(w3, transaction, operator) - assert tx_receipt['status'] == 1, \ - 'Failed to update the on-chain state. ' \ + assert tx_receipt["status"] == 1, ( + "Failed to update the on-chain state. " f'Transaction hash: {tx_receipt["transactionHash"]}.' + ) print() - print('AMM Demo finished successfully :)') + print("AMM Demo finished successfully :)") def tx_kwargs(w3: Web3, sender_account: eth.Account): @@ -170,7 +189,7 @@ def tx_kwargs(w3: Web3, sender_account: eth.Account): sender_account: the account sending the transaction. """ nonce = w3.eth.getTransactionCount(sender_account) - return {'from': sender_account, 'gas': 10**6, 'gasPrice': GAS_PRICE, 'nonce': nonce} + return {"from": sender_account, "gas": 10 ** 6, "gasPrice": GAS_PRICE, "nonce": nonce} def send_transaction(w3, transaction, sender_account: eth.Account): @@ -183,11 +202,11 @@ def send_transaction(w3, transaction, sender_account: eth.Account): """ transaction_dict = transaction.buildTransaction(tx_kwargs(w3, sender_account.address)) signed_transaction = sender_account.signTransaction(transaction_dict) - print('Transaction built and signed.') + print("Transaction built and signed.") tx_hash = w3.eth.sendRawTransaction(signed_transaction.rawTransaction).hex() - print(f'Transaction sent. tx_hash={tx_hash} .') + print(f"Transaction sent. tx_hash={tx_hash} .") receipt = w3.eth.waitForTransactionReceipt(tx_hash) - print('Transaction successfully mined.') + print("Transaction successfully mined.") return receipt @@ -198,9 +217,12 @@ def get_merkle_root(accounts: Dict[int, Balance]) -> int: accounts: the state of the accounts (the merkle tree leaves). """ tree = MerkleTree(tree_height=10, default_leaf=0) - return tree.compute_merkle_root([ - (i, pedersen_hash(pedersen_hash(a.pub_key, a.balance.a), a.balance.b)) - for i, a in accounts.items()]) + return tree.compute_merkle_root( + [ + (i, pedersen_hash(pedersen_hash(a.pub_key, a.balance.a), a.balance.b)) + for i, a in accounts.items() + ] + ) def rand_transaction() -> SwapTransaction: @@ -208,9 +230,10 @@ def rand_transaction() -> SwapTransaction: Draws a random swap transaction. """ return SwapTransaction( - account_id=random.randint(0, N_ACCOUNTS - 1), token_a_amount=random.randint(1, 1000)) + account_id=random.randint(0, N_ACCOUNTS - 1), token_a_amount=random.randint(1, 1000) + ) -if __name__ == '__main__': - with get_crypto_lib_context_manager('Release'): +if __name__ == "__main__": + with get_crypto_lib_context_manager("Release"): main() diff --git a/src/demo/amm_demo/prove_batch.py b/src/demo/amm_demo/prove_batch.py index bd4ecd75..12ecd78a 100644 --- a/src/demo/amm_demo/prove_batch.py +++ b/src/demo/amm_demo/prove_batch.py @@ -12,6 +12,7 @@ class Balance: """ Represents the balance of each of the two tokens. """ + a: int b: int @@ -30,8 +31,12 @@ class SwapTransaction: class BatchProver: def __init__( - self, program: Program, balance: Balance, accounts: Dict[int, Account], - sharp_client: SharpClient): + self, + program: Program, + balance: Balance, + accounts: Dict[int, Account], + sharp_client: SharpClient, + ): """ Initializes the prover client. Parameters: @@ -64,22 +69,23 @@ def get_program_input(self, transactions: List[SwapTransaction]): Constructs the Cairo program input from the provided transactions and the system state. """ program_input: Dict[str, Any] = { - 'token_a_balance': self.balance.a, - 'token_b_balance': self.balance.b, - 'accounts': {}, - 'transactions': [] + "token_a_balance": self.balance.a, + "token_b_balance": self.balance.b, + "accounts": {}, + "transactions": [], } for index, account in self.accounts.items(): - program_input['accounts'][str(index)] = { - 'public_key': hex(account.pub_key), - 'token_a_balance': account.balance.a, - 'token_b_balance': account.balance.b, + program_input["accounts"][str(index)] = { + "public_key": hex(account.pub_key), + "token_a_balance": account.balance.a, + "token_b_balance": account.balance.b, } for tx in transactions: - program_input['transactions'].append( - {'account_id': tx.account_id, 'token_a_amount': tx.token_a_amount}) + program_input["transactions"].append( + {"account_id": tx.account_id, "token_a_amount": tx.token_a_amount} + ) return program_input @@ -88,12 +94,12 @@ def submit_job(self, program_input) -> Tuple[str, str, List[int]]: Submits a SHARP job to prove the state transition implied by the provided transactions. Returns the job id in the SHARP service, the fact to be registered, and the program output. """ - with tempfile.NamedTemporaryFile(mode='w') as program_input_file: - json.dump( - program_input, program_input_file, indent=4, sort_keys=True) + with tempfile.NamedTemporaryFile(mode="w") as program_input_file: + json.dump(program_input, program_input_file, indent=4, sort_keys=True) program_input_file.flush() cairo_pie = self.sharp_client.run_program( - program=self.program, program_input_path=program_input_file.name) + program=self.program, program_input_path=program_input_file.name + ) job_key = self.sharp_client.submit_cairo_pie(cairo_pie=cairo_pie) fact = self.sharp_client.get_fact(cairo_pie) diff --git a/src/services/everest/api/feeder_gateway/feeder_gateway_client.py b/src/services/everest/api/feeder_gateway/feeder_gateway_client.py index d552bc1d..098d4956 100644 --- a/src/services/everest/api/feeder_gateway/feeder_gateway_client.py +++ b/src/services/everest/api/feeder_gateway/feeder_gateway_client.py @@ -9,11 +9,11 @@ class EverestFeederGatewayClient(BaseClient): Base class to FeederGatewayClient classes. """ - prefix: ClassVar[str] = '/feeder_gateway' + prefix: ClassVar[str] = "/feeder_gateway" async def is_alive(self) -> str: - return await self._send_request(send_method='GET', uri='/is_alive') + return await self._send_request(send_method="GET", uri="/is_alive") async def get_last_batch_id(self) -> int: - raw_response = await self._send_request(send_method='GET', uri='/get_last_batch_id') + raw_response = await self._send_request(send_method="GET", uri="/get_last_batch_id") return json.loads(raw_response) diff --git a/src/services/everest/api/gateway/gateway_client.py b/src/services/everest/api/gateway/gateway_client.py index 7a357492..a000f8aa 100644 --- a/src/services/everest/api/gateway/gateway_client.py +++ b/src/services/everest/api/gateway/gateway_client.py @@ -10,17 +10,19 @@ class EverestGatewayClient(BaseClient): Base class to GatewayClient classes. """ - prefix: ClassVar[str] = '/gateway' + prefix: ClassVar[str] = "/gateway" async def is_alive(self) -> str: - return await self._send_request(send_method='GET', uri='/is_alive') + return await self._send_request(send_method="GET", uri="/is_alive") async def add_transaction_request( - self, add_tx_request: EverestAddTransactionRequest) -> Dict[str, str]: + self, add_tx_request: EverestAddTransactionRequest + ) -> Dict[str, str]: raw_response = await self._send_request( - send_method='POST', uri='/add_transaction', data=add_tx_request.dumps()) + send_method="POST", uri="/add_transaction", data=add_tx_request.dumps() + ) return json.loads(raw_response) async def get_first_unused_tx_id(self) -> int: - response = await self._send_request(send_method='GET', uri='/get_first_unused_tx_id') + response = await self._send_request(send_method="GET", uri="/get_first_unused_tx_id") return json.loads(response) diff --git a/src/services/everest/api/gateway/transaction.py b/src/services/everest/api/gateway/transaction.py index 78c5c20e..2fccfc04 100644 --- a/src/services/everest/api/gateway/transaction.py +++ b/src/services/everest/api/gateway/transaction.py @@ -14,6 +14,7 @@ class EverestTransaction(ValidatedMarshmallowDataclass): Schema: ClassVar[Type[marshmallow_oneofschema.OneOfSchema]] + class EverestAddTransactionRequest(ValidatedMarshmallowDataclass): tx: EverestTransaction tx_id: int diff --git a/src/services/everest/business_logic/internal_transaction.py b/src/services/everest/business_logic/internal_transaction.py index 7cb8cc57..a4d4d196 100644 --- a/src/services/everest/business_logic/internal_transaction.py +++ b/src/services/everest/business_logic/internal_transaction.py @@ -28,8 +28,9 @@ class TransactionSchema(OneOfSchema): def get_obj_type(self, obj): name = type(obj).__name__ - assert name in classes.keys() and classes[name] == type(obj), \ - f'Trying to serialized the object {obj} that was not registered first.' + assert name in classes.keys() and classes[name] == type( + obj + ), f"Trying to serialized the object {obj} that was not registered first." # We register the Schema object here, since it might not exists when the object # itself is registered. if name not in self.type_schemas.keys(): @@ -41,8 +42,9 @@ def get_obj_type(self, obj): def add_class(self, cls: type): cls_name = cls.__name__ if cls_name in self.classes: - assert self.classes[cls_name] == cls, \ - f'Trying to register two classes with the same name {cls_name}' + assert ( + self.classes[cls_name] == cls + ), f"Trying to register two classes with the same name {cls_name}" else: self.classes[cls_name] = cls @@ -106,8 +108,8 @@ def external_name(self) -> str: @classmethod @abstractmethod def from_external( - cls, external_tx: EverestTransaction, - general_config: Config) -> 'EverestInternalTransaction': + cls, external_tx: EverestTransaction, general_config: Config + ) -> "EverestInternalTransaction": """ Returns an internal transaction genearated based on an external one. """ @@ -129,8 +131,8 @@ def get_state_selector(self, general_config: Config) -> StateSelectorBase: @abstractmethod async def apply_state_updates( - self, state: CarriedStateBase, - general_config: Config) -> Optional[EverestTransactionExecutionInfo]: + self, state: CarriedStateBase, general_config: Config + ) -> Optional[EverestTransactionExecutionInfo]: """ Applies the transaction on the Merkle state in an atomic manner. Returns an object containing information about the execution of the transaction, or None - @@ -148,8 +150,8 @@ def verify_signatures(self): @staticmethod @abstractmethod def get_state_selector_of_many( - txs: Iterable['EverestInternalTransaction'], - general_config: Config) -> StateSelectorBase: + txs: Iterable["EverestInternalTransaction"], general_config: Config + ) -> StateSelectorBase: """ Returns the state selector of a collection of transactions (i.e., union of selectors). The implementation of this method must be to downcast the return type. @@ -157,9 +159,12 @@ def get_state_selector_of_many( @staticmethod def _get_state_selector_of_many( - txs: Iterable['EverestInternalTransaction'], - general_config: Config, - state_selector_cls: Type[StateSelectorBase]) -> StateSelectorBase: + txs: Iterable["EverestInternalTransaction"], + general_config: Config, + state_selector_cls: Type[StateSelectorBase], + ) -> StateSelectorBase: return functools.reduce( - operator.__or__, (tx.get_state_selector(general_config=general_config) for tx in txs), - state_selector_cls.empty()) + operator.__or__, + (tx.get_state_selector(general_config=general_config) for tx in txs), + state_selector_cls.empty(), + ) diff --git a/src/services/everest/business_logic/state.py b/src/services/everest/business_logic/state.py index 814edf20..c3ef5bb3 100644 --- a/src/services/everest/business_logic/state.py +++ b/src/services/everest/business_logic/state.py @@ -6,10 +6,10 @@ from starkware.starkware_utils.config_base import Config from starkware.storage.storage import FactFetchingContext -TStateSelector = TypeVar('TStateSelector', bound='StateSelectorBase') -TCarriedState = TypeVar('TCarriedState', bound='CarriedStateBase') -TSharedState = TypeVar('TSharedState', bound='SharedStateBase') -TGeneralConfig = TypeVar('TGeneralConfig', bound=Config) +TStateSelector = TypeVar("TStateSelector", bound="StateSelectorBase") +TCarriedState = TypeVar("TCarriedState", bound="CarriedStateBase") +TSharedState = TypeVar("TSharedState", bound="SharedStateBase") +TGeneralConfig = TypeVar("TGeneralConfig", bound=Config) class StateSelectorBase(ABC): @@ -96,8 +96,9 @@ def fill_missing(self, other: TCarriedState): Fills missing entries from another CarriedState instance. """ state_selector = self.state_selector - assert state_selector & other.state_selector == type(state_selector).empty(), \ - 'Selectors must be disjoint.' + assert ( + state_selector & other.state_selector == type(state_selector).empty() + ), "Selectors must be disjoint." self._fill_missing(other=other) @abstractmethod @@ -150,8 +151,8 @@ def __repr__(self) -> str: @classmethod @abstractmethod async def empty( - cls: Type[TSharedState], ffc: FactFetchingContext, - general_config: Config) -> TSharedState: + cls: Type[TSharedState], ffc: FactFetchingContext, general_config: Config + ) -> TSharedState: """ Returns an empty state. This is called before creating very first batch. """ @@ -164,12 +165,15 @@ def to_carried_state(self: TSharedState, ffc: FactFetchingContext) -> CarriedSta @abstractmethod async def get_filled_carried_state( - self: TSharedState, ffc: FactFetchingContext, - state_selector: StateSelectorBase) -> CarriedStateBase: + self: TSharedState, ffc: FactFetchingContext, state_selector: StateSelectorBase + ) -> CarriedStateBase: pass @abstractmethod async def apply_state_updates( - self: TSharedState, ffc: FactFetchingContext, previous_carried_state: CarriedStateBase, - current_carried_state: CarriedStateBase) -> TSharedState: + self: TSharedState, + ffc: FactFetchingContext, + previous_carried_state: CarriedStateBase, + current_carried_state: CarriedStateBase, + ) -> TSharedState: pass diff --git a/src/services/everest/definitions/fields.py b/src/services/everest/definitions/fields.py index 7f7d4196..9cb976a7 100644 --- a/src/services/everest/definitions/fields.py +++ b/src/services/everest/definitions/fields.py @@ -14,7 +14,8 @@ # Fields data: validation data, dataclass metadata. tx_id_marshmallow_field = mfields.Integer( - strict=True, required=True, validate=validate_non_negative('tx_id')) + strict=True, required=True, validate=validate_non_negative("tx_id") +) tx_id_field_metadata = dict(marshmallow_field=tx_id_marshmallow_field) @@ -37,8 +38,8 @@ def name(self) -> str: # Randomization. def get_random_value(self, random_object: Optional[random.Random] = None) -> str: r = initialize_random(random_object=random_object) - raw_address = ''.join(r.choices(population=string.hexdigits, k=40)) - return Web3.toChecksumAddress(value=f'0x{raw_address}') + raw_address = "".join(r.choices(population=string.hexdigits, k=40)) + return Web3.toChecksumAddress(value=f"0x{raw_address}") # Validation. def is_valid(self, value: str) -> bool: @@ -46,9 +47,9 @@ def is_valid(self, value: str) -> bool: def get_invalid_values(self) -> List[str]: return [ - '0x0Fa81Ec60fe5422d49174F1abdfdC06a9F1c52F2', # Not checksummed. + "0x0Fa81Ec60fe5422d49174F1abdfdC06a9F1c52F2", # Not checksummed. self.get_random_value()[:-1], # Too short address. - self.get_random_value() + '0' # type: ignore # Too long address. + self.get_random_value() + "0", # type: ignore # Too long address. ] @property @@ -57,7 +58,7 @@ def error_code(self) -> ErrorCode: def format_invalid_value_error_message(self, value: str, name: Optional[str] = None) -> str: name = self.name if name is None else name - return f'{name} {value} is out of range / not checksummed.' + return f"{name} {value} is out of range / not checksummed." # Serialization. def get_marshmallow_field(self) -> mfields.Field: @@ -74,15 +75,16 @@ def format(self, value: str) -> str: FactRegistryField = EthAddressTypeField( - name='Address of fact registry', - error_code=StarkErrorCode.INVALID_CONTRACT_ADDRESS) + name="Address of fact registry", error_code=StarkErrorCode.INVALID_CONTRACT_ADDRESS +) EthAddressField = EthAddressTypeField( - name='Ethereum address', - error_code=StarkErrorCode.INVALID_ETH_ADDRESS) + name="Ethereum address", error_code=StarkErrorCode.INVALID_ETH_ADDRESS +) EthAddressIntField = RangeValidatedField( lower_bound=constants.ETH_ADDRESS_LOWER_BOUND, upper_bound=constants.ETH_ADDRESS_UPPER_BOUND, - name_in_error_message='Ethereum address', - out_of_range_error_code=StarkErrorCode.OUT_OF_RANGE_ETH_ADDRESS) + name_in_error_message="Ethereum address", + out_of_range_error_code=StarkErrorCode.OUT_OF_RANGE_ETH_ADDRESS, +) diff --git a/src/services/external_api/base_client.py b/src/services/external_api/base_client.py index 40e5e8bd..a23e31d8 100644 --- a/src/services/external_api/base_client.py +++ b/src/services/external_api/base_client.py @@ -24,7 +24,7 @@ def __init__(self, status_code: int, text: str): self.text = text def __repr__(self) -> str: - return f'HTTP error ocurred. Status: {self.status_code}. Text: {self.text}' + return f"HTTP error ocurred. Status: {self.status_code}. Text: {self.text}" def __str__(self) -> str: """ @@ -42,7 +42,10 @@ class RetryConfig: # Set n_retries == -1 for unlimited retries (for any error type). n_retries: int = 30 retry_codes: Sequence[HTTPStatus] = ( - HTTPStatus.BAD_GATEWAY, HTTPStatus.SERVICE_UNAVAILABLE, HTTPStatus.GATEWAY_TIMEOUT) + HTTPStatus.BAD_GATEWAY, + HTTPStatus.SERVICE_UNAVAILABLE, + HTTPStatus.GATEWAY_TIMEOUT, + ) class BaseClient(HasUriPrefix): @@ -51,14 +54,18 @@ class BaseClient(HasUriPrefix): """ def __init__( - self, url: str, certificates_path: Optional[str] = None, - retry_config: Optional[RetryConfig] = None): + self, + url: str, + certificates_path: Optional[str] = None, + retry_config: Optional[RetryConfig] = None, + ): self.url = url self.ssl_context: Optional[ssl.SSLContext] = None self.retry_config = RetryConfig() if retry_config is None else retry_config - assert self.retry_config.n_retries > 0 or self.retry_config.n_retries == -1, \ - 'RetryConfig n_retries parameter value must be either a positive int or equals to -1.' + assert ( + self.retry_config.n_retries > 0 or self.retry_config.n_retries == -1 + ), "RetryConfig n_retries parameter value must be either a positive int or equals to -1." if certificates_path is not None: self.ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLSv1_2) @@ -66,14 +73,15 @@ def __init__( self.ssl_context.check_hostname = True self.ssl_context.load_cert_chain( - certfile=os.path.join(certificates_path, 'user.crt'), - keyfile=os.path.join(certificates_path, 'user.key')) + certfile=os.path.join(certificates_path, "user.crt"), + keyfile=os.path.join(certificates_path, "user.key"), + ) - self.ssl_context.load_verify_locations(os.path.join(certificates_path, 'server.crt')) + self.ssl_context.load_verify_locations(os.path.join(certificates_path, "server.crt")) async def _send_request( - self, send_method: str, uri: str, - data: Optional[Union[str, Dict[str, Any]]] = None) -> str: + self, send_method: str, uri: str, data: Optional[Union[str, Dict[str, Any]]] = None + ) -> str: """ Sends an HTTP request to the target URI. Retries upon failure according to the retry configuration: @@ -95,7 +103,8 @@ async def _send_request( async with aiohttp.TCPConnector(ssl=self.ssl_context) as connector: async with aiohttp.ClientSession(connector=connector) as session: async with session.request( - method=send_method, url=url, data=data) as response: + method=send_method, url=url, data=data + ) as response: text = await response.text() if response.status != HTTPStatus.OK: raise BadRequest(status_code=response.status, text=text) @@ -105,19 +114,21 @@ async def _send_request( if limited_retries and n_retries_left == 0: raise - logger.error('ClientConnectorError, retrying...', exc_info=True) + logger.error("ClientConnectorError, retrying...", exc_info=True) except BadRequest as exception: if limited_retries and ( - n_retries_left == 0 or - exception.status_code not in self.retry_config.retry_codes): + n_retries_left == 0 + or exception.status_code not in self.retry_config.retry_codes + ): raise logger.error( - f'Got BadRequest while trying to access {url}. ' - f'status_code: {exception.status_code}. text: {exception.text}, ' - 'retrying...') + f"Got BadRequest while trying to access {url}. " + f"status_code: {exception.status_code}. text: {exception.text}, " + "retrying..." + ) await asyncio.sleep(1) async def is_alive(self) -> str: - return await self._send_request(send_method='GET', uri='/is_alive') + return await self._send_request(send_method="GET", uri="/is_alive") diff --git a/src/services/external_api/has_uri_prefix.py b/src/services/external_api/has_uri_prefix.py index 2ded7f0a..9e293e53 100644 --- a/src/services/external_api/has_uri_prefix.py +++ b/src/services/external_api/has_uri_prefix.py @@ -6,6 +6,7 @@ class HasUriPrefix(ABC): """ A base class of HTTP Gateway services. """ + @property @classmethod @abstractmethod @@ -21,4 +22,4 @@ def format_uri(cls, name: str) -> str: Concatenates cls.prefix with given URI. """ prefix = cast(str, cls.prefix) # Mypy sees the property as a callable. - return name if len(prefix) == 0 else f'{cls.prefix}{name}' + return name if len(prefix) == 0 else f"{cls.prefix}{name}" diff --git a/src/starkware/cairo/bootloader/compute_fact.py b/src/starkware/cairo/bootloader/compute_fact.py index a64a14e1..a1fa9378 100644 --- a/src/starkware/cairo/bootloader/compute_fact.py +++ b/src/starkware/cairo/bootloader/compute_fact.py @@ -13,20 +13,26 @@ def keccak_ints(values: List[int]) -> str: This function is compatible with Web3.solidityKeccak(['uint256[]'], [values]).hex() """ - return '0x' + binascii.hexlify( - keccak(b''.join(value.to_bytes(32, 'big') for value in values))).decode('ascii') + return "0x" + binascii.hexlify( + keccak(b"".join(value.to_bytes(32, "big") for value in values)) + ).decode("ascii") def generate_program_fact( - program_hash: int, program_output: List[int], fact_topology: FactTopology) -> str: + program_hash: int, program_output: List[int], fact_topology: FactTopology +) -> str: """ Generates the program fact of the Cairo program with program_hash and program_output. See GpsOutputParser.sol for more information on the way the fact is computed. """ - return keccak_ints([ - program_hash, - generate_output_root(program_output=program_output, fact_topology=fact_topology).node_hash - ]) + return keccak_ints( + [ + program_hash, + generate_output_root( + program_output=program_output, fact_topology=fact_topology + ).node_hash, + ] + ) @dataclasses.dataclass @@ -34,11 +40,10 @@ class FactNode: node_hash: int end_offset: int size: int - children: List['FactNode'] + children: List["FactNode"] -def generate_output_root( - program_output: List[int], fact_topology: FactTopology) -> FactNode: +def generate_output_root(program_output: List[int], fact_topology: FactTopology) -> FactNode: """ Generates the root of the output Merkle tree for the program fact computation. See GpsOutputParser.sol for more information on the way the fact is computed. @@ -50,32 +55,36 @@ def generate_output_root( node_stack: List[FactNode] = [] for n_pages, n_nodes in zip(tree_structure[::2], tree_structure[1::2]): # Push n_pages to the stack. - assert 0 <= n_pages <= len(page_sizes), 'Invalid tree structure: n_pages is out of range.' + assert 0 <= n_pages <= len(page_sizes), "Invalid tree structure: n_pages is out of range." for _ in range(n_pages): page_size = page_sizes.pop(0) - page_hash = int(keccak_ints(program_output[offset:offset + page_size]), 16) + page_hash = int(keccak_ints(program_output[offset : offset + page_size]), 16) offset += page_size - node_stack.append(FactNode( - node_hash=page_hash, end_offset=offset, size=page_size, children=[])) + node_stack.append( + FactNode(node_hash=page_hash, end_offset=offset, size=page_size, children=[]) + ) - assert 0 <= n_nodes <= len(node_stack), 'Invalid tree structure: n_nodes is out of range.' + assert 0 <= n_nodes <= len(node_stack), "Invalid tree structure: n_nodes is out of range." if n_nodes > 0: # Create a parent node to the last n_nodes in the head of the stack. node_stack, child_nodes = node_stack[:-n_nodes], node_stack[-n_nodes:] # Create an alternating list of hashes and end offsets. node_data = [val for node in child_nodes for val in [node.node_hash, node.end_offset]] - node_stack.append(FactNode( - node_hash=1 + int(keccak_ints(node_data), 16), - end_offset=child_nodes[-1].end_offset, - size=sum(node.size for node in child_nodes), - children=child_nodes)) + node_stack.append( + FactNode( + node_hash=1 + int(keccak_ints(node_data), 16), + end_offset=child_nodes[-1].end_offset, + size=sum(node.size for node in child_nodes), + children=child_nodes, + ) + ) # Make sure there is one node in the stack (hash and end). - assert len(node_stack) == 1, 'Invalid tree structure: stack contains more than one node.' + assert len(node_stack) == 1, "Invalid tree structure: stack contains more than one node." # Make sure all pages were processed. - assert len(page_sizes) == 0, 'Invalid tree structure: not all pages were processed.' + assert len(page_sizes) == 0, "Invalid tree structure: not all pages were processed." assert offset == node_stack[0].end_offset == len(program_output) return node_stack[0] diff --git a/src/starkware/cairo/bootloader/fact_topology.py b/src/starkware/cairo/bootloader/fact_topology.py index 8910b264..65290003 100644 --- a/src/starkware/cairo/bootloader/fact_topology.py +++ b/src/starkware/cairo/bootloader/fact_topology.py @@ -5,7 +5,7 @@ import marshmallow import marshmallow_dataclass -GPS_FACT_TOPOLOGY = 'gps_fact_topology' +GPS_FACT_TOPOLOGY = "gps_fact_topology" @dataclasses.dataclass(frozen=True) diff --git a/src/starkware/cairo/bootloader/generate_fact.py b/src/starkware/cairo/bootloader/generate_fact.py index 08cebba0..c7121186 100644 --- a/src/starkware/cairo/bootloader/generate_fact.py +++ b/src/starkware/cairo/bootloader/generate_fact.py @@ -11,17 +11,19 @@ def get_program_output(cairo_pie: CairoPie) -> List[int]: """ Returns the program output. """ - assert 'output' in cairo_pie.metadata.builtin_segments, 'The output builtin must be used.' - output = cairo_pie.metadata.builtin_segments['output'] + assert "output" in cairo_pie.metadata.builtin_segments, "The output builtin must be used." + output = cairo_pie.metadata.builtin_segments["output"] def verify_int(x: MaybeRelocatable) -> int: - assert isinstance(x, int), \ - f'Expected program output to contain absolute values, found: {x}.' + assert isinstance( + x, int + ), f"Expected program output to contain absolute values, found: {x}." return x return [ verify_int(cairo_pie.memory[RelocatableValue(segment_index=output.index, offset=i)]) - for i in range(output.size)] + for i in range(output.size) + ] def get_cairo_pie_fact_info(cairo_pie: CairoPie, program_hash: Optional[int] = None) -> FactInfo: @@ -31,7 +33,8 @@ def get_cairo_pie_fact_info(cairo_pie: CairoPie, program_hash: Optional[int] = N program_output = get_program_output(cairo_pie=cairo_pie) fact_topology = get_fact_topology_from_additional_data( output_size=len(program_output), - output_builtin_additional_data=cairo_pie.additional_data['output_builtin']) + output_builtin_additional_data=cairo_pie.additional_data["output_builtin"], + ) if program_hash is None: program_hash = get_program_hash(cairo_pie) fact = generate_program_fact(program_hash, program_output, fact_topology=fact_topology) @@ -59,52 +62,59 @@ def get_page_sizes_from_page_dict(output_size: int, pages: dict) -> List[int]: pages_list = [ (int(page_id_str), page_start, page_size) - for page_id_str, (page_start, page_size) in pages.items()] + for page_id_str, (page_start, page_size) in pages.items() + ] for page_id, page_start, page_size in sorted(pages_list): - assert page_id == expected_page_id, f'Expected page id {expected_page_id}, found {page_id}.' + assert page_id == expected_page_id, f"Expected page id {expected_page_id}, found {page_id}." if page_id == 1: - assert isinstance(page_start, int) and 0 < page_start <= output_size, \ - f'Invalid page start {page_start}.' + assert ( + isinstance(page_start, int) and 0 < page_start <= output_size + ), f"Invalid page start {page_start}." page0_size = page_start else: - assert page_start == expected_page_start, \ - f'Expected page start {expected_page_start}, found {page_start}.' + assert ( + page_start == expected_page_start + ), f"Expected page start {expected_page_start}, found {page_start}." - assert isinstance(page_size, int) and 0 < page_size <= output_size, \ - f'Invalid page size {page_size}.' + assert ( + isinstance(page_size, int) and 0 < page_size <= output_size + ), f"Invalid page size {page_size}." expected_page_start = page_start + page_size expected_page_id += 1 if len(pages) > 0: - assert expected_page_start == output_size, 'Pages must cover the entire program output.' + assert expected_page_start == output_size, "Pages must cover the entire program output." return [page0_size] + [page_size for _, (_, page_size) in sorted(pages.items())] def get_fact_topology_from_additional_data( - output_size: int, output_builtin_additional_data: Dict[str, Any]) -> FactTopology: + output_size: int, output_builtin_additional_data: Dict[str, Any] +) -> FactTopology: """ Returns the fact topology from the additional data of the output builtin. """ - pages = output_builtin_additional_data['pages'] - attributes = output_builtin_additional_data['attributes'] + pages = output_builtin_additional_data["pages"] + attributes = output_builtin_additional_data["attributes"] # If the GPS_FACT_TOPOLOGY attribute is present, use it. Otherwise, the task is expected to # use exactly one page (page 0). if GPS_FACT_TOPOLOGY in attributes: tree_structure = attributes[GPS_FACT_TOPOLOGY] - assert isinstance(tree_structure, list) and \ - len(tree_structure) % 2 == 0 and \ - 0 < len(tree_structure) <= 10 and \ - all(isinstance(x, int) and 0 <= x < 2**30 for x in tree_structure), \ - f"Invalid tree structure specified in the '{GPS_FACT_TOPOLOGY}' attribute." + assert ( + isinstance(tree_structure, list) + and len(tree_structure) % 2 == 0 + and 0 < len(tree_structure) <= 10 + and all(isinstance(x, int) and 0 <= x < 2 ** 30 for x in tree_structure) + ), f"Invalid tree structure specified in the '{GPS_FACT_TOPOLOGY}' attribute." else: - assert len(pages) == 0, \ - f"Additional pages cannot be used since the '{GPS_FACT_TOPOLOGY}' attribute is not " \ - 'specified.' + assert len(pages) == 0, ( + f"Additional pages cannot be used since the '{GPS_FACT_TOPOLOGY}' attribute is not " + "specified." + ) tree_structure = [1, 0] return FactTopology( - tree_structure=tree_structure, - page_sizes=get_page_sizes_from_page_dict(output_size, pages)) + tree_structure=tree_structure, page_sizes=get_page_sizes_from_page_dict(output_size, pages) + ) diff --git a/src/starkware/cairo/bootloader/hash_program.py b/src/starkware/cairo/bootloader/hash_program.py index 24e1bfcb..7262373f 100644 --- a/src/starkware/cairo/bootloader/hash_program.py +++ b/src/starkware/cairo/bootloader/hash_program.py @@ -11,7 +11,7 @@ def compute_program_hash_chain(program: ProgramBase, bootloader_version=0): """ Computes a hash chain over a program, including the length of the data chain. """ - builtin_list = [int.from_bytes(builtin.encode('ascii'), 'big') for builtin in program.builtins] + builtin_list = [int.from_bytes(builtin.encode("ascii"), "big") for builtin in program.builtins] # The program header below is missing the data length, which is later added to the data_chain. program_header = [bootloader_version, program.main, len(program.builtins)] + builtin_list data_chain = program_header + program.data @@ -20,15 +20,21 @@ def compute_program_hash_chain(program: ProgramBase, bootloader_version=0): def main(): - parser = argparse.ArgumentParser( - description='A tool to compute the hash of a cairo program') - parser.add_argument('-v', '--version', action='version', version=f'%(prog)s {__version__}') + parser = argparse.ArgumentParser(description="A tool to compute the hash of a cairo program") + parser.add_argument("-v", "--version", action="version", version=f"%(prog)s {__version__}") parser.add_argument( - '--program', type=argparse.FileType('r'), required=True, - help='The name of the program json file.') + "--program", + type=argparse.FileType("r"), + required=True, + help="The name of the program json file.", + ) parser.add_argument( - '--flavor', type=str, default='Release', choices=['Debug', 'Release', 'RelWithDebInfo'], - help='Build flavor') + "--flavor", + type=str, + default="Release", + choices=["Debug", "Release", "RelWithDebInfo"], + help="Build flavor", + ) args = parser.parse_args() with get_crypto_lib_context_manager(args.flavor): @@ -36,5 +42,5 @@ def main(): print(hex(compute_program_hash_chain(program))) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/starkware/cairo/builtin_selection/select_input_builtins_test.py b/src/starkware/cairo/builtin_selection/select_input_builtins_test.py index b2ee9263..105734d7 100644 --- a/src/starkware/cairo/builtin_selection/select_input_builtins_test.py +++ b/src/starkware/cairo/builtin_selection/select_input_builtins_test.py @@ -8,17 +8,21 @@ @pytest.mark.parametrize( - 'builtin_selection_indicators', [ - [True, True, True, True, True], [False, False, False, False, False], - [True, False, False, True, False]], - ids=['select_all_builtins', 'do_not_select_any_builtin', 'select_output_and_ecdsa_builtins']) + "builtin_selection_indicators", + [ + [True, True, True, True, True], + [False, False, False, False, False], + [True, False, False, True, False], + ], + ids=["select_all_builtins", "do_not_select_any_builtin", "select_output_and_ecdsa_builtins"], +) def test_select_input_builtins(builtin_selection_indicators): """ Tests the select_input_builtins Cairo function: calls the function with different builtins selection and checks that the function returns the expected builtin pointers. """ # Setup runner. - cairo_file = os.path.join(os.path.dirname(__file__), 'select_input_builtins.cairo') + cairo_file = os.path.join(os.path.dirname(__file__), "select_input_builtins.cairo") runner = CairoRunner.from_file(cairo_file, DEFAULT_PRIME) runner.initialize_segments() @@ -30,18 +34,24 @@ def test_select_input_builtins(builtin_selection_indicators): # Setup function. builtins_encoding = { - builtin: int.from_bytes(builtin.encode('ascii'), 'big') - for builtin in ['output', 'pedersen', 'range_check', 'ecdsa', 'bitwise']} + builtin: int.from_bytes(builtin.encode("ascii"), "big") + for builtin in ["output", "pedersen", "range_check", "ecdsa", "bitwise"] + } all_builtins = [output_base, hash_base, range_check_base, signature_base, bitwise_base] selected_builtin_encodings = [ - builtin_encoding for builtin_encoding, is_builtin_selected in zip( - builtins_encoding.values(), builtin_selection_indicators) - if is_builtin_selected] + builtin_encoding + for builtin_encoding, is_builtin_selected in zip( + builtins_encoding.values(), builtin_selection_indicators + ) + if is_builtin_selected + ] selected_builtins = [ - builtin for builtin, is_builtin_selected in zip(all_builtins, builtin_selection_indicators) - if is_builtin_selected] + builtin + for builtin, is_builtin_selected in zip(all_builtins, builtin_selection_indicators) + if is_builtin_selected + ] all_encodings = create_memory_struct(runner, builtins_encoding.values()) selected_encodings = create_memory_struct(runner, selected_builtin_encodings) @@ -49,7 +59,7 @@ def test_select_input_builtins(builtin_selection_indicators): n_builtins = len(selected_builtin_encodings) args = [all_encodings, all_ptrs, selected_encodings, n_builtins] - end = runner.initialize_function_entrypoint('select_input_builtins', args) + end = runner.initialize_function_entrypoint("select_input_builtins", args) # Setup context. runner.initialize_vm(hint_locals={}) @@ -62,5 +72,5 @@ def test_select_input_builtins(builtin_selection_indicators): # 'select_input_builtins' should return the pointers to the selected builtins. return_values_addr = context.ap - n_builtins assert [ - context.memory[return_values_addr + i] for i in range(len(selected_builtins))] == \ - selected_builtins + context.memory[return_values_addr + i] for i in range(len(selected_builtins)) + ] == selected_builtins diff --git a/src/starkware/cairo/builtin_selection/validate_builtins_test.py b/src/starkware/cairo/builtin_selection/validate_builtins_test.py index 01fc65ce..2521b444 100644 --- a/src/starkware/cairo/builtin_selection/validate_builtins_test.py +++ b/src/starkware/cairo/builtin_selection/validate_builtins_test.py @@ -4,17 +4,18 @@ from starkware.cairo.common.test_utils import create_memory_struct from starkware.cairo.lang.builtins.range_check.range_check_builtin_runner import ( - RangeCheckBuiltinRunner) + RangeCheckBuiltinRunner, +) from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME from starkware.cairo.lang.instances import small_instance from starkware.cairo.lang.vm.cairo_runner import CairoRunner from starkware.cairo.lang.vm.vm import VmException -CAIRO_FILE = os.path.join(os.path.dirname(__file__), 'validate_builtins.cairo') +CAIRO_FILE = os.path.join(os.path.dirname(__file__), "validate_builtins.cairo") @pytest.mark.parametrize( - 'old_builtins, new_builtins, builtin_sizes, expect_throw', + "old_builtins, new_builtins, builtin_sizes, expect_throw", [ ([10, 10], [10, 10], [1, 1], False), # Second builtin usage is negative. @@ -23,7 +24,8 @@ # Second builtin usage is not a multiple of the builtin_size. ([0, 0], [1, 2], [1, 3], True), ([5, 5], [9, 26], [2, 7], False), - ]) + ], +) def test_validate_builtins(old_builtins, new_builtins, builtin_sizes, expect_throw): """ Tests the inner_validate_builtins_usage Cairo function: calls the function with different @@ -31,12 +33,15 @@ def test_validate_builtins(old_builtins, new_builtins, builtin_sizes, expect_thr """ # Setup runner. runner = CairoRunner.from_file(CAIRO_FILE, DEFAULT_PRIME) - assert len(runner.program.hints) == 0, 'Expecting validator to have no hints.' + assert len(runner.program.hints) == 0, "Expecting validator to have no hints." range_check_builtin = RangeCheckBuiltinRunner( - included=True, ratio=None, inner_rc_bound=2 ** 16, - n_parts=small_instance.builtins['range_check'].n_parts) - runner.builtin_runners['range_check_builtin'] = range_check_builtin + included=True, + ratio=None, + inner_rc_bound=2 ** 16, + n_parts=small_instance.builtins["range_check"].n_parts, + ) + runner.builtin_runners["range_check_builtin"] = range_check_builtin runner.initialize_segments() # Setup function. @@ -50,12 +55,12 @@ def test_validate_builtins(old_builtins, new_builtins, builtin_sizes, expect_thr builtins_sizes, len(builtin_sizes), ] - end = runner.initialize_function_entrypoint('validate_builtins', args) + end = runner.initialize_function_entrypoint("validate_builtins", args) # Setup context. runner.initialize_vm(hint_locals={}) if expect_throw: - with pytest.raises(VmException, match='is out of range'): + with pytest.raises(VmException, match="is out of range"): runner.run_until_pc(end) else: runner.run_until_pc(end) diff --git a/src/starkware/cairo/common/CMakeLists.txt b/src/starkware/cairo/common/CMakeLists.txt index 57f2d731..d4426f76 100644 --- a/src/starkware/cairo/common/CMakeLists.txt +++ b/src/starkware/cairo/common/CMakeLists.txt @@ -3,8 +3,9 @@ python_lib(cairo_common_lib FILES alloc.cairo bitwise.cairo - cairo_keccak/keccak_utils.py cairo_builtins.cairo + cairo_keccak/keccak_utils.py + cairo_sha256/sha256_utils.py default_dict.cairo dict_access.cairo dict.cairo @@ -21,6 +22,7 @@ python_lib(cairo_common_lib math_utils.py math.cairo memcpy.cairo + memset.cairo merkle_multi_update.cairo merkle_update.cairo patricia_utils.py diff --git a/src/starkware/cairo/common/cairo_function_runner.py b/src/starkware/cairo/common/cairo_function_runner.py index 70d70587..714c4253 100644 --- a/src/starkware/cairo/common/cairo_function_runner.py +++ b/src/starkware/cairo/common/cairo_function_runner.py @@ -8,7 +8,8 @@ from starkware.cairo.lang.builtins.ec.instance_def import EcOpInstanceDef from starkware.cairo.lang.builtins.hash.hash_builtin_runner import HashBuiltinRunner from starkware.cairo.lang.builtins.range_check.range_check_builtin_runner import ( - RangeCheckBuiltinRunner) + RangeCheckBuiltinRunner, +) from starkware.cairo.lang.builtins.signature.signature_builtin_runner import SignatureBuiltinRunner from starkware.cairo.lang.compiler.identifier_definition import LabelDefinition from starkware.cairo.lang.compiler.program import Program @@ -28,20 +29,27 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) pedersen_builtin = HashBuiltinRunner( - name='pedersen', included=True, ratio=32, hash_func=pedersen_hash) - self.builtin_runners['pedersen_builtin'] = pedersen_builtin + name="pedersen", included=True, ratio=32, hash_func=pedersen_hash + ) + self.builtin_runners["pedersen_builtin"] = pedersen_builtin range_check_builtin = RangeCheckBuiltinRunner( - included=True, ratio=None, inner_rc_bound=2 ** 16, n_parts=8) - self.builtin_runners['range_check_builtin'] = range_check_builtin + included=True, ratio=None, inner_rc_bound=2 ** 16, n_parts=8 + ) + self.builtin_runners["range_check_builtin"] = range_check_builtin output_builtin = OutputBuiltinRunner(included=True) - self.builtin_runners['output_builtin'] = output_builtin + self.builtin_runners["output_builtin"] = output_builtin signature_builtin = SignatureBuiltinRunner( - name='ecdsa', included=True, ratio=None, process_signature=process_ecdsa, - verify_signature=verify_ecdsa_sig) - self.builtin_runners['ecdsa_builtin'] = signature_builtin - bitwise_builtin = BitwiseBuiltinRunner(included=True, bitwise_builtin=BitwiseInstanceDef( - ratio=None, total_n_bits=251)) - self.builtin_runners['bitwise_builtin'] = bitwise_builtin + name="ecdsa", + included=True, + ratio=None, + process_signature=process_ecdsa, + verify_signature=verify_ecdsa_sig, + ) + self.builtin_runners["ecdsa_builtin"] = signature_builtin + bitwise_builtin = BitwiseBuiltinRunner( + included=True, bitwise_builtin=BitwiseInstanceDef(ratio=None, total_n_bits=251) + ) + self.builtin_runners["bitwise_builtin"] = bitwise_builtin ec_op_builtin = EcOpBuiltinRunner( included=True, ec_op_builtin=EcOpInstanceDef( @@ -51,33 +59,33 @@ def __init__(self, *args, **kwargs): scalar_limit=None, ), ) - self.builtin_runners['ec_op_builtin'] = ec_op_builtin + self.builtin_runners["ec_op_builtin"] = ec_op_builtin self.initialize_segments() @property def pedersen_builtin(self) -> HashBuiltinRunner: - return cast(HashBuiltinRunner, self.builtin_runners['pedersen_builtin']) + return cast(HashBuiltinRunner, self.builtin_runners["pedersen_builtin"]) @property def range_check_builtin(self) -> RangeCheckBuiltinRunner: - return cast(RangeCheckBuiltinRunner, self.builtin_runners['range_check_builtin']) + return cast(RangeCheckBuiltinRunner, self.builtin_runners["range_check_builtin"]) @property def output_builtin(self) -> OutputBuiltinRunner: - return cast(OutputBuiltinRunner, self.builtin_runners['output_builtin']) + return cast(OutputBuiltinRunner, self.builtin_runners["output_builtin"]) @property def ecdsa_builtin(self) -> SignatureBuiltinRunner: - return cast(SignatureBuiltinRunner, self.builtin_runners['ecdsa_builtin']) + return cast(SignatureBuiltinRunner, self.builtin_runners["ecdsa_builtin"]) @property def bitwise_builtin(self) -> BitwiseBuiltinRunner: - return cast(BitwiseBuiltinRunner, self.builtin_runners['bitwise_builtin']) + return cast(BitwiseBuiltinRunner, self.builtin_runners["bitwise_builtin"]) @property def ec_op_builtin(self) -> EcOpBuiltinRunner: - return cast(EcOpBuiltinRunner, self.builtin_runners['ec_op_builtin']) + return cast(EcOpBuiltinRunner, self.builtin_runners["ec_op_builtin"]) def assert_eq(self, arg: MaybeRelocatable, expected_value, apply_modulo: bool = True): """ @@ -86,7 +94,7 @@ def assert_eq(self, arg: MaybeRelocatable, expected_value, apply_modulo: bool = and assert_eq is called recursively on all the items in expected_value. If apply_modulo=True, all the integers are taken modulo the program's prime. """ - assert isinstance(arg, (int, RelocatableValue)), f'Expecting MaybeRelocatable got {arg}' + assert isinstance(arg, (int, RelocatableValue)), f"Expecting MaybeRelocatable got {arg}" if isinstance(expected_value, Iterable): for idx, value in enumerate(expected_value): @@ -96,13 +104,20 @@ def assert_eq(self, arg: MaybeRelocatable, expected_value, apply_modulo: bool = if apply_modulo and isinstance(arg, int): expected_value = expected_value % self.program.prime - assert arg == expected_value, f'{arg} does not equal expected value {expected_value}.' + assert arg == expected_value, f"{arg} does not equal expected value {expected_value}." def run( - self, func_name: str, *args, hint_locals: Optional[Dict[str, Any]] = None, - static_locals: Optional[Dict[str, Any]] = None, - verify_secure: Optional[bool] = None, trace_on_failure: bool = False, - apply_modulo_to_args: Optional[bool] = None, use_full_name: bool = False, **kwargs): + self, + func_name: str, + *args, + hint_locals: Optional[Dict[str, Any]] = None, + static_locals: Optional[Dict[str, Any]] = None, + verify_secure: Optional[bool] = None, + trace_on_failure: bool = False, + apply_modulo_to_args: Optional[bool] = None, + use_full_name: bool = False, + **kwargs, + ): """ Runs func_name(*args). args are converted to Cairo-friendly ones using gen_arg. @@ -117,13 +132,15 @@ def run( assert isinstance(self.program, Program) structs_factory = CairoStructFactory.from_program(program=self.program) full_args_struct = structs_factory.build_func_args( - func=ScopedName.from_string(scope=func_name)) + func=ScopedName.from_string(scope=func_name) + ) all_args = full_args_struct(*args, **kwargs) entrypoint: Union[str, int] if use_full_name: identifier = self.program.identifiers.get_by_full_name( - name=ScopedName.from_string(scope=func_name)) + name=ScopedName.from_string(scope=func_name) + ) assert isinstance(identifier, LabelDefinition) entrypoint = identifier.pc else: @@ -131,22 +148,34 @@ def run( try: self.run_from_entrypoint( - entrypoint, *all_args, hint_locals=hint_locals, static_locals=static_locals, - verify_secure=verify_secure, apply_modulo_to_args=apply_modulo_to_args) + entrypoint, + *all_args, + hint_locals=hint_locals, + static_locals=static_locals, + verify_secure=verify_secure, + apply_modulo_to_args=apply_modulo_to_args, + ) except (VmException, SecurityError, AssertionError) as ex: if trace_on_failure: - print(f"""\ + print( + f"""\ Got {type(ex).__name__} exception during the execution of {func_name}: {str(ex)} -""") +""" + ) trace_runner(runner=self) raise def run_from_entrypoint( - self, entrypoint: Union[str, int], *args, hint_locals: Optional[Dict[str, Any]] = None, - static_locals: Optional[Dict[str, Any]] = None, - run_resources: Optional[RunResources] = None, verify_secure: Optional[bool] = None, - apply_modulo_to_args: Optional[bool] = None): + self, + entrypoint: Union[str, int], + *args, + hint_locals: Optional[Dict[str, Any]] = None, + static_locals: Optional[Dict[str, Any]] = None, + run_resources: Optional[RunResources] = None, + verify_secure: Optional[bool] = None, + apply_modulo_to_args: Optional[bool] = None, + ): """ Runs the program from the given entrypoint. diff --git a/src/starkware/cairo/common/cairo_keccak/keccak_utils.py b/src/starkware/cairo/common/cairo_keccak/keccak_utils.py index 9804478b..a6bde3cb 100644 --- a/src/starkware/cairo/common/cairo_keccak/keccak_utils.py +++ b/src/starkware/cairo/common/cairo_keccak/keccak_utils.py @@ -1,12 +1,16 @@ from typing import List -OFFSETS = list(zip(*[ - [0, 1, 62, 28, 27], - [36, 44, 6, 55, 20], - [3, 10, 43, 25, 39], - [41, 45, 15, 21, 8], - [18, 2, 61, 56, 14] -])) +OFFSETS = list( + zip( + *[ + [0, 1, 62, 28, 27], + [36, 44, 6, 55, 20], + [3, 10, 43, 25, 39], + [41, 45, 15, 21, 8], + [18, 2, 61, 56, 14], + ] + ) +) ROUND_CONSTANTS = [ 0x0000000000000001, @@ -40,7 +44,7 @@ def rot_left(x, n): """ Rotates a 64-bit number n bits to the left. """ - return ((x << n) & (2**64 - 1)) | (x >> (64 - n)) + return ((x << n) & (2 ** 64 - 1)) | (x >> (64 - n)) def keccak_round(a: List[List[int]], rc: int) -> List[List[int]]: @@ -56,9 +60,7 @@ def keccak_round(a: List[List[int]], rc: int) -> List[List[int]]: for y in range(5): b[y][(2 * x + 3 * y) % 5] = rot_left(a[x][y], OFFSETS[x][y]) - a = [ - [b[x][y] ^ ((~b[(x + 1) % 5][y]) & b[(x + 2) % 5][y]) for y in range(5)] - for x in range(5)] + a = [[b[x][y] ^ ((~b[(x + 1) % 5][y]) & b[(x + 2) % 5][y]) for y in range(5)] for x in range(5)] a[0][0] ^= rc return a diff --git a/src/starkware/cairo/common/cairo_sha256/sha256_utils.py b/src/starkware/cairo/common/cairo_sha256/sha256_utils.py new file mode 100644 index 00000000..a6a1c3be --- /dev/null +++ b/src/starkware/cairo/common/cairo_sha256/sha256_utils.py @@ -0,0 +1,128 @@ +from typing import List + +IV = [ + 0x6A09E667, + 0xBB67AE85, + 0x3C6EF372, + 0xA54FF53A, + 0x510E527F, + 0x9B05688C, + 0x1F83D9AB, + 0x5BE0CD19, +] + +ROUND_CONSTANTS = [ + 0x428A2F98, + 0x71374491, + 0xB5C0FBCF, + 0xE9B5DBA5, + 0x3956C25B, + 0x59F111F1, + 0x923F82A4, + 0xAB1C5ED5, + 0xD807AA98, + 0x12835B01, + 0x243185BE, + 0x550C7DC3, + 0x72BE5D74, + 0x80DEB1FE, + 0x9BDC06A7, + 0xC19BF174, + 0xE49B69C1, + 0xEFBE4786, + 0x0FC19DC6, + 0x240CA1CC, + 0x2DE92C6F, + 0x4A7484AA, + 0x5CB0A9DC, + 0x76F988DA, + 0x983E5152, + 0xA831C66D, + 0xB00327C8, + 0xBF597FC7, + 0xC6E00BF3, + 0xD5A79147, + 0x06CA6351, + 0x14292967, + 0x27B70A85, + 0x2E1B2138, + 0x4D2C6DFC, + 0x53380D13, + 0x650A7354, + 0x766A0ABB, + 0x81C2C92E, + 0x92722C85, + 0xA2BFE8A1, + 0xA81A664B, + 0xC24B8B70, + 0xC76C51A3, + 0xD192E819, + 0xD6990624, + 0xF40E3585, + 0x106AA070, + 0x19A4C116, + 0x1E376C08, + 0x2748774C, + 0x34B0BCB5, + 0x391C0CB3, + 0x4ED8AA4A, + 0x5B9CCA4F, + 0x682E6FF3, + 0x748F82EE, + 0x78A5636F, + 0x84C87814, + 0x8CC70208, + 0x90BEFFFA, + 0xA4506CEB, + 0xBEF9A3F7, + 0xC67178F2, +] + + +def right_rot(value, n): + return (value >> n) | ((value & (2 ** n - 1)) << (32 - n)) + + +def compute_message_schedule(message: List[int]) -> List[int]: + w = list(message) + assert len(w) == 16 + + for i in range(16, 64): + s0 = right_rot(w[i - 15], 7) ^ right_rot(w[i - 15], 18) ^ (w[i - 15] >> 3) + s1 = right_rot(w[i - 2], 17) ^ right_rot(w[i - 2], 19) ^ (w[i - 2] >> 10) + w.append((w[i - 16] + s0 + w[i - 7] + s1) % 2 ** 32) + + return w + + +def sha2_compress_function(state: List[int], w: List[int]) -> List[int]: + a, b, c, d, e, f, g, h = state + + for i in range(64): + s0 = right_rot(a, 2) ^ right_rot(a, 13) ^ right_rot(a, 22) + s1 = right_rot(e, 6) ^ right_rot(e, 11) ^ right_rot(e, 25) + ch = (e & f) ^ ((~e) & g) + temp1 = (h + s1 + ch + ROUND_CONSTANTS[i] + w[i]) % 2 ** 32 + maj = (a & b) ^ (a & c) ^ (b & c) + temp2 = (s0 + maj) % 2 ** 32 + + h = g + g = f + f = e + e = (d + temp1) % 2 ** 32 + d = c + c = b + b = a + a = (temp1 + temp2) % 2 ** 32 + + # Add the compression result to the original state. + return [ + (state[0] + a) % 2 ** 32, + (state[1] + b) % 2 ** 32, + (state[2] + c) % 2 ** 32, + (state[3] + d) % 2 ** 32, + (state[4] + e) % 2 ** 32, + (state[5] + f) % 2 ** 32, + (state[6] + g) % 2 ** 32, + (state[7] + h) % 2 ** 32, + ] diff --git a/src/starkware/cairo/common/dict.py b/src/starkware/cairo/common/dict.py index c21a8eee..3fb352fe 100644 --- a/src/starkware/cairo/common/dict.py +++ b/src/starkware/cairo/common/dict.py @@ -11,6 +11,7 @@ class DictTracker: """ Tracks the python dict associated with a Cairo dict. """ + # Python dict. data: dict # Pointer to the first unused position in the dict segment. @@ -35,8 +36,7 @@ def new_dict(self, segments, initial_dict): base = segments.add() assert base.segment_index not in self.trackers self.trackers[base.segment_index] = DictTracker( - data={ - key: segments.gen_arg(value) for key, value in initial_dict.items()}, + data={key: segments.gen_arg(value) for key, value in initial_dict.items()}, current_ptr=base, ) return base @@ -61,9 +61,10 @@ def get_tracker(self, dict_ptr): dict_ptr = dict_ptr.address_ dict_tracker = self.trackers.get(dict_ptr.segment_index) if dict_tracker is None: - raise ValueError(f'Dictionary pointer {dict_ptr} was not created using dict_new().') - assert dict_tracker.current_ptr == dict_ptr, 'Wrong dict pointer supplied. ' \ - f'Got {dict_ptr}, expected {dict_tracker.current_ptr}.' + raise ValueError(f"Dictionary pointer {dict_ptr} was not created using dict_new().") + assert dict_tracker.current_ptr == dict_ptr, ( + "Wrong dict pointer supplied. " f"Got {dict_ptr}, expected {dict_tracker.current_ptr}." + ) return dict_tracker def get_dict(self, dict_ptr) -> dict: diff --git a/src/starkware/cairo/common/math.cairo b/src/starkware/cairo/common/math.cairo index 14f14a6c..1f50cb23 100644 --- a/src/starkware/cairo/common/math.cairo +++ b/src/starkware/cairo/common/math.cairo @@ -292,3 +292,30 @@ func signed_div_rem{range_check_ptr}(value, div, bound) -> (q, r): assert_le(biased_q, 2 * bound - 1) return (q, r) end + +# Splits the given (unsigned) value into n "limbs", where each limb is in the range [0, bound), +# as follows: +# value = x[0] + x[1] * base + x[2] * base**2 + ... + x[n - 1] * base**(n - 1). +# bound must be less than the range check bound (2**128). +# Note that bound may be smaller than base, in which case the function will fail if there is a +# limb which is >= bound. +# Assumptions: +# 1 < bound <= base +# base**n < field characteristic. +func split_int{range_check_ptr}(value, n, base, bound, output : felt*): + if n == 0: + %{ assert ids.value == 0, 'split_int(): value is out of range.' %} + assert value = 0 + return () + end + + %{ + memory[ids.output] = res = (int(ids.value) % PRIME) % ids.base + assert res < ids.bound, f'split_int(): Limb {res} is out of range.' + %} + tempvar low_part = [output] + assert_nn_le(low_part, bound - 1) + + return split_int( + value=(value - low_part) / base, n=n - 1, base=base, bound=bound, output=output + 1) +end diff --git a/src/starkware/cairo/common/math_utils.py b/src/starkware/cairo/common/math_utils.py index d9dee17f..97b76bd1 100644 --- a/src/starkware/cairo/common/math_utils.py +++ b/src/starkware/cairo/common/math_utils.py @@ -2,7 +2,7 @@ def assert_integer(val): """ Asserts that the input is an integer (and not relocatable value). """ - assert isinstance(val, int), f'Expected integer, found: {val}.' + assert isinstance(val, int), f"Expected integer, found: {val}." def as_int(val, prime): @@ -21,5 +21,5 @@ def is_positive(value, prime, rc_bound): Raises an exception if the element is not within that range. """ val = as_int(value, prime) - assert abs(val) < rc_bound, f'value={val} is out of the valid range.' + assert abs(val) < rc_bound, f"value={val} is out of the valid range." return val > 0 diff --git a/src/starkware/cairo/common/memset.cairo b/src/starkware/cairo/common/memset.cairo new file mode 100644 index 00000000..de97c0b2 --- /dev/null +++ b/src/starkware/cairo/common/memset.cairo @@ -0,0 +1,33 @@ +# Writes value into [dst + 0], ..., [dst + n - 1]. +func memset(dst : felt*, value : felt, n): + struct LoopFrame: + member dst : felt* + end + + if n == 0: + return () + end + + %{ vm_enter_scope({'n': ids.n}) %} + tempvar frame = LoopFrame(dst=dst) + + loop: + let frame = [cast(ap - LoopFrame.SIZE, LoopFrame*)] + assert [frame.dst] = value + + let continue_loop = [ap] + # Reserve space for continue_loop. + let next_frame = cast(ap + 1, LoopFrame*) + next_frame.dst = frame.dst + 1; ap++ + %{ + n -= 1 + ids.continue_loop = 1 if n > 0 else 0 + %} + static_assert next_frame + LoopFrame.SIZE == ap + 1 + jmp loop if continue_loop != 0; ap++ + # Assert that the loop executed n times. + n = cast(next_frame.dst, felt) - cast(dst, felt) + + %{ vm_exit_scope() %} + return () +end diff --git a/src/starkware/cairo/common/patricia_utils.py b/src/starkware/cairo/common/patricia_utils.py index 9284cfcc..38ee92ef 100644 --- a/src/starkware/cairo/common/patricia_utils.py +++ b/src/starkware/cairo/common/patricia_utils.py @@ -21,7 +21,7 @@ from starkware.cairo.lang.vm.crypto import pedersen_hash from starkware.python.math_utils import is_power_of_2 -from starkware.starkware_utils.binary_fact_tree_node import UpdateTree +from starkware.starkware_utils.commitment_tree.binary_fact_tree_node import UpdateTree Triplet = Tuple[int, int, int] @@ -64,7 +64,7 @@ def hash_node(e): if left == EMPTY and right == EMPTY: next_node = EMPTY elif left == EMPTY: - next_node = (r_len + 1, r_path + 2**r_len, r_bottom) + next_node = (r_len + 1, r_path + 2 ** r_len, r_bottom) elif right == EMPTY: next_node = (l_len + 1, l_path, l_bottom) else: @@ -73,7 +73,7 @@ def hash_node(e): layer = next_layer height += 1 - root, = layer + (root,) = layer node_at_path[height, 0] = root return hash_node(root), preimage, node_at_path @@ -177,7 +177,7 @@ def get_descents(height: int, path: int, nodes: List[NodeType]): res = {} # length <= 1 is not a descent. if length > 1: - res[orig_height, orig_path] = length, path % 2**length + res[orig_height, orig_path] = length, path % 2 ** length if height > 0: res.update(get_descents(height - 1, path * 2, lefts)) @@ -196,14 +196,17 @@ def compute_siblings_from_tree(height, node: UpdateTree, node_at_path, descent_m return [] left, right = node if left is None: - res = [hash_node(node_at_path[height - 1, path * 2])] + \ - compute_siblings_from_tree(height - 1, right, node_at_path, descent_map, path * 2 + 1) + res = [hash_node(node_at_path[height - 1, path * 2])] + compute_siblings_from_tree( + height - 1, right, node_at_path, descent_map, path * 2 + 1 + ) elif right is None: - res = [hash_node(node_at_path[height - 1, path * 2 + 1])] + \ - compute_siblings_from_tree(height - 1, left, node_at_path, descent_map, path * 2) + res = [hash_node(node_at_path[height - 1, path * 2 + 1])] + compute_siblings_from_tree( + height - 1, left, node_at_path, descent_map, path * 2 + ) else: - res = compute_siblings_from_tree(height - 1, left, node_at_path, descent_map, path * 2) + \ - compute_siblings_from_tree(height - 1, right, node_at_path, descent_map, path * 2 + 1) + res = compute_siblings_from_tree( + height - 1, left, node_at_path, descent_map, path * 2 + ) + compute_siblings_from_tree(height - 1, right, node_at_path, descent_map, path * 2 + 1) descend = descent_map.get((height, path)) if descend is None: diff --git a/src/starkware/cairo/common/small_merkle_tree.py b/src/starkware/cairo/common/small_merkle_tree.py index 99db5101..54958ee2 100644 --- a/src/starkware/cairo/common/small_merkle_tree.py +++ b/src/starkware/cairo/common/small_merkle_tree.py @@ -50,8 +50,11 @@ def compute_merkle_root(self, modifications: Collection[Tuple[int, int]]): def get_preimage_dictionary( - initial_leaves: Collection[Tuple[int, int]], modifications: Collection[Tuple[int, int]], - tree_height: int, default_leaf: int) -> Tuple[int, int, Dict[int, Tuple[int, int]]]: + initial_leaves: Collection[Tuple[int, int]], + modifications: Collection[Tuple[int, int]], + tree_height: int, + default_leaf: int, +) -> Tuple[int, int, Dict[int, Tuple[int, int]]]: """ Given a set of initial leaves and a set of modifications (both are maps from leaf index to value, where all the leaves in `modifications` appear diff --git a/src/starkware/cairo/common/structs.py b/src/starkware/cairo/common/structs.py index b8522c14..aac2c7c7 100644 --- a/src/starkware/cairo/common/structs.py +++ b/src/starkware/cairo/common/structs.py @@ -12,7 +12,8 @@ class CairoStructFactory: def __init__( - self, identifiers: IdentifierManager, additional_imports: Optional[List[str]] = None): + self, identifiers: IdentifierManager, additional_imports: Optional[List[str]] = None + ): """ Creates a CairoStructFactory that converts Cairo structs to python namedtuples. @@ -27,9 +28,7 @@ def __init__( for identifier_path in additional_imports: scope_name = ScopedName.from_string(identifier_path) # Call get_struct_definition to make sure scope_name is a struct. - get_struct_definition( - struct_name=scope_name, - identifier_manager=identifiers) + get_struct_definition(struct_name=scope_name, identifier_manager=identifiers) self.resolved_identifiers[scope_name[-1:]] = scope_name @classmethod @@ -42,8 +41,8 @@ def _get_full_name(self, name: ScopedName): return full_name return self.identifiers.search( - accessible_scopes=[ScopedName.from_string('__main__'), ScopedName()], - name=name).get_canonical_name() + accessible_scopes=[ScopedName.from_string("__main__"), ScopedName()], name=name + ).get_canonical_name() def get_struct_definition(self, name: ScopedName) -> StructDefinition: """ @@ -66,11 +65,12 @@ def build_func_args(self, func: ScopedName): full_name = self._get_full_name(func) implict_args = get_struct_definition( - full_name + CodeElementFunction.IMPLICIT_ARGUMENT_SCOPE, - self.identifiers).members + full_name + CodeElementFunction.IMPLICIT_ARGUMENT_SCOPE, self.identifiers + ).members args = get_struct_definition( - full_name + CodeElementFunction.ARGUMENT_SCOPE, self.identifiers).members - return namedtuple(f'{func[-1:]}_full_args', list({**implict_args, **args})) + full_name + CodeElementFunction.ARGUMENT_SCOPE, self.identifiers + ).members + return namedtuple(f"{func[-1:]}_full_args", list({**implict_args, **args})) @property def structs(self): @@ -90,7 +90,7 @@ def __init__(self, factory: CairoStructFactory, path: ScopedName): self.factory = factory self.path = path - def __getattr__(self, name: str) -> 'CairoStructProxy': + def __getattr__(self, name: str) -> "CairoStructProxy": return CairoStructProxy(self.factory, self.path + name) def build(self): @@ -114,6 +114,6 @@ def from_ptr(self, memory, addr): """ named_tuple = self.build() - return named_tuple(**{ - name: memory[addr + index] - for index, name in enumerate(named_tuple._fields)}) + return named_tuple( + **{name: memory[addr + index] for index, name in enumerate(named_tuple._fields)} + ) diff --git a/src/starkware/cairo/lang/VERSION b/src/starkware/cairo/lang/VERSION index 1d0ba9ea..267577d4 100644 --- a/src/starkware/cairo/lang/VERSION +++ b/src/starkware/cairo/lang/VERSION @@ -1 +1 @@ -0.4.0 +0.4.1 diff --git a/src/starkware/cairo/lang/builtins/bitwise/bitwise_builtin_runner.py b/src/starkware/cairo/lang/builtins/bitwise/bitwise_builtin_runner.py index 67bd46bd..760b5697 100644 --- a/src/starkware/cairo/lang/builtins/bitwise/bitwise_builtin_runner.py +++ b/src/starkware/cairo/lang/builtins/bitwise/bitwise_builtin_runner.py @@ -1,7 +1,10 @@ from typing import Any, Dict, Optional from starkware.cairo.lang.builtins.bitwise.instance_def import ( - CELLS_PER_BITWISE, INPUT_CELLS_PER_BITWISE, BitwiseInstanceDef) + CELLS_PER_BITWISE, + INPUT_CELLS_PER_BITWISE, + BitwiseInstanceDef, +) from starkware.cairo.lang.vm.builtin_runner import SimpleBuiltinRunner from starkware.cairo.lang.vm.relocatable import RelocatableValue @@ -9,11 +12,12 @@ class BitwiseBuiltinRunner(SimpleBuiltinRunner): def __init__(self, included: bool, bitwise_builtin: BitwiseInstanceDef): super().__init__( - name='bitwise', + name="bitwise", included=included, ratio=None if bitwise_builtin is None else bitwise_builtin.ratio, cells_per_instance=CELLS_PER_BITWISE, - n_input_cells=INPUT_CELLS_PER_BITWISE) + n_input_cells=INPUT_CELLS_PER_BITWISE, + ) self.stop_ptr: Optional[RelocatableValue] = None self.bitwise_builtin: BitwiseInstanceDef = bitwise_builtin @@ -27,18 +31,22 @@ def rule(vm, addr): y_addr = x_addr + 1 if x_addr not in memory or y_addr not in memory: return - assert vm.is_integer_value(memory[x_addr]), \ - f'{self.name} builtin: Expected integer at address {x_addr}. ' + \ - f'Got: {memory[x_addr]}.' - assert memory[x_addr] < 2**self.bitwise_builtin.total_n_bits, \ - f'{self.name} builtin: Expected integer at address {x_addr} to be smaller than ' + \ - f'2^{self.bitwise_builtin.total_n_bits}. Got: {memory[x_addr]}.' - assert vm.is_integer_value(memory[y_addr]), \ - f'{self.name} builtin: Expected integer at address {y_addr}. ' + \ - f'Got: {memory[y_addr]}.' - assert memory[y_addr] < 2**self.bitwise_builtin.total_n_bits, \ - f'{self.name} builtin: Expected integer at address {y_addr} to be smaller than ' + \ - f'2^{self.bitwise_builtin.total_n_bits}. Got: {memory[y_addr]}.' + assert vm.is_integer_value(memory[x_addr]), ( + f"{self.name} builtin: Expected integer at address {x_addr}. " + + f"Got: {memory[x_addr]}." + ) + assert memory[x_addr] < 2 ** self.bitwise_builtin.total_n_bits, ( + f"{self.name} builtin: Expected integer at address {x_addr} to be smaller than " + + f"2^{self.bitwise_builtin.total_n_bits}. Got: {memory[x_addr]}." + ) + assert vm.is_integer_value(memory[y_addr]), ( + f"{self.name} builtin: Expected integer at address {y_addr}. " + + f"Got: {memory[y_addr]}." + ) + assert memory[y_addr] < 2 ** self.bitwise_builtin.total_n_bits, ( + f"{self.name} builtin: Expected integer at address {y_addr} to be smaller than " + + f"2^{self.bitwise_builtin.total_n_bits}. Got: {memory[y_addr]}." + ) if index == 2: res = memory[x_addr] & memory[y_addr] elif index == 3: @@ -50,32 +58,41 @@ def rule(vm, addr): runner.vm.add_auto_deduction_rule(self.base.segment_index, rule) def air_private_input(self, runner) -> Dict[str, Any]: - assert self.base is not None, 'Uninitialized self.base.' + assert self.base is not None, "Uninitialized self.base." res: Dict[int, Any] = {} for addr, val in runner.vm_memory.items(): - if not isinstance(addr, RelocatableValue) or \ - addr.segment_index != self.base.segment_index: + if ( + not isinstance(addr, RelocatableValue) + or addr.segment_index != self.base.segment_index + ): continue idx, typ = divmod(addr.offset, CELLS_PER_BITWISE) if typ >= 2: continue assert isinstance(val, int) - res.setdefault(idx, {'index': idx})['x' if typ == 0 else 'y'] = hex(val) + res.setdefault(idx, {"index": idx})["x" if typ == 0 else "y"] = hex(val) for index, item in res.items(): - assert 'x' in item, f'Missing first input of bitwise instance {index}.' - assert 'y' in item, f'Missing second input of bitwise instance {index}.' + assert "x" in item, f"Missing first input of bitwise instance {index}." + assert "y" in item, f"Missing second input of bitwise instance {index}." - return {'bitwise': sorted(res.values(), key=lambda item: item['index'])} + return {"bitwise": sorted(res.values(), key=lambda item: item["index"])} def get_used_diluted_check_units(self, diluted_spacing: int, diluted_n_bits: int) -> int: total_n_bits = self.bitwise_builtin.total_n_bits partition = [ - i + j for i in range(0, total_n_bits, diluted_spacing * diluted_n_bits) - for j in range(diluted_spacing) if i + j < total_n_bits] - num_trimmed = len([ - 1 for shift in partition - if shift + diluted_spacing * (diluted_n_bits - 1) + 1 > total_n_bits]) + i + j + for i in range(0, total_n_bits, diluted_spacing * diluted_n_bits) + for j in range(diluted_spacing) + if i + j < total_n_bits + ] + num_trimmed = len( + [ + 1 + for shift in partition + if shift + diluted_spacing * (diluted_n_bits - 1) + 1 > total_n_bits + ] + ) return 4 * len(partition) + num_trimmed diff --git a/src/starkware/cairo/lang/builtins/builtin_runner_test_utils.py b/src/starkware/cairo/lang/builtins/builtin_runner_test_utils.py index e67c239e..477e1ecf 100644 --- a/src/starkware/cairo/lang/builtins/builtin_runner_test_utils.py +++ b/src/starkware/cairo/lang/builtins/builtin_runner_test_utils.py @@ -1,7 +1,7 @@ from starkware.cairo.lang.compiler.cairo_compile import compile_cairo from starkware.cairo.lang.vm.cairo_runner import CairoRunner -PRIME = 2**251 + 17 * 2**192 + 1 +PRIME = 2 ** 251 + 17 * 2 ** 192 + 1 def compile_and_run(code: str): @@ -9,7 +9,7 @@ def compile_and_run(code: str): Compiles the given code and runs it in the VM. """ program = compile_cairo(code, PRIME) - runner = CairoRunner(program, layout='small', proof_mode=False) + runner = CairoRunner(program, layout="small", proof_mode=False) runner.initialize_segments() end = runner.initialize_main_entrypoint() runner.initialize_vm({}) diff --git a/src/starkware/cairo/lang/builtins/ec/ec_op_builtin_runner.py b/src/starkware/cairo/lang/builtins/ec/ec_op_builtin_runner.py index 4e011d3e..4a7bca26 100644 --- a/src/starkware/cairo/lang/builtins/ec/ec_op_builtin_runner.py +++ b/src/starkware/cairo/lang/builtins/ec/ec_op_builtin_runner.py @@ -1,7 +1,10 @@ from typing import Any, Dict, Optional, Tuple, Union from starkware.cairo.lang.builtins.ec.instance_def import ( - CELLS_PER_EC_OP, INPUT_CELLS_PER_EC_OP, EcOpInstanceDef) + CELLS_PER_EC_OP, + INPUT_CELLS_PER_EC_OP, + EcOpInstanceDef, +) from starkware.cairo.lang.vm.builtin_runner import SimpleBuiltinRunner from starkware.cairo.lang.vm.relocatable import RelocatableValue from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME @@ -11,7 +14,7 @@ EC_POINT_INDICES = [(0, 1), (2, 3), (5, 6)] M_INDEX = 4 OUTPUT_INDICES = EC_POINT_INDICES[2] -INPUT_NAMES = ['p_x', 'p_y', 'q_x', 'q_y', 'm'] +INPUT_NAMES = ["p_x", "p_y", "q_x", "q_y", "m"] assert INPUT_CELLS_PER_EC_OP == len(INPUT_NAMES) assert INPUT_CELLS_PER_EC_OP + len(OUTPUT_INDICES) == CELLS_PER_EC_OP @@ -26,8 +29,8 @@ def point_on_curve(x: int, y: int, alpha: int, beta: int, p: int) -> bool: def ec_op_impl( - p_x: int, p_y: int, q_x: int, q_y: int, m: int, alpha: int, - p: int) -> Union[Tuple[int, int], str]: + p_x: int, p_y: int, q_x: int, q_y: int, m: int, alpha: int, p: int +) -> Union[Tuple[int, int], str]: """ Returns the result of the EC operation P + m * Q. where P = (p_x, p_y), Q = (q_x, q_y) are points on the elliptic curve defined as @@ -39,11 +42,12 @@ def ec_op_impl( class EcOpBuiltinRunner(SimpleBuiltinRunner): def __init__(self, included: bool, ec_op_builtin: EcOpInstanceDef): super().__init__( - name='ec_op', + name="ec_op", included=included, ratio=None if ec_op_builtin is None else ec_op_builtin.ratio, cells_per_instance=CELLS_PER_EC_OP, - n_input_cells=INPUT_CELLS_PER_EC_OP) + n_input_cells=INPUT_CELLS_PER_EC_OP, + ) self.stop_ptr: Optional[RelocatableValue] = None self.ec_op_builtin: EcOpInstanceDef = ec_op_builtin @@ -61,46 +65,52 @@ def rule(vm, addr): # Assert that m <= scalar_limit. if self.ec_op_builtin.scalar_limit is not None: - assert memory[instance + M_INDEX] <= self.ec_op_builtin.scalar_limit,\ - f'{self.name} builtin: m must be at most {self.ec_op_builtin.scalar_limit}.' + assert ( + memory[instance + M_INDEX] <= self.ec_op_builtin.scalar_limit + ), f"{self.name} builtin: m must be at most {self.ec_op_builtin.scalar_limit}." for i in range(INPUT_CELLS_PER_EC_OP): - assert vm.is_integer_value(memory[instance + i]), \ - f'{self.name} builtin: Expected integer at address {instance + i}.' \ - f'Got: {memory[instance + i]}.' + assert vm.is_integer_value(memory[instance + i]), ( + f"{self.name} builtin: Expected integer at address {instance + i}." + f"Got: {memory[instance + i]}." + ) # Assert that if the current address is part of a point which is all set in the # memory, the point is on the curve. for pair in EC_POINT_INDICES[:2]: ec_point = [memory[instance + i] for i in pair] - assert point_on_curve(*ec_point, ALPHA, BETA, FIELD_PRIME), \ - f'{self.name} builtin: point {pair} is not on the curve.' + assert point_on_curve( + *ec_point, ALPHA, BETA, FIELD_PRIME + ), f"{self.name} builtin: point {pair} is not on the curve." res = ec_op_impl( - *[memory[instance + i] for i in range(INPUT_CELLS_PER_EC_OP)], ALPHA, FIELD_PRIME) + *[memory[instance + i] for i in range(INPUT_CELLS_PER_EC_OP)], ALPHA, FIELD_PRIME + ) # The result cannot be the point at infinity. - assert res != EC_INFINITY, 'The result cannot be the point at infinity.' + assert res != EC_INFINITY, "The result cannot be the point at infinity." return res[index - INPUT_CELLS_PER_EC_OP] runner.vm.add_auto_deduction_rule(self.base.segment_index, rule) def air_private_input(self, runner) -> Dict[str, Any]: - assert self.base is not None, 'Uninitialized self.base.' + assert self.base is not None, "Uninitialized self.base." res: Dict[int, Any] = {} for addr, val in runner.vm_memory.items(): - if not isinstance(addr, RelocatableValue) or \ - addr.segment_index != self.base.segment_index: + if ( + not isinstance(addr, RelocatableValue) + or addr.segment_index != self.base.segment_index + ): continue idx, typ = divmod(addr.offset, CELLS_PER_EC_OP) if typ >= INPUT_CELLS_PER_EC_OP: continue assert isinstance(val, int) - res.setdefault(idx, {'index': idx})[INPUT_NAMES[typ]] = hex(val) + res.setdefault(idx, {"index": idx})[INPUT_NAMES[typ]] = hex(val) for index, item in res.items(): for name in INPUT_NAMES: assert name in item, f"Missing input '{name}' of {self.name} instance {index}." - return {self.name: sorted(res.values(), key=lambda item: item['index'])} + return {self.name: sorted(res.values(), key=lambda item: item["index"])} diff --git a/src/starkware/cairo/lang/builtins/hash/hash_builtin_runner.py b/src/starkware/cairo/lang/builtins/hash/hash_builtin_runner.py index cd9db059..ef12ae87 100644 --- a/src/starkware/cairo/lang/builtins/hash/hash_builtin_runner.py +++ b/src/starkware/cairo/lang/builtins/hash/hash_builtin_runner.py @@ -13,7 +13,8 @@ def __init__(self, name: str, included: bool, ratio: int, hash_func): included=included, ratio=ratio, cells_per_instance=CELLS_PER_HASH, - n_input_cells=INPUT_CELLS_PER_HASH) + n_input_cells=INPUT_CELLS_PER_HASH, + ) self.hash_func = hash_func self.stop_ptr: Optional[RelocatableValue] = None self.verified_addresses: Set[MaybeRelocatable] = set() @@ -27,12 +28,14 @@ def rule(vm, addr, verified_addresses): return if addr - 1 not in memory or addr - 2 not in memory: return - assert vm.is_integer_value(memory[addr - 2]), \ - f'{self.name} builtin: Expected integer at address {addr - 2}. ' + \ - f'Got: {memory[addr - 2]}.' - assert vm.is_integer_value(memory[addr - 1]), \ - f'{self.name} builtin: Expected integer at address {addr - 1}. ' + \ - f'Got: {memory[addr - 1]}.' + assert vm.is_integer_value(memory[addr - 2]), ( + f"{self.name} builtin: Expected integer at address {addr - 2}. " + + f"Got: {memory[addr - 2]}." + ) + assert vm.is_integer_value(memory[addr - 1]), ( + f"{self.name} builtin: Expected integer at address {addr - 1}. " + + f"Got: {memory[addr - 1]}." + ) res = self.hash_func(memory[addr - 2], memory[addr - 1]) verified_addresses.add(addr) return res @@ -40,24 +43,26 @@ def rule(vm, addr, verified_addresses): runner.vm.add_auto_deduction_rule(self.base.segment_index, rule, self.verified_addresses) def air_private_input(self, runner) -> Dict[str, Any]: - assert self.base is not None, 'Uninitialized self.base.' + assert self.base is not None, "Uninitialized self.base." res: Dict[int, Any] = {} for addr, val in runner.vm_memory.items(): - if not isinstance(addr, RelocatableValue) or \ - addr.segment_index != self.base.segment_index: + if ( + not isinstance(addr, RelocatableValue) + or addr.segment_index != self.base.segment_index + ): continue idx, typ = divmod(addr.offset, CELLS_PER_HASH) if typ == 2: continue assert isinstance(val, int) - res.setdefault(idx, {'index': idx})['x' if typ == 0 else 'y'] = hex(val) + res.setdefault(idx, {"index": idx})["x" if typ == 0 else "y"] = hex(val) for index, item in res.items(): - assert 'x' in item, f'Missing first input of {self.name} instance {index}.' - assert 'y' in item, f'Missing second input of {self.name} instance {index}.' + assert "x" in item, f"Missing first input of {self.name} instance {index}." + assert "y" in item, f"Missing second input of {self.name} instance {index}." - return {self.name: sorted(res.values(), key=lambda item: item['index'])} + return {self.name: sorted(res.values(), key=lambda item: item["index"])} def get_additional_data(self): return [list(RelocatableValue.to_tuple(x)) for x in sorted(self.verified_addresses)] @@ -79,8 +84,13 @@ def expected_stack(self, public_input): if not self.included: return [], [] - addresses = public_input.memory_segments['pedersen'] + addresses = public_input.memory_segments["pedersen"] max_size = CELLS_PER_HASH * safe_div(public_input.n_steps, self.ratio) - assert 0 <= addresses.begin_addr <= addresses.stop_ptr <= \ - addresses.begin_addr + max_size < 2**64 + assert ( + 0 + <= addresses.begin_addr + <= addresses.stop_ptr + <= addresses.begin_addr + max_size + < 2 ** 64 + ) return [addresses.begin_addr], [addresses.stop_ptr] diff --git a/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner.py b/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner.py index ee2f890b..b53208d3 100644 --- a/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner.py +++ b/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner.py @@ -8,11 +8,12 @@ class RangeCheckBuiltinRunner(SimpleBuiltinRunner): def __init__(self, included: bool, ratio, inner_rc_bound, n_parts): super().__init__( - name='range_check', + name="range_check", included=included, ratio=ratio, cells_per_instance=1, - n_input_cells=1) + n_input_cells=1, + ) self.inner_rc_bound = inner_rc_bound self.bound = inner_rc_bound ** n_parts self.n_parts = n_parts @@ -20,40 +21,46 @@ def __init__(self, included: bool, ratio, inner_rc_bound, n_parts): def add_validation_rules(self, runner): def rule(memory, addr): value = memory[addr] - assert isinstance(value, int), \ - f'Range-check builtin: Expected value at address {addr} to be an integer. ' \ - f'Got: {value}.' + assert isinstance(value, int), ( + f"Range-check builtin: Expected value at address {addr} to be an integer. " + f"Got: {value}." + ) # The range check builtin asserts that 0 <= value < BOUND. # For example, if the layout uses 8 16-bit range-checks per instance, # bound will be 2**(16 * 8) = 2**128. - assert 0 <= value < self.bound, \ - f'Value {value}, in range check builtin {addr - self.base}, is out of range ' \ - f'[0, {self.bound}).' + assert 0 <= value < self.bound, ( + f"Value {value}, in range check builtin {addr - self.base}, is out of range " + f"[0, {self.bound})." + ) return {addr} runner.vm.add_validation_rule(self.base.segment_index, rule) def air_private_input(self, runner) -> Dict[str, Any]: - assert self.base is not None, 'Uninitialized self.base.' + assert self.base is not None, "Uninitialized self.base." res: Dict[int, Any] = {} for addr, val in runner.vm_memory.items(): - if not isinstance(addr, RelocatableValue) or \ - addr.segment_index != self.base.segment_index: + if ( + not isinstance(addr, RelocatableValue) + or addr.segment_index != self.base.segment_index + ): continue idx = addr.offset assert isinstance(val, int) - res[idx] = {'index': idx, 'value': hex(val)} + res[idx] = {"index": idx, "value": hex(val)} - return {'range_check': sorted(res.values(), key=lambda item: item['index'])} + return {"range_check": sorted(res.values(), key=lambda item: item["index"])} def get_range_check_usage(self, runner) -> Optional[Tuple[int, int]]: - assert self.base is not None, 'Uninitialized self.base.' + assert self.base is not None, "Uninitialized self.base." rc_min = None rc_max = None for addr, val in runner.vm_memory.items(): - if not isinstance(addr, RelocatableValue) or \ - addr.segment_index != self.base.segment_index: + if ( + not isinstance(addr, RelocatableValue) + or addr.segment_index != self.base.segment_index + ): continue # Split val into n_parts parts. @@ -85,8 +92,13 @@ def expected_stack(self, public_input): if not self.included: return [], [] - addresses = public_input.memory_segments['range_check'] + addresses = public_input.memory_segments["range_check"] max_size = safe_div(public_input.n_steps, self.ratio) - assert 0 <= addresses.begin_addr <= addresses.stop_ptr <= \ - addresses.begin_addr + max_size < 2**64 + assert ( + 0 + <= addresses.begin_addr + <= addresses.stop_ptr + <= addresses.begin_addr + max_size + < 2 ** 64 + ) return [addresses.begin_addr], [addresses.stop_ptr] diff --git a/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner_test.py b/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner_test.py index dc170187..91ddd680 100644 --- a/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner_test.py +++ b/src/starkware/cairo/lang/builtins/range_check/range_check_builtin_runner_test.py @@ -19,12 +19,14 @@ def test_validation_rules(): compile_and_run(CODE_FORMAT.format(value=1)) with pytest.raises( - VmException, - match=f'Value {PRIME - 1}, in range check builtin 0, is out of range ' - r'\[0, {bound}\)'.format(bound=2**128)): + VmException, + match=f"Value {PRIME - 1}, in range check builtin 0, is out of range " + r"\[0, {bound}\)".format(bound=2 ** 128), + ): compile_and_run(CODE_FORMAT.format(value=-1)) with pytest.raises( - VmException, - match=f'Range-check builtin: Expected value at address 2:0 to be an integer. Got: 2:0'): - compile_and_run(CODE_FORMAT.format(value='range_check_ptr')) + VmException, + match=f"Range-check builtin: Expected value at address 2:0 to be an integer. Got: 2:0", + ): + compile_and_run(CODE_FORMAT.format(value="range_check_ptr")) diff --git a/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner.py b/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner.py index b8d6eafb..1ca08062 100644 --- a/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner.py +++ b/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner.py @@ -1,7 +1,9 @@ from typing import Any, Dict from starkware.cairo.lang.builtins.signature.instance_def import ( - CELLS_PER_SIGNATURE, INPUT_CELLS_PER_SIGNATURE) + CELLS_PER_SIGNATURE, + INPUT_CELLS_PER_SIGNATURE, +) from starkware.cairo.lang.vm.builtin_runner import BuiltinVerifier, SimpleBuiltinRunner from starkware.cairo.lang.vm.relocatable import RelocatableValue from starkware.python.math_utils import safe_div @@ -20,7 +22,8 @@ def __init__(self, name: str, included: bool, ratio, process_signature, verify_s included=included, ratio=ratio, cells_per_instance=CELLS_PER_SIGNATURE, - n_input_cells=INPUT_CELLS_PER_SIGNATURE) + n_input_cells=INPUT_CELLS_PER_SIGNATURE, + ) self.process_signature = process_signature self.verify_signature = verify_signature @@ -41,20 +44,24 @@ def rule(memory, addr): pubkey = memory[pubkey_addr] msg = memory[msg_addr] - assert isinstance(pubkey, int), \ - f'ECDSA builtin: Expected public key at address {pubkey_addr} to be an integer. ' \ - f'Got: {pubkey}.' - assert isinstance(msg, int), \ - f'ECDSA builtin: Expected message hash at address {msg_addr} to be an integer. ' \ - f'Got: {msg}.' - assert pubkey_addr in self.signatures, \ - f'Signature hint is missing for ECDSA builtin at address {pubkey_addr}. ' \ + assert isinstance(pubkey, int), ( + f"ECDSA builtin: Expected public key at address {pubkey_addr} to be an integer. " + f"Got: {pubkey}." + ) + assert isinstance(msg, int), ( + f"ECDSA builtin: Expected message hash at address {msg_addr} to be an integer. " + f"Got: {msg}." + ) + assert pubkey_addr in self.signatures, ( + f"Signature hint is missing for ECDSA builtin at address {pubkey_addr}. " "Add it using 'ecdsa_builtin.add_signature'." + ) signature = self.signatures[pubkey_addr] - assert self.verify_signature(pubkey, msg, signature), \ - f'Signature {signature}, is invalid, with respect to the public key {pubkey}, ' \ - f'and the message hash {msg}.' + assert self.verify_signature(pubkey, msg, signature), ( + f"Signature {signature}, is invalid, with respect to the public key {pubkey}, " + f"and the message hash {msg}." + ) return {pubkey_addr, msg_addr} runner.vm.add_validation_rule(self.base.segment_index, rule) @@ -67,28 +74,31 @@ def air_private_input(self, runner) -> Dict[str, Any]: pubkey = runner.vm_memory[addr] msg = runner.vm_memory[addr + 1] res[idx] = { - 'index': idx, - 'pubkey': hex(pubkey), - 'msg': hex(msg), - 'signature_input': self.process_signature(pubkey, msg, signature), + "index": idx, + "pubkey": hex(pubkey), + "msg": hex(msg), + "signature_input": self.process_signature(pubkey, msg, signature), } - return {self.name: sorted(res.values(), key=lambda item: item['index'])} + return {self.name: sorted(res.values(), key=lambda item: item["index"])} def add_signature(self, addr, signature): """ This function should be used in Cairo hints. """ - assert isinstance(addr, RelocatableValue), \ - f'Expected memory address to be relocatable value. Found: {addr}.' - assert addr.offset % CELLS_PER_SIGNATURE == 0, \ - f'Signature hint must point to the public key cell, not {addr}.' + assert isinstance( + addr, RelocatableValue + ), f"Expected memory address to be relocatable value. Found: {addr}." + assert ( + addr.offset % CELLS_PER_SIGNATURE == 0 + ), f"Signature hint must point to the public key cell, not {addr}." self.signatures[addr] = signature def get_additional_data(self): return [ [list(RelocatableValue.to_tuple(addr)), signature] - for addr, signature in sorted(self.signatures.items())] + for addr, signature in sorted(self.signatures.items()) + ] def extend_additional_data(self, data, relocate_callback, data_is_trusted=True): for addr, signature in data: @@ -104,8 +114,13 @@ def expected_stack(self, public_input): if not self.included: return [], [] - addresses = public_input.memory_segments['signature'] + addresses = public_input.memory_segments["signature"] max_size = safe_div(public_input.n_steps, self.ratio) * CELLS_PER_SIGNATURE - assert 0 <= addresses.begin_addr <= addresses.stop_ptr <= \ - addresses.begin_addr + max_size < 2**64 + assert ( + 0 + <= addresses.begin_addr + <= addresses.stop_ptr + <= addresses.begin_addr + max_size + < 2 ** 64 + ) return [addresses.begin_addr], [addresses.stop_ptr] diff --git a/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner_test.py b/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner_test.py index a7c656e9..434615d0 100644 --- a/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner_test.py +++ b/src/starkware/cairo/lang/builtins/signature/signature_builtin_runner_test.py @@ -15,6 +15,7 @@ class SignatureCodeSections: Code sections relevant for using the signature builtin. See code snippet structure below. """ + hint: str write_pubkey: str write_msg: str @@ -29,21 +30,22 @@ class SignatureExample: # Constants used for creating a code snippet using the signature builtin. # See signature_builtin_runner_test.py. -SIG_PTR = 'ecdsa_ptr' +SIG_PTR = "ecdsa_ptr" formats = SimpleNamespace( - hint_code_format='%{{ ecdsa_builtin.add_signature({addr}, {signature}) %}}', - pubkey_code_format=f'assert [{SIG_PTR} + SignatureBuiltin.pub_key] = {{pubkey}}', - msg_code_format=f'assert [{SIG_PTR} + SignatureBuiltin.message] = {{msg}}', + hint_code_format="%{{ ecdsa_builtin.add_signature({addr}, {signature}) %}}", + pubkey_code_format=f"assert [{SIG_PTR} + SignatureBuiltin.pub_key] = {{pubkey}}", + msg_code_format=f"assert [{SIG_PTR} + SignatureBuiltin.message] = {{msg}}", ) # The address is used inside a hint. -VALID_ADDR = f'ids.{SIG_PTR}' +VALID_ADDR = f"ids.{SIG_PTR}" VALID_SIG = ( 3086480810278599376317923499561306189851900463386393948998357832163236918254, - 598673427589502599949712887611119751108407514580626464031881322743364689811) + 598673427589502599949712887611119751108407514580626464031881322743364689811, +) constants = SimpleNamespace( valid_addr=VALID_ADDR, - invalid_addr=VALID_ADDR + ' + 1', + invalid_addr=VALID_ADDR + " + 1", valid_sig=VALID_SIG, invalid_sig=(VALID_SIG[0] + 1, VALID_SIG[1]), valid_pubkey=1735102664668487605176656616876767369909409133946409161569774794110049207117, @@ -60,16 +62,18 @@ class SignatureTest: """ def __init__(self): - self.test_cases = {'valid': SignatureExample( - error_msg=None, - code_sections=SignatureCodeSections( - hint=formats.hint_code_format.format( - addr=constants.valid_addr, signature=constants.valid_sig), - write_pubkey=formats.pubkey_code_format.format( - pubkey=constants.valid_pubkey), - write_msg=formats.msg_code_format.format(msg=constants.valid_msg), + self.test_cases = { + "valid": SignatureExample( + error_msg=None, + code_sections=SignatureCodeSections( + hint=formats.hint_code_format.format( + addr=constants.valid_addr, signature=constants.valid_sig + ), + write_pubkey=formats.pubkey_code_format.format(pubkey=constants.valid_pubkey), + write_msg=formats.msg_code_format.format(msg=constants.valid_msg), + ), ) - )} + } def add_test_case(self, name: str, error_msg: Optional[str], **code_section_changes): """ @@ -78,7 +82,8 @@ def add_test_case(self, name: str, error_msg: Optional[str], **code_section_chan """ self.test_cases[name] = SignatureExample( code_sections=dataclasses.replace( - self.test_cases['valid'].code_sections, **code_section_changes), + self.test_cases["valid"].code_sections, **code_section_changes + ), error_msg=error_msg, ) @@ -101,53 +106,57 @@ def get_test_cases(self): test = SignatureTest() test.add_test_case( - name='invalid_signature_address', - error_msg='Signature hint must point to the public key cell, not 2:1.', + name="invalid_signature_address", + error_msg="Signature hint must point to the public key cell, not 2:1.", hint=formats.hint_code_format.format( - addr=constants.invalid_addr, signature=constants.valid_sig), + addr=constants.invalid_addr, signature=constants.valid_sig + ), ) test.add_test_case( - name='invalid_signature', + name="invalid_signature", error_msg=( - r'Signature .* is invalid, with respect to the public key ' - '1735102664668487605176656616876767369909409133946409161569774794110049207117, ' - 'and the message hash 2718.'), + r"Signature .* is invalid, with respect to the public key " + "1735102664668487605176656616876767369909409133946409161569774794110049207117, " + "and the message hash 2718." + ), hint=formats.hint_code_format.format( - addr=constants.valid_addr, signature=constants.invalid_sig), + addr=constants.valid_addr, signature=constants.invalid_sig + ), ) test.add_test_case( - name='invalid_public_key', - error_msg='ECDSA builtin: Expected public key at address 2:0 to be an integer. Got: 2:0.', + name="invalid_public_key", + error_msg="ECDSA builtin: Expected public key at address 2:0 to be an integer. Got: 2:0.", write_pubkey=formats.pubkey_code_format.format(pubkey=constants.invalid_pubkey_or_msg), ) test.add_test_case( - name='invalid_message', - error_msg='ECDSA builtin: Expected message hash at address 2:1 to be an integer. Got: 2:0.', + name="invalid_message", + error_msg="ECDSA builtin: Expected message hash at address 2:1 to be an integer. Got: 2:0.", write_msg=formats.msg_code_format.format(msg=constants.invalid_pubkey_or_msg), ) test.add_test_case( - name='missing_hint', + name="missing_hint", error_msg=( - 'Signature hint is missing for ECDSA builtin at address 2:0. ' - "Add it using 'ecdsa_builtin.add_signature'."), - hint='', + "Signature hint is missing for ECDSA builtin at address 2:0. " + "Add it using 'ecdsa_builtin.add_signature'." + ), + hint="", ) # Missing public key or message would not cause a runtime error, but would fail the prover. -test.add_test_case(name='missing_public_key', error_msg=None, write_pubkey='') -test.add_test_case(name='missing_message', error_msg=None, write_msg='') +test.add_test_case(name="missing_public_key", error_msg=None, write_pubkey="") +test.add_test_case(name="missing_message", error_msg=None, write_msg="") test_cases = test.get_test_cases() -@pytest.mark.parametrize('case', test_cases.values(), ids=test_cases.keys()) +@pytest.mark.parametrize("case", test_cases.values(), ids=test_cases.keys()) def test_validation_rules(case): code = CODE.format(**dataclasses.asdict(case.code_sections)) with maybe_raises( - expected_exception=VmException, error_message=case.error_msg, - escape_error_message=False): + expected_exception=VmException, error_message=case.error_msg, escape_error_message=False + ): compile_and_run(code) diff --git a/src/starkware/cairo/lang/cairo_constants.py b/src/starkware/cairo/lang/cairo_constants.py index 84f21458..e9550db7 100644 --- a/src/starkware/cairo/lang/cairo_constants.py +++ b/src/starkware/cairo/lang/cairo_constants.py @@ -1 +1 @@ -DEFAULT_PRIME = 2**251 + 17 * 2**192 + 1 +DEFAULT_PRIME = 2 ** 251 + 17 * 2 ** 192 + 1 diff --git a/src/starkware/cairo/lang/compiler/CMakeLists.txt b/src/starkware/cairo/lang/compiler/CMakeLists.txt index 2ba0c24a..95e1bd30 100644 --- a/src/starkware/cairo/lang/compiler/CMakeLists.txt +++ b/src/starkware/cairo/lang/compiler/CMakeLists.txt @@ -66,6 +66,7 @@ python_lib(cairo_compile_lib type_casts.py type_system_visitor.py type_system.py + type_utils.py LIBS cairo_constants_lib @@ -152,6 +153,7 @@ full_python_test(cairo_compile_test scoped_name_test.py type_casts_test.py type_system_visitor_test.py + type_utils_test.py LIBS cairo_compile_lib diff --git a/src/starkware/cairo/lang/compiler/assembler.py b/src/starkware/cairo/lang/compiler/assembler.py index 808346ef..928f5249 100644 --- a/src/starkware/cairo/lang/compiler/assembler.py +++ b/src/starkware/cairo/lang/compiler/assembler.py @@ -9,37 +9,52 @@ def assemble( - preprocessed_program: PreprocessedProgram, main_scope: ScopedName = ScopedName(), - add_debug_info: bool = False, file_contents_for_debug_info: Dict[str, str] = {}) -> Program: + preprocessed_program: PreprocessedProgram, + main_scope: ScopedName = ScopedName(), + add_debug_info: bool = False, + file_contents_for_debug_info: Dict[str, str] = {}, +) -> Program: data: List[int] = [] hints: Dict[int, List[CairoHint]] = {} - debug_info = DebugInfo(instruction_locations={}, file_contents=file_contents_for_debug_info) \ - if add_debug_info else None + debug_info = ( + DebugInfo(instruction_locations={}, file_contents=file_contents_for_debug_info) + if add_debug_info + else None + ) for inst in preprocessed_program.instructions: for hint, hint_flow_tracking_data in inst.hints: - hints.setdefault(len(data), []).append(CairoHint( - code=hint.hint_code, - accessible_scopes=inst.accessible_scopes, - flow_tracking_data=hint_flow_tracking_data)) + hints.setdefault(len(data), []).append( + CairoHint( + code=hint.hint_code, + accessible_scopes=inst.accessible_scopes, + flow_tracking_data=hint_flow_tracking_data, + ) + ) if debug_info is not None and inst.instruction.location is not None: hint_locations: List[Optional[HintLocation]] = [] for hint, _ in inst.hints: if hint.location is None: hint_locations.append(None) else: - hint_locations.append(HintLocation( - location=hint.location, - n_prefix_newlines=hint.hint.n_prefix_newlines, - )) - debug_info.instruction_locations[len(data)] = \ - InstructionLocation( - inst=inst.instruction.location, - hints=hint_locations, - accessible_scopes=inst.accessible_scopes, - flow_tracking_data=inst.flow_tracking_data) - data += [word for word in encode_instruction( - build_instruction(inst.instruction), prime=preprocessed_program.prime)] + hint_locations.append( + HintLocation( + location=hint.location, + n_prefix_newlines=hint.hint.n_prefix_newlines, + ) + ) + debug_info.instruction_locations[len(data)] = InstructionLocation( + inst=inst.instruction.location, + hints=hint_locations, + accessible_scopes=inst.accessible_scopes, + flow_tracking_data=inst.flow_tracking_data, + ) + data += [ + word + for word in encode_instruction( + build_instruction(inst.instruction), prime=preprocessed_program.prime + ) + ] if debug_info is not None: debug_info.add_autogen_file_contents() @@ -52,4 +67,5 @@ def assemble( identifiers=preprocessed_program.identifiers, builtins=preprocessed_program.builtins, reference_manager=preprocessed_program.reference_manager, - debug_info=debug_info) + debug_info=debug_info, + ) diff --git a/src/starkware/cairo/lang/compiler/assembler_test.py b/src/starkware/cairo/lang/compiler/assembler_test.py index d0dc430e..1b26b08a 100644 --- a/src/starkware/cairo/lang/compiler/assembler_test.py +++ b/src/starkware/cairo/lang/compiler/assembler_test.py @@ -2,56 +2,80 @@ from starkware.cairo.lang.compiler.identifier_definition import ConstDefinition, LabelDefinition from starkware.cairo.lang.compiler.identifier_manager import ( - IdentifierManager, MissingIdentifierError) + IdentifierManager, + MissingIdentifierError, +) from starkware.cairo.lang.compiler.preprocessor.flow import ReferenceManager from starkware.cairo.lang.compiler.program import Program from starkware.cairo.lang.compiler.scoped_name import ScopedName def test_main_scope(): - identifiers = IdentifierManager.from_dict({ - ScopedName.from_string('a.b'): ConstDefinition(value=1), - ScopedName.from_string('x.y.z'): ConstDefinition(value=2), - }) + identifiers = IdentifierManager.from_dict( + { + ScopedName.from_string("a.b"): ConstDefinition(value=1), + ScopedName.from_string("x.y.z"): ConstDefinition(value=2), + } + ) reference_manager = ReferenceManager() program = Program( - prime=0, data=[], hints={}, builtins=[], main_scope=ScopedName.from_string('a'), - identifiers=identifiers, reference_manager=reference_manager) + prime=0, + data=[], + hints={}, + builtins=[], + main_scope=ScopedName.from_string("a"), + identifiers=identifiers, + reference_manager=reference_manager, + ) # Check accessible identifiers. - assert program.get_identifier('b', ConstDefinition).value == 1 + assert program.get_identifier("b", ConstDefinition).value == 1 # Ensure inaccessible identifiers. with pytest.raises(MissingIdentifierError, match="Unknown identifier 'a'."): - program.get_identifier('a.b', ConstDefinition) + program.get_identifier("a.b", ConstDefinition) with pytest.raises(MissingIdentifierError, match="Unknown identifier 'x'."): - program.get_identifier('x.y', ConstDefinition) + program.get_identifier("x.y", ConstDefinition) with pytest.raises(MissingIdentifierError, match="Unknown identifier 'y'."): - program.get_identifier('y', ConstDefinition) + program.get_identifier("y", ConstDefinition) # Full name lookup. - assert program.get_identifier('a.b', ConstDefinition, full_name_lookup=True).value == 1 - assert program.get_identifier('x.y.z', ConstDefinition, full_name_lookup=True).value == 2 + assert program.get_identifier("a.b", ConstDefinition, full_name_lookup=True).value == 1 + assert program.get_identifier("x.y.z", ConstDefinition, full_name_lookup=True).value == 2 def test_program_start_property(): - identifiers = IdentifierManager.from_dict({ - ScopedName.from_string('some.main.__start__'): LabelDefinition(3), - }) + identifiers = IdentifierManager.from_dict( + { + ScopedName.from_string("some.main.__start__"): LabelDefinition(3), + } + ) reference_manager = ReferenceManager() - main_scope = ScopedName.from_string('some.main') + main_scope = ScopedName.from_string("some.main") # The label __start__ is in identifiers. program = Program( - prime=0, data=[], hints={}, builtins=[], main_scope=main_scope, identifiers=identifiers, - reference_manager=reference_manager) + prime=0, + data=[], + hints={}, + builtins=[], + main_scope=main_scope, + identifiers=identifiers, + reference_manager=reference_manager, + ) assert program.start == 3 # The label __start__ is not in identifiers. program = Program( - prime=0, data=[], hints={}, builtins=[], main_scope=main_scope, - identifiers=IdentifierManager(), reference_manager=reference_manager) + prime=0, + data=[], + hints={}, + builtins=[], + main_scope=main_scope, + identifiers=IdentifierManager(), + reference_manager=reference_manager, + ) assert program.start == 0 diff --git a/src/starkware/cairo/lang/compiler/ast/aliased_identifier.py b/src/starkware/cairo/lang/compiler/ast/aliased_identifier.py index 4666b752..e26e3f37 100644 --- a/src/starkware/cairo/lang/compiler/ast/aliased_identifier.py +++ b/src/starkware/cairo/lang/compiler/ast/aliased_identifier.py @@ -14,8 +14,9 @@ class AliasedIdentifier(AstNode): location: Optional[Location] = LocationField def format(self): - return f'{self.orig_identifier.format()}' + \ - (f' as {self.local_name.format()}' if self.local_name else '') + return f"{self.orig_identifier.format()}" + ( + f" as {self.local_name.format()}" if self.local_name else "" + ) @property def identifier(self): diff --git a/src/starkware/cairo/lang/compiler/ast/ast_objects_test_utils.py b/src/starkware/cairo/lang/compiler/ast/ast_objects_test_utils.py index a11d9815..ed4133ee 100644 --- a/src/starkware/cairo/lang/compiler/ast/ast_objects_test_utils.py +++ b/src/starkware/cairo/lang/compiler/ast/ast_objects_test_utils.py @@ -1,5 +1,12 @@ from starkware.cairo.lang.compiler.ast.expr import ( - ExprAddressOf, ExprDeref, ExprDot, ExprNeg, ExprOperator, ExprParentheses, ExprSubscript) + ExprAddressOf, + ExprDeref, + ExprDot, + ExprNeg, + ExprOperator, + ExprParentheses, + ExprSubscript, +) def remove_parentheses(expr): @@ -20,5 +27,6 @@ def remove_parentheses(expr): return ExprDot(expr=remove_parentheses(expr.expr), member=expr.member) if isinstance(expr, ExprSubscript): return ExprSubscript( - expr=remove_parentheses(expr.expr), offset=remove_parentheses(expr.offset)) + expr=remove_parentheses(expr.expr), offset=remove_parentheses(expr.offset) + ) return expr diff --git a/src/starkware/cairo/lang/compiler/ast/bool_expr.py b/src/starkware/cairo/lang/compiler/ast/bool_expr.py index ca6769be..5fd3a5c1 100644 --- a/src/starkware/cairo/lang/compiler/ast/bool_expr.py +++ b/src/starkware/cairo/lang/compiler/ast/bool_expr.py @@ -15,8 +15,8 @@ class BoolExpr(AstNode): location: Optional[Location] = LocationField def get_particles(self): - relation = '==' if self.eq else '!=' - return [f'{self.a.format()} {relation} ', self.b.format()] + relation = "==" if self.eq else "!=" + return [f"{self.a.format()} {relation} ", self.b.format()] def get_children(self) -> Sequence[Optional[AstNode]]: return [self.a, self.b] diff --git a/src/starkware/cairo/lang/compiler/ast/cairo_types.py b/src/starkware/cairo/lang/compiler/ast/cairo_types.py index 713a1034..a97ec4ec 100644 --- a/src/starkware/cairo/lang/compiler/ast/cairo_types.py +++ b/src/starkware/cairo/lang/compiler/ast/cairo_types.py @@ -18,7 +18,7 @@ def format(self) -> str: Returns a representation of the type as a string. """ - def get_pointer_type(self) -> 'CairoType': + def get_pointer_type(self) -> "CairoType": """ Returns a type of a pointer to the current type. """ @@ -30,7 +30,7 @@ class TypeFelt(CairoType): location: Optional[Location] = LocationField def format(self): - return 'felt' + return "felt" def get_children(self) -> Sequence[Optional[AstNode]]: return [] @@ -42,7 +42,7 @@ class TypePointer(CairoType): location: Optional[Location] = LocationField def format(self): - return f'{self.pointee.format()}*' + return f"{self.pointee.format()}*" def get_children(self) -> Sequence[Optional[AstNode]]: return [self.pointee] @@ -63,7 +63,7 @@ def resolved_scope(self): """ Verifies that is_fully_resolved=True and returns scope. """ - assert self.is_fully_resolved, 'Type is expected to be fully resolved at this point.' + assert self.is_fully_resolved, "Type is expected to be fully resolved at this point." return self.scope def get_children(self) -> Sequence[Optional[AstNode]]: @@ -75,6 +75,7 @@ class TypeTuple(CairoType): """ Type for a tuple. """ + members: List[CairoType] location: Optional[Location] = LocationField diff --git a/src/starkware/cairo/lang/compiler/ast/code_elements.py b/src/starkware/cairo/lang/compiler/ast/code_elements.py index 1b5afca6..48bed16a 100644 --- a/src/starkware/cairo/lang/compiler/ast/code_elements.py +++ b/src/starkware/cairo/lang/compiler/ast/code_elements.py @@ -6,10 +6,18 @@ from starkware.cairo.lang.compiler.ast.arguments import IdentifierList from starkware.cairo.lang.compiler.ast.bool_expr import BoolExpr from starkware.cairo.lang.compiler.ast.expr import ( - ExprAssignment, Expression, ExprHint, ExprIdentifier) + ExprAssignment, + Expression, + ExprHint, + ExprIdentifier, +) from starkware.cairo.lang.compiler.ast.formatting_utils import ( - INDENTATION, LocationField, ParticleFormattingConfig, create_particle_sublist, - particles_in_lines) + INDENTATION, + LocationField, + ParticleFormattingConfig, + create_particle_sublist, + particles_in_lines, +) from starkware.cairo.lang.compiler.ast.instructions import InstructionAst from starkware.cairo.lang.compiler.ast.node import AstNode from starkware.cairo.lang.compiler.ast.notes import NoteListField, Notes @@ -48,7 +56,7 @@ class CodeElementConst(CodeElement): expr: Expression def format(self, allowed_line_length): - return f'const {self.identifier.format()} = {self.expr.format()}' + return f"const {self.identifier.format()} = {self.expr.format()}" def get_children(self) -> Sequence[Optional[AstNode]]: return [self.identifier, self.expr] @@ -59,7 +67,7 @@ class CodeElementMember(CodeElement): typed_identifier: TypedIdentifier def format(self, allowed_line_length): - return f'member {self.typed_identifier.format()}' + return f"member {self.typed_identifier.format()}" def get_children(self) -> Sequence[Optional[AstNode]]: return [self.typed_identifier] @@ -71,7 +79,7 @@ class CodeElementReference(CodeElement): expr: Expression def format(self, allowed_line_length): - return f'let {self.typed_identifier.format()} = {self.expr.format()}' + return f"let {self.typed_identifier.format()} = {self.expr.format()}" def get_children(self) -> Sequence[Optional[AstNode]]: return [self.typed_identifier, self.expr] @@ -85,13 +93,14 @@ class CodeElementLocalVariable(CodeElement): Both the expr_type and the initialization expr are optional. """ + typed_identifier: TypedIdentifier expr: Optional[Expression] location: Optional[Location] = LocationField def format(self, allowed_line_length): - assignment = '' if self.expr is None else f' = {self.expr.format()}' - return f'local {self.typed_identifier.format()}{assignment}' + assignment = "" if self.expr is None else f" = {self.expr.format()}" + return f"local {self.typed_identifier.format()}{assignment}" def get_children(self) -> Sequence[Optional[AstNode]]: return [self.typed_identifier, self.expr] @@ -103,13 +112,14 @@ class CodeElementTemporaryVariable(CodeElement): Represents a statement of the form: tempvar x = expr. """ + typed_identifier: TypedIdentifier expr: Optional[Expression] location: Optional[Location] = LocationField def format(self, allowed_line_length): - assignment = '' if self.expr is None else f' = {self.expr.format()}' - return f'tempvar {self.typed_identifier.format()}{assignment}' + assignment = "" if self.expr is None else f" = {self.expr.format()}" + return f"tempvar {self.typed_identifier.format()}{assignment}" def get_children(self) -> Sequence[Optional[AstNode]]: return [self.typed_identifier, self.expr] @@ -122,12 +132,13 @@ class CodeElementCompoundAssertEq(CodeElement): Unlike AssertEqInstruction, a CodeElementCompoundAssertEq may translate to a few instructions to deal with expressions which contain more than one operation. """ + a: Expression b: Expression location: Optional[Location] = LocationField def format(self, allowed_line_length): - return f'assert {self.a.format()} = {self.b.format()}' + return f"assert {self.a.format()} = {self.b.format()}" def get_children(self) -> Sequence[Optional[AstNode]]: return [self.a, self.b] @@ -140,7 +151,7 @@ class CodeElementStaticAssert(CodeElement): location: Optional[Location] = LocationField def format(self, allowed_line_length): - return f'static_assert {self.a.format()} == {self.b.format()}' + return f"static_assert {self.a.format()} == {self.b.format()}" def get_children(self) -> Sequence[Optional[AstNode]]: return [self.a, self.b] @@ -152,19 +163,20 @@ class CodeElementReturn(CodeElement): Represents a statement of the form: return ([ident=]expr, ...). """ + exprs: List[ExprAssignment] location: Optional[Location] = LocationField def format(self, allowed_line_length): expr_codes = [x.format() for x in self.exprs] - particles = ['return (', create_particle_sublist(expr_codes, ')')] + particles = ["return (", create_particle_sublist(expr_codes, ")")] return particles_in_lines( particles=particles, config=ParticleFormattingConfig( - allowed_line_length=allowed_line_length, - line_indent=INDENTATION, - one_per_line=True)) + allowed_line_length=allowed_line_length, line_indent=INDENTATION, one_per_line=True + ), + ) def get_children(self) -> Sequence[Optional[AstNode]]: return self.exprs @@ -176,20 +188,21 @@ class CodeElementTailCall(CodeElement): Represents a statement of the form: return func_ident([ident=]expr, ...). """ + func_call: RvalueFuncCall location: Optional[Location] = LocationField def get_particles(self): particales = self.func_call.get_particles() - return ['return ' + particales[0]] + particales[1:] + return ["return " + particales[0]] + particales[1:] def format(self, allowed_line_length): return particles_in_lines( particles=self.get_particles(), config=ParticleFormattingConfig( - allowed_line_length=allowed_line_length, - line_indent=INDENTATION, - one_per_line=True)) + allowed_line_length=allowed_line_length, line_indent=INDENTATION, one_per_line=True + ), + ) def get_children(self) -> Sequence[Optional[AstNode]]: return [self.func_call] @@ -201,6 +214,7 @@ class CodeElementFuncCall(CodeElement): Represents a statement of the form: func_ident([ident=]expr, ...). """ + func_call: RvalueFuncCall def get_particles(self): @@ -224,19 +238,20 @@ class CodeElementReturnValueReference(CodeElement): 'x [: type]' is the 'typed_identifier' 'func(...)' is the 'func_call'. """ + typed_identifier: TypedIdentifier func_call: RvalueCall def format(self, allowed_line_length): call_particles = self.func_call.get_particles() - first_particle = f'let {self.typed_identifier.format()} = ' + call_particles[0] + first_particle = f"let {self.typed_identifier.format()} = " + call_particles[0] return particles_in_lines( particles=[first_particle] + call_particles[1:], config=ParticleFormattingConfig( - allowed_line_length=allowed_line_length, - line_indent=INDENTATION, - one_per_line=True)) + allowed_line_length=allowed_line_length, line_indent=INDENTATION, one_per_line=True + ), + ) def get_children(self) -> Sequence[Optional[AstNode]]: return [self.typed_identifier, self.func_call] @@ -251,23 +266,26 @@ class CodeElementUnpackBinding(CodeElement): '(a, b, c)' is the 'unpacking_list' 'func(...)' is the 'rvalue'. """ + unpacking_list: IdentifierList rvalue: Rvalue def format(self, allowed_line_length): particles = self.rvalue.get_particles() - end_particle = ') = ' + particles[0] - particles = ['let ('] + \ - create_particle_sublist(self.unpacking_list.get_particles(), end_particle) + \ - particles[1:] + end_particle = ") = " + particles[0] + particles = ( + ["let ("] + + create_particle_sublist(self.unpacking_list.get_particles(), end_particle) + + particles[1:] + ) return particles_in_lines( particles=particles, config=ParticleFormattingConfig( - allowed_line_length=allowed_line_length, - line_indent=INDENTATION, - one_per_line=True)) + allowed_line_length=allowed_line_length, line_indent=INDENTATION, one_per_line=True + ), + ) def get_children(self) -> Sequence[Optional[AstNode]]: return [self.unpacking_list, self.rvalue] @@ -278,7 +296,7 @@ class CodeElementLabel(CodeElement): identifier: ExprIdentifier def format(self, allowed_line_length): - return f'{self.identifier.format()}:' + return f"{self.identifier.format()}:" def get_children(self) -> Sequence[Optional[AstNode]]: return [self.identifier] @@ -303,7 +321,7 @@ def get_children(self) -> Sequence[Optional[AstNode]]: @dataclasses.dataclass class CodeElementEmptyLine(CodeElement): def format(self, allowed_line_length): - return '' + return "" def get_children(self) -> Sequence[Optional[AstNode]]: return [] @@ -317,11 +335,11 @@ class CommentedCodeElement(AstNode): def format(self, allowed_line_length): elm_str = self.code_elm.format(allowed_line_length=allowed_line_length) - comment_str = f'#{self.comment}' if self.comment is not None else '' - separator = ' ' if elm_str != '' and comment_str != '' else '' + comment_str = f"#{self.comment}" if self.comment is not None else "" + separator = " " if elm_str != "" and comment_str != "" else "" return elm_str + separator + comment_str.rstrip() - def fix_comment_spaces(self, allow_additional_comment_spaces: bool) -> 'CommentedCodeElement': + def fix_comment_spaces(self, allow_additional_comment_spaces: bool) -> "CommentedCodeElement": """ Comments should start with exactly one space after '#' except for some cases (in which allow_additional_comment_spaces=True). @@ -332,14 +350,14 @@ def fix_comment_spaces(self, allow_additional_comment_spaces: bool) -> 'Commente if comment is None: return self - if set(comment) == {'#'}: + if set(comment) == {"#"}: # Allow a line of '#'. return self if not allow_additional_comment_spaces: comment = comment.strip() - if not comment.startswith(' '): - comment = ' ' + comment + if not comment.startswith(" "): + comment = " " + comment return CommentedCodeElement(code_elm=self.code_elm, comment=comment, location=self.location) @@ -356,7 +374,7 @@ def format(self, allowed_line_length): code_elements = add_empty_lines_before_labels(code_elements) code_elements = fix_comment_spaces(code_elements) - return ''.join(f'{code_elm.format(allowed_line_length)}\n' for code_elm in code_elements) + return "".join(f"{code_elm.format(allowed_line_length)}\n" for code_elm in code_elements) def get_children(self) -> Sequence[Optional[AstNode]]: return self.code_elements @@ -368,11 +386,12 @@ class CodeElementScoped(CodeElement): Represents a list of code elements that should be handled inside a scope. This class does not appear naturally in the parsed AST. """ + scope: ScopedName code_elements: List[CodeElement] def format(self, allowed_line_length): - raise NotImplementedError(f'Formatting {type(self).__name__} is not supported.') + raise NotImplementedError(f"Formatting {type(self).__name__} is not supported.") def get_children(self) -> Sequence[Optional[AstNode]]: return self.code_elements @@ -387,6 +406,7 @@ class CodeElementFunction(CodeElement): return (z=x, w=y) end """ + # The type of the code element. Either 'func', 'namespace' or 'struct'. element_type: str identifier: ExprIdentifier @@ -397,9 +417,9 @@ class CodeElementFunction(CodeElement): decorators: List[ExprIdentifier] additional_attributes: Dict[str, Any] = dataclasses.field(default_factory=dict) - ARGUMENT_SCOPE = ScopedName.from_string('Args') - IMPLICIT_ARGUMENT_SCOPE = ScopedName.from_string('ImplicitArgs') - RETURN_SCOPE = ScopedName.from_string('Return') + ARGUMENT_SCOPE = ScopedName.from_string("Args") + IMPLICIT_ARGUMENT_SCOPE = ScopedName.from_string("ImplicitArgs") + RETURN_SCOPE = ScopedName.from_string("Return") @property def name(self): @@ -408,40 +428,49 @@ def name(self): def format(self, allowed_line_length): code = self.code_block.format(allowed_line_length=allowed_line_length - INDENTATION) code = indent(code, INDENTATION) - if self.element_type in ['struct', 'namespace']: - particles = [f'{self.element_type} {self.name}:'] + if self.element_type in ["struct", "namespace"]: + particles = [f"{self.element_type} {self.name}:"] else: if self.implicit_arguments is not None: - first_particle_suffix = '{' + first_particle_suffix = "{" implicit_args_particles = [ - create_particle_sublist(self.implicit_arguments.get_particles(), '}(')] + create_particle_sublist(self.implicit_arguments.get_particles(), "}(") + ] else: - first_particle_suffix = '(' + first_particle_suffix = "(" implicit_args_particles = [] if self.returns is not None: particles = [ - f'{self.element_type} {self.name}{first_particle_suffix}', + f"{self.element_type} {self.name}{first_particle_suffix}", *implicit_args_particles, - create_particle_sublist(self.arguments.get_particles(), ') -> ('), - create_particle_sublist(self.returns.get_particles(), '):')] + create_particle_sublist(self.arguments.get_particles(), ") -> ("), + create_particle_sublist(self.returns.get_particles(), "):"), + ] else: particles = [ - f'{self.element_type} {self.name}{first_particle_suffix}', + f"{self.element_type} {self.name}{first_particle_suffix}", *implicit_args_particles, - create_particle_sublist(self.arguments.get_particles(), '):')] + create_particle_sublist(self.arguments.get_particles(), "):"), + ] - decorators = ''.join(f'@{decorator.format()}\n' for decorator in self.decorators) + decorators = "".join(f"@{decorator.format()}\n" for decorator in self.decorators) header = particles_in_lines( particles=particles, config=ParticleFormattingConfig( - allowed_line_length=allowed_line_length, - line_indent=INDENTATION * 2)) - return f'{decorators}{header}\n{code}end' + allowed_line_length=allowed_line_length, line_indent=INDENTATION * 2 + ), + ) + return f"{decorators}{header}\n{code}end" def get_children(self) -> Sequence[Optional[AstNode]]: return [ - self.identifier, self.arguments, self.implicit_arguments, self.returns, self.code_block] + self.identifier, + self.arguments, + self.implicit_arguments, + self.returns, + self.code_block, + ] @dataclasses.dataclass @@ -450,10 +479,10 @@ class CodeElementWith(CodeElement): code_block: CodeBlock def format(self, allowed_line_length): - identifier_list_str = ', '.join(identifier.format() for identifier in self.identifiers) + identifier_list_str = ", ".join(identifier.format() for identifier in self.identifiers) inner_code = self.code_block.format(allowed_line_length=allowed_line_length - INDENTATION) inner_code = indent(inner_code, INDENTATION) - return f'with {identifier_list_str}:\n{inner_code}end' + return f"with {identifier_list_str}:\n{inner_code}end" def get_children(self) -> Sequence[Optional[AstNode]]: return [*self.identifiers, self.code_block] @@ -469,24 +498,27 @@ class CodeElementIf(CodeElement): location: Optional[Location] = LocationField def format(self, allowed_line_length): - cond_particles = ['if ', *self.condition.get_particles()] - cond_particles[-1] = cond_particles[-1] + ':' + cond_particles = ["if ", *self.condition.get_particles()] + cond_particles[-1] = cond_particles[-1] + ":" code = particles_in_lines( particles=cond_particles, config=ParticleFormattingConfig( - allowed_line_length=allowed_line_length, - line_indent=INDENTATION)) + allowed_line_length=allowed_line_length, line_indent=INDENTATION + ), + ) main_code = self.main_code_block.format( - allowed_line_length=allowed_line_length - INDENTATION) + allowed_line_length=allowed_line_length - INDENTATION + ) main_code = indent(main_code, INDENTATION) - code += f'\n{main_code}' + code += f"\n{main_code}" if self.else_code_block is not None: - code += f'else:' + code += f"else:" else_code = self.else_code_block.format( - allowed_line_length=allowed_line_length - INDENTATION) + allowed_line_length=allowed_line_length - INDENTATION + ) else_code = indent(else_code, INDENTATION) - code += f'\n{else_code}' - code += 'end' + code += f"\n{else_code}" + code += "end" return code def get_children(self) -> Sequence[Optional[AstNode]]: @@ -517,7 +549,7 @@ class LangDirective(Directive): location: Optional[Location] = LocationField def format(self): - return f'%lang {self.name}' + return f"%lang {self.name}" def get_children(self) -> Sequence[Optional[AstNode]]: return [] @@ -547,19 +579,19 @@ def format(self, allowed_line_length): note.assert_no_comments() items = [item.format() for item in self.import_items] - prefix = f'from {self.path.format()} import ' - one_liner = prefix + ', '.join(items) + prefix = f"from {self.path.format()} import " + one_liner = prefix + ", ".join(items) if len(one_liner) <= allowed_line_length: return one_liner - particles = [f'{prefix}(', create_particle_sublist(items, ')')] + particles = [f"{prefix}(", create_particle_sublist(items, ")")] return particles_in_lines( particles=particles, config=ParticleFormattingConfig( - allowed_line_length=allowed_line_length, - line_indent=INDENTATION, - one_per_line=False)) + allowed_line_length=allowed_line_length, line_indent=INDENTATION, one_per_line=False + ), + ) def get_children(self) -> Sequence[Optional[AstNode]]: return [self.path, *self.import_items] @@ -570,10 +602,11 @@ class CodeElementAllocLocals(CodeElement): """ Represents a statement of the form "alloc_locals". """ + location: Optional[Location] = LocationField def format(self, allowed_line_length): - return 'alloc_locals' + return "alloc_locals" def get_children(self) -> Sequence[Optional[AstNode]]: return [] @@ -584,12 +617,14 @@ def is_empty_line(code_element: CommentedCodeElement): def is_comment_line(code_element: CommentedCodeElement): - return isinstance(code_element.code_elm, CodeElementEmptyLine) and \ - code_element.comment is not None + return ( + isinstance(code_element.code_elm, CodeElementEmptyLine) and code_element.comment is not None + ) def remove_redundant_empty_lines( - code_elements: List[CommentedCodeElement]) -> List[CommentedCodeElement]: + code_elements: List[CommentedCodeElement], +) -> List[CommentedCodeElement]: """ Returns a new list of code elements where redundant empty lines are removed. Redundant empty lines are empty lines which are after: @@ -618,7 +653,8 @@ def remove_redundant_empty_lines( def add_empty_lines_before_labels( - code_elements: List[CommentedCodeElement]) -> List[CommentedCodeElement]: + code_elements: List[CommentedCodeElement], +) -> List[CommentedCodeElement]: """ Makes sure there is an empty line before labels. The empty line is added before the comment lines preceding the label. @@ -630,10 +666,11 @@ def add_empty_lines_before_labels( if is_empty_line(code_elm): add_empty_line = False elif not is_comment_line(code_elm): - new_code_elements_reversed.append(CommentedCodeElement( - code_elm=CodeElementEmptyLine(), - comment=None, - location=None)) + new_code_elements_reversed.append( + CommentedCodeElement( + code_elm=CodeElementEmptyLine(), comment=None, location=None + ) + ) add_empty_line = False if isinstance(code_elm.code_elm, CodeElementLabel): diff --git a/src/starkware/cairo/lang/compiler/ast/expr.py b/src/starkware/cairo/lang/compiler/ast/expr.py index 4ea4effa..7f75378a 100644 --- a/src/starkware/cairo/lang/compiler/ast/expr.py +++ b/src/starkware/cairo/lang/compiler/ast/expr.py @@ -22,9 +22,9 @@ class Expression(AstNode): def format(self): res = str(self.to_expr_str()) # Indent all lines except for the first. - res = res.replace('\n', '\n' + ' ' * INDENTATION) + res = res.replace("\n", "\n" + " " * INDENTATION) # Remove trailing spaces. - res = re.sub(r' +\n', '\n', res) + res = re.sub(r" +\n", "\n", res) return res @abstractmethod @@ -41,8 +41,11 @@ class ExprConst(Expression): # Indicates the way the absolute value of the expression should be formatted in the code. # For example, it may contain the hexadecimal representation. format_str: Optional[str] = field( - default=None, hash=False, compare=False, metadata=dict( - marshmallow_field=marshmallow.fields.Field(load_only=True, dump_only=True))) + default=None, + hash=False, + compare=False, + metadata=dict(marshmallow_field=marshmallow.fields.Field(load_only=True, dump_only=True)), + ) location: Optional[Location] = LocationField def to_expr_str(self): @@ -62,13 +65,13 @@ class ExprPyConst(Expression): @classmethod def from_str(cls, src: str, location: Optional[Location] = None): - assert src.startswith('%[') - assert src.endswith('%]') + assert src.startswith("%[") + assert src.endswith("%]") code = src[2:-2] return cls(code, location) def to_expr_str(self): - return ExpressionString.highest(f'%[{self.code}%]') + return ExpressionString.highest(f"%[{self.code}%]") def get_children(self) -> Sequence[Optional[AstNode]]: return [] @@ -83,36 +86,37 @@ class ExprHint(Expression): @classmethod def from_str(cls, val, location): - HINT_PATTERN = r'%\{(?P([ \t]*\n)*)(?P.*?)%\}' + HINT_PATTERN = r"%\{(?P([ \t]*\n)*)(?P.*?)%\}" m = re.match(HINT_PATTERN, val, re.DOTALL) assert m is not None - code = m.group('code').rstrip() + code = m.group("code").rstrip() if code is None: - code = '' + code = "" # Remove common indentation. - lines = code.split('\n') + lines = code.split("\n") common_indent = min( - (len(line) - len(line.lstrip(' ')) for line in lines if line), - default=0) - code = '\n'.join(line[common_indent:] for line in lines) + (len(line) - len(line.lstrip(" ")) for line in lines if line), default=0 + ) + code = "\n".join(line[common_indent:] for line in lines) return cls( hint_code=code, - n_prefix_newlines=m.group('prefix_whitespace').count('\n'), - location=location) + n_prefix_newlines=m.group("prefix_whitespace").count("\n"), + location=location, + ) def to_str(self): - if self.hint_code == '': - return '%{\n%}' - if '\n' not in self.hint_code: + if self.hint_code == "": + return "%{\n%}" + if "\n" not in self.hint_code: # One liner. - return f'%{{ {self.hint_code} %}}' + return f"%{{ {self.hint_code} %}}" code = indent(self.hint_code, INDENTATION) - return f'%{{\n{code}\n%}}' + return f"%{{\n{code}\n%}}" def to_expr_str(self): - return ExpressionString.highest(f'nondet {self.to_str()}') + return ExpressionString.highest(f"nondet {self.to_str()}") def get_children(self) -> Sequence[Optional[AstNode]]: return [] @@ -135,6 +139,7 @@ class ExprAssignment(AstNode): """ A code element of the form [ident=]expr. The identifier is optional. """ + identifier: Optional[ExprIdentifier] expr: Expression location: Optional[Location] = LocationField @@ -142,7 +147,7 @@ class ExprAssignment(AstNode): def format(self): if self.identifier is None: return self.expr.format() - return f'{self.identifier.format()}={self.expr.format()}' + return f"{self.identifier.format()}={self.expr.format()}" def get_children(self) -> Sequence[Optional[AstNode]]: return [self.identifier, self.expr] @@ -154,6 +159,7 @@ class ArgList(AstNode): Represents a list of arguments (e.g., to a function call or a return statement). For example: 'a=1, b=2'. """ + args: List[ExprAssignment] notes: List[Notes] has_trailing_comma: bool @@ -168,18 +174,18 @@ def format(self): assert len(self.notes) == 1 return self.notes[0].format() - code = '' + code = "" assert len(self.args) + 1 == len(self.notes) for notes, arg in zip(self.notes[:-1], self.args): - if code != '': - code += ',' + if code != "": + code += "," if notes.empty: - code += ' ' - code += f'{notes.format()}{arg.format()}' + code += " " + code += f"{notes.format()}{arg.format()}" # Add trailing comma at the end if necessary. if self.has_trailing_comma: - code += ',' + code += "," code += self.notes[-1].format() return code @@ -212,14 +218,14 @@ def to_expr_str(self): a = self.a.to_expr_str() b = self.b.to_expr_str() if not self.notes.empty: - b = b.prepend('\n') - if self.op == '+': + b = b.prepend("\n") + if self.op == "+": return a + b - elif self.op == '-': + elif self.op == "-": return a - b - elif self.op == '*': + elif self.op == "*": return a * b - elif self.op == '/': + elif self.op == "/": return a / b else: raise NotImplementedError(f"Unexpected operator '{self.op}'") @@ -240,7 +246,7 @@ def to_expr_str(self): a = self.a.to_expr_str() b = self.b.to_expr_str() if not self.notes.empty: - b = b.prepend('\n') + b = b.prepend("\n") return a.double_star_pow(b) def get_children(self) -> Sequence[Optional[AstNode]]: @@ -252,6 +258,7 @@ class ExprAddressOf(Expression): """ Represents an expression of the form "&expr". """ + expr: Expression location: Optional[Location] = LocationField @@ -281,7 +288,7 @@ class ExprParentheses(Expression): location: Optional[Location] = LocationField def to_expr_str(self): - return ExpressionString.highest(f'({self.notes.format()}{str(self.val.to_expr_str())})') + return ExpressionString.highest(f"({self.notes.format()}{str(self.val.to_expr_str())})") def get_children(self) -> Sequence[Optional[AstNode]]: return [self.val] @@ -292,14 +299,15 @@ class ExprDeref(Expression): """ Represents an expression of the form "[addr]". """ + addr: Expression notes: Notes = NotesField location: Optional[Location] = LocationField def to_expr_str(self): self.notes.assert_no_comments() - notes = '' if self.notes.empty else '\n' - return ExpressionString.highest(f'[{notes}{str(self.addr.to_expr_str())}]') + notes = "" if self.notes.empty else "\n" + return ExpressionString.highest(f"[{notes}{str(self.addr.to_expr_str())}]") def get_children(self) -> Sequence[Optional[AstNode]]: return [self.addr] @@ -310,6 +318,7 @@ class ExprSubscript(Expression): """ Represents an expression of the form "expr[offset]". """ + expr: Expression offset: Expression notes: Notes = NotesField @@ -317,10 +326,11 @@ class ExprSubscript(Expression): def to_expr_str(self): self.notes.assert_no_comments() - notes = '' if self.notes.empty else '\n' + notes = "" if self.notes.empty else "\n" # If expr is not an atom, add parentheses. return ExpressionString.highest( - f'{self.expr.to_expr_str():HIGHEST}[{notes}{str(self.offset.to_expr_str())}]') + f"{self.expr.to_expr_str():HIGHEST}[{notes}{str(self.offset.to_expr_str())}]" + ) def get_children(self) -> Sequence[Optional[AstNode]]: return [self.expr, self.offset] @@ -331,6 +341,7 @@ class ExprDot(Expression): """ Represents an expression of the form "expr.member". """ + expr: Expression member: ExprIdentifier location: Optional[Location] = LocationField @@ -338,7 +349,8 @@ class ExprDot(Expression): def to_expr_str(self): # If expr is not an atom, add parentheses. return ExpressionString.highest( - f'{self.expr.to_expr_str():HIGHEST}.{str(self.member.to_expr_str())}') + f"{self.expr.to_expr_str():HIGHEST}.{str(self.member.to_expr_str())}" + ) def get_children(self) -> Sequence[Optional[AstNode]]: return [self.expr, self.member] @@ -349,6 +361,7 @@ class ExprCast(Expression): """ Represents a cast expression of the form "cast(expr, T)" (which transforms expr to type T). """ + expr: Expression dest_type: CairoType # Cast expressions resulting from the Cairo code always have cast_type=CastType.EXPLICIT. @@ -359,9 +372,10 @@ class ExprCast(Expression): def to_expr_str(self): self.notes.assert_no_comments() - notes = '' if self.notes.empty else '\n' + notes = "" if self.notes.empty else "\n" return ExpressionString.highest( - f'cast({notes}{str(self.expr.to_expr_str())}, {self.dest_type.format()})') + f"cast({notes}{str(self.expr.to_expr_str())}, {self.dest_type.format()})" + ) def get_children(self) -> Sequence[Optional[AstNode]]: return [self.expr, self.dest_type] @@ -374,7 +388,7 @@ class ExprTuple(Expression): def to_expr_str(self): code = self.members.format() - return ExpressionString.highest(f'({code})') + return ExpressionString.highest(f"({code})") def get_children(self) -> Sequence[Optional[AstNode]]: return [self.members] @@ -385,6 +399,7 @@ class ExprFutureLabel(Expression): """ Represents a future label whose current pc is not known yet. """ + identifier: ExprIdentifier def to_expr_str(self): diff --git a/src/starkware/cairo/lang/compiler/ast/formatting_utils.py b/src/starkware/cairo/lang/compiler/ast/formatting_utils.py index 99b87bfb..6fa6733c 100644 --- a/src/starkware/cairo/lang/compiler/ast/formatting_utils.py +++ b/src/starkware/cairo/lang/compiler/ast/formatting_utils.py @@ -13,9 +13,13 @@ from starkware.cairo.lang.compiler.error_handling import LocationError INDENTATION = 4 -LocationField = field(default=None, hash=False, compare=False, metadata=dict( - marshmallow_field=marshmallow.fields.Field(load_only=True, dump_only=True))) -max_line_length_ctx_var: ContextVar[int] = ContextVar('max_line_length', default=100) +LocationField = field( + default=None, + hash=False, + compare=False, + metadata=dict(marshmallow_field=marshmallow.fields.Field(load_only=True, dump_only=True)), +) +max_line_length_ctx_var: ContextVar[int] = ContextVar("max_line_length", default=100) def get_max_line_length(): @@ -44,7 +48,7 @@ class ParticleFormattingConfig: # The indentation, starting from the second line. line_indent: int # The prefix of the first line. - first_line_prefix: str = '' + first_line_prefix: str = "" # At most one item per line. one_per_line: bool = False @@ -69,7 +73,7 @@ def newline(self): return self.lines.append(self.line) self.line_is_new = True - self.line = ' ' * self.config.line_indent + self.line = " " * self.config.line_indent def add_to_line(self, string): """ @@ -86,13 +90,13 @@ def finalize(self): """ if self.line: self.lines.append(self.line) - return '\n'.join(line.rstrip() for line in self.lines) + return "\n".join(line.rstrip() for line in self.lines) -def create_particle_sublist(lst, end='', separator=', '): - if not lst: +def create_particle_sublist(lst: List[str], end: str = "", separator: str = ", ") -> List[str]: + if len(lst) == 0: # If the list is empty, return the single element 'end'. - return end + return [end] # Concatenate the 'separator' to all elements of the 'lst' and 'end' to the last one. return [elm + separator for elm in lst[:-1]] + [lst[-1] + end] @@ -139,7 +143,7 @@ def particles_in_lines(particles, config: ParticleFormattingConfig): if isinstance(particle, list): # If the entire sublist fits in a single line, add it. if sum(map(len, particle), config.line_indent) < config.allowed_line_length: - builder.add_to_line(''.join(particle)) + builder.add_to_line("".join(particle)) continue builder.newline() for member in particle: diff --git a/src/starkware/cairo/lang/compiler/ast/formatting_utils_test.py b/src/starkware/cairo/lang/compiler/ast/formatting_utils_test.py index d5a9c0dd..59345b04 100644 --- a/src/starkware/cairo/lang/compiler/ast/formatting_utils_test.py +++ b/src/starkware/cairo/lang/compiler/ast/formatting_utils_test.py @@ -1,14 +1,17 @@ from starkware.cairo.lang.compiler.ast.formatting_utils import ( - ParticleFormattingConfig, create_particle_sublist, particles_in_lines) + ParticleFormattingConfig, + create_particle_sublist, + particles_in_lines, +) def test_particles_in_lines(): particles = [ - 'start ', - 'foo ', - 'bar ', - create_particle_sublist(['a', 'b', 'c', 'dddd', 'e', 'f'], '*'), - ' asdf', + "start ", + "foo ", + "bar ", + create_particle_sublist(["a", "b", "c", "dddd", "e", "f"], "*"), + " asdf", ] expected = """\ start foo @@ -17,15 +20,18 @@ def test_particles_in_lines(): dddd, e, f* asdf\ """ - assert particles_in_lines( - particles=particles, - config=ParticleFormattingConfig(allowed_line_length=12, line_indent=2), - ) == expected + assert ( + particles_in_lines( + particles=particles, + config=ParticleFormattingConfig(allowed_line_length=12, line_indent=2), + ) + == expected + ) particles = [ - 'func f(', - create_particle_sublist(['x', 'y', 'z'], ') -> ('), - create_particle_sublist(['a', 'b', 'c'], '):'), + "func f(", + create_particle_sublist(["x", "y", "z"], ") -> ("), + create_particle_sublist(["a", "b", "c"], "):"), ] expected = """\ func f( @@ -34,10 +40,13 @@ def test_particles_in_lines(): a, b, c):\ """ - assert particles_in_lines( - particles=particles, - config=ParticleFormattingConfig(allowed_line_length=12, line_indent=4), - ) == expected + assert ( + particles_in_lines( + particles=particles, + config=ParticleFormattingConfig(allowed_line_length=12, line_indent=4), + ) + == expected + ) # Same particles, using one_per_line=True. expected = """\ @@ -49,11 +58,15 @@ def test_particles_in_lines(): b, c):\ """ - assert particles_in_lines( - particles=particles, - config=ParticleFormattingConfig( - allowed_line_length=12, line_indent=4, one_per_line=True), - ) == expected + assert ( + particles_in_lines( + particles=particles, + config=ParticleFormattingConfig( + allowed_line_length=12, line_indent=4, one_per_line=True + ), + ) + == expected + ) # Same particles, using one_per_line=True, longer lines. expected = """\ @@ -61,22 +74,29 @@ def test_particles_in_lines(): x, y, z) -> ( a, b, c):\ """ - assert particles_in_lines( - particles=particles, - config=ParticleFormattingConfig( - allowed_line_length=19, line_indent=4, one_per_line=True), - ) == expected + assert ( + particles_in_lines( + particles=particles, + config=ParticleFormattingConfig( + allowed_line_length=19, line_indent=4, one_per_line=True + ), + ) + == expected + ) particles = [ - 'func f(', - create_particle_sublist(['x', 'y', 'z'], ') -> ('), - create_particle_sublist([], '):'), + "func f(", + create_particle_sublist(["x", "y", "z"], ") -> ("), + create_particle_sublist([], "):"), ] expected = """\ func f( x, y, z) -> ():\ """ - assert particles_in_lines( - particles=particles, - config=ParticleFormattingConfig(allowed_line_length=19, line_indent=4), - ) == expected + assert ( + particles_in_lines( + particles=particles, + config=ParticleFormattingConfig(allowed_line_length=19, line_indent=4), + ) + == expected + ) diff --git a/src/starkware/cairo/lang/compiler/ast/instructions.py b/src/starkware/cairo/lang/compiler/ast/instructions.py index 8cf7b297..58d2b77f 100644 --- a/src/starkware/cairo/lang/compiler/ast/instructions.py +++ b/src/starkware/cairo/lang/compiler/ast/instructions.py @@ -33,7 +33,7 @@ class AssertEqInstruction(InstructionBody): location: Optional[Location] = LocationField def format(self): - return f'{self.a.format()} = {self.b.format()}' + return f"{self.a.format()} = {self.b.format()}" def get_children(self) -> Sequence[Optional[AstNode]]: return [self.a, self.b] @@ -67,8 +67,8 @@ class JumpToLabelInstruction(InstructionBody): location: Optional[Location] = LocationField def format(self): - condition_str = '' if self.condition is None else f' if {self.condition.format()} != 0' - return f'jmp {self.label.format()}{condition_str}' + condition_str = "" if self.condition is None else f" if {self.condition.format()} != 0" + return f"jmp {self.label.format()}{condition_str}" def get_children(self) -> Sequence[Optional[AstNode]]: return [self.label, self.condition] @@ -85,7 +85,7 @@ class JnzInstruction(InstructionBody): location: Optional[Location] = LocationField def format(self): - return f'jmp rel {self.jump_offset.format()} if {self.condition.format()} != 0' + return f"jmp rel {self.jump_offset.format()} if {self.condition.format()} != 0" def get_children(self) -> Sequence[Optional[AstNode]]: return [self.jump_offset, self.condition] @@ -118,7 +118,7 @@ class CallLabelInstruction(InstructionBody): location: Optional[Location] = LocationField def format(self): - return f'call {self.label.format()}' + return f"call {self.label.format()}" def get_children(self) -> Sequence[Optional[AstNode]]: return [self.label] @@ -133,7 +133,7 @@ class RetInstruction(InstructionBody): location: Optional[Location] = LocationField def format(self): - return 'ret' + return "ret" def get_children(self) -> Sequence[Optional[AstNode]]: return [] @@ -149,7 +149,7 @@ class AddApInstruction(InstructionBody): location: Optional[Location] = LocationField def format(self): - return f'ap += {self.expr.format()}' + return f"ap += {self.expr.format()}" def get_children(self) -> Sequence[Optional[AstNode]]: return [self.expr] @@ -160,12 +160,13 @@ class InstructionAst(AstNode): """ Represents an instruction, including the ap++ flag (inc_ap). """ + body: InstructionBody inc_ap: bool location: Optional[Location] = LocationField def format(self): - return self.body.format() + ('; ap++' if self.inc_ap else '') + return self.body.format() + ("; ap++" if self.inc_ap else "") def get_children(self) -> Sequence[Optional[AstNode]]: return [self.body] diff --git a/src/starkware/cairo/lang/compiler/ast/node.py b/src/starkware/cairo/lang/compiler/ast/node.py index 72413521..d0208f53 100644 --- a/src/starkware/cairo/lang/compiler/ast/node.py +++ b/src/starkware/cairo/lang/compiler/ast/node.py @@ -4,7 +4,7 @@ class AstNode(ABC): @abstractmethod - def get_children(self) -> Sequence[Optional['AstNode']]: + def get_children(self) -> Sequence[Optional["AstNode"]]: """ Returns a list of the node's children (notes are not included). """ diff --git a/src/starkware/cairo/lang/compiler/ast/notes.py b/src/starkware/cairo/lang/compiler/ast/notes.py index dafc09f2..786f618a 100644 --- a/src/starkware/cairo/lang/compiler/ast/notes.py +++ b/src/starkware/cairo/lang/compiler/ast/notes.py @@ -20,6 +20,7 @@ class Notes(AstNode): assert a = b + # Hello. c + d # World. """ + # The comments of the note. If empty, the value of starts_new_line is ignored. comments: List[str] = field(default_factory=list) # Whether the note starts on its own line. @@ -34,8 +35,9 @@ def assert_no_comments(self): if len(self.comments) == 0: return raise FormattingError( - 'Comments inside expressions are not supported by the auto-formatter.', - location=self.location) + "Comments inside expressions are not supported by the auto-formatter.", + location=self.location, + ) def __add__(self, other): if not isinstance(other, type(self)): @@ -45,20 +47,21 @@ def __add__(self, other): return Notes( comments=self.comments + other.comments, starts_new_line=self.starts_new_line, - location=self.location) + location=self.location, + ) def format(self): - code = '' + code = "" if self.starts_new_line: - code += '\n' + code += "\n" elif len(self.comments) > 0: - code += ' ' + code += " " for comment in self.comments: - assert comment.startswith('#') + assert comment.startswith("#") comment_body = comment[1:].strip() - if comment_body != '': - comment_body = ' ' + comment_body - code += f'#{comment_body}\n' + if comment_body != "": + comment_body = " " + comment_body + code += f"#{comment_body}\n" return code def get_children(self) -> Sequence[Optional[AstNode]]: diff --git a/src/starkware/cairo/lang/compiler/ast/rvalue.py b/src/starkware/cairo/lang/compiler/ast/rvalue.py index d5fe3967..8c069069 100644 --- a/src/starkware/cairo/lang/compiler/ast/rvalue.py +++ b/src/starkware/cairo/lang/compiler/ast/rvalue.py @@ -4,8 +4,12 @@ from starkware.cairo.lang.compiler.ast.expr import ArgList, Expression, ExprIdentifier from starkware.cairo.lang.compiler.ast.formatting_utils import ( - INDENTATION, LocationField, ParticleFormattingConfig, create_particle_sublist, - particles_in_lines) + INDENTATION, + LocationField, + ParticleFormattingConfig, + create_particle_sublist, + particles_in_lines, +) from starkware.cairo.lang.compiler.ast.instructions import CallInstruction from starkware.cairo.lang.compiler.ast.node import AstNode from starkware.cairo.lang.compiler.error_handling import Location @@ -45,6 +49,7 @@ class RvalueExpr(Rvalue): """ Represents an rvalue which is a simple expression. E.g., fp + 17. """ + expr: Expression @property @@ -76,6 +81,7 @@ class RvalueCallInst(RvalueCall): call_inst is CallInstruction that calls the function. """ + call_inst: CallInstruction @property @@ -98,6 +104,7 @@ class RvalueFuncCall(RvalueCall): Represents an rvalue of the form: func_ident([ident=]expr, ...). """ + func_ident: ExprIdentifier arguments: ArgList implicit_arguments: Optional[ArgList] @@ -114,13 +121,14 @@ def get_particles(self): particles = [self.func_ident.format()] if self.implicit_arguments is not None: - particles[-1] += '{' - particles.append(create_particle_sublist( - [x.format() for x in self.implicit_arguments.args], '}(')) + particles[-1] += "{" + particles.append( + create_particle_sublist([x.format() for x in self.implicit_arguments.args], "}(") + ) else: - particles[-1] += '(' + particles[-1] += "(" - particles.append(create_particle_sublist([x.format() for x in self.arguments.args], ')')) + particles.append(create_particle_sublist([x.format() for x in self.arguments.args], ")")) return particles def format(self, allowed_line_length): @@ -128,9 +136,9 @@ def format(self, allowed_line_length): return particles_in_lines( particles=self.get_particles(), config=ParticleFormattingConfig( - allowed_line_length=allowed_line_length, - line_indent=INDENTATION, - one_per_line=True)) + allowed_line_length=allowed_line_length, line_indent=INDENTATION, one_per_line=True + ), + ) def format_for_expr(self) -> str: """ @@ -141,9 +149,9 @@ def format_for_expr(self) -> str: res = self.func_ident.format() if self.implicit_arguments is not None: - res += '{' + self.implicit_arguments.format() + '}' + res += "{" + self.implicit_arguments.format() + "}" - res += '(' + self.arguments.format() + ')' + res += "(" + self.arguments.format() + ")" return res def get_children(self) -> Sequence[Optional[AstNode]]: diff --git a/src/starkware/cairo/lang/compiler/ast/types.py b/src/starkware/cairo/lang/compiler/ast/types.py index 77970cdc..bc10a617 100644 --- a/src/starkware/cairo/lang/compiler/ast/types.py +++ b/src/starkware/cairo/lang/compiler/ast/types.py @@ -32,8 +32,8 @@ class TypedIdentifier(AstNode): modifier: Optional[Modifier] = None def format(self): - modifier_str = '' if self.modifier is None else self.modifier.format() + ' ' - type_str = '' if self.expr_type is None else f' : {self.expr_type.format()}' + modifier_str = "" if self.modifier is None else self.modifier.format() + " " + type_str = "" if self.expr_type is None else f" : {self.expr_type.format()}" return modifier_str + self.identifier.format() + type_str def override_type(self, expr_type): diff --git a/src/starkware/cairo/lang/compiler/ast/visitor.py b/src/starkware/cairo/lang/compiler/ast/visitor.py index b4ed4b5e..ba8a17c6 100644 --- a/src/starkware/cairo/lang/compiler/ast/visitor.py +++ b/src/starkware/cairo/lang/compiler/ast/visitor.py @@ -2,8 +2,14 @@ from typing import List, Optional from starkware.cairo.lang.compiler.ast.code_elements import ( - CodeBlock, CodeElementDirective, CodeElementFunction, CodeElementScoped, CodeElementWith, - CommentedCodeElement, LangDirective) + CodeBlock, + CodeElementDirective, + CodeElementFunction, + CodeElementScoped, + CodeElementWith, + CommentedCodeElement, + LangDirective, +) from starkware.cairo.lang.compiler.ast.module import CairoFile, CairoModule from starkware.cairo.lang.compiler.ast.node import AstNode from starkware.cairo.lang.compiler.error_handling import LocationError @@ -29,10 +35,10 @@ def visit(self, obj): Visits an object by calling its type's 'visit_{type}'. If no corresponding visit function is found, calls '_visit_default'. """ - return getattr(self, f'visit_{type(obj).__name__}', self._visit_default)(obj) + return getattr(self, f"visit_{type(obj).__name__}", self._visit_default)(obj) def visit_CodeElementFunction(self, elm: CodeElementFunction): - if elm.element_type == 'struct': + if elm.element_type == "struct": return elm new_scope = self.current_scope + elm.name @@ -48,8 +54,9 @@ def visit_CodeElementFunction(self, elm: CodeElementFunction): ) def visit_CairoModule(self, module: CairoModule): - with self.scoped(module.module_name, parent=module), \ - self.with_file_lang(get_lang_from_file(module.cairo_file)): + with self.scoped(module.module_name, parent=module), self.with_file_lang( + get_lang_from_file(module.cairo_file) + ): return CairoModule( cairo_file=CairoFile(code_block=self.visit(module.cairo_file.code_block)), module_name=module.module_name, @@ -63,13 +70,16 @@ def visit_CodeElementScoped(self, elm: CodeElementScoped): ) def visit_CodeBlock(self, elm: CodeBlock): - return CodeBlock(code_elements=[ - CommentedCodeElement( - code_elm=self.visit(commented_code_elm.code_elm), - comment=commented_code_elm.comment, - location=commented_code_elm.location) - for commented_code_elm in elm.code_elements - ]) + return CodeBlock( + code_elements=[ + CommentedCodeElement( + code_elm=self.visit(commented_code_elm.code_elm), + comment=commented_code_elm.comment, + location=commented_code_elm.location, + ) + for commented_code_elm in elm.code_elements + ] + ) def visit_CodeElementWith(self, elm: CodeElementWith): return CodeElementWith(identifiers=elm.identifiers, code_block=self.visit(elm.code_block)) @@ -79,7 +89,8 @@ def _visit_default(self, obj): Default behavior for visitor if 'obj' type isn't handled. By default, raise exception. """ raise NotImplementedError( - f'No handler found for type {type(obj).__name__} in {type(self).__name__}.') + f"No handler found for type {type(obj).__name__} in {type(self).__name__}." + ) @contextmanager def scoped(self, new_scope: ScopedName, parent: Optional[AstNode]): @@ -128,6 +139,6 @@ def get_lang_from_file(cairo_file: CairoFile) -> Optional[str]: if not isinstance(directive, LangDirective): continue if lang is not None: - raise VisitorError('Found two %lang directives', location=code_elm.location) + raise VisitorError("Found two %lang directives", location=code_elm.location) lang = directive.name return lang diff --git a/src/starkware/cairo/lang/compiler/ast_objects_test.py b/src/starkware/cairo/lang/compiler/ast_objects_test.py index 34f3b1da..9dc9c325 100644 --- a/src/starkware/cairo/lang/compiler/ast_objects_test.py +++ b/src/starkware/cairo/lang/compiler/ast_objects_test.py @@ -13,44 +13,50 @@ def test_format_parentheses(): # Call remove_parentheses(parse_expr()) to create an expression tree in the given structure # without ExprParentheses. - assert remove_parentheses(parse_expr('(a + b) * (c - d) * (e * f)')).format() == \ - '(a + b) * (c - d) * e * f' - assert remove_parentheses(parse_expr('x - (a + b) - (c - d) - (e * f)')).format() == \ - 'x - (a + b) - (c - d) - e * f' - assert remove_parentheses(parse_expr('(a + b) + (c - d) + (e * f)')).format() == \ - 'a + b + c - d + e * f' - assert remove_parentheses(parse_expr('-(a + b + c)')).format() == '-(a + b + c)' - assert remove_parentheses(parse_expr('a + -b + c')).format() == 'a + (-b) + c' - assert remove_parentheses(parse_expr('&(a + b)')).format() == '&(a + b)' - assert remove_parentheses(parse_expr('a ** b ** c ** d')).format() == 'a ** (b ** (c ** d))' + assert ( + remove_parentheses(parse_expr("(a + b) * (c - d) * (e * f)")).format() + == "(a + b) * (c - d) * e * f" + ) + assert ( + remove_parentheses(parse_expr("x - (a + b) - (c - d) - (e * f)")).format() + == "x - (a + b) - (c - d) - e * f" + ) + assert ( + remove_parentheses(parse_expr("(a + b) + (c - d) + (e * f)")).format() + == "a + b + c - d + e * f" + ) + assert remove_parentheses(parse_expr("-(a + b + c)")).format() == "-(a + b + c)" + assert remove_parentheses(parse_expr("a + -b + c")).format() == "a + (-b) + c" + assert remove_parentheses(parse_expr("&(a + b)")).format() == "&(a + b)" + assert remove_parentheses(parse_expr("a ** b ** c ** d")).format() == "a ** (b ** (c ** d))" # Test that parentheses are added to non-atomized Dot and Subscript expressions. - assert remove_parentheses(parse_expr('(x * y).z')).format() == '(x * y).z' - assert remove_parentheses(parse_expr('(-x).y')).format() == '(-x).y' - assert remove_parentheses(parse_expr('(&x).y')).format() == '(&x).y' - assert remove_parentheses(parse_expr('(x * y)[z]')).format() == '(x * y)[z]' - assert remove_parentheses(parse_expr('(-x)[y]')).format() == '(-x)[y]' - assert remove_parentheses(parse_expr('(&x)[y]')).format() == '(&x)[y]' - - assert remove_parentheses(parse_expr('&(x.y)')).format() == '&x.y' - assert remove_parentheses(parse_expr('-(x.y)')).format() == '-x.y' - assert remove_parentheses(parse_expr('(x.y)*z')).format() == 'x.y * z' - assert remove_parentheses(parse_expr('x-(y.z)')).format() == 'x - y.z' - - assert remove_parentheses(parse_expr('([x].y).z')).format() == '[x].y.z' - assert remove_parentheses(parse_expr('&(x[y])')).format() == '&x[y]' - assert remove_parentheses(parse_expr('-(x[y])')).format() == '-x[y]' - assert remove_parentheses(parse_expr('(x[y])*z')).format() == 'x[y] * z' - assert remove_parentheses(parse_expr('x-(y[z])')).format() == 'x - y[z]' - assert remove_parentheses(parse_expr('(([x][y])[z])')).format() == '[x][y][z]' - assert remove_parentheses(parse_expr('x[(y+z)]')).format() == 'x[y + z]' - - assert remove_parentheses(parse_expr('[((x+y) + z)]')).format() == '[x + y + z]' + assert remove_parentheses(parse_expr("(x * y).z")).format() == "(x * y).z" + assert remove_parentheses(parse_expr("(-x).y")).format() == "(-x).y" + assert remove_parentheses(parse_expr("(&x).y")).format() == "(&x).y" + assert remove_parentheses(parse_expr("(x * y)[z]")).format() == "(x * y)[z]" + assert remove_parentheses(parse_expr("(-x)[y]")).format() == "(-x)[y]" + assert remove_parentheses(parse_expr("(&x)[y]")).format() == "(&x)[y]" + + assert remove_parentheses(parse_expr("&(x.y)")).format() == "&x.y" + assert remove_parentheses(parse_expr("-(x.y)")).format() == "-x.y" + assert remove_parentheses(parse_expr("(x.y)*z")).format() == "x.y * z" + assert remove_parentheses(parse_expr("x-(y.z)")).format() == "x - y.z" + + assert remove_parentheses(parse_expr("([x].y).z")).format() == "[x].y.z" + assert remove_parentheses(parse_expr("&(x[y])")).format() == "&x[y]" + assert remove_parentheses(parse_expr("-(x[y])")).format() == "-x[y]" + assert remove_parentheses(parse_expr("(x[y])*z")).format() == "x[y] * z" + assert remove_parentheses(parse_expr("x-(y[z])")).format() == "x - y[z]" + assert remove_parentheses(parse_expr("(([x][y])[z])")).format() == "[x][y][z]" + assert remove_parentheses(parse_expr("x[(y+z)]")).format() == "x[y + z]" + + assert remove_parentheses(parse_expr("[((x+y) + z)]")).format() == "[x + y + z]" # Test that parentheses are not added if they were already present. - assert parse_expr('(a * (b + c))').format() == '(a * (b + c))' - assert parse_expr('((a * ((b + c))))').format() == '((a * ((b + c))))' - assert parse_expr('(x + y)[z]').format() == '(x + y)[z]' + assert parse_expr("(a * (b + c))").format() == "(a * (b + c))" + assert parse_expr("((a * ((b + c))))").format() == "((a * ((b + c))))" + assert parse_expr("(x + y)[z]").format() == "(x + y)[z]" def test_format_parentheses_notes(): @@ -99,18 +105,20 @@ def test_format_func_call_notes(): before = """\ foo(x = 12 # Comment. )""" - with pytest.raises(FormattingError, match='Comments inside expressions are not supported'): + with pytest.raises(FormattingError, match="Comments inside expressions are not supported"): parse_code_element(before).format(allowed_line_length=100) def test_negative_numbers(): - assert ExprConst(-1).format() == '-1' - assert ExprNeg(val=ExprConst(val=1)).format() == '-1' - assert ExprOperator(a=ExprConst(val=-1), op='+', b=ExprConst(val=-2)).format() == '(-1) + (-2)' - assert ExprOperator( - a=ExprNeg(val=ExprConst(val=1)), - op='+', - b=ExprNeg(val=ExprConst(val=2))).format() == '(-1) + (-2)' + assert ExprConst(-1).format() == "-1" + assert ExprNeg(val=ExprConst(val=1)).format() == "-1" + assert ExprOperator(a=ExprConst(val=-1), op="+", b=ExprConst(val=-2)).format() == "(-1) + (-2)" + assert ( + ExprOperator( + a=ExprNeg(val=ExprConst(val=1)), op="+", b=ExprNeg(val=ExprConst(val=2)) + ).format() + == "(-1) + (-2)" + ) def test_file_format(): @@ -136,6 +144,7 @@ def test_file_format(): local z :T*=x assert x*z+x= y+y static_assert ap + (3 + 7 )+ ap ==fp +let()=foo() return (1,[fp], [ap +3],) fibonacci (a = 3 , b=[fp +1]) @@ -164,6 +173,7 @@ def test_file_format(): local z : T* = x assert x * z + x = y + y static_assert ap + (3 + 7) + ap == fp +let () = foo() return (1, [fp], [ap + 3]) fibonacci(a=3, b=[fp + 1]) [ap - 1] = [fp] # This is a comment. @@ -240,7 +250,9 @@ def test_file_format_comment_spaces(): # First line. # Second line. [ap] = [ap] #{spaces} -""".format(spaces=' ') +""".format( + spaces=" " + ) after = """\ # First line. # @@ -511,8 +523,11 @@ def test_with(): [ap] = [ap] end """ - assert parse_file(code).format() == """\ + assert ( + parse_file(code).format() + == """\ with a, b as c, d: [ap] = [ap] end """ + ) diff --git a/src/starkware/cairo/lang/compiler/cairo_compile.py b/src/starkware/cairo/lang/compiler/cairo_compile.py index 2ed6ca45..26871326 100644 --- a/src/starkware/cairo/lang/compiler/cairo_compile.py +++ b/src/starkware/cairo/lang/compiler/cairo_compile.py @@ -22,40 +22,60 @@ def cairo_compile_add_common_args(parser: argparse.ArgumentParser): - parser.add_argument('-v', '--version', action='version', version=f'%(prog)s {__version__}') - parser.add_argument('files', metavar='file', type=str, nargs='+', help='File names') + parser.add_argument("-v", "--version", action="version", version=f"%(prog)s {__version__}") + parser.add_argument("files", metavar="file", type=str, nargs="+", help="File names") parser.add_argument( - '--prime', type=int, default=DEFAULT_PRIME, help='The size of the finite field.') + "--prime", type=int, default=DEFAULT_PRIME, help="The size of the finite field." + ) parser.add_argument( - '--cairo_path', type=str, default='', + "--cairo_path", + type=str, + default="", help=( 'A list of directories, separated by ":" to resolve import paths. ' - 'The full list will consist of directories defined by this argument, followed by ' - f'the environment variable {LIBS_DIR_ENVVAR}, the working directory and the standard ' - 'library path.')) + "The full list will consist of directories defined by this argument, followed by " + f"the environment variable {LIBS_DIR_ENVVAR}, the working directory and the standard " + "library path." + ), + ) parser.add_argument( - '--preprocess', action='store_true', - help='Stop after the preprocessor step and output the preprocessed program.') + "--preprocess", + action="store_true", + help="Stop after the preprocessor step and output the preprocessed program.", + ) parser.add_argument( - '--output', type=argparse.FileType('w'), help='The output file name (default: stdout).') + "--output", type=argparse.FileType("w"), help="The output file name (default: stdout)." + ) parser.add_argument( - '--no_debug_info', dest='debug_info', action='store_false', - help='Include debug information.') + "--no_debug_info", + dest="debug_info", + action="store_false", + help="Include debug information.", + ) parser.add_argument( - '--debug_info_with_source', action='store_true', - help='Include debug information with a copy of the source code.') + "--debug_info_with_source", + action="store_true", + help="Include debug information with a copy of the source code.", + ) parser.add_argument( - '--cairo_dependencies', type=str, - help='Output a list of the Cairo source files used during the compilation as a CMake file.') + "--cairo_dependencies", + type=str, + help="Output a list of the Cairo source files used during the compilation as a CMake file.", + ) parser.add_argument( - '--no_opt_unused_functions', dest='opt_unused_functions', action='store_false', - default=True, help='Disables unused function optimization.') + "--no_opt_unused_functions", + dest="opt_unused_functions", + action="store_false", + default=True, + help="Disables unused function optimization.", + ) def cairo_compile_common( - args: argparse.Namespace, - pass_manager_factory: Callable[[argparse.Namespace, ModuleReader], PassManager], - assemble_func: Callable) -> PreprocessedProgram: + args: argparse.Namespace, + pass_manager_factory: Callable[[argparse.Namespace, ModuleReader], PassManager], + assemble_func: Callable, +) -> PreprocessedProgram: """ Common code for CLI Cairo compilation. @@ -72,35 +92,39 @@ def cairo_compile_common( try: codes = get_codes(args.files) file_contents_for_debug_info = {} - if getattr(args, 'proof_mode', False): + if getattr(args, "proof_mode", False): codes = add_start_code(codes) file_contents_for_debug_info[START_FILE_NAME] = codes[0][0] out = args.output if args.output is not None else sys.stdout - cairo_path: List[str] = list(filter( - None, args.cairo_path.split(':') + os.getenv(LIBS_DIR_ENVVAR, '').split(':'))) + cairo_path: List[str] = list( + filter(None, args.cairo_path.split(":") + os.getenv(LIBS_DIR_ENVVAR, "").split(":")) + ) module_reader = get_module_reader(cairo_path=cairo_path) pass_manager = pass_manager_factory(args, module_reader) preprocessed = preprocess_codes( - codes=codes, - pass_manager=pass_manager, - main_scope=MAIN_SCOPE) + codes=codes, pass_manager=pass_manager, main_scope=MAIN_SCOPE + ) if args.preprocess: - print(preprocessed.format(with_locations=debug_info), end='', file=out) + print(preprocessed.format(with_locations=debug_info), end="", file=out) else: if args.debug_info_with_source: for source_file in module_reader.source_files | set(args.files): file_contents_for_debug_info[source_file] = open(source_file).read() assembled_program = assemble_func( - preprocessed, main_scope=MAIN_SCOPE, add_debug_info=debug_info, - file_contents_for_debug_info=file_contents_for_debug_info) + preprocessed, + main_scope=MAIN_SCOPE, + add_debug_info=debug_info, + file_contents_for_debug_info=file_contents_for_debug_info, + ) json.dump( - assembled_program.Schema().dump(assembled_program), out, indent=4, sort_keys=True) + assembled_program.Schema().dump(assembled_program), out, indent=4, sort_keys=True + ) # Print a new line at the end. print(file=out) @@ -108,24 +132,26 @@ def cairo_compile_common( finally: if args.cairo_dependencies: generate_cairo_dependencies_file( - args.cairo_dependencies, module_reader.source_files | set(args.files), start_time) + args.cairo_dependencies, module_reader.source_files | set(args.files), start_time + ) def get_module_reader(cairo_path: List[str]) -> ModuleReader: - starkware_src = os.path.join(os.path.dirname(__file__), '../../../..') + starkware_src = os.path.join(os.path.dirname(__file__), "../../../..") cairo_path = [ os.path.abspath(path) for path in cairo_path + [os.curdir, starkware_src] - if path is not None and os.path.isdir(path)] + if path is not None and os.path.isdir(path) + ] - return ModuleReader(paths=cairo_path, cairo_suffix='.cairo') + return ModuleReader(paths=cairo_path, cairo_suffix=".cairo") def get_codes(file_names: List[str]) -> List[Tuple[str, str]]: """ Returns a list of pairs (file_content, file_name). """ - codes = (open(path).read() if path != '-' else sys.stdin.read() for path in file_names) + codes = (open(path).read() if path != "-" else sys.stdin.read() for path in file_names) codes_with_filenames = list(zip(codes, file_names)) return codes_with_filenames @@ -136,33 +162,44 @@ def add_start_code(codes_with_filenames: List[Tuple[str, str]]) -> List[Tuple[st def compile_cairo_files( - files: List[str], prime: Optional[int] = None, - cairo_path: List[str] = [], debug_info: bool = False, - pass_manager: Optional[PassManager] = None, - main_scope: Optional[ScopedName] = None) -> Program: + files: List[str], + prime: Optional[int] = None, + cairo_path: List[str] = [], + debug_info: bool = False, + pass_manager: Optional[PassManager] = None, + main_scope: Optional[ScopedName] = None, +) -> Program: """ Compiles a list of files (provided by their names). Note that cairo_path is ignored when reading the input files, it is only used when importing modules. """ return compile_cairo( - code=get_codes(files), prime=prime, cairo_path=cairo_path, debug_info=debug_info, - pass_manager=pass_manager, main_scope=main_scope) + code=get_codes(files), + prime=prime, + cairo_path=cairo_path, + debug_info=debug_info, + pass_manager=pass_manager, + main_scope=main_scope, + ) def compile_cairo_ex( - code: Union[str, Sequence[Tuple[str, str]]], prime: Optional[int] = None, - cairo_path: List[str] = [], debug_info: bool = False, - pass_manager: Optional[PassManager] = None, - add_start: bool = False, main_scope: Optional[ScopedName] = None) -> \ - Tuple[Program, PreprocessedProgram]: + code: Union[str, Sequence[Tuple[str, str]]], + prime: Optional[int] = None, + cairo_path: List[str] = [], + debug_info: bool = False, + pass_manager: Optional[PassManager] = None, + add_start: bool = False, + main_scope: Optional[ScopedName] = None, +) -> Tuple[Program, PreprocessedProgram]: """ Same as compile_cairo, but returns the preprocessed program as well. """ file_contents_for_debug_info = {} if isinstance(code, str): - codes_with_filenames = [(code, '')] + codes_with_filenames = [(code, "")] if isinstance(code, list): codes_with_filenames = code @@ -174,31 +211,37 @@ def compile_cairo_ex( file_contents_for_debug_info[START_FILE_NAME] = codes_with_filenames[0][0] if pass_manager is None: - assert prime is not None, 'Exactly one of prime and pass_manager must be given.' + assert prime is not None, "Exactly one of prime and pass_manager must be given." module_reader = get_module_reader(cairo_path) pass_manager = default_pass_manager(prime=prime, read_module=module_reader.read) else: - assert prime is None, 'Exactly one of prime and pass_manager must be given.' - assert len(cairo_path) == 0, 'cairo_path cannot be specified where pass_manager is used.' + assert prime is None, "Exactly one of prime and pass_manager must be given." + assert len(cairo_path) == 0, "cairo_path cannot be specified where pass_manager is used." if main_scope is None: main_scope = MAIN_SCOPE preprocessed_program = preprocess_codes( - codes=codes_with_filenames, - pass_manager=pass_manager, - main_scope=main_scope) + codes=codes_with_filenames, pass_manager=pass_manager, main_scope=main_scope + ) program = cairo_assemble_program( - preprocessed_program, main_scope=main_scope, add_debug_info=debug_info, - file_contents_for_debug_info=file_contents_for_debug_info) + preprocessed_program, + main_scope=main_scope, + add_debug_info=debug_info, + file_contents_for_debug_info=file_contents_for_debug_info, + ) return program, preprocessed_program def compile_cairo( - code: Union[str, Sequence[Tuple[str, str]]], prime: Optional[int] = None, - cairo_path: List[str] = [], debug_info: bool = False, - pass_manager: Optional[PassManager] = None, - add_start: bool = False, main_scope: Optional[ScopedName] = None) -> Program: + code: Union[str, Sequence[Tuple[str, str]]], + prime: Optional[int] = None, + cairo_path: List[str] = [], + debug_info: bool = False, + pass_manager: Optional[PassManager] = None, + add_start: bool = False, + main_scope: Optional[ScopedName] = None, +) -> Program: """ Compiles a single code represented by a string, or a list codes. The codes in the list are joined with file names, used for indicative @@ -206,8 +249,14 @@ def compile_cairo( Returns the program. """ program, _ = compile_cairo_ex( - code=code, prime=prime, cairo_path=cairo_path, debug_info=debug_info, - pass_manager=pass_manager, add_start=add_start, main_scope=main_scope) + code=code, + prime=prime, + cairo_path=cairo_path, + debug_info=debug_info, + pass_manager=pass_manager, + add_start=add_start, + main_scope=main_scope, + ) return program @@ -216,36 +265,47 @@ def check_main_args(program: Program): Makes sure that for every builtin included in the program an appropriate ptr was passed as an argument to main() and is subsequently returned. """ - expected_builtin_ptrs = [f'{builtin_name}_ptr' for builtin_name in program.builtins] + expected_builtin_ptrs = [f"{builtin_name}_ptr" for builtin_name in program.builtins] try: - implicit_args = list(get_struct_definition( - struct_name=ScopedName.from_string('__main__.main.ImplicitArgs'), - identifier_manager=program.identifiers).members) + implicit_args = list( + get_struct_definition( + struct_name=ScopedName.from_string("__main__.main.ImplicitArgs"), + identifier_manager=program.identifiers, + ).members + ) except IdentifierError: return try: - main_args = implicit_args + list(get_struct_definition( - struct_name=ScopedName.from_string('__main__.main.Args'), - identifier_manager=program.identifiers).members) + main_args = implicit_args + list( + get_struct_definition( + struct_name=ScopedName.from_string("__main__.main.Args"), + identifier_manager=program.identifiers, + ).members + ) except IdentifierError: pass else: - assert main_args == expected_builtin_ptrs, \ - 'Expected main to contain the following arguments (in this order): ' \ - f'{expected_builtin_ptrs}. Found: {main_args}.' + assert main_args == expected_builtin_ptrs, ( + "Expected main to contain the following arguments (in this order): " + f"{expected_builtin_ptrs}. Found: {main_args}." + ) try: - main_returns = implicit_args + list(get_struct_definition( - struct_name=ScopedName.from_string('__main__.main.Return'), - identifier_manager=program.identifiers).members) + main_returns = implicit_args + list( + get_struct_definition( + struct_name=ScopedName.from_string("__main__.main.Return"), + identifier_manager=program.identifiers, + ).members + ) except IdentifierError: pass else: - assert main_returns == expected_builtin_ptrs, \ - 'Expected main to return the following values (in this order): ' \ - f'{expected_builtin_ptrs}. Found: {main_returns}.' + assert main_returns == expected_builtin_ptrs, ( + "Expected main to return the following values (in this order): " + f"{expected_builtin_ptrs}. Found: {main_returns}." + ) def get_start_code(): @@ -265,11 +325,11 @@ def get_start_code(): def generate_cairo_dependencies_file(dependencies_path: str, files: Set[str], start_time): # Generate Cairo dependencies. - res = '' - res += 'SET (DEPENDENCIES\n' + res = "" + res += "SET (DEPENDENCIES\n" for filename in sorted(files): - res += filename + '\n' - res += ')\n' + res += filename + "\n" + res += ")\n" try: if open(dependencies_path).read() == res: @@ -278,7 +338,7 @@ def generate_cairo_dependencies_file(dependencies_path: str, files: Set[str], st except FileNotFoundError: pass - with open(dependencies_path, 'w') as dependencies_file: + with open(dependencies_path, "w") as dependencies_file: dependencies_file.write(res) # Change the modification time of the file to make sure it is older than the generated @@ -287,42 +347,57 @@ def generate_cairo_dependencies_file(dependencies_path: str, files: Set[str], st def cairo_assemble_program( - preprocessed_program: PreprocessedProgram, main_scope: ScopedName, - add_debug_info: bool, file_contents_for_debug_info: Dict[str, str]) -> Program: + preprocessed_program: PreprocessedProgram, + main_scope: ScopedName, + add_debug_info: bool, + file_contents_for_debug_info: Dict[str, str], +) -> Program: program = assemble( - preprocessed_program, main_scope=MAIN_SCOPE, add_debug_info=add_debug_info, - file_contents_for_debug_info=file_contents_for_debug_info) + preprocessed_program, + main_scope=MAIN_SCOPE, + add_debug_info=add_debug_info, + file_contents_for_debug_info=file_contents_for_debug_info, + ) check_main_args(program) return program def main(): - parser = argparse.ArgumentParser(description='A tool to compile Cairo code.') + parser = argparse.ArgumentParser(description="A tool to compile Cairo code.") parser.add_argument( - '--proof_mode', action='store_true', default=False, - help='Add instructions to call main() at the beginning of the program. This should be used ' - 'if the program is proven directly (without the bootloader).') + "--proof_mode", + action="store_true", + default=False, + help="Add instructions to call main() at the beginning of the program. This should be used " + "if the program is proven directly (without the bootloader).", + ) parser.add_argument( - '--no_proof_mode', dest='proof_mode', action='store_false', - help='Disable proof mode (see --proof_mode).') + "--no_proof_mode", + dest="proof_mode", + action="store_false", + help="Disable proof mode (see --proof_mode).", + ) def pass_manager_factory(args: argparse.Namespace, module_reader: ModuleReader) -> PassManager: return default_pass_manager( prime=args.prime, read_module=module_reader.read, - opt_unused_functions=args.opt_unused_functions) + opt_unused_functions=args.opt_unused_functions, + ) try: cairo_compile_add_common_args(parser) args = parser.parse_args() cairo_compile_common( - args=args, pass_manager_factory=pass_manager_factory, - assemble_func=cairo_assemble_program) + args=args, + pass_manager_factory=pass_manager_factory, + assemble_func=cairo_assemble_program, + ) except LocationError as err: print(err, file=sys.stderr) return 1 return 0 -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/src/starkware/cairo/lang/compiler/cairo_compile_test.py b/src/starkware/cairo/lang/compiler/cairo_compile_test.py index f8657bf9..083e972a 100644 --- a/src/starkware/cairo/lang/compiler/cairo_compile_test.py +++ b/src/starkware/cairo/lang/compiler/cairo_compile_test.py @@ -12,8 +12,10 @@ def test_main_args_match_builtins(): Checks that an appropriate exception is thrown if the arguments in a given Cairo program's main don't match the list of builtins specified by the directive (and their order). """ - expected_error_msg = 'Expected main to contain the following arguments (in this order): ' \ + expected_error_msg = ( + "Expected main to contain the following arguments (in this order): " "['output_ptr', 'range_check_ptr']" + ) with pytest.raises(AssertionError, match=re.escape(expected_error_msg)): compile_cairo( code=""" @@ -22,7 +24,9 @@ def test_main_args_match_builtins(): func main(output_ptr) -> (output_ptr): return (output_ptr=output_ptr + 1) end -""", prime=PRIME) +""", + prime=PRIME, + ) # Check that even if all builtin ptrs were passed as arguments but in the wrong order then # the same exception is thrown. @@ -34,7 +38,9 @@ def test_main_args_match_builtins(): func main(range_check_ptr, output_ptr) -> (range_check_ptr, output_ptr): return (range_check_ptr + 1, output_ptr=output_ptr + 1) end -""", prime=PRIME) +""", + prime=PRIME, + ) def test_main_return_match_builtins(): @@ -42,8 +48,10 @@ def test_main_return_match_builtins(): Checks that an appropriate exception is thrown if the arguments in a given Cairo program's main don't match the list of builtins specified by the directive (and their order). """ - expected_error_msg = 'Expected main to return the following values (in this order): ' \ + expected_error_msg = ( + "Expected main to return the following values (in this order): " "['output_ptr', 'range_check_ptr']" + ) with pytest.raises(AssertionError, match=re.escape(expected_error_msg)): compile_cairo( code=""" @@ -52,4 +60,6 @@ def test_main_return_match_builtins(): func main(output_ptr, range_check_ptr) -> (output_ptr): return (output_ptr=output_ptr + 1) end -""", prime=PRIME) +""", + prime=PRIME, + ) diff --git a/src/starkware/cairo/lang/compiler/cairo_format.py b/src/starkware/cairo/lang/compiler/cairo_format.py index 976281f5..b197a0fe 100644 --- a/src/starkware/cairo/lang/compiler/cairo_format.py +++ b/src/starkware/cairo/lang/compiler/cairo_format.py @@ -6,39 +6,39 @@ def main(): - parser = argparse.ArgumentParser( - description='A tool to automatically format Cairo code.') - parser.add_argument('-v', '--version', action='version', version=f'%(prog)s {__version__}') - parser.add_argument('files', metavar='file', type=str, nargs='+', help='File names') + parser = argparse.ArgumentParser(description="A tool to automatically format Cairo code.") + parser.add_argument("-v", "--version", action="version", version=f"%(prog)s {__version__}") + parser.add_argument("files", metavar="file", type=str, nargs="+", help="File names") action = parser.add_mutually_exclusive_group(required=False) - action.add_argument('-i', dest='inplace', action='store_true', help='Edit files inplace.') - action.add_argument('-c', dest='check', action='store_true', help='Check files\' formats.') + action.add_argument("-i", dest="inplace", action="store_true", help="Edit files inplace.") + action.add_argument("-c", dest="check", action="store_true", help="Check files' formats.") args = parser.parse_args() return_code = 0 for path in args.files: - old_content = open(path).read() if path != '-' else sys.stdin.read() + old_content = open(path).read() if path != "-" else sys.stdin.read() try: new_content = parse_file( - old_content, filename='' if path == '-' else path).format() + old_content, filename="" if path == "-" else path + ).format() except Exception as exc: print(exc, file=sys.stderr) return 2 if args.inplace: - assert path != '-', 'Using "-i" together with "-" is not supported.' - open(path, 'w').write(new_content) + assert path != "-", 'Using "-i" together with "-" is not supported.' + open(path, "w").write(new_content) elif args.check: - assert path != '-', 'Using "-c" together with "-" is not supported.' + assert path != "-", 'Using "-c" together with "-" is not supported.' if old_content != new_content: print(f'File "{path}" is incorrectly formatted.', file=sys.stderr) return_code = 1 else: - print(new_content, end='') + print(new_content, end="") return return_code -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/src/starkware/cairo/lang/compiler/conftest.py b/src/starkware/cairo/lang/compiler/conftest.py index 5e0ea869..62304241 100644 --- a/src/starkware/cairo/lang/compiler/conftest.py +++ b/src/starkware/cairo/lang/compiler/conftest.py @@ -3,4 +3,4 @@ # Instruct pytest to print full information (e.g., the values on both sides of the equality) # about asserts that failed in the module below. # Normally, pytest prints full information only for test files (according to their name). -pytest.register_assert_rewrite('starkware.cairo.lang.compiler.parser_test_utils') +pytest.register_assert_rewrite("starkware.cairo.lang.compiler.parser_test_utils") diff --git a/src/starkware/cairo/lang/compiler/const_expr_checker.py b/src/starkware/cairo/lang/compiler/const_expr_checker.py index 40148f03..4bd8abed 100644 --- a/src/starkware/cairo/lang/compiler/const_expr_checker.py +++ b/src/starkware/cairo/lang/compiler/const_expr_checker.py @@ -1,5 +1,12 @@ from starkware.cairo.lang.compiler.ast.expr import ( - ExprConst, ExprDeref, ExprFutureLabel, ExprNeg, ExprOperator, ExprPyConst, ExprReg) + ExprConst, + ExprDeref, + ExprFutureLabel, + ExprNeg, + ExprOperator, + ExprPyConst, + ExprReg, +) class ConstExprChecker: @@ -9,7 +16,7 @@ class ConstExprChecker: """ def visit(self, obj): - return getattr(self, f'visit_{type(obj).__name__}')(obj) + return getattr(self, f"visit_{type(obj).__name__}")(obj) def visit_ExprConst(self, expr: ExprConst): return True diff --git a/src/starkware/cairo/lang/compiler/constants.py b/src/starkware/cairo/lang/compiler/constants.py index 228b9db9..825425a9 100644 --- a/src/starkware/cairo/lang/compiler/constants.py +++ b/src/starkware/cairo/lang/compiler/constants.py @@ -1,6 +1,6 @@ from starkware.cairo.lang.compiler.scoped_name import ScopedName -LIBS_DIR_ENVVAR = 'CAIRO_PATH' -MAIN_SCOPE = ScopedName.from_string('__main__') -START_FILE_NAME = '' -SIZE_CONSTANT = ScopedName.from_string('SIZE') +LIBS_DIR_ENVVAR = "CAIRO_PATH" +MAIN_SCOPE = ScopedName.from_string("__main__") +START_FILE_NAME = "" +SIZE_CONSTANT = ScopedName.from_string("SIZE") diff --git a/src/starkware/cairo/lang/compiler/debug_info.py b/src/starkware/cairo/lang/compiler/debug_info.py index 818b2c3c..ddbe5236 100644 --- a/src/starkware/cairo/lang/compiler/debug_info.py +++ b/src/starkware/cairo/lang/compiler/debug_info.py @@ -23,7 +23,8 @@ class InstructionLocation: inst: Location hints: List[Optional[HintLocation]] accessible_scopes: List[ScopedName] = field( - metadata=dict(marshmallow_field=mfields.List(ScopedNameAsStr))) + metadata=dict(marshmallow_field=mfields.List(ScopedNameAsStr)) + ) flow_tracking_data: FlowTrackingDataActual @@ -64,14 +65,16 @@ def add_autogen_file_contents(self): for loc in instruction_location.get_all_locations(): input_file = loc.input_file is_autogen = ( - input_file.filename is not None and - input_file.filename.startswith('autogen/') and - input_file.content is not None) + input_file.filename is not None + and input_file.filename.startswith("autogen/") + and input_file.content is not None + ) if not is_autogen: continue if input_file.filename in self.file_contents: - assert self.file_contents[input_file.filename] == input_file.content, \ - f'Found two versions of auto-generated file "{input_file.filename}":\n' \ - f'{input_file.content}\n\n\n{self.file_contents[input_file.filename]}' + assert self.file_contents[input_file.filename] == input_file.content, ( + f'Found two versions of auto-generated file "{input_file.filename}":\n' + f"{input_file.content}\n\n\n{self.file_contents[input_file.filename]}" + ) else: self.file_contents[input_file.filename] = input_file.content diff --git a/src/starkware/cairo/lang/compiler/debug_info_test.py b/src/starkware/cairo/lang/compiler/debug_info_test.py index 9f5d6452..b3043a60 100644 --- a/src/starkware/cairo/lang/compiler/debug_info_test.py +++ b/src/starkware/cairo/lang/compiler/debug_info_test.py @@ -9,25 +9,34 @@ def dummy_instruction_location(filename: str, content: Optional[str]) -> InstructionLocation: location = Location( - start_line=1, start_col=2, end_line=3, end_col=4, - input_file=InputFile(filename=filename, content=content)) + start_line=1, + start_col=2, + end_line=3, + end_col=4, + input_file=InputFile(filename=filename, content=content), + ) return InstructionLocation( - inst=location, hints=[], accessible_scopes=[], - flow_tracking_data=FlowTrackingDataActual.new(lambda: 0)) + inst=location, + hints=[], + accessible_scopes=[], + flow_tracking_data=FlowTrackingDataActual.new(lambda: 0), + ) def test_autogen_files(): - inst_location0 = dummy_instruction_location('autogen/1', 'content 1') - inst_location1 = dummy_instruction_location('not/autogen/2', 'content 2') - inst_location2 = dummy_instruction_location('autogen/3', None) + inst_location0 = dummy_instruction_location("autogen/1", "content 1") + inst_location1 = dummy_instruction_location("not/autogen/2", "content 2") + inst_location2 = dummy_instruction_location("autogen/3", None) debug_info = DebugInfo( - instruction_locations={0: inst_location0, 1: inst_location1, 2: inst_location2}) + instruction_locations={0: inst_location0, 1: inst_location1, 2: inst_location2} + ) debug_info.add_autogen_file_contents() - assert debug_info.file_contents == {'autogen/1': 'content 1'} + assert debug_info.file_contents == {"autogen/1": "content 1"} # Create a location to the same file name, with a different content. - mismatch_location = dummy_instruction_location('autogen/1', 'a different content') + mismatch_location = dummy_instruction_location("autogen/1", "a different content") debug_info = DebugInfo(instruction_locations={0: inst_location0, 1: mismatch_location}) with pytest.raises( - AssertionError, match='Found two versions of auto-generated file "autogen/1"'): + AssertionError, match='Found two versions of auto-generated file "autogen/1"' + ): debug_info.add_autogen_file_contents() diff --git a/src/starkware/cairo/lang/compiler/encode.py b/src/starkware/cairo/lang/compiler/encode.py index e63d9687..64270595 100644 --- a/src/starkware/cairo/lang/compiler/encode.py +++ b/src/starkware/cairo/lang/compiler/encode.py @@ -1,7 +1,11 @@ from typing import List, Optional from starkware.cairo.lang.compiler.instruction import ( - OFFSET_BITS, Instruction, Register, decode_instruction_values) + OFFSET_BITS, + Instruction, + Register, + decode_instruction_values, +) DST_REG_BIT = 0 OP0_REG_BIT = 1 @@ -26,12 +30,15 @@ def encode_instruction(inst: Instruction, prime: int) -> List[int]: Given an Instruction, returns a list of 1 or 2 integers representing the instruction. """ assert prime > 2 ** (3 * OFFSET_BITS + 16) - assert -2 ** (OFFSET_BITS - 1) <= inst.off0 < 2 ** (OFFSET_BITS - 1), \ - f'off0 must be in range [-2**{OFFSET_BITS - 1}, 2**{OFFSET_BITS - 1})' - assert -2 ** (OFFSET_BITS - 1) <= inst.off1 < 2 ** (OFFSET_BITS - 1), \ - f'off1 must be in range [-2**{OFFSET_BITS - 1}, 2**{OFFSET_BITS - 1})' - assert -2 ** (OFFSET_BITS - 1) <= inst.off2 < 2 ** (OFFSET_BITS - 1), \ - f'off2 must be in range [-2**{OFFSET_BITS - 1}, 2**{OFFSET_BITS - 1})' + assert ( + -(2 ** (OFFSET_BITS - 1)) <= inst.off0 < 2 ** (OFFSET_BITS - 1) + ), f"off0 must be in range [-2**{OFFSET_BITS - 1}, 2**{OFFSET_BITS - 1})" + assert ( + -(2 ** (OFFSET_BITS - 1)) <= inst.off1 < 2 ** (OFFSET_BITS - 1) + ), f"off1 must be in range [-2**{OFFSET_BITS - 1}, 2**{OFFSET_BITS - 1})" + assert ( + -(2 ** (OFFSET_BITS - 1)) <= inst.off2 < 2 ** (OFFSET_BITS - 1) + ), f"off2 must be in range [-2**{OFFSET_BITS - 1}, 2**{OFFSET_BITS - 1})" off0_enc = inst.off0 + 2 ** (OFFSET_BITS - 1) off1_enc = inst.off1 + 2 ** (OFFSET_BITS - 1) off2_enc = inst.off2 + 2 ** (OFFSET_BITS - 1) @@ -46,14 +53,15 @@ def encode_instruction(inst: Instruction, prime: int) -> List[int]: flags |= (1 << OP0_REG_BIT) if inst.op0_register is Register.FP else 0 # Set op1_addr. - assert (inst.imm is not None) == (inst.op1_addr is Instruction.Op1Addr.IMM), \ - 'Immediate must appear iff op1_addr is Op1Addr.IMM' + assert (inst.imm is not None) == ( + inst.op1_addr is Instruction.Op1Addr.IMM + ), "Immediate must appear iff op1_addr is Op1Addr.IMM" flags |= { Instruction.Op1Addr.IMM: 1 << OP1_IMM_BIT, Instruction.Op1Addr.AP: 1 << OP1_AP_BIT, Instruction.Op1Addr.FP: 1 << OP1_FP_BIT, - Instruction.Op1Addr.OP0: 0 + Instruction.Op1Addr.OP0: 0, }[inst.op1_addr] # Set res. @@ -63,9 +71,9 @@ def encode_instruction(inst: Instruction, prime: int) -> List[int]: Instruction.Res.OP1: 0, Instruction.Res.UNCONSTRAINED: 0, }[inst.res] - assert (inst.res is Instruction.Res.UNCONSTRAINED) == \ - (inst.pc_update == Instruction.PcUpdate.JNZ), \ - 'res must be UNCONSTRAINED iff pc_update is JNZ' + assert (inst.res is Instruction.Res.UNCONSTRAINED) == ( + inst.pc_update == Instruction.PcUpdate.JNZ + ), "res must be UNCONSTRAINED iff pc_update is JNZ" # Set pc_update. flags |= { @@ -76,9 +84,9 @@ def encode_instruction(inst: Instruction, prime: int) -> List[int]: }[inst.pc_update] # Set ap_update. - assert (inst.ap_update is Instruction.ApUpdate.ADD2) == \ - (inst.opcode is Instruction.Opcode.CALL), \ - 'ap_update is ADD2 iff opcode is CALL' + assert (inst.ap_update is Instruction.ApUpdate.ADD2) == ( + inst.opcode is Instruction.Opcode.CALL + ), "ap_update is ADD2 iff opcode is CALL" flags |= { Instruction.ApUpdate.ADD: 1 << AP_ADD_BIT, Instruction.ApUpdate.ADD1: 1 << AP_ADD1_BIT, @@ -93,7 +101,7 @@ def encode_instruction(inst: Instruction, prime: int) -> List[int]: Instruction.Opcode.CALL: Instruction.FpUpdate.AP_PLUS2, Instruction.Opcode.RET: Instruction.FpUpdate.DST, Instruction.Opcode.ASSERT_EQ: Instruction.FpUpdate.REGULAR, - }[inst.opcode], f'fp_update {inst.fp_update} does not match opcode f{inst.opcode}' + }[inst.opcode], f"fp_update {inst.fp_update} does not match opcode f{inst.opcode}" # Set opcode. flags |= { @@ -132,11 +140,11 @@ def decode_instruction(encoding: int, imm: Optional[int] = None) -> Instruction: (1, 0, 0): Instruction.Op1Addr.IMM, (0, 1, 0): Instruction.Op1Addr.AP, (0, 0, 1): Instruction.Op1Addr.FP, - (0, 0, 0): Instruction.Op1Addr.OP0 + (0, 0, 0): Instruction.Op1Addr.OP0, }[(flags >> OP1_IMM_BIT) & 1, (flags >> OP1_AP_BIT) & 1, (flags >> OP1_FP_BIT) & 1] if op1_addr is Instruction.Op1Addr.IMM: - assert imm is not None, 'op1_addr is Op1Addr.IMM, but no immediate given' + assert imm is not None, "op1_addr is Op1Addr.IMM, but no immediate given" else: imm = None @@ -145,16 +153,16 @@ def decode_instruction(encoding: int, imm: Optional[int] = None) -> Instruction: (1, 0, 0): Instruction.PcUpdate.JUMP, (0, 1, 0): Instruction.PcUpdate.JUMP_REL, (0, 0, 1): Instruction.PcUpdate.JNZ, - (0, 0, 0): Instruction.PcUpdate.REGULAR + (0, 0, 0): Instruction.PcUpdate.REGULAR, }[(flags >> PC_JUMP_ABS_BIT) & 1, (flags >> PC_JUMP_REL_BIT) & 1, (flags >> PC_JNZ_BIT) & 1] # Get res. res = { (1, 0): Instruction.Res.ADD, (0, 1): Instruction.Res.MUL, - (0, 0): - Instruction.Res.UNCONSTRAINED if pc_update is Instruction.PcUpdate.JNZ - else Instruction.Res.OP1, + (0, 0): Instruction.Res.UNCONSTRAINED + if pc_update is Instruction.PcUpdate.JNZ + else Instruction.Res.OP1, }[(flags >> RES_ADD_BIT) & 1, (flags >> RES_MUL_BIT) & 1] # JNZ opcode means res must be UNCONSTRAINED. @@ -173,14 +181,16 @@ def decode_instruction(encoding: int, imm: Optional[int] = None) -> Instruction: (1, 0, 0): Instruction.Opcode.CALL, (0, 1, 0): Instruction.Opcode.RET, (0, 0, 1): Instruction.Opcode.ASSERT_EQ, - (0, 0, 0): Instruction.Opcode.NOP + (0, 0, 0): Instruction.Opcode.NOP, }[ - (flags >> OPCODE_CALL_BIT) & 1, (flags >> OPCODE_RET_BIT) & 1, - (flags >> OPCODE_ASSERT_EQ_BIT) & 1] + (flags >> OPCODE_CALL_BIT) & 1, + (flags >> OPCODE_RET_BIT) & 1, + (flags >> OPCODE_ASSERT_EQ_BIT) & 1, + ] # CALL opcode means ap_update must be ADD2. if opcode is Instruction.Opcode.CALL: - assert ap_update is Instruction.ApUpdate.REGULAR, 'CALL must have update_ap is ADD2' + assert ap_update is Instruction.ApUpdate.REGULAR, "CALL must have update_ap is ADD2" ap_update = Instruction.ApUpdate.ADD2 # Get fp_update. @@ -194,7 +204,7 @@ def decode_instruction(encoding: int, imm: Optional[int] = None) -> Instruction: return Instruction( off0=off0_enc - 2 ** (OFFSET_BITS - 1), off1=off1_enc - 2 ** (OFFSET_BITS - 1), - off2=off2_enc - 2**(OFFSET_BITS - 1), + off2=off2_enc - 2 ** (OFFSET_BITS - 1), imm=imm, dst_register=dst_register, op0_register=op0_register, @@ -216,9 +226,9 @@ def is_call_instruction(encoded_instruction: int, imm: Optional[int]): except Exception: return False return ( - instruction.res is Instruction.Res.OP1 and - instruction.pc_update in [Instruction.PcUpdate.JUMP, Instruction.PcUpdate.JUMP_REL] and - instruction.ap_update is Instruction.ApUpdate.ADD2 and - instruction.fp_update is Instruction.FpUpdate.AP_PLUS2 and - instruction.opcode is Instruction.Opcode.CALL + instruction.res is Instruction.Res.OP1 + and instruction.pc_update in [Instruction.PcUpdate.JUMP, Instruction.PcUpdate.JUMP_REL] + and instruction.ap_update is Instruction.ApUpdate.ADD2 + and instruction.fp_update is Instruction.FpUpdate.AP_PLUS2 + and instruction.opcode is Instruction.Opcode.CALL ) diff --git a/src/starkware/cairo/lang/compiler/encode_test.py b/src/starkware/cairo/lang/compiler/encode_test.py index 16db1b10..48d2ced9 100644 --- a/src/starkware/cairo/lang/compiler/encode_test.py +++ b/src/starkware/cairo/lang/compiler/encode_test.py @@ -1,16 +1,19 @@ import dataclasses from starkware.cairo.lang.compiler.encode import ( - decode_instruction, encode_instruction, is_call_instruction) + decode_instruction, + encode_instruction, + is_call_instruction, +) from starkware.cairo.lang.compiler.instruction import Instruction, Register from starkware.cairo.lang.compiler.instruction_builder import build_instruction from starkware.cairo.lang.compiler.parser import parse_instruction -PRIME = 2**64 + 13 +PRIME = 2 ** 64 + 13 def test_assert_eq(): - encoded = [0x480680017fff8000, 1] + encoded = [0x480680017FFF8000, 1] instruction = Instruction( off0=0, off1=-1, @@ -25,20 +28,20 @@ def test_assert_eq(): fp_update=Instruction.FpUpdate.REGULAR, opcode=Instruction.Opcode.ASSERT_EQ, ) - assert build_instruction(parse_instruction('[ap] = 1; ap++')) == instruction + assert build_instruction(parse_instruction("[ap] = 1; ap++")) == instruction assert encode_instruction(instruction, prime=PRIME) == encoded assert decode_instruction(*encoded) == instruction # Remove "ap++". instruction = dataclasses.replace(instruction, ap_update=Instruction.ApUpdate.REGULAR) - encoded = [0x400680017fff8000, 1] + encoded = [0x400680017FFF8000, 1] assert encode_instruction(instruction, prime=PRIME) == encoded assert decode_instruction(*encoded) == instruction assert is_call_instruction(*encoded) is False def test_jmp(): - encoded = [0x0129800080027fff] + encoded = [0x0129800080027FFF] instruction = Instruction( off0=-1, off1=2, @@ -53,20 +56,20 @@ def test_jmp(): fp_update=Instruction.FpUpdate.REGULAR, opcode=Instruction.Opcode.NOP, ) - assert build_instruction(parse_instruction('jmp rel [ap + 2] + [fp]')) == instruction + assert build_instruction(parse_instruction("jmp rel [ap + 2] + [fp]")) == instruction assert encode_instruction(instruction, prime=PRIME) == encoded assert decode_instruction(*encoded) == instruction # Change to jmp abs. instruction = dataclasses.replace(instruction, pc_update=Instruction.PcUpdate.JUMP) - encoded = [0x00a9800080027fff] + encoded = [0x00A9800080027FFF] assert encode_instruction(instruction, prime=PRIME) == encoded assert decode_instruction(*encoded) == instruction assert is_call_instruction(encoded[0], None) is False def test_jnz(): - encoded = [0x020a7ff07fff8003] + encoded = [0x020A7FF07FFF8003] instruction = Instruction( off0=3, off1=-1, @@ -81,7 +84,7 @@ def test_jnz(): fp_update=Instruction.FpUpdate.REGULAR, opcode=Instruction.Opcode.NOP, ) - assert build_instruction(parse_instruction('jmp rel [fp - 16] if [ap + 3] != 0')) == instruction + assert build_instruction(parse_instruction("jmp rel [fp - 16] if [ap + 3] != 0")) == instruction assert encode_instruction(instruction, prime=PRIME) == encoded assert decode_instruction(*encoded) == instruction assert is_call_instruction(encoded[0], None) is False @@ -103,14 +106,14 @@ def test_call(): fp_update=Instruction.FpUpdate.AP_PLUS2, opcode=Instruction.Opcode.CALL, ) - assert build_instruction(parse_instruction('call rel 1234')) == instruction + assert build_instruction(parse_instruction("call rel 1234")) == instruction assert encode_instruction(instruction, prime=PRIME) == encoded assert decode_instruction(*encoded) == instruction assert is_call_instruction(*encoded) is True def test_ret(): - encoded = [0x208b7fff7fff7ffe] + encoded = [0x208B7FFF7FFF7FFE] instruction = Instruction( off0=-2, off1=-1, @@ -125,14 +128,14 @@ def test_ret(): fp_update=Instruction.FpUpdate.DST, opcode=Instruction.Opcode.RET, ) - assert build_instruction(parse_instruction('ret')) == instruction + assert build_instruction(parse_instruction("ret")) == instruction assert encode_instruction(instruction, prime=PRIME) == encoded assert decode_instruction(*encoded) == instruction assert is_call_instruction(encoded[0], None) is False def test_addap(): - encoded = [0x40780017fff7fff, 123] + encoded = [0x40780017FFF7FFF, 123] instruction = Instruction( off0=-1, off1=-1, @@ -145,8 +148,9 @@ def test_addap(): pc_update=Instruction.PcUpdate.REGULAR, ap_update=Instruction.ApUpdate.ADD, fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.NOP) - assert build_instruction(parse_instruction('ap += 123')) == instruction + opcode=Instruction.Opcode.NOP, + ) + assert build_instruction(parse_instruction("ap += 123")) == instruction assert encode_instruction(instruction, prime=PRIME) == encoded assert decode_instruction(*encoded) == instruction assert is_call_instruction(*encoded) is False diff --git a/src/starkware/cairo/lang/compiler/error_handling.py b/src/starkware/cairo/lang/compiler/error_handling.py index ce293a7c..f5d4d5eb 100644 --- a/src/starkware/cairo/lang/compiler/error_handling.py +++ b/src/starkware/cairo/lang/compiler/error_handling.py @@ -16,12 +16,12 @@ def get_content(self) -> str: If the content member is not set, it will be updated. """ if self.content is None: - assert self.filename is not None, 'Content must be set if filename is None.' - self.content = open(self.filename, 'r').read() + assert self.filename is not None, "Content must be set if filename is None." + self.content = open(self.filename, "r").read() return self.content -ParentLocation = Tuple['Location', str] +ParentLocation = Tuple["Location", str] @dataclasses.dataclass(frozen=True) @@ -36,15 +36,17 @@ class Location: # expansion type, such as "While expanding the reference 'x'...". parent_location: Optional[ParentLocation] = None - def with_parent_location(self, new_parent_location: 'Location', message: str): + def with_parent_location(self, new_parent_location: "Location", message: str): if self.parent_location is None: return dataclasses.replace(self, parent_location=(new_parent_location, message)) else: old_self_parent_location, self_parent_location_message = self.parent_location new_self_parent_location = old_self_parent_location.with_parent_location( - new_parent_location, message) - return dataclasses.replace(self, parent_location=( - new_self_parent_location, self_parent_location_message)) + new_parent_location, message + ) + return dataclasses.replace( + self, parent_location=(new_self_parent_location, self_parent_location_message) + ) def topmost_location(self): """ @@ -57,30 +59,27 @@ def topmost_location(self): @post_dump def remove_none_values(self, data, many=False): - return { - key: value for key, value in data.items() - if value is not None - } + return {key: value for key, value in data.items() if value is not None} - def to_string(self, message: str = ''): + def to_string(self, message: str = ""): """ Prints the location with the passed message. """ input_file = self.input_file line = self.start_line col = self.start_col - filename = '' if input_file.filename is None else input_file.filename - message_prefix = ': ' if len(message) > 0 else '' - return f'{filename}:{line}:{col}{message_prefix}{message}' + filename = "" if input_file.filename is None else input_file.filename + message_prefix = ": " if len(message) > 0 else "" + return f"{filename}:{line}:{col}{message_prefix}{message}" - def to_string_with_content(self, message: str = ''): + def to_string_with_content(self, message: str = ""): """ Prints the location with the passed message, including the content of the line and the location marks. """ first_line = self.to_string(message=message) content = self.input_file.get_content() - return first_line + '\n' + get_location_marks(content, self) + return first_line + "\n" + get_location_marks(content, self) def __repr__(self): return self.to_string() @@ -102,16 +101,16 @@ def get_location_marks(content: str, location: Location): # The location does not refer to a valid location in the source file. This may happen when # the file is changed after compilation. # Don't return location marks in this case. - return '' + return "" start_line = lines[location.start_line - 1] start_col = location.start_col - res = start_line + '\n' + res = start_line + "\n" end_col = location.end_col if location.start_line == location.end_line else len(start_line) + 1 if end_col > start_col + 1: - res += ' ' * (start_col - 1) + '^' + '*' * (end_col - start_col - 2) + '^' + res += " " * (start_col - 1) + "^" + "*" * (end_col - start_col - 2) + "^" else: - res += ' ' * (start_col - 1) + '^' + res += " " * (start_col - 1) + "^" return res @@ -121,8 +120,12 @@ class LocationError(Exception): """ def __init__( - self, message, location: Optional[Location], traceback: Optional[str] = None, - notes: Optional[List[str]] = None): + self, + message, + location: Optional[Location], + traceback: Optional[str] = None, + notes: Optional[List[str]] = None, + ): super().__init__(message, location) self.message = message self.location = location @@ -131,21 +134,21 @@ def __init__( def __str__(self): if self.location is None: - res = self.message + '\n' + res = self.message + "\n" else: - res = '' + res = "" location, message = self.location, self.message while True: - res = location.to_string_with_content(message) + '\n' + res + res = location.to_string_with_content(message) + "\n" + res if location.parent_location is None: break location, message = location.parent_location if self.traceback is not None: - res += self.traceback + '\n' + res += self.traceback + "\n" # Add notes. for note in self.notes: - res += note + '\n' + res += note + "\n" return res.rstrip() diff --git a/src/starkware/cairo/lang/compiler/error_handling_test.py b/src/starkware/cairo/lang/compiler/error_handling_test.py index 47434b60..bf27182a 100644 --- a/src/starkware/cairo/lang/compiler/error_handling_test.py +++ b/src/starkware/cairo/lang/compiler/error_handling_test.py @@ -11,25 +11,27 @@ def test_location_error(): start_col=8, end_line=2, end_col=12, - input_file=InputFile( - filename='file.cairo', - content=content)) + input_file=InputFile(filename="file.cairo", content=content), + ) - expected_message = 'file.cairo:2:8: Error message.' + expected_message = "file.cairo:2:8: Error message." expected_message_with_content = f"""\ {expected_message} second line. ^**^\ """ - assert location.to_string('Error message.') == expected_message - assert str(location) == 'file.cairo:2:8' - assert location.to_string_with_content('Error message.') == expected_message_with_content - assert str(LocationError('Error message.', location=location)) == expected_message_with_content + assert location.to_string("Error message.") == expected_message + assert str(location) == "file.cairo:2:8" + assert location.to_string_with_content("Error message.") == expected_message_with_content + assert str(LocationError("Error message.", location=location)) == expected_message_with_content - err2 = LocationError(message='Error message.', location=None) - err2.notes.append('note1') - err2.notes.append('note2') - assert str(err2) == """\ + err2 = LocationError(message="Error message.", location=None) + err2.notes.append("note1") + err2.notes.append("note2") + assert ( + str(err2) + == """\ Error message. note1 note2""" + ) diff --git a/src/starkware/cairo/lang/compiler/expression_evaluator.py b/src/starkware/cairo/lang/compiler/expression_evaluator.py index 6e2c9e7b..2bc477b5 100644 --- a/src/starkware/cairo/lang/compiler/expression_evaluator.py +++ b/src/starkware/cairo/lang/compiler/expression_evaluator.py @@ -17,8 +17,13 @@ class ExpressionEvaluator(ExpressionSimplifier): prime: int def __init__( - self, prime: int, ap: Optional[int], fp: int, memory: MutableMapping[int, int], - identifiers: Optional[IdentifierManager] = None): + self, + prime: int, + ap: Optional[int], + fp: int, + memory: MutableMapping[int, int], + identifiers: Optional[IdentifierManager] = None, + ): super().__init__(prime=prime) assert self.prime is not None self.ap = ap @@ -28,8 +33,9 @@ def __init__( def eval(self, expr: Expression) -> int: expr, expr_type = simplify_type_system(expr, identifiers=self.identifiers) - assert isinstance(expr_type, (TypeFelt, TypePointer)), \ - f"Unable to evaluate expression of type '{expr_type.format()}'." + assert isinstance( + expr_type, (TypeFelt, TypePointer) + ), f"Unable to evaluate expression of type '{expr_type.format()}'." res = self.visit(expr) assert isinstance(res, ExprConst), f"Unable to evaluate expression '{expr.format()}'." assert self.prime is not None @@ -37,12 +43,12 @@ def eval(self, expr: Expression) -> int: def visit_ExprReg(self, expr: ExprReg) -> ExprConst: if expr.reg is Register.AP: - assert self.ap is not None, 'Cannot substitute ap in the expression.' + assert self.ap is not None, "Cannot substitute ap in the expression." return ExprConst(val=self.ap, location=expr.location) elif expr.reg is Register.FP: return ExprConst(val=self.fp, location=expr.location) else: - raise NotImplementedError(f'Register of type {expr.reg} is not supported') + raise NotImplementedError(f"Register of type {expr.reg} is not supported") def visit_ExprDeref(self, expr: ExprDeref) -> Expression: addr = self.visit(expr.addr) diff --git a/src/starkware/cairo/lang/compiler/expression_evaluator_test.py b/src/starkware/cairo/lang/compiler/expression_evaluator_test.py index 53b65786..c6a5ff4e 100644 --- a/src/starkware/cairo/lang/compiler/expression_evaluator_test.py +++ b/src/starkware/cairo/lang/compiler/expression_evaluator_test.py @@ -8,7 +8,7 @@ def test_eval_registers(): prime = 13 evaluator = ExpressionEvaluator(prime=prime, ap=ap, fp=fp, memory={}) - assert evaluator.eval(parse_expr('2 * ap + 3 * fp - 5')) == (2 * ap + 3 * fp - 5) % prime + assert evaluator.eval(parse_expr("2 * ap + 3 * fp - 5")) == (2 * ap + 3 * fp - 5) % prime def test_eval_with_types(): @@ -17,7 +17,7 @@ def test_eval_with_types(): prime = 13 evaluator = ExpressionEvaluator(prime=prime, ap=ap, fp=fp, memory={}) - assert evaluator.eval(parse_expr('cast(ap, T*)')) == ap + assert evaluator.eval(parse_expr("cast(ap, T*)")) == ap def test_eval_registers_and_memory(): @@ -27,7 +27,9 @@ def test_eval_registers_and_memory(): memory = {(2 * ap + 3 * fp - 5) % prime: 7, 7: 5, 6: 0} evaluator = ExpressionEvaluator(prime=prime, ap=ap, fp=fp, memory=memory) - assert evaluator.eval(parse_expr('[2 * ap + 3 * fp - 5]')) == 7 - assert evaluator.eval(parse_expr('[[2 * ap + 3 * fp - 5]] + 3 * ap')) == \ - (memory[7] + 3 * ap) % prime - assert evaluator.eval(parse_expr('[[[2 * ap + 3 * fp - 5]]+1]')) == 0 + assert evaluator.eval(parse_expr("[2 * ap + 3 * fp - 5]")) == 7 + assert ( + evaluator.eval(parse_expr("[[2 * ap + 3 * fp - 5]] + 3 * ap")) + == (memory[7] + 3 * ap) % prime + ) + assert evaluator.eval(parse_expr("[[[2 * ap + 3 * fp - 5]]+1]")) == 0 diff --git a/src/starkware/cairo/lang/compiler/expression_simplifier.py b/src/starkware/cairo/lang/compiler/expression_simplifier.py index 2b524d5e..2c8e434b 100644 --- a/src/starkware/cairo/lang/compiler/expression_simplifier.py +++ b/src/starkware/cairo/lang/compiler/expression_simplifier.py @@ -2,15 +2,22 @@ from typing import Optional from starkware.cairo.lang.compiler.ast.expr import ( - ExprConst, ExprDeref, ExprNeg, ExprOperator, ExprParentheses, ExprPow, ExprPyConst) + ExprConst, + ExprDeref, + ExprNeg, + ExprOperator, + ExprParentheses, + ExprPow, + ExprPyConst, +) from starkware.cairo.lang.compiler.error_handling import LocationError from starkware.cairo.lang.compiler.expression_transformer import ExpressionTransformer from starkware.python.math_utils import div_mod OPERATOR_DICT = { - '+': operator.add, - '-': operator.sub, - '*': operator.mul, + "+": operator.add, + "-": operator.sub, + "*": operator.mul, } @@ -29,13 +36,12 @@ def __init__(self, prime: Optional[int] = None): self.prime = prime def visit_ExprConst(self, expr: ExprConst): - return ExprConst( - val=self._to_field_element(expr.val), location=expr.location) + return ExprConst(val=self._to_field_element(expr.val), location=expr.location) def visit_ExprPyConst(self, expr: ExprPyConst): if self.prime is None: return expr - val = eval(expr.code, {'PRIME': self.prime}, {}) + val = eval(expr.code, {"PRIME": self.prime}, {}) return ExprConst(val=val, location=expr.location) def visit_ExprOperator(self, expr: ExprOperator): @@ -43,65 +49,67 @@ def visit_ExprOperator(self, expr: ExprOperator): b = self.visit(expr.b) op = expr.op - if isinstance(b, ExprConst) and op == '/' and b.val == 0: - raise SimplifierError('Division by zero.', location=b.location) + if isinstance(b, ExprConst) and op == "/" and b.val == 0: + raise SimplifierError("Division by zero.", location=b.location) if isinstance(a, ExprConst) and isinstance(b, ExprConst): val = None - if op == '/' and self.prime is not None: + if op == "/" and self.prime is not None: if b.val % self.prime == 0: - raise SimplifierError('Division by zero.', location=b.location) + raise SimplifierError("Division by zero.", location=b.location) val = div_mod(a.val, b.val, self.prime) - if op != '/': + if op != "/": val = self._to_field_element(OPERATOR_DICT[op](a.val, b.val)) if val is not None: return ExprConst(val, location=expr.location) - if isinstance(a, ExprConst) and op == '+': + if isinstance(a, ExprConst) and op == "+": assert not isinstance(b, ExprConst) # Move constant expression to the right. E.g., "5 + fp" -> "fp + 5" a, b = b, a - if isinstance(b, ExprConst) and op == '-': + if isinstance(b, ExprConst) and op == "-": # Replace x - y with x + (-y) for constant y. - op = '+' + op = "+" b = ExprConst(val=self._to_field_element(-b.val), location=b.location) - if isinstance(b, ExprConst) and op == '/' and self.prime is not None: + if isinstance(b, ExprConst) and op == "/" and self.prime is not None: # Replace x / y with x * (1/y) for constant y. - op = '*' + op = "*" if b.val % self.prime == 0: - raise SimplifierError('Division by zero.', location=b.location) + raise SimplifierError("Division by zero.", location=b.location) inv_val = div_mod(1, b.val, self.prime) b = ExprConst(val=self._to_field_element(inv_val), location=b.location) - if isinstance(b, ExprConst) and b.val == 0 and op in ['+', '-']: + if isinstance(b, ExprConst) and b.val == 0 and op in ["+", "-"]: # Replace x + 0 and x - 0 by x. return a - if isinstance(b, ExprConst) and b.val == 1 and op in ['*', '/']: + if isinstance(b, ExprConst) and b.val == 1 and op in ["*", "/"]: # Replace x * 1 and x / 1 by x. return a - if isinstance(a, ExprConst) and a.val == 1 and op == '*': + if isinstance(a, ExprConst) and a.val == 1 and op == "*": # Replace 1 * x by x. return b - if isinstance(b, ExprConst) and isinstance(a, ExprOperator) and \ - ((op == '+' and a.op in ['+', '-']) or (op == '*' and a.op == '*')): + if ( + isinstance(b, ExprConst) + and isinstance(a, ExprOperator) + and ((op == "+" and a.op in ["+", "-"]) or (op == "*" and a.op == "*")) + ): # If the expression is of the form "(a + b) + c" where c is constant, change it to # "a + (b + c)", this allows compiling expressions of the form: "[fp + x + y]". # Rotate right. - return self.visit(ExprOperator( - a=a.a, - op=a.op, - b=ExprOperator( - a=a.b, + return self.visit( + ExprOperator( + a=a.a, op=a.op, - b=b, - location=expr.location), - location=expr.location)) + b=ExprOperator(a=a.b, op=a.op, b=b, location=expr.location), + location=expr.location, + ) + ) return ExprOperator(a=a, op=op, b=b, location=expr.location) @@ -115,7 +123,8 @@ def visit_ExprPow(self, expr: ExprPow): if isinstance(a, ExprConst) and isinstance(b, ExprConst): if b.val < 0: raise SimplifierError( - 'Power is not supported with a negative exponent.', location=expr.location) + "Power is not supported with a negative exponent.", location=expr.location + ) if self.prime is not None: val = pow(a.val, b.val, self.prime) else: @@ -127,8 +136,7 @@ def visit_ExprPow(self, expr: ExprPow): def visit_ExprNeg(self, expr: ExprNeg): val = self.visit(expr.val) if isinstance(val, ExprConst): - return ExprConst( - val=self._to_field_element(-val.val), location=expr.location) + return ExprConst(val=self._to_field_element(-val.val), location=expr.location) return ExprNeg(val=val, location=expr.location) def visit_ExprParentheses(self, expr: ExprParentheses): diff --git a/src/starkware/cairo/lang/compiler/expression_simplifier_test.py b/src/starkware/cairo/lang/compiler/expression_simplifier_test.py index 2b7c8c1f..ec9f5e0b 100644 --- a/src/starkware/cairo/lang/compiler/expression_simplifier_test.py +++ b/src/starkware/cairo/lang/compiler/expression_simplifier_test.py @@ -1,88 +1,89 @@ import pytest from starkware.cairo.lang.compiler.expression_simplifier import ( - ExpressionSimplifier, SimplifierError) + ExpressionSimplifier, + SimplifierError, +) from starkware.cairo.lang.compiler.parser import parse_expr from starkware.cairo.lang.compiler.substitute_identifiers import substitute_identifiers -@pytest.mark.parametrize('prime', [None, 3 * 2**30 + 1]) +@pytest.mark.parametrize("prime", [None, 3 * 2 ** 30 + 1]) def test_simplifier(prime): - assignments = {'x': 10, 'y': 3, 'z': -2, 'w': -60} + assignments = {"x": 10, "y": 3, "z": -2, "w": -60} simplifier = ExpressionSimplifier(prime) - simplify = lambda expr: simplifier.visit(substitute_identifiers( - expr=expr, get_identifier_callback=lambda var: assignments[var.name])) - assert simplify(parse_expr('fp + x * (y + -1)')).format() == 'fp + 20' - assert simplify(parse_expr('[fp + x] + [ap - (-z)]')).format() == \ - '[fp + 10] + [ap + (-2)]' - assert simplify(parse_expr('fp + x - y')).format() == 'fp + 7' - assert simplify(parse_expr('[1 + fp + 5]')).format() == '[fp + 6]' - assert simplify(parse_expr('[fp] - 3')).format() == '[fp] + (-3)' + simplify = lambda expr: simplifier.visit( + substitute_identifiers(expr=expr, get_identifier_callback=lambda var: assignments[var.name]) + ) + assert simplify(parse_expr("fp + x * (y + -1)")).format() == "fp + 20" + assert simplify(parse_expr("[fp + x] + [ap - (-z)]")).format() == "[fp + 10] + [ap + (-2)]" + assert simplify(parse_expr("fp + x - y")).format() == "fp + 7" + assert simplify(parse_expr("[1 + fp + 5]")).format() == "[fp + 6]" + assert simplify(parse_expr("[fp] - 3")).format() == "[fp] + (-3)" if prime is not None: - assert simplify(parse_expr('fp * (x - 1) / y')).format() == 'fp * 3' - assert simplify(parse_expr('fp * w / x / y / z')).format() == 'fp' + assert simplify(parse_expr("fp * (x - 1) / y")).format() == "fp * 3" + assert simplify(parse_expr("fp * w / x / y / z")).format() == "fp" else: - assert simplify(parse_expr('fp * (x - 1) / y')).format() == 'fp * 9 / 3' - assert simplify(parse_expr('fp * w / x / y / z')).format() == \ - 'fp * (-60) / 10 / 3 / (-2)' - assert simplify(parse_expr('fp * 1')).format() == 'fp' - assert simplify(parse_expr('1 * fp')).format() == 'fp' + assert simplify(parse_expr("fp * (x - 1) / y")).format() == "fp * 9 / 3" + assert simplify(parse_expr("fp * w / x / y / z")).format() == "fp * (-60) / 10 / 3 / (-2)" + assert simplify(parse_expr("fp * 1")).format() == "fp" + assert simplify(parse_expr("1 * fp")).format() == "fp" -@pytest.mark.parametrize('prime', [None, 3 * 2**30 + 1]) +@pytest.mark.parametrize("prime", [None, 3 * 2 ** 30 + 1]) def test_pow(prime): simplifier = ExpressionSimplifier(prime) - assert simplifier.visit(parse_expr('4 ** 3 ** 2')).format() == '262144' + assert simplifier.visit(parse_expr("4 ** 3 ** 2")).format() == "262144" if prime is not None: # Make sure the exponent is not computed modulo prime (if it were, # the result would have been 1). - assert simplifier.visit(parse_expr('(3 * 2**30 + 4) ** (3 * 2**30 + 1)')).format() == '3' + assert simplifier.visit(parse_expr("(3 * 2**30 + 4) ** (3 * 2**30 + 1)")).format() == "3" - with pytest.raises(SimplifierError, match='Power is not supported with a negative exponent'): - simplifier.visit(parse_expr('2 ** (-1)')) + with pytest.raises(SimplifierError, match="Power is not supported with a negative exponent"): + simplifier.visit(parse_expr("2 ** (-1)")) def test_modulo(): PRIME = 19 simplifier = ExpressionSimplifier(PRIME) # Check that the range is (-PRIME/2, PRIME/2). - assert simplifier.visit(parse_expr('-9')).format() == '-9' - assert simplifier.visit(parse_expr('-10')).format() == '9' - assert simplifier.visit(parse_expr('9')).format() == '9' - assert simplifier.visit(parse_expr('10')).format() == '-9' + assert simplifier.visit(parse_expr("-9")).format() == "-9" + assert simplifier.visit(parse_expr("-10")).format() == "9" + assert simplifier.visit(parse_expr("9")).format() == "9" + assert simplifier.visit(parse_expr("10")).format() == "-9" # Check value which is bigger than PRIME. - assert simplifier.visit(parse_expr('20')).format() == '1' + assert simplifier.visit(parse_expr("20")).format() == "1" # Check operators. - assert simplifier.visit(parse_expr('10 + 10')).format() == '1' - assert simplifier.visit(parse_expr('10 - 30')).format() == '-1' - assert simplifier.visit(parse_expr('10 * 10')).format() == '5' - assert simplifier.visit(parse_expr('2 / 3')).format() == '7' + assert simplifier.visit(parse_expr("10 + 10")).format() == "1" + assert simplifier.visit(parse_expr("10 - 30")).format() == "-1" + assert simplifier.visit(parse_expr("10 * 10")).format() == "5" + assert simplifier.visit(parse_expr("2 / 3")).format() == "7" -@pytest.mark.parametrize('prime', [None, 3 * 2**30 + 1]) +@pytest.mark.parametrize("prime", [None, 3 * 2 ** 30 + 1]) def test_rotation(prime): simplifier = ExpressionSimplifier(prime) - assert simplifier.visit(parse_expr('(fp + 10) + 1')).format() == 'fp + 11' - assert simplifier.visit(parse_expr('(fp + 10) - 1')).format() == 'fp + 9' - assert simplifier.visit(parse_expr('(fp - 10) + 1')).format() == 'fp + (-9)' - assert simplifier.visit(parse_expr('(fp - 10) - 1')).format() == 'fp + (-11)' + assert simplifier.visit(parse_expr("(fp + 10) + 1")).format() == "fp + 11" + assert simplifier.visit(parse_expr("(fp + 10) - 1")).format() == "fp + 9" + assert simplifier.visit(parse_expr("(fp - 10) + 1")).format() == "fp + (-9)" + assert simplifier.visit(parse_expr("(fp - 10) - 1")).format() == "fp + (-11)" - assert simplifier.visit(parse_expr('(10 + fp) - 1')).format() == 'fp + 9' - assert simplifier.visit(parse_expr('10 + (fp - 1)')).format() == 'fp + 9' - assert simplifier.visit(parse_expr('10 + (1 + fp)')).format() == 'fp + 11' - assert simplifier.visit(parse_expr('10 + (1 + fp) + 100')).format() == 'fp + 111' - assert simplifier.visit(parse_expr('10 + (1 + (fp + 100))')).format() == 'fp + 111' + assert simplifier.visit(parse_expr("(10 + fp) - 1")).format() == "fp + 9" + assert simplifier.visit(parse_expr("10 + (fp - 1)")).format() == "fp + 9" + assert simplifier.visit(parse_expr("10 + (1 + fp)")).format() == "fp + 11" + assert simplifier.visit(parse_expr("10 + (1 + fp) + 100")).format() == "fp + 111" + assert simplifier.visit(parse_expr("10 + (1 + (fp + 100))")).format() == "fp + 111" -@pytest.mark.parametrize('prime', [None, 3 * 2**30 + 1]) +@pytest.mark.parametrize("prime", [None, 3 * 2 ** 30 + 1]) def test_division_by_zero(prime): simplifier = ExpressionSimplifier(prime) - with pytest.raises(SimplifierError, match='Division by zero'): - simplifier.visit(parse_expr('fp / 0')) - with pytest.raises(SimplifierError, match='Division by zero'): - simplifier.visit(parse_expr('5 / 0')) + with pytest.raises(SimplifierError, match="Division by zero"): + simplifier.visit(parse_expr("fp / 0")) + with pytest.raises(SimplifierError, match="Division by zero"): + simplifier.visit(parse_expr("5 / 0")) if prime is not None: - with pytest.raises(SimplifierError, match='Division by zero'): - simplifier.visit(parse_expr(f'fp / {prime}')) + with pytest.raises(SimplifierError, match="Division by zero"): + simplifier.visit(parse_expr(f"fp / {prime}")) diff --git a/src/starkware/cairo/lang/compiler/expression_transformer.py b/src/starkware/cairo/lang/compiler/expression_transformer.py index c164400c..36502116 100644 --- a/src/starkware/cairo/lang/compiler/expression_transformer.py +++ b/src/starkware/cairo/lang/compiler/expression_transformer.py @@ -1,9 +1,26 @@ from typing import Optional from starkware.cairo.lang.compiler.ast.expr import ( - ArgList, ExprAddressOf, ExprAssignment, ExprCast, ExprConst, ExprDeref, ExprDot, Expression, - ExprFutureLabel, ExprHint, ExprIdentifier, ExprNeg, ExprOperator, ExprParentheses, ExprPow, - ExprPyConst, ExprReg, ExprSubscript, ExprTuple) + ArgList, + ExprAddressOf, + ExprAssignment, + ExprCast, + ExprConst, + ExprDeref, + ExprDot, + Expression, + ExprFutureLabel, + ExprHint, + ExprIdentifier, + ExprNeg, + ExprOperator, + ExprParentheses, + ExprPow, + ExprPyConst, + ExprReg, + ExprSubscript, + ExprTuple, +) from starkware.cairo.lang.compiler.ast.expr_func_call import ExprFuncCall from starkware.cairo.lang.compiler.ast.rvalue import RvalueFuncCall from starkware.cairo.lang.compiler.error_handling import Location, LocationError @@ -27,25 +44,23 @@ def visit_ExprParentheses(self, expr: ExprParentheses): """ def visit(self, expr: Expression): - funcname = f'visit_{type(expr).__name__}' + funcname = f"visit_{type(expr).__name__}" return getattr(self, funcname)(expr) def visit_ExprConst(self, expr: ExprConst): return ExprConst( - val=expr.val, - format_str=expr.format_str, - location=self.location_modifier(expr.location)) + val=expr.val, format_str=expr.format_str, location=self.location_modifier(expr.location) + ) def visit_ExprPyConst(self, expr: ExprPyConst): - return ExprPyConst( - code=expr.code, - location=self.location_modifier(expr.location)) + return ExprPyConst(code=expr.code, location=self.location_modifier(expr.location)) def visit_ExprHint(self, expr: ExprHint): return ExprHint( hint_code=expr.hint_code, n_prefix_newlines=expr.n_prefix_newlines, - location=self.location_modifier(expr.location)) + location=self.location_modifier(expr.location), + ) def visit_ExprIdentifier(self, expr: ExprIdentifier): return ExprIdentifier(name=expr.name, location=self.location_modifier(expr.location)) @@ -58,20 +73,26 @@ def visit_ExprReg(self, expr: ExprReg): def visit_ExprOperator(self, expr: ExprOperator): return ExprOperator( - a=self.visit(expr.a), op=expr.op, b=self.visit(expr.b), - location=self.location_modifier(expr.location)) + a=self.visit(expr.a), + op=expr.op, + b=self.visit(expr.b), + location=self.location_modifier(expr.location), + ) def visit_ExprPow(self, expr: ExprPow): return ExprPow( - a=self.visit(expr.a), b=self.visit(expr.b), - location=self.location_modifier(expr.location)) + a=self.visit(expr.a), + b=self.visit(expr.b), + location=self.location_modifier(expr.location), + ) def visit_ExprNeg(self, expr: ExprNeg): return ExprNeg(val=self.visit(expr.val), location=self.location_modifier(expr.location)) def visit_ExprParentheses(self, expr: ExprParentheses): return ExprParentheses( - val=self.visit(expr.val), location=self.location_modifier(expr.location)) + val=self.visit(expr.val), location=self.location_modifier(expr.location) + ) def visit_ExprDeref(self, expr: ExprDeref): return ExprDeref(addr=self.visit(expr.addr), location=self.location_modifier(expr.location)) @@ -80,7 +101,8 @@ def visit_ExprSubscript(self, expr: ExprSubscript): return ExprSubscript( expr=self.visit(expr.expr), offset=self.visit(expr.offset), - location=self.location_modifier(expr.location)) + location=self.location_modifier(expr.location), + ) def visit_ExprDot(self, expr: ExprDot): return ExprDot( @@ -88,12 +110,12 @@ def visit_ExprDot(self, expr: ExprDot): # Avoid visiting 'member' with an overridden visit_ExprIdentifier, as it is not a # proper identifier. member=ExpressionTransformer.visit_ExprIdentifier(self, expr.member), - location=self.location_modifier(expr.location)) + location=self.location_modifier(expr.location), + ) def visit_ExprAddressOf(self, expr: ExprAddressOf): inner_expr = self.visit(expr.expr) - return ExprAddressOf( - expr=inner_expr, location=self.location_modifier(expr.location)) + return ExprAddressOf(expr=inner_expr, location=self.location_modifier(expr.location)) def visit_ExprCast(self, expr: ExprCast): inner_expr = self.visit(expr.expr) @@ -101,7 +123,8 @@ def visit_ExprCast(self, expr: ExprCast): expr=inner_expr, dest_type=expr.dest_type, cast_type=expr.cast_type, - location=self.location_modifier(expr.location)) + location=self.location_modifier(expr.location), + ) def visit_ArgList(self, arg_list: ArgList): return ArgList( @@ -109,30 +132,35 @@ def visit_ArgList(self, arg_list: ArgList): ExprAssignment( identifier=item.identifier, expr=self.visit(item.expr), - location=self.location_modifier(item.location)) + location=self.location_modifier(item.location), + ) for item in arg_list.args ], notes=arg_list.notes, has_trailing_comma=arg_list.has_trailing_comma, - location=self.location_modifier(arg_list.location)) + location=self.location_modifier(arg_list.location), + ) def visit_ExprTuple(self, expr: ExprTuple): return ExprTuple( - members=self.visit_ArgList(expr.members), - location=self.location_modifier(expr.location)) + members=self.visit_ArgList(expr.members), location=self.location_modifier(expr.location) + ) def visit_RvalueFuncCall(self, rvalue: RvalueFuncCall): return RvalueFuncCall( func_ident=self.visit(rvalue.func_ident), arguments=self.visit_ArgList(rvalue.arguments), - implicit_arguments=None if rvalue.implicit_arguments is None else self.visit_ArgList( - rvalue.implicit_arguments), - location=self.location_modifier(rvalue.location)) + implicit_arguments=None + if rvalue.implicit_arguments is None + else self.visit_ArgList(rvalue.implicit_arguments), + location=self.location_modifier(rvalue.location), + ) def visit_ExprFuncCall(self, expr: ExprFuncCall): return ExprFuncCall( rvalue=self.visit_RvalueFuncCall(expr.rvalue), - location=self.location_modifier(expr.location)) + location=self.location_modifier(expr.location), + ) def location_modifier(self, location: Optional[Location]) -> Optional[Location]: """ diff --git a/src/starkware/cairo/lang/compiler/fields.py b/src/starkware/cairo/lang/compiler/fields.py index 2ff61b27..35b53290 100644 --- a/src/starkware/cairo/lang/compiler/fields.py +++ b/src/starkware/cairo/lang/compiler/fields.py @@ -3,7 +3,10 @@ from starkware.cairo.lang.compiler.ast.cairo_types import CairoType from starkware.cairo.lang.compiler.parser import parse_expr, parse_type from starkware.cairo.lang.compiler.type_system import ( - is_type_resolved, mark_type_resolved, mark_types_in_expr_resolved) + is_type_resolved, + mark_type_resolved, + mark_types_in_expr_resolved, +) class ExpressionAsStr(mfields.Field): @@ -14,8 +17,9 @@ class ExpressionAsStr(mfields.Field): def _serialize(self, value, attr, obj, **kwargs): if value is None: return None - assert mark_types_in_expr_resolved(value) == value, \ - f"Expected types in '{value}' to be resolved." + assert ( + mark_types_in_expr_resolved(value) == value + ), f"Expected types in '{value}' to be resolved." return value.format() def _deserialize(self, value, attr, data, **kwargs): @@ -30,7 +34,7 @@ class CairoTypeAsStr(mfields.Field): def _serialize(self, value, attr, obj, **kwargs): if value is None: return None - assert isinstance(value, CairoType), f'Expected CairoType, found: {type(value).__name__}.' + assert isinstance(value, CairoType), f"Expected CairoType, found: {type(value).__name__}." assert is_type_resolved(value), f"Cairo type '{value}' must be resolved." return value.format() diff --git a/src/starkware/cairo/lang/compiler/identifier_definition.py b/src/starkware/cairo/lang/compiler/identifier_definition.py index 2ab4449b..812aa3a6 100644 --- a/src/starkware/cairo/lang/compiler/identifier_definition.py +++ b/src/starkware/cairo/lang/compiler/identifier_definition.py @@ -13,7 +13,10 @@ from starkware.cairo.lang.compiler.error_handling import Location from starkware.cairo.lang.compiler.fields import CairoTypeAsStr from starkware.cairo.lang.compiler.preprocessor.flow import ( - FlowTrackingData, FlowTrackingDataActual, ReferenceManager) + FlowTrackingData, + FlowTrackingDataActual, + ReferenceManager, +) from starkware.cairo.lang.compiler.references import Reference from starkware.cairo.lang.compiler.scoped_name import ScopedName, ScopedNameAsStr @@ -35,13 +38,13 @@ class FutureIdentifierDefinition(IdentifierDefinition): Represents an identifier that will be defined later in the code. """ - TYPE: ClassVar[str] = 'future' + TYPE: ClassVar[str] = "future" identifier_type: type @marshmallow_dataclass.dataclass class AliasDefinition(IdentifierDefinition): - TYPE: ClassVar[str] = 'alias' + TYPE: ClassVar[str] = "alias" Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema destination: ScopedName = field(metadata=dict(marshmallow_field=ScopedNameAsStr())) @@ -49,7 +52,7 @@ class AliasDefinition(IdentifierDefinition): @marshmallow_dataclass.dataclass class ConstDefinition(IdentifierDefinition): - TYPE: ClassVar[str] = 'const' + TYPE: ClassVar[str] = "const" Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema value: int @@ -57,12 +60,11 @@ class ConstDefinition(IdentifierDefinition): @marshmallow_dataclass.dataclass class MemberDefinition(IdentifierDefinition): - TYPE: ClassVar[str] = 'member' + TYPE: ClassVar[str] = "member" Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema offset: int - cairo_type: CairoType = field( - metadata=dict(marshmallow_field=CairoTypeAsStr(required=True))) + cairo_type: CairoType = field(metadata=dict(marshmallow_field=CairoTypeAsStr(required=True))) location: Optional[Location] = LocationField @@ -76,7 +78,8 @@ class StructDefinition(IdentifierDefinition): ... end """ - TYPE: ClassVar[str] = 'struct' + + TYPE: ClassVar[str] = "struct" Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema full_name: ScopedName = field(metadata=dict(marshmallow_field=ScopedNameAsStr())) @@ -91,14 +94,15 @@ def sort_members(self, item, many, **kwargs): """ Sorts the members according to their offset. """ - item['members'] = dict( - sorted(item['members'].items(), key=lambda key_value: key_value[1].offset)) + item["members"] = dict( + sorted(item["members"].items(), key=lambda key_value: key_value[1].offset) + ) return item @marshmallow_dataclass.dataclass class LabelDefinition(IdentifierDefinition): - TYPE: ClassVar[str] = 'label' + TYPE: ClassVar[str] = "label" Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema pc: int @@ -106,7 +110,7 @@ class LabelDefinition(IdentifierDefinition): @marshmallow_dataclass.dataclass class FunctionDefinition(LabelDefinition): - TYPE: ClassVar[str] = 'function' + TYPE: ClassVar[str] = "function" Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema decorators: List[str] @@ -114,22 +118,22 @@ class FunctionDefinition(LabelDefinition): @marshmallow_dataclass.dataclass class ReferenceDefinition(IdentifierDefinition): - TYPE: ClassVar[str] = 'reference' + TYPE: ClassVar[str] = "reference" Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema full_name: ScopedName = field(metadata=dict(marshmallow_field=ScopedNameAsStr())) - cairo_type: CairoType = field( - metadata=dict(marshmallow_field=CairoTypeAsStr(required=True))) + cairo_type: CairoType = field(metadata=dict(marshmallow_field=CairoTypeAsStr(required=True))) references: List[Reference] def eval( - self, reference_manager: ReferenceManager, flow_tracking_data: FlowTrackingData) -> \ - Expression: + self, reference_manager: ReferenceManager, flow_tracking_data: FlowTrackingData + ) -> Expression: reference = flow_tracking_data.resolve_reference( - reference_manager=reference_manager, - name=self.full_name) - assert isinstance(flow_tracking_data, FlowTrackingDataActual), \ - 'Resolved references can only come from FlowTrackingDataActual.' + reference_manager=reference_manager, name=self.full_name + ) + assert isinstance( + flow_tracking_data, FlowTrackingDataActual + ), "Resolved references can only come from FlowTrackingDataActual." expr = reference.eval(flow_tracking_data.ap_tracking) return expr @@ -137,7 +141,7 @@ def eval( @marshmallow_dataclass.dataclass class ScopeDefinition(IdentifierDefinition): - TYPE: ClassVar[str] = 'scope' + TYPE: ClassVar[str] = "scope" Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema diff --git a/src/starkware/cairo/lang/compiler/identifier_definition_test.py b/src/starkware/cairo/lang/compiler/identifier_definition_test.py index 5a5f4855..920c9b7d 100644 --- a/src/starkware/cairo/lang/compiler/identifier_definition_test.py +++ b/src/starkware/cairo/lang/compiler/identifier_definition_test.py @@ -1,6 +1,9 @@ from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt from starkware.cairo.lang.compiler.identifier_definition import ( - IdentifierDefinitionSchema, MemberDefinition, StructDefinition) + IdentifierDefinitionSchema, + MemberDefinition, + StructDefinition, +) from starkware.cairo.lang.compiler.scoped_name import ScopedName scope = ScopedName.from_string @@ -8,20 +11,22 @@ def test_struct_sorting(): orig = StructDefinition( - full_name=ScopedName.from_string('T'), + full_name=ScopedName.from_string("T"), members={ - 'b': MemberDefinition(offset=1, cairo_type=TypeFelt()), - 'a': MemberDefinition(offset=0, cairo_type=TypeFelt()), + "b": MemberDefinition(offset=1, cairo_type=TypeFelt()), + "a": MemberDefinition(offset=0, cairo_type=TypeFelt()), }, - size=2 + size=2, ) members = orig.members assert list(members.items()) != sorted( - members.items(), key=lambda key_value: key_value[1].offset) + members.items(), key=lambda key_value: key_value[1].offset + ) schema = IdentifierDefinitionSchema() loaded = schema.load(schema.dump(orig)) members = loaded.members assert list(members.items()) == sorted( - members.items(), key=lambda key_value: key_value[1].offset) + members.items(), key=lambda key_value: key_value[1].offset + ) diff --git a/src/starkware/cairo/lang/compiler/identifier_manager.py b/src/starkware/cairo/lang/compiler/identifier_manager.py index 59ad8940..378f7e89 100644 --- a/src/starkware/cairo/lang/compiler/identifier_manager.py +++ b/src/starkware/cairo/lang/compiler/identifier_manager.py @@ -2,7 +2,10 @@ from typing import Dict, List, Optional, Set, Union from starkware.cairo.lang.compiler.identifier_definition import ( - AliasDefinition, FutureIdentifierDefinition, IdentifierDefinition) + AliasDefinition, + FutureIdentifierDefinition, + IdentifierDefinition, +) from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -22,7 +25,8 @@ class NotAScopeError(IdentifierError): """ def __init__( - self, fullname: ScopedName, definition: IdentifierDefinition, non_parsed: ScopedName): + self, fullname: ScopedName, definition: IdentifierDefinition, non_parsed: ScopedName + ): self.fullname = fullname self.definition = definition self.non_parsed = non_parsed @@ -49,7 +53,8 @@ def assert_fully_parsed(self): return raise IdentifierError( f"Unexpected '.' after '{self.canonical_name}' which is " - f'{self.identifier_definition.TYPE}.') + f"{self.identifier_definition.TYPE}." + ) def get_canonical_name(self) -> ScopedName: """ @@ -77,7 +82,8 @@ def add_identifier(self, name: ScopedName, definition: IdentifierDefinition): @classmethod def from_dict( - cls, identifier_dict: Dict[ScopedName, IdentifierDefinition]) -> 'IdentifierManager': + cls, identifier_dict: Dict[ScopedName, IdentifierDefinition] + ) -> "IdentifierManager": identifier_manager = cls() for name, identifier_definition in identifier_dict.items(): identifier_manager.add_identifier(name, identifier_definition) @@ -111,15 +117,15 @@ def get(self, name: ScopedName) -> IdentifierSearchResult: # Detect cycles. if current_identifier in visited_identifiers: - cycle_str = ' -> '.join(map(str, visited_identifiers + [current_identifier])) - raise IdentifierError(f'Cyclic aliasing detected: {cycle_str}') + cycle_str = " -> ".join(map(str, visited_identifiers + [current_identifier])) + raise IdentifierError(f"Cyclic aliasing detected: {cycle_str}") visited_identifiers.append(current_identifier) try: result = self.root.get(current_identifier) except MissingIdentifierError as exc: - resolution_str = ' -> '.join(map(str, visited_identifiers)) - raise IdentifierError(f'Alias resolution failed: {resolution_str}. {exc}') + resolution_str = " -> ".join(map(str, visited_identifiers)) + raise IdentifierError(f"Alias resolution failed: {resolution_str}. {exc}") return result @@ -142,7 +148,7 @@ def get_by_full_name(self, name: ScopedName) -> Optional[IdentifierDefinition]: return result.identifier_definition - def get_scope(self, name: ScopedName) -> 'IdentifierScope': + def get_scope(self, name: ScopedName) -> "IdentifierScope": """ Finds the scope with the given name. Includes alias resolution. """ @@ -170,16 +176,16 @@ def get_scope(self, name: ScopedName) -> 'IdentifierScope': if len(visited_identifiers) == 1: raise # Add a prefix with the alias resolution. - resolution_str = ' -> '.join(map(str, visited_identifiers)) - raise IdentifierError(f'Alias resolution failed: {resolution_str}. {exc}') from None + resolution_str = " -> ".join(map(str, visited_identifiers)) + raise IdentifierError(f"Alias resolution failed: {resolution_str}. {exc}") from None # We found an alias cycle. - cycle_str = ' -> '.join(map(str, visited_identifiers + [current_identifier])) - raise IdentifierError(f'Cyclic aliasing detected: {cycle_str}') + cycle_str = " -> ".join(map(str, visited_identifiers + [current_identifier])) + raise IdentifierError(f"Cyclic aliasing detected: {cycle_str}") def _search( - self, accessible_scopes: List[ScopedName], - name: ScopedName, get_scope: bool) -> Union[IdentifierSearchResult, 'IdentifierScope']: + self, accessible_scopes: List[ScopedName], name: ScopedName, get_scope: bool + ) -> Union[IdentifierSearchResult, "IdentifierScope"]: """ Searches an identifier (if get_scope=False) or a scope (if get_scope=True) in the given accessible scopes. Later scopes override the first ones. @@ -206,7 +212,8 @@ def _search( raise MissingIdentifierError(name[:1]) def search( - self, accessible_scopes: List[ScopedName], name: ScopedName) -> IdentifierSearchResult: + self, accessible_scopes: List[ScopedName], name: ScopedName + ) -> IdentifierSearchResult: """ Searches an identifier in the given accessible scopes. Later scopes override the first ones. """ @@ -215,7 +222,8 @@ def search( return res def search_scope( - self, accessible_scopes: List[ScopedName], name: ScopedName) -> 'IdentifierScope': + self, accessible_scopes: List[ScopedName], name: ScopedName + ) -> "IdentifierScope": """ Searches a scope in the given accessible scopes. Later scopes override the first ones. """ @@ -223,16 +231,14 @@ def search_scope( assert isinstance(res, IdentifierScope) return res - def exclude(self, other: 'IdentifierManager') -> 'IdentifierManager': + def exclude(self, other: "IdentifierManager") -> "IdentifierManager": """ Returns a copy of the identifier manager without the identifiers that exist in other. """ other_as_dict = other.as_dict() - return IdentifierManager.from_dict({ - name: value - for name, value in self.as_dict().items() - if name not in other_as_dict - }) + return IdentifierManager.from_dict( + {name: value for name, value in self.as_dict().items() if name not in other_as_dict} + ) def prune(self, prefixes_to_prune: Set[ScopedName]): """ @@ -247,9 +253,10 @@ def prune(self, prefixes_to_prune: Set[ScopedName]): break parent = parent[:-1] if parent in prefixes_to_prune: - assert isinstance(value, (IdentifierDefinition, FutureIdentifierDefinition)), \ - f"Attempted to prune identifier '{value}'" \ + assert isinstance(value, (IdentifierDefinition, FutureIdentifierDefinition)), ( + f"Attempted to prune identifier '{value}'" f" of unprunable type '{type(value).__name__}'." + ) continue new_dict[name] = value self.dict = new_dict @@ -279,7 +286,7 @@ def add_identifier(self, name: ScopedName, definition: IdentifierDefinition): Adds an identifier to the manager. name is relative to the current scope. """ if len(name) == 0: - raise ValueError('The name argument must not be empty.') + raise ValueError("The name argument must not be empty.") first_name, non_parsed = name.path[0], name[1:] @@ -290,7 +297,8 @@ def add_identifier(self, name: ScopedName, definition: IdentifierDefinition): if first_name not in self.subscopes: self.subscopes[first_name] = IdentifierScope( - manager=self.manager, fullname=self.fullname + first_name) + manager=self.manager, fullname=self.fullname + first_name + ) self.subscopes[first_name].add_identifier(non_parsed, definition) @@ -311,11 +319,12 @@ def get(self, name: ScopedName) -> IdentifierSearchResult: return IdentifierSearchResult( identifier_definition=self.identifiers[first_name], canonical_name=canonical_name, - non_parsed=non_parsed) + non_parsed=non_parsed, + ) raise MissingIdentifierError(fullname=self.fullname + first_name) - def get_scope(self, name: ScopedName) -> 'IdentifierScope': + def get_scope(self, name: ScopedName) -> "IdentifierScope": """ Retrieves the scope with the given name. Raises NotAScopeError if name refers to an identifier rather than a scope @@ -328,8 +337,10 @@ def get_scope(self, name: ScopedName) -> 'IdentifierScope': fullname = self.fullname + first_name if first_name in self.identifiers: raise NotAScopeError( - fullname=fullname, definition=self.identifiers[first_name], - non_parsed=non_parsed) + fullname=fullname, + definition=self.identifiers[first_name], + non_parsed=non_parsed, + ) else: raise MissingIdentifierError(fullname=fullname) return self.subscopes[first_name].get_scope(non_parsed) diff --git a/src/starkware/cairo/lang/compiler/identifier_manager_field.py b/src/starkware/cairo/lang/compiler/identifier_manager_field.py index 4bad013a..0d5d67f3 100644 --- a/src/starkware/cairo/lang/compiler/identifier_manager_field.py +++ b/src/starkware/cairo/lang/compiler/identifier_manager_field.py @@ -21,8 +21,11 @@ def _serialize(self, value, attr, obj, **kwargs): def _deserialize(self, value, attr, data, **kwargs) -> IdentifierManager: identifier_definition_schema = IdentifierDefinitionSchema() - return IdentifierManager.from_dict({ - ScopedName.from_string(name): identifier_definition_schema.load( - serialized_identifier_definition) - for name, serialized_identifier_definition in value.items() - }) + return IdentifierManager.from_dict( + { + ScopedName.from_string(name): identifier_definition_schema.load( + serialized_identifier_definition + ) + for name, serialized_identifier_definition in value.items() + } + ) diff --git a/src/starkware/cairo/lang/compiler/identifier_manager_field_test.py b/src/starkware/cairo/lang/compiler/identifier_manager_field_test.py index 8f86bfd5..bc103cf2 100644 --- a/src/starkware/cairo/lang/compiler/identifier_manager_field_test.py +++ b/src/starkware/cairo/lang/compiler/identifier_manager_field_test.py @@ -14,18 +14,19 @@ def test_identifier_manager_field_serialization(): @marshmallow_dataclass.dataclass class Foo: identifiers: IdentifierManager = field( - metadata=dict(marshmallow_field=IdentifierManagerField())) + metadata=dict(marshmallow_field=IdentifierManagerField()) + ) Schema = marshmallow_dataclass.class_schema(Foo) - foo = Foo(identifiers=IdentifierManager.from_dict({ - scope('aa.b'): LabelDefinition(pc=1000), - })) + foo = Foo( + identifiers=IdentifierManager.from_dict( + { + scope("aa.b"): LabelDefinition(pc=1000), + } + ) + ) serialized = Schema().dump(foo) - assert serialized == { - 'identifiers': { - 'aa.b': {'pc': 1000, 'type': 'label'} - } - } + assert serialized == {"identifiers": {"aa.b": {"pc": 1000, "type": "label"}}} assert Schema().load(serialized) == foo diff --git a/src/starkware/cairo/lang/compiler/identifier_manager_test.py b/src/starkware/cairo/lang/compiler/identifier_manager_test.py index 525b2e36..82accf50 100644 --- a/src/starkware/cairo/lang/compiler/identifier_manager_test.py +++ b/src/starkware/cairo/lang/compiler/identifier_manager_test.py @@ -4,7 +4,11 @@ from starkware.cairo.lang.compiler.identifier_definition import AliasDefinition, ConstDefinition from starkware.cairo.lang.compiler.identifier_manager import ( - IdentifierError, IdentifierManager, IdentifierSearchResult, MissingIdentifierError) + IdentifierError, + IdentifierManager, + IdentifierSearchResult, + MissingIdentifierError, +) from starkware.cairo.lang.compiler.scoped_name import ScopedName scope = ScopedName.from_string @@ -12,21 +16,22 @@ def test_identifier_manager_get(): identifier_dict = { - scope('a.b.c'): ConstDefinition(value=7), + scope("a.b.c"): ConstDefinition(value=7), } manager = IdentifierManager.from_dict(identifier_dict) - for name in ['a', 'a.b']: + for name in ["a", "a.b"]: with pytest.raises(MissingIdentifierError, match=f"Unknown identifier '{name}'."): manager.get(scope(name)) # Search 'a.b.c.*'. - for suffix in ['d', 'd.e']: - result = manager.get(scope('a.b.c') + scope(suffix)) + for suffix in ["d", "d.e"]: + result = manager.get(scope("a.b.c") + scope(suffix)) assert result == IdentifierSearchResult( - identifier_definition=identifier_dict[scope('a.b.c')], - canonical_name=scope('a.b.c'), - non_parsed=scope(suffix)) + identifier_definition=identifier_dict[scope("a.b.c")], + canonical_name=scope("a.b.c"), + non_parsed=scope(suffix), + ) error_msg = re.escape("Unexpected '.' after 'a.b.c' which is const") with pytest.raises(IdentifierError, match=error_msg): @@ -34,15 +39,16 @@ def test_identifier_manager_get(): with pytest.raises(IdentifierError, match=error_msg): result.get_canonical_name() - result = manager.get(scope('a.b.c')) + result = manager.get(scope("a.b.c")) assert result == IdentifierSearchResult( - identifier_definition=identifier_dict[scope('a.b.c')], - canonical_name=scope('a.b.c'), - non_parsed=ScopedName()) + identifier_definition=identifier_dict[scope("a.b.c")], + canonical_name=scope("a.b.c"), + non_parsed=ScopedName(), + ) result.assert_fully_parsed() - assert result.get_canonical_name() == scope('a.b.c') + assert result.get_canonical_name() == scope("a.b.c") - for name in ['a.d', 'a.d.e']: + for name in ["a.d", "a.d.e"]: # The error should point to the first unknown item, rather then the entire name. with pytest.raises(MissingIdentifierError, match="Unknown identifier 'a.d'."): manager.get(scope(name)) @@ -50,101 +56,113 @@ def test_identifier_manager_get(): def test_identifier_manager_get_by_full_name(): identifier_dict = { - scope('a.b.c'): ConstDefinition(value=7), - scope('x'): AliasDefinition(destination=scope('a')), + scope("a.b.c"): ConstDefinition(value=7), + scope("x"): AliasDefinition(destination=scope("a")), } manager = IdentifierManager.from_dict(identifier_dict) - assert manager.get_by_full_name(scope('a.b.c')) == identifier_dict[scope('a.b.c')] - assert manager.get_by_full_name(scope('x')) == identifier_dict[scope('x')] + assert manager.get_by_full_name(scope("a.b.c")) == identifier_dict[scope("a.b.c")] + assert manager.get_by_full_name(scope("x")) == identifier_dict[scope("x")] - assert manager.get_by_full_name(scope('a.b')) is None - assert manager.get_by_full_name(scope('a.b.c.d')) is None - assert manager.get_by_full_name(scope('x.b.c')) is None + assert manager.get_by_full_name(scope("a.b")) is None + assert manager.get_by_full_name(scope("a.b.c.d")) is None + assert manager.get_by_full_name(scope("x.b.c")) is None def test_identifier_manager_aliases(): identifier_dict = { - scope('a.b.c'): AliasDefinition(destination=scope('x.y')), - scope('x.y'): AliasDefinition(destination=scope('x.y2')), - scope('x.y2.z'): ConstDefinition(value=3), - scope('x.y2.s.z'): ConstDefinition(value=4), - scope('x.y2.s2'): AliasDefinition(destination=scope('x.y2.s')), - - scope('z0'): AliasDefinition(destination=scope('z1.z2')), - scope('z1.z2'): AliasDefinition(destination=scope('z3')), - scope('z3'): AliasDefinition(destination=scope('z0')), - - scope('to_const'): AliasDefinition(destination=scope('x.y2.z')), - scope('unresolved'): AliasDefinition(destination=scope('z1.missing')), + scope("a.b.c"): AliasDefinition(destination=scope("x.y")), + scope("x.y"): AliasDefinition(destination=scope("x.y2")), + scope("x.y2.z"): ConstDefinition(value=3), + scope("x.y2.s.z"): ConstDefinition(value=4), + scope("x.y2.s2"): AliasDefinition(destination=scope("x.y2.s")), + scope("z0"): AliasDefinition(destination=scope("z1.z2")), + scope("z1.z2"): AliasDefinition(destination=scope("z3")), + scope("z3"): AliasDefinition(destination=scope("z0")), + scope("to_const"): AliasDefinition(destination=scope("x.y2.z")), + scope("unresolved"): AliasDefinition(destination=scope("z1.missing")), } manager = IdentifierManager.from_dict(identifier_dict) # Test manager.get(). - assert manager.get(scope('a.b.c.z.w')) == IdentifierSearchResult( - identifier_definition=identifier_dict[scope('x.y2.z')], - canonical_name=scope('x.y2.z'), - non_parsed=scope('w')) - assert manager.get(scope('to_const.w')) == IdentifierSearchResult( - identifier_definition=identifier_dict[scope('x.y2.z')], - canonical_name=scope('x.y2.z'), - non_parsed=scope('w')) - - with pytest.raises(IdentifierError, match='Cyclic aliasing detected: z0 -> z1.z2 -> z3 -> z0'): - manager.get(scope('z0')) - - with pytest.raises(IdentifierError, match=(re.escape( - 'Alias resolution failed: unresolved -> z1.missing. ' - "Unknown identifier 'z1.missing'."))): - manager.get(scope('unresolved')) + assert manager.get(scope("a.b.c.z.w")) == IdentifierSearchResult( + identifier_definition=identifier_dict[scope("x.y2.z")], + canonical_name=scope("x.y2.z"), + non_parsed=scope("w"), + ) + assert manager.get(scope("to_const.w")) == IdentifierSearchResult( + identifier_definition=identifier_dict[scope("x.y2.z")], + canonical_name=scope("x.y2.z"), + non_parsed=scope("w"), + ) + + with pytest.raises(IdentifierError, match="Cyclic aliasing detected: z0 -> z1.z2 -> z3 -> z0"): + manager.get(scope("z0")) + + with pytest.raises( + IdentifierError, + match=( + re.escape( + "Alias resolution failed: unresolved -> z1.missing. " + "Unknown identifier 'z1.missing'." + ) + ), + ): + manager.get(scope("unresolved")) # Test manager.get_scope(). - assert manager.get_scope(scope('a.b')).fullname == scope('a.b') - assert manager.get_scope(scope('a.b.c')).fullname == scope('x.y2') - assert manager.get_scope(scope('a.b.c.s')).fullname == scope('x.y2.s') - assert manager.get_scope(scope('a.b.c.s2')).fullname == scope('x.y2.s') - - with pytest.raises(IdentifierError, match='Cyclic aliasing detected: z0 -> z1.z2 -> z3 -> z0'): - manager.get_scope(scope('z0')) - with pytest.raises(IdentifierError, match=( - 'Alias resolution failed: unresolved -> z1.missing. ' - "Unknown identifier 'z1.missing'.")): - manager.get_scope(scope('unresolved')) - with pytest.raises(IdentifierError, match=( - "^Identifier 'x.y2.z' is const, expected a scope.")): - manager.get_scope(scope('x.y2.z')) - with pytest.raises(IdentifierError, match=( - 'Alias resolution failed: a.b.c.z.w -> x.y.z.w -> x.y2.z.w. ' - "Identifier 'x.y2.z' is const, expected a scope.")): - manager.get_scope(scope('a.b.c.z.w')) + assert manager.get_scope(scope("a.b")).fullname == scope("a.b") + assert manager.get_scope(scope("a.b.c")).fullname == scope("x.y2") + assert manager.get_scope(scope("a.b.c.s")).fullname == scope("x.y2.s") + assert manager.get_scope(scope("a.b.c.s2")).fullname == scope("x.y2.s") + + with pytest.raises(IdentifierError, match="Cyclic aliasing detected: z0 -> z1.z2 -> z3 -> z0"): + manager.get_scope(scope("z0")) + with pytest.raises( + IdentifierError, + match=( + "Alias resolution failed: unresolved -> z1.missing. " "Unknown identifier 'z1.missing'." + ), + ): + manager.get_scope(scope("unresolved")) + with pytest.raises(IdentifierError, match=("^Identifier 'x.y2.z' is const, expected a scope.")): + manager.get_scope(scope("x.y2.z")) + with pytest.raises( + IdentifierError, + match=( + "Alias resolution failed: a.b.c.z.w -> x.y.z.w -> x.y2.z.w. " + "Identifier 'x.y2.z' is const, expected a scope." + ), + ): + manager.get_scope(scope("a.b.c.z.w")) def test_identifier_manager_search(): identifier_dict = { - scope('a.b.c.y'): ConstDefinition(value=1), - scope('a.b.x'): ConstDefinition(value=2), - scope('a.b.z'): ConstDefinition(value=3), - scope('a.x'): ConstDefinition(value=4), - scope('x'): ConstDefinition(value=5), - scope('d.b.w'): ConstDefinition(value=6), + scope("a.b.c.y"): ConstDefinition(value=1), + scope("a.b.x"): ConstDefinition(value=2), + scope("a.b.z"): ConstDefinition(value=3), + scope("a.x"): ConstDefinition(value=4), + scope("x"): ConstDefinition(value=5), + scope("d.b.w"): ConstDefinition(value=6), } manager = IdentifierManager.from_dict(identifier_dict) for accessible_scopes, name, canonical_name in [ - (['a', 'a.b', 'a.b.c', 'e'], 'x', 'a.b.x'), - (['a', 'a.b'], 'x', 'a.b.x'), - (['a.b', 'a'], 'x', 'a.x'), - ([''], 'x', 'x'), - (['a', 'e', 'a.b.c'], 'b.z', 'a.b.z'), + (["a", "a.b", "a.b.c", "e"], "x", "a.b.x"), + (["a", "a.b"], "x", "a.b.x"), + (["a.b", "a"], "x", "a.x"), + ([""], "x", "x"), + (["a", "e", "a.b.c"], "b.z", "a.b.z"), ]: result = manager.search(list(map(scope, accessible_scopes)), scope(name)) assert result.canonical_name == scope(canonical_name) assert result.identifier_definition == identifier_dict[scope(canonical_name)] with pytest.raises(IdentifierError, match="Unknown identifier 'x'"): - manager.search([], scope('x')) + manager.search([], scope("x")) # Since 'd.b' exists, and it does not contain a sub-identifier 'z' the following raises an # exception (even though a.b.z exists). # Compare with the line (['a', 'e', 'a.b.c'], 'b.z', 'a.b.z') above. with pytest.raises(IdentifierError, match="Unknown identifier 'd.b.z'."): - manager.search([scope('a'), scope('d'), scope('e'), scope('a.b.c')], scope('b.z.w')) + manager.search([scope("a"), scope("d"), scope("e"), scope("a.b.c")], scope("b.z.w")) diff --git a/src/starkware/cairo/lang/compiler/identifier_utils.py b/src/starkware/cairo/lang/compiler/identifier_utils.py index 7900318c..311b8619 100644 --- a/src/starkware/cairo/lang/compiler/identifier_utils.py +++ b/src/starkware/cairo/lang/compiler/identifier_utils.py @@ -2,13 +2,15 @@ from starkware.cairo.lang.compiler.identifier_definition import DefinitionError, StructDefinition from starkware.cairo.lang.compiler.identifier_manager import ( - IdentifierManager, MissingIdentifierError) + IdentifierManager, + MissingIdentifierError, +) from starkware.cairo.lang.compiler.scoped_name import ScopedName def get_struct_definition( - struct_name: ScopedName, - identifier_manager: IdentifierManager) -> StructDefinition: + struct_name: ScopedName, identifier_manager: IdentifierManager +) -> StructDefinition: """ Returns the struct definition of a struct given its full name (no alias resolution). """ @@ -18,22 +20,23 @@ def get_struct_definition( raise MissingIdentifierError(struct_name) if not isinstance(struct_def, StructDefinition): - raise DefinitionError(f"""\ -Expected '{struct_name}' to be a {StructDefinition.TYPE}. Found: '{struct_def.TYPE}'.""") + raise DefinitionError( + f"""\ +Expected '{struct_name}' to be a {StructDefinition.TYPE}. Found: '{struct_def.TYPE}'.""" + ) return struct_def def get_struct_member_offsets( - struct_name: ScopedName, - identifier_manager: IdentifierManager) -> Dict[str, int]: + struct_name: ScopedName, identifier_manager: IdentifierManager +) -> Dict[str, int]: """ Returns a dict that maps a struct member name to its offset in the struct. """ struct_def = get_struct_definition( - struct_name=struct_name, identifier_manager=identifier_manager) + struct_name=struct_name, identifier_manager=identifier_manager + ) - return { - name: member_def.offset for name, member_def in struct_def.members.items() - } + return {name: member_def.offset for name, member_def in struct_def.members.items()} diff --git a/src/starkware/cairo/lang/compiler/identifier_utils_test.py b/src/starkware/cairo/lang/compiler/identifier_utils_test.py index 23d3ce8f..76d16e44 100644 --- a/src/starkware/cairo/lang/compiler/identifier_utils_test.py +++ b/src/starkware/cairo/lang/compiler/identifier_utils_test.py @@ -4,9 +4,15 @@ from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt from starkware.cairo.lang.compiler.identifier_definition import ( - ConstDefinition, DefinitionError, MemberDefinition, StructDefinition) + ConstDefinition, + DefinitionError, + MemberDefinition, + StructDefinition, +) from starkware.cairo.lang.compiler.identifier_manager import ( - IdentifierManager, MissingIdentifierError) + IdentifierManager, + MissingIdentifierError, +) from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -15,32 +21,31 @@ def test_get_struct_definition(): identifier_dict = { - scope('T'): StructDefinition( - full_name=scope('T'), + scope("T"): StructDefinition( + full_name=scope("T"), members={ - 'a': MemberDefinition(offset=0, cairo_type=TypeFelt()), - 'b': MemberDefinition(offset=1, cairo_type=TypeFelt()), + "a": MemberDefinition(offset=0, cairo_type=TypeFelt()), + "b": MemberDefinition(offset=1, cairo_type=TypeFelt()), }, size=2, ), - scope('MyConst'): ConstDefinition(value=5), + scope("MyConst"): ConstDefinition(value=5), } manager = IdentifierManager.from_dict(identifier_dict) - struct_def = get_struct_definition(ScopedName.from_string('T'), manager) + struct_def = get_struct_definition(ScopedName.from_string("T"), manager) # Convert to a list, to check the order of the elements in the dict. assert list(struct_def.members.items()) == [ - ('a', MemberDefinition(offset=0, cairo_type=TypeFelt())), - ('b', MemberDefinition(offset=1, cairo_type=TypeFelt())), + ("a", MemberDefinition(offset=0, cairo_type=TypeFelt())), + ("b", MemberDefinition(offset=1, cairo_type=TypeFelt())), ] assert struct_def.size == 2 with pytest.raises(DefinitionError, match="Expected 'MyConst' to be a struct. Found: 'const'."): - get_struct_definition(scope('MyConst'), manager) + get_struct_definition(scope("MyConst"), manager) - with pytest.raises( - MissingIdentifierError, match=re.escape("Unknown identifier 'abc'.")): - get_struct_definition(scope('abc'), manager) + with pytest.raises(MissingIdentifierError, match=re.escape("Unknown identifier 'abc'.")): + get_struct_definition(scope("abc"), manager) diff --git a/src/starkware/cairo/lang/compiler/import_loader.py b/src/starkware/cairo/lang/compiler/import_loader.py index 124a58ff..670a9ca7 100644 --- a/src/starkware/cairo/lang/compiler/import_loader.py +++ b/src/starkware/cairo/lang/compiler/import_loader.py @@ -1,7 +1,11 @@ from typing import Callable, Dict, List, Optional, Tuple from starkware.cairo.lang.compiler.ast.code_elements import ( - CodeBlock, CodeElement, CodeElementFunction, CodeElementImport) + CodeBlock, + CodeElement, + CodeElementFunction, + CodeElementImport, +) from starkware.cairo.lang.compiler.ast.module import CairoFile from starkware.cairo.lang.compiler.ast.visitor import Visitor, get_lang_from_file from starkware.cairo.lang.compiler.error_handling import Location, LocationError @@ -10,7 +14,8 @@ def collect_imports( - curr_pkg_name: str, read_file: Callable[[str], Tuple[str, str]]) -> Dict[str, CairoFile]: + curr_pkg_name: str, read_file: Callable[[str], Tuple[str, str]] +) -> Dict[str, CairoFile]: """ Scans the graph of file imports (using DFS), starting with curr_pkg_name, and returns an ordered dictionary mapping package names to CairoFile AST. @@ -32,14 +37,14 @@ class UsingCycleError(Exception): """ def __init__(self, cycle: List[str]): - super().__init__(f'Found circular imports dependency:\n{self.cycle_to_string(cycle)}') + super().__init__(f"Found circular imports dependency:\n{self.cycle_to_string(cycle)}") self.cycle = cycle @staticmethod def cycle_to_string(cycle): - res = '' + res = "" for v in cycle[:-1]: - res += f'{v} imports\n' + res += f"{v} imports\n" res += cycle[-1] return res @@ -70,8 +75,8 @@ def collect(self, curr_pkg_name: str, location: Optional[Location] = None): raise ImportLoaderError(str(e), location=location) except Exception as e: raise ImportLoaderError( - f"Could not load module '{curr_pkg_name}'.\nError: {e}", - location=location) + f"Could not load module '{curr_pkg_name}'.\nError: {e}", location=location + ) parsed_file: CairoFile = parse_file(code, filename=filename) @@ -90,8 +95,9 @@ def collect(self, curr_pkg_name: str, location: Optional[Location] = None): if not (self.lang[pkg_name] is None or self.lang[pkg_name] == lang): raise ImportLoaderError( f"Importing modules with %lang directive '{self.lang[pkg_name]}' must " - 'be from a module with the same directive.', - location=location) + "be from a module with the same directive.", + location=location, + ) # Pop current package from ancestors list after scanning its dependencies. self.curr_ancestors.pop() @@ -118,7 +124,7 @@ def get_using_pkgs_in_block(self, code_block: CodeBlock): self.visit(elm.code_elm) def _visit_default(self, obj): - assert isinstance(obj, CodeElement), f'Got unexpected type {type(obj).__name__}.' + assert isinstance(obj, CodeElement), f"Got unexpected type {type(obj).__name__}." def visit_CodeBlock(self, elm: CodeBlock): pass diff --git a/src/starkware/cairo/lang/compiler/import_loader_test.py b/src/starkware/cairo/lang/compiler/import_loader_test.py index 1febc599..a936bf6c 100644 --- a/src/starkware/cairo/lang/compiler/import_loader_test.py +++ b/src/starkware/cairo/lang/compiler/import_loader_test.py @@ -6,7 +6,11 @@ from starkware.cairo.lang.compiler.error_handling import LocationError, get_location_marks from starkware.cairo.lang.compiler.import_loader import ( - DirectDependenciesCollector, ImportLoaderError, UsingCycleError, collect_imports) + DirectDependenciesCollector, + ImportLoaderError, + UsingCycleError, + collect_imports, +) from starkware.cairo.lang.compiler.parser import ParserError, parse_file from starkware.cairo.lang.compiler.test_utils import read_file_from_dict @@ -28,91 +32,93 @@ def test_get_imports(): ast = parse_file(code) collector = DirectDependenciesCollector() collector.get_using_pkgs_in_block(ast.code_block) - assert set([x for x, _ in collector.packages]) == {'a', 'b.c.d.e', 'vim', 'pytest'} + assert set([x for x, _ in collector.packages]) == {"a", "b.c.d.e", "vim", "pytest"} def test_unreachabale_file(): files = { - 'root.file': """ + "root.file": """ from fo.o import aa from bar import bb """, - 'bar': '[ap] = 2' + "bar": "[ap] = 2", } # Failed to parse internal module. with pytest.raises(ImportLoaderError) as e: - collect_imports('root.file', read_file_from_dict(files)) + collect_imports("root.file", read_file_from_dict(files)) assert f""" {get_location_marks(files['root.file'], e.value.location)} {e.value.message} -""".startswith(""" +""".startswith( + """ from fo.o import aa ^**^ Could not load module 'fo.o'. -Error: """) +Error: """ + ) # Failed to parse root module. with pytest.raises(ImportLoaderError) as e: - collect_imports('bad.root', read_file_from_dict(files)) + collect_imports("bad.root", read_file_from_dict(files)) assert e.value.message.startswith("Could not load module 'bad.root'.") def test_unparsable_import(): files = { - 'root.file': """ + "root.file": """ from foo import bar """, - 'foo': 'this is not cairo code' + "foo": "this is not cairo code", } with pytest.raises(ParserError): - collect_imports('root.file', read_file_from_dict(files)) + collect_imports("root.file", read_file_from_dict(files)) def test_shallow_tree_graph(): files = { - 'root.file': """ + "root.file": """ from a import aa from b import bb """, - 'a': '[ap] = 1', - 'b': '[ap] = 2' + "a": "[ap] = 1", + "b": "[ap] = 2", } expected_res = {name: parse_file(code) for name, code in files.items()} - assert collect_imports('root.file', read_file_from_dict(files)) == expected_res - assert set(collect_imports('a', read_file_from_dict(files)).keys()) == {'a'} + assert collect_imports("root.file", read_file_from_dict(files)) == expected_res + assert set(collect_imports("a", read_file_from_dict(files)).keys()) == {"a"} def test_long_path_grph(): - files = {f'a{i}': f'from a{i+1} import b' for i in range(10)} - files['a9'] = '[ap] = 0' + files = {f"a{i}": f"from a{i+1} import b" for i in range(10)} + files["a9"] = "[ap] = 0" expected_res = {name: parse_file(code) for name, code in files.items()} - assert collect_imports('a0', read_file_from_dict(files)) == expected_res + assert collect_imports("a0", read_file_from_dict(files)) == expected_res def test_dag(): files = { - 'root.file': """ + "root.file": """ from a import aa from b import bb """, - 'a': """ + "a": """ from common.first import some1 from common.second import some2 """, - 'b': """ + "b": """ from common.first import some1 from common.second import some2 """, - 'common.first': '[ap] = 1', - 'common.second': '[ap] = 2', + "common.first": "[ap] = 1", + "common.second": "[ap] = 2", } expected_res = {name: parse_file(code) for name, code in files.items()} - assert collect_imports('root.file', read_file_from_dict(files)) == expected_res + assert collect_imports("root.file", read_file_from_dict(files)) == expected_res def test_topologycal_order(): @@ -139,10 +145,10 @@ def test_topologycal_order(): files: Dict[str, str] = {} for i in range(N_VERTICES): # Build the i-th file. - files[f'a{i}'] = '\n'.join([f'from a{j} import nothing' for j in dependencies[i]]) + files[f"a{i}"] = "\n".join([f"from a{j} import nothing" for j in dependencies[i]]) # Collect packages. - packages = collect_imports('a0', read_file_from_dict(files)) + packages = collect_imports("a0", read_file_from_dict(files)) # Test order. seen = [False] * N_VERTICES @@ -156,18 +162,23 @@ def test_topologycal_order(): def test_circular_dep(): # Singleton circle. with pytest.raises(UsingCycleError) as e: - collect_imports('a', read_file_from_dict({'a': 'from a import b'})) - assert str(e.value) == """\ + collect_imports("a", read_file_from_dict({"a": "from a import b"})) + assert ( + str(e.value) + == """\ Found circular imports dependency: a imports a""" + ) # Big circle. with pytest.raises(UsingCycleError) as e: collect_imports( - 'a0', - read_file_from_dict({f'a{i}': f'from a{(i+1) % 9} import b' for i in range(10)})) - assert str(e.value) == """\ + "a0", read_file_from_dict({f"a{i}": f"from a{(i+1) % 9} import b" for i in range(10)}) + ) + assert ( + str(e.value) + == """\ Found circular imports dependency: a0 imports a1 imports @@ -179,54 +190,68 @@ def test_circular_dep(): a7 imports a8 imports a0""" + ) def test_lang_directive(): files = { - 'a': """ + "a": """ from c import x """, - 'b': """ + "b": """ %lang other_lang from c import x """, - 'c': """ + "c": """ %lang lang from d_lang import x from d_no_lang import x """, - 'd_lang': """ + "d_lang": """ %lang lang const x = 0 """, - 'd_no_lang': """ + "d_no_lang": """ const x = 0 """, - 'e': """ + "e": """ %lang lang # First line. %lang lang # Second line. -"""} +""", + } # Make sure that starting from 'c' does not raise an exception. - collect_imports('c', read_file_from_dict(files)) + collect_imports("c", read_file_from_dict(files)) - verify_exception(files, 'a', """ + verify_exception( + files, + "a", + """ a:?:?: Importing modules with %lang directive 'lang' must be from a module with the same directive. from c import x ^ -""") +""", + ) - verify_exception(files, 'b', """ + verify_exception( + files, + "b", + """ b:?:?: Importing modules with %lang directive 'lang' must be from a module with the same directive. from c import x ^ -""") +""", + ) - verify_exception(files, 'e', """ + verify_exception( + files, + "e", + """ e:?:?: Found two %lang directives %lang lang # Second line. ^********^ -""") +""", + ) def verify_exception(files: Dict[str, str], main_file: str, error: str): @@ -236,4 +261,4 @@ def verify_exception(files: Dict[str, str], main_file: str, error: str): with pytest.raises(LocationError) as e: collect_imports(main_file, read_file_from_dict(files)) # Remove line and column information from the error using a regular expression. - assert re.sub(':[0-9]+:[0-9]+: ', ':?:?: ', str(e.value)) == error.strip() + assert re.sub(":[0-9]+:[0-9]+: ", ":?:?: ", str(e.value)) == error.strip() diff --git a/src/starkware/cairo/lang/compiler/instruction.py b/src/starkware/cairo/lang/compiler/instruction.py index eade2337..91ca26c1 100644 --- a/src/starkware/cairo/lang/compiler/instruction.py +++ b/src/starkware/cairo/lang/compiler/instruction.py @@ -34,6 +34,7 @@ class Op1Addr(Enum): FP = auto() # op1 = [op0]. OP0 = auto() + op1_addr: Op1Addr class Res(Enum): @@ -45,6 +46,7 @@ class Res(Enum): MUL = auto() # res is not constrained. UNCONSTRAINED = auto() + res: Res # Flags for register update. @@ -58,6 +60,7 @@ class PcUpdate(Enum): # Next pc: jnz_addr (jnz), where jnz_addr is a complex expression, representing the jnz # logic. JNZ = auto() + pc_update: PcUpdate class ApUpdate(Enum): @@ -69,6 +72,7 @@ class ApUpdate(Enum): ADD1 = auto() # Next ap: ap + 2. ADD2 = auto() + ap_update: ApUpdate class FpUpdate(Enum): @@ -78,6 +82,7 @@ class FpUpdate(Enum): AP_PLUS2 = auto() # Next fp: operand_dst. DST = auto() + fp_update: FpUpdate # Flags for opcodes. @@ -86,6 +91,7 @@ class Opcode(Enum): ASSERT_EQ = auto() CALL = auto() RET = auto() + opcode: Opcode @property @@ -97,7 +103,7 @@ def decode_instruction_values(encoded_instruction): """ Returns a tuple (flags, off0, off1, off2) according to the given encoded instruction. """ - assert 0 <= encoded_instruction < 2 ** (3 * OFFSET_BITS + N_FLAGS), 'Unsupported instruction.' + assert 0 <= encoded_instruction < 2 ** (3 * OFFSET_BITS + N_FLAGS), "Unsupported instruction." off0 = encoded_instruction & (2 ** OFFSET_BITS - 1) off1 = (encoded_instruction >> OFFSET_BITS) & (2 ** OFFSET_BITS - 1) off2 = (encoded_instruction >> (2 * OFFSET_BITS)) & (2 ** OFFSET_BITS - 1) diff --git a/src/starkware/cairo/lang/compiler/instruction_builder.py b/src/starkware/cairo/lang/compiler/instruction_builder.py index 68c18f4e..4d36004f 100644 --- a/src/starkware/cairo/lang/compiler/instruction_builder.py +++ b/src/starkware/cairo/lang/compiler/instruction_builder.py @@ -2,10 +2,23 @@ from typing import Optional, Tuple, cast from starkware.cairo.lang.compiler.ast.expr import ( - ExprConst, ExprDeref, Expression, ExprOperator, ExprReg) + ExprConst, + ExprDeref, + Expression, + ExprOperator, + ExprReg, +) from starkware.cairo.lang.compiler.ast.instructions import ( - AddApInstruction, AssertEqInstruction, CallInstruction, CallLabelInstruction, InstructionAst, - JnzInstruction, JumpInstruction, JumpToLabelInstruction, RetInstruction) + AddApInstruction, + AssertEqInstruction, + CallInstruction, + CallLabelInstruction, + InstructionAst, + JnzInstruction, + JumpInstruction, + JumpToLabelInstruction, + RetInstruction, +) from starkware.cairo.lang.compiler.const_expr_checker import is_const_expr from starkware.cairo.lang.compiler.error_handling import LocationError from starkware.cairo.lang.compiler.instruction import OFFSET_BITS, Instruction, Register @@ -30,8 +43,9 @@ def build_instruction(instruction: InstructionAst) -> Instruction: return _build_addap_instruction(instruction) else: raise InstructionBuilderError( - f'Instructions of type {type(instruction.body).__name__} are not implemented.', - location=instruction.body.location) + f"Instructions of type {type(instruction.body).__name__} are not implemented.", + location=instruction.body.location, + ) def get_instruction_size(instruction: InstructionAst, allow_auto_deduction: bool = False): @@ -63,12 +77,12 @@ def _apply_inverse_syntactic_sugar(instruction_ast: AssertEqInstruction) -> Asse return instruction_ast expr: ExprOperator = instruction_ast.b - for op, inv_op in [('+', '-'), ('*', '/')]: + for op, inv_op in [("+", "-"), ("*", "/")]: if expr.op == inv_op: if isinstance(expr.b, ExprConst): # The preprocessor should have taken care of this. raise InstructionBuilderError( - 'Subtraction and division are not supported for immediates.', + "Subtraction and division are not supported for immediates.", location=expr.b.location, ) return AssertEqInstruction( @@ -79,7 +93,7 @@ def _apply_inverse_syntactic_sugar(instruction_ast: AssertEqInstruction) -> Asse b=expr.b, location=instruction_ast.location, ), - location=instruction_ast.location + location=instruction_ast.location, ) return instruction_ast @@ -101,12 +115,14 @@ def _build_assert_eq_instruction(instruction_ast: InstructionAst) -> Instruction try: # If it fails, try to parse it as b = a instead of a = b. instruction_body: AssertEqInstruction = cast(AssertEqInstruction, instruction_ast.body) - return _build_assert_eq_instruction_inner(dataclasses.replace( - instruction_ast, - body=dataclasses.replace( - instruction_body, - a=instruction_body.b, - b=instruction_body.a))) + return _build_assert_eq_instruction_inner( + dataclasses.replace( + instruction_ast, + body=dataclasses.replace( + instruction_body, a=instruction_body.b, b=instruction_body.a + ), + ) + ) except Exception: # If both fail, raise the exception thrown by parsing the original form. raise exc from None @@ -125,8 +141,9 @@ def _build_assert_eq_instruction_inner(instruction_ast: InstructionAst) -> Instr res_desc = _parse_res(instruction_body.b) - ap_update = Instruction.ApUpdate.ADD1 if instruction_ast.inc_ap else \ - Instruction.ApUpdate.REGULAR + ap_update = ( + Instruction.ApUpdate.ADD1 if instruction_ast.inc_ap else Instruction.ApUpdate.REGULAR + ) return Instruction( off0=off0, @@ -152,10 +169,12 @@ def _build_jump_instruction(instruction_ast: InstructionAst) -> Instruction: res_desc = _parse_res(instruction_body.val) - ap_update = Instruction.ApUpdate.ADD1 if instruction_ast.inc_ap else \ - Instruction.ApUpdate.REGULAR - pc_update = Instruction.PcUpdate.JUMP_REL if instruction_body.relative else \ - Instruction.PcUpdate.JUMP + ap_update = ( + Instruction.ApUpdate.ADD1 if instruction_ast.inc_ap else Instruction.ApUpdate.REGULAR + ) + pc_update = ( + Instruction.PcUpdate.JUMP_REL if instruction_body.relative else Instruction.PcUpdate.JUMP + ) return Instruction( # In this case dst is not involved. Choose [fp - 1] as the default. @@ -194,10 +213,12 @@ def _build_jnz_instruction(instruction_ast: InstructionAst) -> Instruction: op1_addr = Instruction.Op1Addr.IMM else: raise InstructionBuilderError( - 'Invalid expression for jmp offset.', location=jump_offset.location) + "Invalid expression for jmp offset.", location=jump_offset.location + ) - ap_update = Instruction.ApUpdate.ADD1 if instruction_ast.inc_ap else \ - Instruction.ApUpdate.REGULAR + ap_update = ( + Instruction.ApUpdate.ADD1 if instruction_ast.inc_ap else Instruction.ApUpdate.REGULAR + ) return Instruction( off0=off0, @@ -232,15 +253,16 @@ def _build_call_instruction(instruction_ast: InstructionAst) -> Instruction: imm = val.val op1_addr = Instruction.Op1Addr.IMM else: - raise InstructionBuilderError( - 'Invalid offset for call.', location=val.location) + raise InstructionBuilderError("Invalid offset for call.", location=val.location) if instruction_ast.inc_ap: raise InstructionBuilderError( - 'ap++ may not be used with the call opcode.', location=instruction_ast.location) + "ap++ may not be used with the call opcode.", location=instruction_ast.location + ) - pc_update = Instruction.PcUpdate.JUMP_REL if instruction_body.relative else \ - Instruction.PcUpdate.JUMP + pc_update = ( + Instruction.PcUpdate.JUMP_REL if instruction_body.relative else Instruction.PcUpdate.JUMP + ) return Instruction( # Use dst for [ap] <- fp. @@ -268,7 +290,8 @@ def _build_ret_instruction(instruction_ast: InstructionAst) -> Instruction: if instruction_ast.inc_ap: raise InstructionBuilderError( - 'ap++ may not be used with the ret opcode.', location=instruction_ast.location) + "ap++ may not be used with the ret opcode.", location=instruction_ast.location + ) return Instruction( # Use dst for fp <- [fp - 2]. @@ -300,7 +323,8 @@ def _build_addap_instruction(instruction_ast: InstructionAst) -> Instruction: if instruction_ast.inc_ap: raise InstructionBuilderError( - 'ap++ may not be used with the addap opcode.', location=instruction_ast.location) + "ap++ may not be used with the addap opcode.", location=instruction_ast.location + ) return Instruction( # In this case dst is not involved. Choose [fp - 1] as the default. @@ -352,9 +376,7 @@ def _parse_res(expr: Expression) -> ResDescription: elif isinstance(expr, ExprOperator): return _parse_res_operator(expr) else: - raise InstructionBuilderError( - 'Invalid RHS expression.', - location=expr.location) + raise InstructionBuilderError("Invalid RHS expression.", location=expr.location) def _parse_res_deref(expr: Expression) -> ResDescription: @@ -363,7 +385,8 @@ def _parse_res_deref(expr: Expression) -> ResDescription: corresponding to [[reg + off] + off] or [fp + off] respectively. """ if isinstance(expr, ExprDeref) or ( - isinstance(expr, ExprOperator) and isinstance(expr.a, ExprDeref)): + isinstance(expr, ExprOperator) and isinstance(expr.a, ExprDeref) + ): # Double dereference. inner, off2 = _parse_offset(expr) inner_addr = _parse_dereference(inner) @@ -394,14 +417,14 @@ def _parse_res_operator(expr: ExprOperator) -> ResDescription: Given an expression of the form "[reg + off] * [reg + off]" or "[reg + off] * imm" (* can be replaced by +), returns the corresponding ResDescription. """ - if expr.op == '+': + if expr.op == "+": res = Instruction.Res.ADD - elif expr.op == '*': + elif expr.op == "*": res = Instruction.Res.MUL else: raise InstructionBuilderError( - f"Expected '+' or '*', found: '{expr.op}'.", - location=expr.location) + f"Expected '+' or '*', found: '{expr.op}'.", location=expr.location + ) # Parse op0. op0_expr = _parse_dereference(expr.a) @@ -419,8 +442,9 @@ def _parse_res_operator(expr: ExprOperator) -> ResDescription: op1_addr = Instruction.Op1Addr.FP if op1_reg is Register.FP else Instruction.Op1Addr.AP else: raise InstructionBuilderError( - 'Expected a constant expression or a dereference expression.', - location=op1_expr.location) + "Expected a constant expression or a dereference expression.", + location=op1_expr.location, + ) return ResDescription( off1=off1, @@ -439,8 +463,7 @@ def _parse_dereference(expr: Expression): """ if not isinstance(expr, ExprDeref): - raise InstructionBuilderError( - 'Expected a dereference expression.', location=expr.location) + raise InstructionBuilderError("Expected a dereference expression.", location=expr.location) return expr.addr @@ -466,18 +489,20 @@ def _parse_offset(expr: Expression) -> Tuple[Expression, int]: if not isinstance(expr, ExprOperator): return expr, 0 - if expr.op == '+': + if expr.op == "+": sign = 1 - elif expr.op == '-': + elif expr.op == "-": sign = -1 else: raise InstructionBuilderError( - f"Expected '+' or '-', found: '{expr.op}'.", location=expr.location) + f"Expected '+' or '-', found: '{expr.op}'.", location=expr.location + ) offset_limit = 2 ** (OFFSET_BITS - 1) if not isinstance(expr.b, ExprConst) or not -offset_limit <= sign * expr.b.val < offset_limit: raise InstructionBuilderError( - f'Expected a constant offset in the range [-2^{OFFSET_BITS - 1}, 2^{OFFSET_BITS - 1}).', - location=expr.b.location) + f"Expected a constant offset in the range [-2^{OFFSET_BITS - 1}, 2^{OFFSET_BITS - 1}).", + location=expr.b.location, + ) return expr.a, sign * expr.b.val @@ -487,6 +512,7 @@ def _parse_register(expr: Expression) -> Register: """ if not isinstance(expr, ExprReg): raise InstructionBuilderError( - f'Expected a register. Found: {expr.format()}.', location=expr.location) + f"Expected a register. Found: {expr.format()}.", location=expr.location + ) return expr.reg diff --git a/src/starkware/cairo/lang/compiler/instruction_builder_test.py b/src/starkware/cairo/lang/compiler/instruction_builder_test.py index 4520e6c2..36ca2bc3 100644 --- a/src/starkware/cairo/lang/compiler/instruction_builder_test.py +++ b/src/starkware/cairo/lang/compiler/instruction_builder_test.py @@ -3,7 +3,9 @@ from starkware.cairo.lang.compiler.error_handling import get_location_marks from starkware.cairo.lang.compiler.instruction import Instruction, Register from starkware.cairo.lang.compiler.instruction_builder import ( - InstructionBuilderError, build_instruction) + InstructionBuilderError, + build_instruction, +) from starkware.cairo.lang.compiler.parser import parse_instruction @@ -15,519 +17,572 @@ def parse_and_build(inst: str) -> Instruction: def test_assert_eq(): - assert parse_and_build('[ap] = [fp]; ap++') == \ - Instruction( - off0=0, - off1=-1, - off2=0, - imm=None, - dst_register=Register.AP, - op0_register=Register.FP, - op1_addr=Instruction.Op1Addr.FP, - res=Instruction.Res.OP1, - pc_update=Instruction.PcUpdate.REGULAR, - ap_update=Instruction.ApUpdate.ADD1, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.ASSERT_EQ) - assert parse_and_build('[fp - 3] = [fp + 7]') == \ - Instruction( - off0=-3, - off1=-1, - off2=7, - imm=None, - dst_register=Register.FP, - op0_register=Register.FP, - op1_addr=Instruction.Op1Addr.FP, - res=Instruction.Res.OP1, - pc_update=Instruction.PcUpdate.REGULAR, - ap_update=Instruction.ApUpdate.REGULAR, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.ASSERT_EQ) - assert parse_and_build('[ap - 3] = [ap]') == \ - Instruction( - off0=-3, - off1=-1, - off2=0, - imm=None, - dst_register=Register.AP, - op0_register=Register.FP, - op1_addr=Instruction.Op1Addr.AP, - res=Instruction.Res.OP1, - pc_update=Instruction.PcUpdate.REGULAR, - ap_update=Instruction.ApUpdate.REGULAR, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.ASSERT_EQ) + assert parse_and_build("[ap] = [fp]; ap++") == Instruction( + off0=0, + off1=-1, + off2=0, + imm=None, + dst_register=Register.AP, + op0_register=Register.FP, + op1_addr=Instruction.Op1Addr.FP, + res=Instruction.Res.OP1, + pc_update=Instruction.PcUpdate.REGULAR, + ap_update=Instruction.ApUpdate.ADD1, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.ASSERT_EQ, + ) + assert parse_and_build("[fp - 3] = [fp + 7]") == Instruction( + off0=-3, + off1=-1, + off2=7, + imm=None, + dst_register=Register.FP, + op0_register=Register.FP, + op1_addr=Instruction.Op1Addr.FP, + res=Instruction.Res.OP1, + pc_update=Instruction.PcUpdate.REGULAR, + ap_update=Instruction.ApUpdate.REGULAR, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.ASSERT_EQ, + ) + assert parse_and_build("[ap - 3] = [ap]") == Instruction( + off0=-3, + off1=-1, + off2=0, + imm=None, + dst_register=Register.AP, + op0_register=Register.FP, + op1_addr=Instruction.Op1Addr.AP, + res=Instruction.Res.OP1, + pc_update=Instruction.PcUpdate.REGULAR, + ap_update=Instruction.ApUpdate.REGULAR, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.ASSERT_EQ, + ) def test_assert_eq_reversed(): - assert parse_and_build('5 = [fp + 1]') == parse_and_build('[fp + 1] = 5') - assert parse_and_build('[[ap + 2] + 3] = [fp + 1]; ap++') == \ - parse_and_build('[fp + 1] = [[ap + 2] + 3]; ap++') - assert parse_and_build('[ap] + [fp] = [fp + 1]') == parse_and_build('[fp + 1] = [ap] + [fp]') + assert parse_and_build("5 = [fp + 1]") == parse_and_build("[fp + 1] = 5") + assert parse_and_build("[[ap + 2] + 3] = [fp + 1]; ap++") == parse_and_build( + "[fp + 1] = [[ap + 2] + 3]; ap++" + ) + assert parse_and_build("[ap] + [fp] = [fp + 1]") == parse_and_build("[fp + 1] = [ap] + [fp]") def test_assert_eq_instruction_failures(): - verify_exception("""\ + verify_exception( + """\ fp - 3 = [fp] ^****^ Expected a dereference expression. -""") - verify_exception("""\ +""" + ) + verify_exception( + """\ ap = [fp] ^^ Expected a dereference expression. -""") - verify_exception("""\ +""" + ) + verify_exception( + """\ [ap] = [fp * 3] ^****^ Expected '+' or '-', found: '*'. -""") - verify_exception("""\ +""" + ) + verify_exception( + """\ [ap] = [fp + 32768] ^***^ Expected a constant offset in the range [-2^15, 2^15). -""") - verify_exception("""\ +""" + ) + verify_exception( + """\ [ap] = [fp - 32769] ^***^ Expected a constant offset in the range [-2^15, 2^15). -""") - verify_exception("""\ +""" + ) + verify_exception( + """\ [5] = [fp] ^ Expected a register. Found: 5. -""") - verify_exception("""\ +""" + ) + verify_exception( + """\ [x + 7] = [15] ^ Expected a register. Found: x. -""") +""" + ) # Make sure that if the instruction is invalid, the error is given for its original form, # rather than the reversed form. - verify_exception("""\ + verify_exception( + """\ [[ap + 1]] = [[ap + 1]] ^******^ Expected a register. Found: [ap + 1]. -""") +""" + ) def test_assert_eq_double_dereference(): - assert parse_and_build('[ap + 2] = [[fp]]') == \ - Instruction( - off0=2, - off1=0, - off2=0, - imm=None, - dst_register=Register.AP, - op0_register=Register.FP, - op1_addr=Instruction.Op1Addr.OP0, - res=Instruction.Res.OP1, - pc_update=Instruction.PcUpdate.REGULAR, - ap_update=Instruction.ApUpdate.REGULAR, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.ASSERT_EQ) - assert parse_and_build('[ap + 2] = [[ap - 4] + 7]; ap++') == \ - Instruction( - off0=2, - off1=-4, - off2=7, - imm=None, - dst_register=Register.AP, - op0_register=Register.AP, - op1_addr=Instruction.Op1Addr.OP0, - res=Instruction.Res.OP1, - pc_update=Instruction.PcUpdate.REGULAR, - ap_update=Instruction.ApUpdate.ADD1, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.ASSERT_EQ) + assert parse_and_build("[ap + 2] = [[fp]]") == Instruction( + off0=2, + off1=0, + off2=0, + imm=None, + dst_register=Register.AP, + op0_register=Register.FP, + op1_addr=Instruction.Op1Addr.OP0, + res=Instruction.Res.OP1, + pc_update=Instruction.PcUpdate.REGULAR, + ap_update=Instruction.ApUpdate.REGULAR, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.ASSERT_EQ, + ) + assert parse_and_build("[ap + 2] = [[ap - 4] + 7]; ap++") == Instruction( + off0=2, + off1=-4, + off2=7, + imm=None, + dst_register=Register.AP, + op0_register=Register.AP, + op1_addr=Instruction.Op1Addr.OP0, + res=Instruction.Res.OP1, + pc_update=Instruction.PcUpdate.REGULAR, + ap_update=Instruction.ApUpdate.ADD1, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.ASSERT_EQ, + ) def test_assert_eq_double_dereference_failures(): - verify_exception("""\ + verify_exception( + """\ [ap + 2] = [[fp + 32768] + 17] ^***^ Expected a constant offset in the range [-2^15, 2^15). -""") - verify_exception("""\ +""" + ) + verify_exception( + """\ [ap + 2] = [[fp * 32768] + 17] ^********^ Expected '+' or '-', found: '*'. -""") +""" + ) def test_assert_eq_imm(): - assert parse_and_build('[ap + 2] = 1234567890') == \ - Instruction( - off0=2, - off1=-1, - off2=1, - imm=1234567890, - dst_register=Register.AP, - op0_register=Register.FP, - op1_addr=Instruction.Op1Addr.IMM, - res=Instruction.Res.OP1, - pc_update=Instruction.PcUpdate.REGULAR, - ap_update=Instruction.ApUpdate.REGULAR, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.ASSERT_EQ) + assert parse_and_build("[ap + 2] = 1234567890") == Instruction( + off0=2, + off1=-1, + off2=1, + imm=1234567890, + dst_register=Register.AP, + op0_register=Register.FP, + op1_addr=Instruction.Op1Addr.IMM, + res=Instruction.Res.OP1, + pc_update=Instruction.PcUpdate.REGULAR, + ap_update=Instruction.ApUpdate.REGULAR, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.ASSERT_EQ, + ) def test_assert_eq_operation(): - assert parse_and_build('[ap + 1] = [ap - 7] * [fp + 3]') == \ - Instruction( - off0=1, - off1=-7, - off2=3, - imm=None, - dst_register=Register.AP, - op0_register=Register.AP, - op1_addr=Instruction.Op1Addr.FP, - res=Instruction.Res.MUL, - pc_update=Instruction.PcUpdate.REGULAR, - ap_update=Instruction.ApUpdate.REGULAR, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.ASSERT_EQ) - assert parse_and_build('[ap + 10] = [fp] + 1234567890') == \ - Instruction( - off0=10, - off1=0, - off2=1, - imm=1234567890, - dst_register=Register.AP, - op0_register=Register.FP, - op1_addr=Instruction.Op1Addr.IMM, - res=Instruction.Res.ADD, - pc_update=Instruction.PcUpdate.REGULAR, - ap_update=Instruction.ApUpdate.REGULAR, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.ASSERT_EQ) - assert parse_and_build('[fp - 3] = [ap + 7] * [ap + 8]') == \ - Instruction( - off0=-3, - off1=7, - off2=8, - imm=None, - dst_register=Register.FP, - op0_register=Register.AP, - op1_addr=Instruction.Op1Addr.AP, - res=Instruction.Res.MUL, - pc_update=Instruction.PcUpdate.REGULAR, - ap_update=Instruction.ApUpdate.REGULAR, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.ASSERT_EQ) + assert parse_and_build("[ap + 1] = [ap - 7] * [fp + 3]") == Instruction( + off0=1, + off1=-7, + off2=3, + imm=None, + dst_register=Register.AP, + op0_register=Register.AP, + op1_addr=Instruction.Op1Addr.FP, + res=Instruction.Res.MUL, + pc_update=Instruction.PcUpdate.REGULAR, + ap_update=Instruction.ApUpdate.REGULAR, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.ASSERT_EQ, + ) + assert parse_and_build("[ap + 10] = [fp] + 1234567890") == Instruction( + off0=10, + off1=0, + off2=1, + imm=1234567890, + dst_register=Register.AP, + op0_register=Register.FP, + op1_addr=Instruction.Op1Addr.IMM, + res=Instruction.Res.ADD, + pc_update=Instruction.PcUpdate.REGULAR, + ap_update=Instruction.ApUpdate.REGULAR, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.ASSERT_EQ, + ) + assert parse_and_build("[fp - 3] = [ap + 7] * [ap + 8]") == Instruction( + off0=-3, + off1=7, + off2=8, + imm=None, + dst_register=Register.FP, + op0_register=Register.AP, + op1_addr=Instruction.Op1Addr.AP, + res=Instruction.Res.MUL, + pc_update=Instruction.PcUpdate.REGULAR, + ap_update=Instruction.ApUpdate.REGULAR, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.ASSERT_EQ, + ) def test_inverse_syntactic_sugar(): - assert parse_and_build('[fp] = [ap + 10] - [fp - 1]') == \ - parse_and_build('[ap + 10] = [fp] + [fp - 1]') - assert parse_and_build('[fp] = [ap + 10] / [fp - 1]') == \ - parse_and_build('[ap + 10] = [fp] * [fp - 1]') + assert parse_and_build("[fp] = [ap + 10] - [fp - 1]") == parse_and_build( + "[ap + 10] = [fp] + [fp - 1]" + ) + assert parse_and_build("[fp] = [ap + 10] / [fp - 1]") == parse_and_build( + "[ap + 10] = [fp] * [fp - 1]" + ) def test_inverse_syntactic_sugar_failures(): # The syntactic sugar for sub is op0 = dst - op1. - verify_exception("""\ + verify_exception( + """\ [fp] = [ap + 10] - 1234567890 ^********^ Subtraction and division are not supported for immediates. -""") - verify_exception("""\ +""" + ) + verify_exception( + """\ [fp] = [ap + 10] / 1234567890 ^********^ Subtraction and division are not supported for immediates. -""") - verify_exception("""\ +""" + ) + verify_exception( + """\ 1234567890 = [ap + 10] - [fp] ^********^ Expected a dereference expression. -""") - verify_exception("""\ +""" + ) + verify_exception( + """\ [ap] = [[fp]] - [ap] ^**^ Expected a register. Found: [fp]. -""") - verify_exception("""\ +""" + ) + verify_exception( + """\ [ap] = 5 - [ap] ^ Expected a dereference expression. -""") +""" + ) def test_assert_eq_operation_failures(): - verify_exception("""\ + verify_exception( + """\ [ap + 1] = 1234 * [fp] ^**^ Expected a dereference expression. -""") - verify_exception("""\ +""" + ) + verify_exception( + """\ [ap + 1] = [fp] + [fp] * [fp] ^*********^ Expected a constant expression or a dereference expression. -""") +""" + ) def test_jump_instruction(): - assert parse_and_build('jmp rel [ap + 1] + [fp - 7]') == \ - Instruction( - off0=-1, - off1=1, - off2=-7, - imm=None, - dst_register=Register.FP, - op0_register=Register.AP, - op1_addr=Instruction.Op1Addr.FP, - res=Instruction.Res.ADD, - pc_update=Instruction.PcUpdate.JUMP_REL, - ap_update=Instruction.ApUpdate.REGULAR, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.NOP) - assert parse_and_build('jmp abs 123; ap++') == \ - Instruction( - off0=-1, - off1=-1, - off2=1, - imm=123, - dst_register=Register.FP, - op0_register=Register.FP, - op1_addr=Instruction.Op1Addr.IMM, - res=Instruction.Res.OP1, - pc_update=Instruction.PcUpdate.JUMP, - ap_update=Instruction.ApUpdate.ADD1, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.NOP) - assert parse_and_build('jmp rel [ap + 1] + [ap - 7]') == \ - Instruction( - off0=-1, - off1=1, - off2=-7, - imm=None, - dst_register=Register.FP, - op0_register=Register.AP, - op1_addr=Instruction.Op1Addr.AP, - res=Instruction.Res.ADD, - pc_update=Instruction.PcUpdate.JUMP_REL, - ap_update=Instruction.ApUpdate.REGULAR, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.NOP) + assert parse_and_build("jmp rel [ap + 1] + [fp - 7]") == Instruction( + off0=-1, + off1=1, + off2=-7, + imm=None, + dst_register=Register.FP, + op0_register=Register.AP, + op1_addr=Instruction.Op1Addr.FP, + res=Instruction.Res.ADD, + pc_update=Instruction.PcUpdate.JUMP_REL, + ap_update=Instruction.ApUpdate.REGULAR, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.NOP, + ) + assert parse_and_build("jmp abs 123; ap++") == Instruction( + off0=-1, + off1=-1, + off2=1, + imm=123, + dst_register=Register.FP, + op0_register=Register.FP, + op1_addr=Instruction.Op1Addr.IMM, + res=Instruction.Res.OP1, + pc_update=Instruction.PcUpdate.JUMP, + ap_update=Instruction.ApUpdate.ADD1, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.NOP, + ) + assert parse_and_build("jmp rel [ap + 1] + [ap - 7]") == Instruction( + off0=-1, + off1=1, + off2=-7, + imm=None, + dst_register=Register.FP, + op0_register=Register.AP, + op1_addr=Instruction.Op1Addr.AP, + res=Instruction.Res.ADD, + pc_update=Instruction.PcUpdate.JUMP_REL, + ap_update=Instruction.ApUpdate.REGULAR, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.NOP, + ) def test_jnz_instruction(): - assert parse_and_build('jmp rel [fp - 1] if [fp - 7] != 0') == \ - Instruction( - off0=-7, - off1=-1, - off2=-1, - imm=None, - dst_register=Register.FP, - op0_register=Register.FP, - op1_addr=Instruction.Op1Addr.FP, - res=Instruction.Res.UNCONSTRAINED, - pc_update=Instruction.PcUpdate.JNZ, - ap_update=Instruction.ApUpdate.REGULAR, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.NOP) - assert parse_and_build('jmp rel [ap - 1] if [fp - 7] != 0') == \ - Instruction( - off0=-7, - off1=-1, - off2=-1, - imm=None, - dst_register=Register.FP, - op0_register=Register.FP, - op1_addr=Instruction.Op1Addr.AP, - res=Instruction.Res.UNCONSTRAINED, - pc_update=Instruction.PcUpdate.JNZ, - ap_update=Instruction.ApUpdate.REGULAR, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.NOP) - assert parse_and_build('jmp rel 123 if [ap] != 0; ap++') == \ - Instruction( - off0=0, - off1=-1, - off2=1, - imm=123, - dst_register=Register.AP, - op0_register=Register.FP, - op1_addr=Instruction.Op1Addr.IMM, - res=Instruction.Res.UNCONSTRAINED, - pc_update=Instruction.PcUpdate.JNZ, - ap_update=Instruction.ApUpdate.ADD1, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.NOP) + assert parse_and_build("jmp rel [fp - 1] if [fp - 7] != 0") == Instruction( + off0=-7, + off1=-1, + off2=-1, + imm=None, + dst_register=Register.FP, + op0_register=Register.FP, + op1_addr=Instruction.Op1Addr.FP, + res=Instruction.Res.UNCONSTRAINED, + pc_update=Instruction.PcUpdate.JNZ, + ap_update=Instruction.ApUpdate.REGULAR, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.NOP, + ) + assert parse_and_build("jmp rel [ap - 1] if [fp - 7] != 0") == Instruction( + off0=-7, + off1=-1, + off2=-1, + imm=None, + dst_register=Register.FP, + op0_register=Register.FP, + op1_addr=Instruction.Op1Addr.AP, + res=Instruction.Res.UNCONSTRAINED, + pc_update=Instruction.PcUpdate.JNZ, + ap_update=Instruction.ApUpdate.REGULAR, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.NOP, + ) + assert parse_and_build("jmp rel 123 if [ap] != 0; ap++") == Instruction( + off0=0, + off1=-1, + off2=1, + imm=123, + dst_register=Register.AP, + op0_register=Register.FP, + op1_addr=Instruction.Op1Addr.IMM, + res=Instruction.Res.UNCONSTRAINED, + pc_update=Instruction.PcUpdate.JNZ, + ap_update=Instruction.ApUpdate.ADD1, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.NOP, + ) def test_jnz_instruction_failures(): - verify_exception("""\ + verify_exception( + """\ jmp rel [fp] if 5 != 0 ^ Expected a dereference expression. -""") - verify_exception("""\ +""" + ) + verify_exception( + """\ jmp rel [ap] if [fp] + 3 != 0 ^******^ Expected a dereference expression. -""") - verify_exception("""\ +""" + ) + verify_exception( + """\ jmp rel [ap] if [fp * 3] != 0 ^****^ Expected '+' or '-', found: '*'. -""") - verify_exception("""\ +""" + ) + verify_exception( + """\ jmp rel [ap] + [fp] if [fp] != 0 ^*********^ Invalid expression for jmp offset. -""") +""" + ) def test_call_instruction(): - assert parse_and_build('call abs [fp + 4]') == \ - Instruction( - off0=0, - off1=1, - off2=4, - imm=None, - dst_register=Register.AP, - op0_register=Register.AP, - op1_addr=Instruction.Op1Addr.FP, - res=Instruction.Res.OP1, - pc_update=Instruction.PcUpdate.JUMP, - ap_update=Instruction.ApUpdate.ADD2, - fp_update=Instruction.FpUpdate.AP_PLUS2, - opcode=Instruction.Opcode.CALL) - - assert parse_and_build('call rel [fp + 4]') == \ - Instruction( - off0=0, - off1=1, - off2=4, - imm=None, - dst_register=Register.AP, - op0_register=Register.AP, - op1_addr=Instruction.Op1Addr.FP, - res=Instruction.Res.OP1, - pc_update=Instruction.PcUpdate.JUMP_REL, - ap_update=Instruction.ApUpdate.ADD2, - fp_update=Instruction.FpUpdate.AP_PLUS2, - opcode=Instruction.Opcode.CALL) - assert parse_and_build('call rel [ap + 4]') == \ - Instruction( - off0=0, - off1=1, - off2=4, - imm=None, - dst_register=Register.AP, - op0_register=Register.AP, - op1_addr=Instruction.Op1Addr.AP, - res=Instruction.Res.OP1, - pc_update=Instruction.PcUpdate.JUMP_REL, - ap_update=Instruction.ApUpdate.ADD2, - fp_update=Instruction.FpUpdate.AP_PLUS2, - opcode=Instruction.Opcode.CALL) - assert parse_and_build('call rel 123') == \ - Instruction( - off0=0, - off1=1, - off2=1, - imm=123, - dst_register=Register.AP, - op0_register=Register.AP, - op1_addr=Instruction.Op1Addr.IMM, - res=Instruction.Res.OP1, - pc_update=Instruction.PcUpdate.JUMP_REL, - ap_update=Instruction.ApUpdate.ADD2, - fp_update=Instruction.FpUpdate.AP_PLUS2, - opcode=Instruction.Opcode.CALL) + assert parse_and_build("call abs [fp + 4]") == Instruction( + off0=0, + off1=1, + off2=4, + imm=None, + dst_register=Register.AP, + op0_register=Register.AP, + op1_addr=Instruction.Op1Addr.FP, + res=Instruction.Res.OP1, + pc_update=Instruction.PcUpdate.JUMP, + ap_update=Instruction.ApUpdate.ADD2, + fp_update=Instruction.FpUpdate.AP_PLUS2, + opcode=Instruction.Opcode.CALL, + ) + + assert parse_and_build("call rel [fp + 4]") == Instruction( + off0=0, + off1=1, + off2=4, + imm=None, + dst_register=Register.AP, + op0_register=Register.AP, + op1_addr=Instruction.Op1Addr.FP, + res=Instruction.Res.OP1, + pc_update=Instruction.PcUpdate.JUMP_REL, + ap_update=Instruction.ApUpdate.ADD2, + fp_update=Instruction.FpUpdate.AP_PLUS2, + opcode=Instruction.Opcode.CALL, + ) + assert parse_and_build("call rel [ap + 4]") == Instruction( + off0=0, + off1=1, + off2=4, + imm=None, + dst_register=Register.AP, + op0_register=Register.AP, + op1_addr=Instruction.Op1Addr.AP, + res=Instruction.Res.OP1, + pc_update=Instruction.PcUpdate.JUMP_REL, + ap_update=Instruction.ApUpdate.ADD2, + fp_update=Instruction.FpUpdate.AP_PLUS2, + opcode=Instruction.Opcode.CALL, + ) + assert parse_and_build("call rel 123") == Instruction( + off0=0, + off1=1, + off2=1, + imm=123, + dst_register=Register.AP, + op0_register=Register.AP, + op1_addr=Instruction.Op1Addr.IMM, + res=Instruction.Res.OP1, + pc_update=Instruction.PcUpdate.JUMP_REL, + ap_update=Instruction.ApUpdate.ADD2, + fp_update=Instruction.FpUpdate.AP_PLUS2, + opcode=Instruction.Opcode.CALL, + ) def test_call_instruction_failures(): - verify_exception("""\ + verify_exception( + """\ call rel [ap] + 5 ^******^ Invalid offset for call. -""") - verify_exception("""\ +""" + ) + verify_exception( + """\ call rel 5; ap++ ^**************^ ap++ may not be used with the call opcode. -""") +""" + ) def test_ret_instruction(): - assert parse_and_build('ret') == \ - Instruction( - off0=-2, - off1=-1, - off2=-1, - imm=None, - dst_register=Register.FP, - op0_register=Register.FP, - op1_addr=Instruction.Op1Addr.FP, - res=Instruction.Res.OP1, - pc_update=Instruction.PcUpdate.JUMP, - ap_update=Instruction.ApUpdate.REGULAR, - fp_update=Instruction.FpUpdate.DST, - opcode=Instruction.Opcode.RET) + assert parse_and_build("ret") == Instruction( + off0=-2, + off1=-1, + off2=-1, + imm=None, + dst_register=Register.FP, + op0_register=Register.FP, + op1_addr=Instruction.Op1Addr.FP, + res=Instruction.Res.OP1, + pc_update=Instruction.PcUpdate.JUMP, + ap_update=Instruction.ApUpdate.REGULAR, + fp_update=Instruction.FpUpdate.DST, + opcode=Instruction.Opcode.RET, + ) def test_ret_instruction_failures(): - verify_exception("""\ + verify_exception( + """\ ret; ap++ ^*******^ ap++ may not be used with the ret opcode. -""") +""" + ) def test_addap_instruction(): - assert parse_and_build('ap += [fp + 4] + [fp]') == \ - Instruction( - off0=-1, - off1=4, - off2=0, - imm=None, - dst_register=Register.FP, - op0_register=Register.FP, - op1_addr=Instruction.Op1Addr.FP, - res=Instruction.Res.ADD, - pc_update=Instruction.PcUpdate.REGULAR, - ap_update=Instruction.ApUpdate.ADD, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.NOP) - assert parse_and_build('ap += [ap + 4] + [ap]') == \ - Instruction( - off0=-1, - off1=4, - off2=0, - imm=None, - dst_register=Register.FP, - op0_register=Register.AP, - op1_addr=Instruction.Op1Addr.AP, - res=Instruction.Res.ADD, - pc_update=Instruction.PcUpdate.REGULAR, - ap_update=Instruction.ApUpdate.ADD, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.NOP) - assert parse_and_build('ap += 123') == \ - Instruction( - off0=-1, - off1=-1, - off2=1, - imm=123, - dst_register=Register.FP, - op0_register=Register.FP, - op1_addr=Instruction.Op1Addr.IMM, - res=Instruction.Res.OP1, - pc_update=Instruction.PcUpdate.REGULAR, - ap_update=Instruction.ApUpdate.ADD, - fp_update=Instruction.FpUpdate.REGULAR, - opcode=Instruction.Opcode.NOP) + assert parse_and_build("ap += [fp + 4] + [fp]") == Instruction( + off0=-1, + off1=4, + off2=0, + imm=None, + dst_register=Register.FP, + op0_register=Register.FP, + op1_addr=Instruction.Op1Addr.FP, + res=Instruction.Res.ADD, + pc_update=Instruction.PcUpdate.REGULAR, + ap_update=Instruction.ApUpdate.ADD, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.NOP, + ) + assert parse_and_build("ap += [ap + 4] + [ap]") == Instruction( + off0=-1, + off1=4, + off2=0, + imm=None, + dst_register=Register.FP, + op0_register=Register.AP, + op1_addr=Instruction.Op1Addr.AP, + res=Instruction.Res.ADD, + pc_update=Instruction.PcUpdate.REGULAR, + ap_update=Instruction.ApUpdate.ADD, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.NOP, + ) + assert parse_and_build("ap += 123") == Instruction( + off0=-1, + off1=-1, + off2=1, + imm=123, + dst_register=Register.FP, + op0_register=Register.FP, + op1_addr=Instruction.Op1Addr.IMM, + res=Instruction.Res.OP1, + pc_update=Instruction.PcUpdate.REGULAR, + ap_update=Instruction.ApUpdate.ADD, + fp_update=Instruction.FpUpdate.REGULAR, + opcode=Instruction.Opcode.NOP, + ) def test_addap_instruction_failures(): - verify_exception("""\ + verify_exception( + """\ ap += 5; ap++ ^***********^ ap++ may not be used with the addap opcode. -""") +""" + ) def verify_exception(code_with_err): @@ -541,5 +596,7 @@ def verify_exception(code_with_err): code = code_with_err.splitlines()[0] with pytest.raises(InstructionBuilderError) as e: parse_and_build(code) - assert get_location_marks(code, e.value.location) + '\n' + str(e.value.message) == \ - code_with_err.rstrip() + assert ( + get_location_marks(code, e.value.location) + "\n" + str(e.value.message) + == code_with_err.rstrip() + ) diff --git a/src/starkware/cairo/lang/compiler/instruction_test.py b/src/starkware/cairo/lang/compiler/instruction_test.py index 4b136a35..596565e0 100644 --- a/src/starkware/cairo/lang/compiler/instruction_test.py +++ b/src/starkware/cairo/lang/compiler/instruction_test.py @@ -3,12 +3,15 @@ import pytest from starkware.cairo.lang.compiler.instruction import ( - N_FLAGS, OFFSET_BITS, decode_instruction_values) + N_FLAGS, + OFFSET_BITS, + decode_instruction_values, +) def test_decode(): - offsets = [randrange(0, 2**OFFSET_BITS) for _ in range(3)] - flags = randrange(0, 2**N_FLAGS) + offsets = [randrange(0, 2 ** OFFSET_BITS) for _ in range(3)] + flags = randrange(0, 2 ** N_FLAGS) instruction = 0 for part in [flags] + offsets[::-1]: instruction = (instruction << OFFSET_BITS) | part @@ -16,5 +19,5 @@ def test_decode(): def test_unsupported_instruction(): - with pytest.raises(AssertionError, match='Unsupported instruction.'): + with pytest.raises(AssertionError, match="Unsupported instruction."): decode_instruction_values(1 << (3 * OFFSET_BITS + N_FLAGS)) diff --git a/src/starkware/cairo/lang/compiler/location_utils.py b/src/starkware/cairo/lang/compiler/location_utils.py index f4bdf066..587f0550 100644 --- a/src/starkware/cairo/lang/compiler/location_utils.py +++ b/src/starkware/cairo/lang/compiler/location_utils.py @@ -6,7 +6,8 @@ def add_parent_location( - expr: Expression, new_parent_location: Optional[Location], message: str) -> Expression: + expr: Expression, new_parent_location: Optional[Location], message: str +) -> Expression: if new_parent_location is None: return expr @@ -15,5 +16,7 @@ def location_modifier(self, location: Optional[Location]) -> Optional[Location]: if location is None: return new_parent_location return location.with_parent_location( - new_parent_location=new_parent_location, message=message) # type: ignore + new_parent_location=new_parent_location, message=message # type: ignore + ) + return AddParentLocationTransformer().visit(expr) diff --git a/src/starkware/cairo/lang/compiler/module_reader.py b/src/starkware/cairo/lang/compiler/module_reader.py index 7ca457f9..6ec6befe 100644 --- a/src/starkware/cairo/lang/compiler/module_reader.py +++ b/src/starkware/cairo/lang/compiler/module_reader.py @@ -26,7 +26,8 @@ def source_files(self): return set(filename for filename, scope in self.source_files_with_scopes) def module_to_file_path( - self, module_name: str, isfile: Callable[[str], bool] = os.path.isfile) -> str: + self, module_name: str, isfile: Callable[[str], bool] = os.path.isfile + ) -> str: """ Translates module name to file path. """ @@ -55,7 +56,7 @@ def read(self, module_name: str) -> Tuple[str, str]: filename = self.module_to_file_path(module_name) self.source_files.add(filename) self.source_files_with_scopes.add((filename, ScopedName.from_string(module_name))) - with open(filename, 'r') as f: + with open(filename, "r") as f: return f.read(), filename @@ -63,5 +64,5 @@ class ModuleNotFoundException(Exception): def __init__(self, module: str, paths: List[str]): msg = f"Could not find module '{module}'. Searched in the following paths:" for path in paths: - msg += '\n' + path + msg += "\n" + path super().__init__(msg) diff --git a/src/starkware/cairo/lang/compiler/module_reader_test.py b/src/starkware/cairo/lang/compiler/module_reader_test.py index 71cfc91a..eee89d41 100644 --- a/src/starkware/cairo/lang/compiler/module_reader_test.py +++ b/src/starkware/cairo/lang/compiler/module_reader_test.py @@ -5,20 +5,23 @@ def test_file_path_extractor(): isfile = lambda _: True - reader = ModuleReader(paths=['/usr/include'], cairo_suffix='.f~o') - assert reader.module_to_file_path('foo.bar', isfile) == '/usr/include/foo/bar.f~o' + reader = ModuleReader(paths=["/usr/include"], cairo_suffix=".f~o") + assert reader.module_to_file_path("foo.bar", isfile) == "/usr/include/foo/bar.f~o" - reader = ModuleReader(paths=['rel//path'], cairo_suffix='.txt') - assert reader.module_to_file_path('hello.world', isfile) == 'rel//path/hello/world.txt' + reader = ModuleReader(paths=["rel//path"], cairo_suffix=".txt") + assert reader.module_to_file_path("hello.world", isfile) == "rel//path/hello/world.txt" def test_search_file(): - reader = ModuleReader(paths=['a', 'b', 'c'], cairo_suffix='.c') - assert reader.module_to_file_path('f', isfile=lambda x: x in ['b/f.c', 'c/f.c']) == 'b/f.c' + reader = ModuleReader(paths=["a", "b", "c"], cairo_suffix=".c") + assert reader.module_to_file_path("f", isfile=lambda x: x in ["b/f.c", "c/f.c"]) == "b/f.c" - with pytest.raises(ModuleNotFoundException, match="""\ + with pytest.raises( + ModuleNotFoundException, + match="""\ Could not find module 'x.y.z'. Searched in the following paths: a/x/y/z.c b/x/y/z.c -c/x/y/z.c"""): - reader.module_to_file_path('x.y.z', isfile=lambda _: False) +c/x/y/z.c""", + ): + reader.module_to_file_path("x.y.z", isfile=lambda _: False) diff --git a/src/starkware/cairo/lang/compiler/offset_reference.py b/src/starkware/cairo/lang/compiler/offset_reference.py index 1dd86802..f208d4f1 100644 --- a/src/starkware/cairo/lang/compiler/offset_reference.py +++ b/src/starkware/cairo/lang/compiler/offset_reference.py @@ -5,9 +5,14 @@ from starkware.cairo.lang.compiler.ast.expr import ExprDot, Expression, ExprIdentifier from starkware.cairo.lang.compiler.identifier_definition import ( - IdentifierDefinition, ReferenceDefinition) + IdentifierDefinition, + ReferenceDefinition, +) from starkware.cairo.lang.compiler.preprocessor.flow import ( - FlowTrackingData, FlowTrackingDataActual, ReferenceManager) + FlowTrackingData, + FlowTrackingDataActual, + ReferenceManager, +) from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -21,20 +26,22 @@ class OffsetReferenceDefinition(IdentifierDefinition): In the example, 'x' is the parent reference and 'y.z' is the member path. When eval() is called, both 'x' and 'T.y' are evaluated and '[x + T.y]' is returned. """ - TYPE: ClassVar[str] = 'offset-reference' + + TYPE: ClassVar[str] = "offset-reference" Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema parent: ReferenceDefinition member_path: ScopedName def eval( - self, reference_manager: ReferenceManager, flow_tracking_data: FlowTrackingData) -> \ - Expression: + self, reference_manager: ReferenceManager, flow_tracking_data: FlowTrackingData + ) -> Expression: reference = flow_tracking_data.resolve_reference( - reference_manager=reference_manager, - name=self.parent.full_name) - assert isinstance(flow_tracking_data, FlowTrackingDataActual), \ - 'Resolved references can only come from FlowTrackingDataActual.' + reference_manager=reference_manager, name=self.parent.full_name + ) + assert isinstance( + flow_tracking_data, FlowTrackingDataActual + ), "Resolved references can only come from FlowTrackingDataActual." expr = reference.eval(flow_tracking_data.ap_tracking) for member_name in self.member_path.path: diff --git a/src/starkware/cairo/lang/compiler/offset_reference_test.py b/src/starkware/cairo/lang/compiler/offset_reference_test.py index d863e5d2..a905a3fd 100644 --- a/src/starkware/cairo/lang/compiler/offset_reference_test.py +++ b/src/starkware/cairo/lang/compiler/offset_reference_test.py @@ -3,7 +3,10 @@ from starkware.cairo.lang.compiler.offset_reference import OffsetReferenceDefinition from starkware.cairo.lang.compiler.parser import parse_expr from starkware.cairo.lang.compiler.preprocessor.flow import ( - FlowTrackingDataActual, ReferenceManager, RegTrackingData) + FlowTrackingDataActual, + ReferenceManager, + RegTrackingData, +) from starkware.cairo.lang.compiler.references import Reference from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.compiler.type_system import mark_types_in_expr_resolved @@ -12,17 +15,19 @@ def test_offset_reference_definition_typed_members(): - t = TypeStruct(scope=scope('T'), is_fully_resolved=True) + t = TypeStruct(scope=scope("T"), is_fully_resolved=True) t_star = TypePointer(pointee=t) reference_manager = ReferenceManager() - main_reference = ReferenceDefinition(full_name=scope('a'), cairo_type=t_star, references=[]) + main_reference = ReferenceDefinition(full_name=scope("a"), cairo_type=t_star, references=[]) references = { - scope('a'): reference_manager.alloc_id(Reference( - pc=0, - value=mark_types_in_expr_resolved(parse_expr('cast(ap, T*)')), - ap_tracking_data=RegTrackingData(group=0, offset=0), - )), + scope("a"): reference_manager.alloc_id( + Reference( + pc=0, + value=mark_types_in_expr_resolved(parse_expr("cast(ap, T*)")), + ap_tracking_data=RegTrackingData(group=0, offset=0), + ) + ), } flow_tracking_data = FlowTrackingDataActual( @@ -32,7 +37,10 @@ def test_offset_reference_definition_typed_members(): # Create OffsetReferenceDefinition instance for an expression of the form "a.", # in this case a.x.y.z, and check the result of evaluation of this expression. - definition = OffsetReferenceDefinition(parent=main_reference, member_path=scope('x.y.z')) - assert definition.eval( - reference_manager=reference_manager, - flow_tracking_data=flow_tracking_data).format() == 'cast(ap - 1, T*).x.y.z' + definition = OffsetReferenceDefinition(parent=main_reference, member_path=scope("x.y.z")) + assert ( + definition.eval( + reference_manager=reference_manager, flow_tracking_data=flow_tracking_data + ).format() + == "cast(ap - 1, T*).x.y.z" + ) diff --git a/src/starkware/cairo/lang/compiler/parser.py b/src/starkware/cairo/lang/compiler/parser.py index 43a76b57..3926ec23 100644 --- a/src/starkware/cairo/lang/compiler/parser.py +++ b/src/starkware/cairo/lang/compiler/parser.py @@ -4,7 +4,12 @@ import lark from lark.exceptions import ( - LarkError, UnexpectedCharacters, UnexpectedEOF, UnexpectedToken, VisitError) + LarkError, + UnexpectedCharacters, + UnexpectedEOF, + UnexpectedToken, + VisitError, +) from starkware.cairo.lang.compiler.ast.cairo_types import CairoType from starkware.cairo.lang.compiler.ast.code_elements import CodeElement @@ -13,78 +18,92 @@ from starkware.cairo.lang.compiler.ast.module import CairoFile from starkware.cairo.lang.compiler.error_handling import InputFile, Location, LocationError from starkware.cairo.lang.compiler.parser_transformer import ( - ParserContext, ParserError, ParserTransformer) + ParserContext, + ParserError, + ParserTransformer, +) -grammar_file = os.path.join(os.path.dirname(__file__), 'cairo.ebnf') +grammar_file = os.path.join(os.path.dirname(__file__), "cairo.ebnf") gram_parser = lark.Lark( - open(grammar_file, 'r').read(), - start=['cairo_file', 'repl'], - lexer='standard', - propagate_positions=True) + open(grammar_file, "r").read(), + start=["cairo_file", "repl"], + lexer="standard", + propagate_positions=True, +) def wrap_lark_error(err: LarkError, input_file: InputFile) -> Exception: if input_file.content is None: return err lines = input_file.content.splitlines() - assert len(lines) > 0, 'Syntax errors are unexpected in code with no lines.' + assert len(lines) > 0, "Syntax errors are unexpected in code with no lines." err_str = str(err) if isinstance(err, UnexpectedToken): expected = set(err.expected) - if {'FP', 'AP'} <= expected: - expected.remove('FP') - expected.remove('AP') - expected.add('register') - if {'MINUS', 'INT'} <= expected: - expected.remove('MINUS') - if {'CAST', 'LPAR', 'LSQB', 'IDENTIFIER', 'INT', 'AMPERSAND', 'register'} <= expected: + if {"FP", "AP"} <= expected: + expected.remove("FP") + expected.remove("AP") + expected.add("register") + if {"MINUS", "INT"} <= expected: + expected.remove("MINUS") + if {"CAST", "LPAR", "LSQB", "IDENTIFIER", "INT", "AMPERSAND", "register"} <= expected: expected -= { - 'CAST', 'LPAR', 'LSQB', 'IDENTIFIER', 'INT', 'HEXINT', 'PYCONST', 'NONDET', - 'AMPERSAND', 'register'} - expected.add('expression') - if {'PLUS', 'MINUS', 'STAR', 'SLASH'} <= expected: - expected -= {'PLUS', 'MINUS', 'STAR', 'SLASH'} - expected.add('operator') - if 'COMMENT' in expected: - expected.remove('COMMENT') - if '_NEWLINE' in expected: - expected.remove('_NEWLINE') + "CAST", + "LPAR", + "LSQB", + "IDENTIFIER", + "INT", + "HEXINT", + "PYCONST", + "NONDET", + "AMPERSAND", + "register", + } + expected.add("expression") + if {"PLUS", "MINUS", "STAR", "SLASH"} <= expected: + expected -= {"PLUS", "MINUS", "STAR", "SLASH"} + expected.add("operator") + if "COMMENT" in expected: + expected.remove("COMMENT") + if "_NEWLINE" in expected: + expected.remove("_NEWLINE") TOKENS = { - '_ARROW': '"->"', - '_AT': '"@"', - '_DBL_EQ': '"=="', - '_DBL_PLUS': '"++"', - '_NEQ': '"!="', - 'AMPERSAND': '"&"', - 'CAST': '"cast"', - 'CALL': '"call"', - 'COLON': '":"', - 'DOT': '"."', - 'EQUAL': '"="', - 'FUNC': '"func"', - 'IDENTIFIER': 'identifier', - 'INT': 'integer', - 'HEXINT': 'integer', - 'LBRACE': '"{"', - 'LPAR': '"("', - 'LSQB': '"["', - 'MINUS': '"-"', - 'NAMESPACE': '"namespace"', - 'PLUS': '"+"', - 'RBRACE': '"}"', - 'RPAR': '")"', - 'RSQB': '"]"', - 'SEMICOLON': '";"', - 'SLASH': '"/"', - 'STAR': '"*"', - 'STRUCT': '"struct"', + "_ARROW": '"->"', + "_AT": '"@"', + "_DBL_EQ": '"=="', + "_DBL_PLUS": '"++"', + "_NEQ": '"!="', + "AMPERSAND": '"&"', + "CAST": '"cast"', + "CALL": '"call"', + "COLON": '":"', + "DOT": '"."', + "EQUAL": '"="', + "FUNC": '"func"', + "IDENTIFIER": "identifier", + "INT": "integer", + "HEXINT": "integer", + "LBRACE": '"{"', + "LPAR": '"("', + "LSQB": '"["', + "MINUS": '"-"', + "NAMESPACE": '"namespace"', + "PLUS": '"+"', + "RBRACE": '"}"', + "RPAR": '")"', + "RSQB": '"]"', + "SEMICOLON": '";"', + "SLASH": '"/"', + "STAR": '"*"', + "STRUCT": '"struct"', } expected_lst = sorted(TOKENS.get(x, x) for x in expected) if len(expected_lst) > 1: - err_str = \ + err_str = ( f'Unexpected token {repr(err.token)}. Expected one of: {", ".join(expected_lst)}.' + ) else: err_str = f'Unexpected token {repr(err.token)}. Expected: {", ".join(expected_lst)}.' @@ -103,14 +122,22 @@ def wrap_lark_error(err: LarkError, input_file: InputFile) -> Exception: return err location = Location( - start_line=line, start_col=col, end_line=line, - end_col=min(col + width, len(lines[line - 1]) + 1), input_file=input_file) + start_line=line, + start_col=col, + end_line=line, + end_col=min(col + width, len(lines[line - 1]) + 1), + input_file=input_file, + ) return ParserError(err_str, location) def parse( - filename: Optional[str], code: str, code_type: str, expected_type, - parser_context: Optional[ParserContext] = None): + filename: Optional[str], + code: str, + code_type: str, + expected_type, + parser_context: Optional[ParserContext] = None, +): """ Parses the given string and returns an AST tree based on the classes in ast/*.py. code_type is the ebnf rule to start from (e.g., 'expr' or 'cairo_file'). @@ -130,8 +157,9 @@ def parse( raise err.orig_exc else: raise - assert isinstance(parsed, expected_type), \ - f'Expected parsing result to be {expected_type.__name__}. Found: {type(parsed).__name__}' + assert isinstance( + parsed, expected_type + ), f"Expected parsing result to be {expected_type.__name__}. Found: {type(parsed).__name__}" return parsed @@ -144,22 +172,22 @@ def lex(code: str) -> List[lark.lexer.Token]: def parse_file( - code: str, filename: str = '', - parser_context: Optional[ParserContext] = None) -> CairoFile: + code: str, filename: str = "", parser_context: Optional[ParserContext] = None +) -> CairoFile: """ Parses the given string and returns a CairoFile instance. """ # If code does not end with '\n', add it. - if not code.endswith('\n'): - code += '\n' - return parse(filename, code, 'cairo_file', CairoFile, parser_context=parser_context) + if not code.endswith("\n"): + code += "\n" + return parse(filename, code, "cairo_file", CairoFile, parser_context=parser_context) def parse_instruction(code: str) -> InstructionAst: """ Parses the given string and returns an InstructionAst instance. """ - return parse(None, code, 'instruction', InstructionAst) + return parse(None, code, "instruction", InstructionAst) @lru_cache(None) @@ -167,18 +195,18 @@ def parse_expr(code: str) -> Expression: """ Parses the given string and returns an Expression instance. """ - return parse(None, code, 'expr', Expression) + return parse(None, code, "expr", Expression) def parse_type(code: str) -> CairoType: """ Parses the given string and returns an Expression instance. """ - return parse(None, code, 'type', CairoType) + return parse(None, code, "type", CairoType) def parse_code_element(code: str, parser_context: Optional[ParserContext] = None) -> CodeElement: """ Parses the given string and returns a CodeElement instance. """ - return parse(None, code, 'code_element', CodeElement, parser_context=parser_context) + return parse(None, code, "code_element", CodeElement, parser_context=parser_context) diff --git a/src/starkware/cairo/lang/compiler/parser_errors_test.py b/src/starkware/cairo/lang/compiler/parser_errors_test.py index d716f774..6627e6ff 100644 --- a/src/starkware/cairo/lang/compiler/parser_errors_test.py +++ b/src/starkware/cairo/lang/compiler/parser_errors_test.py @@ -5,127 +5,176 @@ def test_unexpected_token(): - verify_exception(""" + verify_exception( + """ x + = y -""", """ +""", + """ file:?:?: Unexpected token Token(EQUAL, '='). Expected: expression. x + = y ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ let x = -""", r""" +""", + r""" file:?:?: Unexpected token Token(_NEWLINE, '\n'). Expected one of: "call", expression. let x = ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ foo bar -""", """ +""", + """ file:?:?: Unexpected token Token(IDENTIFIER, 'bar'). Expected one of: "(", ".", ":", "=", "[", \ "{", operator. foo bar ^*^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ foo = bar test -""", """ +""", + """ file:?:?: Unexpected token Token(IDENTIFIER, 'test'). Expected one of: "(", ".", ";", "[", "{", \ operator. foo = bar test ^**^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ const func -""", """ +""", + """ file:?:?: Unexpected token Token(FUNC, 'func'). Expected: identifier. const func ^**^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ %[ 5 %] %[ 7 %] -""", """ +""", + """ file:?:?: Unexpected token Token(PYCONST, '%[ 7 %]'). Expected one of: ".", "=", "[", operator. %[ 5 %] %[ 7 %] ^*****^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ static_assert ap -""", r""" +""", + r""" file:?:?: Unexpected token Token(_NEWLINE, '\n'). Expected one of: ".", "==", "[", operator. static_assert ap ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ [ap] = x& + y -""", """ +""", + """ file:?:?: Unexpected token Token(AMPERSAND, '&'). Expected one of: "(", ".", ";", "[", "{", \ operator. [ap] = x& + y ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ func & -""", """ +""", + """ file:?:?: Unexpected token Token(AMPERSAND, '&'). Expected: identifier. func & ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ let x : T 5 -""", """ +""", + """ file:?:?: Unexpected token Token(INT, '5'). Expected one of: "*", ".", "=". let x : T 5 ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ foo( * -""", """ +""", + """ file:?:?: Unexpected token Token(STAR, '*'). Expected one of: ")", expression. foo( * ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ if x y -""", """ +""", + """ file:?:?: Unexpected token Token(IDENTIFIER, 'y'). Expected one of: "!=", "(", ".", "==", "[", \ "{", operator. if x y ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ x = y; ap-- -""", """ +""", + """ file:?:?: Unexpected token Token(MINUS, '-'). Expected: "++". x = y; ap-- ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ func foo()* -""", """ +""", + """ file:?:?: Unexpected token Token(STAR, '*'). Expected one of: "->", ":". func foo()* ^ -""") +""", + ) def test_unexpected_character(): - verify_exception(""" + verify_exception( + """ x~y -""", """ +""", + """ file:?:?: Unexpected character "~". x~y ^ -""") +""", + ) def test_parser_error(): # Unexpected EOF - missing 'end'. - with pytest.raises(ParserError, match='Unexpected end-of-input.') as e: - parse_file(code=""" + with pytest.raises(ParserError, match="Unexpected end-of-input.") as e: + parse_file( + code=""" func f(): const a = 5 -""") - assert str(e.value).endswith(""" +""" + ) + assert str(e.value).endswith( + """ const a = 5 - ^""") + ^""" + ) diff --git a/src/starkware/cairo/lang/compiler/parser_test.py b/src/starkware/cairo/lang/compiler/parser_test.py index 29c7b0bc..f3b8ab47 100644 --- a/src/starkware/cairo/lang/compiler/parser_test.py +++ b/src/starkware/cairo/lang/compiler/parser_test.py @@ -3,540 +3,565 @@ from starkware.cairo.lang.compiler.ast.aliased_identifier import AliasedIdentifier from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt, TypeTuple from starkware.cairo.lang.compiler.ast.code_elements import ( - CodeElementImport, CodeElementReference, CodeElementReturnValueReference) + CodeElementImport, + CodeElementReference, + CodeElementReturnValueReference, +) from starkware.cairo.lang.compiler.ast.expr import ( - ExprConst, ExprDeref, ExprDot, ExprIdentifier, ExprNeg, ExprOperator, ExprParentheses, - ExprPyConst, ExprReg, ExprSubscript) + ExprConst, + ExprDeref, + ExprDot, + ExprIdentifier, + ExprNeg, + ExprOperator, + ExprParentheses, + ExprPyConst, + ExprReg, + ExprSubscript, +) from starkware.cairo.lang.compiler.ast.formatting_utils import FormattingError from starkware.cairo.lang.compiler.ast.instructions import ( - AddApInstruction, AssertEqInstruction, CallInstruction, CallLabelInstruction, InstructionAst, - JnzInstruction, JumpInstruction, JumpToLabelInstruction, RetInstruction) + AddApInstruction, + AssertEqInstruction, + CallInstruction, + CallLabelInstruction, + InstructionAst, + JnzInstruction, + JumpInstruction, + JumpToLabelInstruction, + RetInstruction, +) from starkware.cairo.lang.compiler.ast.types import TypedIdentifier from starkware.cairo.lang.compiler.error_handling import LocationError, get_location_marks from starkware.cairo.lang.compiler.expression_simplifier import ExpressionSimplifier from starkware.cairo.lang.compiler.instruction import Register from starkware.cairo.lang.compiler.parser import ( - parse, parse_code_element, parse_expr, parse_instruction, parse_type) + parse, + parse_code_element, + parse_expr, + parse_instruction, + parse_type, +) from starkware.cairo.lang.compiler.parser_test_utils import verify_exception from starkware.cairo.lang.compiler.parser_transformer import ParserContext, ParserError from starkware.python.utils import safe_zip def test_int(): - expr = parse_expr(' 01234 ') + expr = parse_expr(" 01234 ") assert expr == ExprConst(val=1234) - assert expr.format_str == '01234' - assert expr.format() == '01234' + assert expr.format_str == "01234" + assert expr.format() == "01234" - expr = parse_expr('-01234') + expr = parse_expr("-01234") assert expr == ExprNeg(val=ExprConst(val=1234)) - assert expr.val.format_str == '01234' - assert expr.format() == '-01234' + assert expr.val.format_str == "01234" + assert expr.format() == "-01234" - assert parse_expr('-1234') == parse_expr('- 1234') + assert parse_expr("-1234") == parse_expr("- 1234") def test_hex_int(): - expr = parse_expr(' 0x1234 ') + expr = parse_expr(" 0x1234 ") assert expr == ExprConst(val=0x1234) - assert expr.format_str == '0x1234' - assert expr.format() == '0x1234' + assert expr.format_str == "0x1234" + assert expr.format() == "0x1234" - expr = parse_expr('-0x01234') + expr = parse_expr("-0x01234") assert expr == ExprNeg(val=ExprConst(val=0x1234)) - assert expr.val.format_str == '0x01234' - assert expr.format() == '-0x01234' + assert expr.val.format_str == "0x01234" + assert expr.format() == "-0x01234" - assert parse_expr('-0x1234') == parse_expr('- 0x1234') + assert parse_expr("-0x1234") == parse_expr("- 0x1234") def test_types(): - assert isinstance(parse_type('felt'), TypeFelt) - assert parse_type('my_namespace.MyStruct * *').format() == 'my_namespace.MyStruct**' - assert parse_type('my_namespace.MyStruct*****').format() == 'my_namespace.MyStruct*****' + assert isinstance(parse_type("felt"), TypeFelt) + assert parse_type("my_namespace.MyStruct * *").format() == "my_namespace.MyStruct**" + assert parse_type("my_namespace.MyStruct*****").format() == "my_namespace.MyStruct*****" def test_type_tuple(): - typ = parse_type('(felt)') + typ = parse_type("(felt)") assert typ == TypeTuple(members=[TypeFelt()]) - assert typ.format() == '(felt)' - assert parse_type('( felt, felt* , (felt, T.S,)* )').format() == '(felt, felt*, (felt, T.S)*)' + assert typ.format() == "(felt)" + assert parse_type("( felt, felt* , (felt, T.S,)* )").format() == "(felt, felt*, (felt, T.S)*)" def test_identifier_and_dot(): - assert parse_expr('x.y . z + x ').format() == 'x.y.z + x' - assert parse_expr(' [x]. y . z').format() == '[x].y.z' - assert parse_expr('(x-y).z').format() == '(x - y).z' - assert parse_expr('x-y.z').format() == 'x - y.z' - assert parse_expr('[ap+1].x.y').format() == '[ap + 1].x.y' - assert parse_expr('((a.b + c).d * e.f + g.h).i.j').format() == '((a.b + c).d * e.f + g.h).i.j' - - assert parse_expr('(x).y.z') == \ - ExprDot( - expr=ExprDot( - expr=ExprParentheses(val=ExprIdentifier(name='x')), - member=ExprIdentifier(name='y')), - member=ExprIdentifier(name='z')) - assert parse_expr('x.y.z') == ExprIdentifier(name='x.y.z') + assert parse_expr("x.y . z + x ").format() == "x.y.z + x" + assert parse_expr(" [x]. y . z").format() == "[x].y.z" + assert parse_expr("(x-y).z").format() == "(x - y).z" + assert parse_expr("x-y.z").format() == "x - y.z" + assert parse_expr("[ap+1].x.y").format() == "[ap + 1].x.y" + assert parse_expr("((a.b + c).d * e.f + g.h).i.j").format() == "((a.b + c).d * e.f + g.h).i.j" + + assert parse_expr("(x).y.z") == ExprDot( + expr=ExprDot( + expr=ExprParentheses(val=ExprIdentifier(name="x")), member=ExprIdentifier(name="y") + ), + member=ExprIdentifier(name="z"), + ) + assert parse_expr("x.y.z") == ExprIdentifier(name="x.y.z") with pytest.raises(ParserError): - parse_expr('.x') + parse_expr(".x") with pytest.raises(ParserError): - parse_expr('x.') + parse_expr("x.") with pytest.raises(ParserError): - parse_expr('x.(y+z)') + parse_expr("x.(y+z)") with pytest.raises(ParserError): - parse_expr('x.[a]') + parse_expr("x.[a]") def test_typed_identifier(): - typed_identifier = parse(None, 't : felt*', 'typed_identifier', TypedIdentifier) - assert typed_identifier.format() == 't : felt*' + typed_identifier = parse(None, "t : felt*", "typed_identifier", TypedIdentifier) + assert typed_identifier.format() == "t : felt*" - typed_identifier = parse(None, 'local t : felt', 'typed_identifier', TypedIdentifier) - assert typed_identifier.format() == 'local t : felt' + typed_identifier = parse(None, "local t : felt", "typed_identifier", TypedIdentifier) + assert typed_identifier.format() == "local t : felt" def test_exp_pyconst(): - expr = parse_expr(' %[foo bar%] ') - assert expr == ExprPyConst(code='foo bar') - assert expr.format() == '%[foo bar%]' + expr = parse_expr(" %[foo bar%] ") + assert expr == ExprPyConst(code="foo bar") + assert expr.format() == "%[foo bar%]" def test_add_expr(): - expr = parse_expr('[fp + 1] + [ap - x]') - assert expr == \ - ExprOperator( - a=ExprDeref( - addr=ExprOperator( - a=ExprReg(reg=Register.FP), - op='+', - b=ExprConst(val=1))), - op='+', - b=ExprDeref( - addr=ExprOperator( - a=ExprReg(reg=Register.AP), - op='-', - b=ExprIdentifier(name='x')))) - assert expr.format() == '[fp + 1] + [ap - x]' - assert parse_expr('[ap-7]+37').format() == '[ap - 7] + 37' + expr = parse_expr("[fp + 1] + [ap - x]") + assert expr == ExprOperator( + a=ExprDeref(addr=ExprOperator(a=ExprReg(reg=Register.FP), op="+", b=ExprConst(val=1))), + op="+", + b=ExprDeref( + addr=ExprOperator(a=ExprReg(reg=Register.AP), op="-", b=ExprIdentifier(name="x")) + ), + ) + assert expr.format() == "[fp + 1] + [ap - x]" + assert parse_expr("[ap-7]+37").format() == "[ap - 7] + 37" def test_deref_expr(): - expr = parse_expr('[[fp - 7] + 3]') - assert expr == \ - ExprDeref( - addr=ExprOperator( - a=ExprDeref( - addr=ExprOperator( - a=ExprReg(reg=Register.FP), - op='-', - b=ExprConst(val=7))), - op='+', - b=ExprConst(val=3))) - assert expr.format() == '[[fp - 7] + 3]' + expr = parse_expr("[[fp - 7] + 3]") + assert expr == ExprDeref( + addr=ExprOperator( + a=ExprDeref(addr=ExprOperator(a=ExprReg(reg=Register.FP), op="-", b=ExprConst(val=7))), + op="+", + b=ExprConst(val=3), + ) + ) + assert expr.format() == "[[fp - 7] + 3]" def test_subscript_expr(): - assert parse_expr('x[y]').format() == 'x[y]' - assert parse_expr('[x][y][z][w]').format() == '[x][y][z][w]' - assert parse_expr(' x [ [ y[z[w]] ] ]').format() == 'x[[y[z[w]]]]' - assert parse_expr(' (x+y)[z+w] ').format() == '(x + y)[z + w]' - assert parse_expr('(&x)[3][(a-b)*2][&c]').format() == '(&x)[3][(a - b) * 2][&c]' - assert parse_expr('x[i+n*j]').format() == 'x[i + n * j]' - assert parse_expr('x+[y][z]').format() == 'x + [y][z]' - - assert parse_expr('[x][y][[z]]') == \ - ExprSubscript( - expr=ExprSubscript( - expr=ExprDeref(addr=ExprIdentifier(name='x')), - offset=ExprIdentifier(name='y') - ), - offset=ExprDeref(addr=ExprIdentifier(name='z'))) + assert parse_expr("x[y]").format() == "x[y]" + assert parse_expr("[x][y][z][w]").format() == "[x][y][z][w]" + assert parse_expr(" x [ [ y[z[w]] ] ]").format() == "x[[y[z[w]]]]" + assert parse_expr(" (x+y)[z+w] ").format() == "(x + y)[z + w]" + assert parse_expr("(&x)[3][(a-b)*2][&c]").format() == "(&x)[3][(a - b) * 2][&c]" + assert parse_expr("x[i+n*j]").format() == "x[i + n * j]" + assert parse_expr("x+[y][z]").format() == "x + [y][z]" + + assert parse_expr("[x][y][[z]]") == ExprSubscript( + expr=ExprSubscript( + expr=ExprDeref(addr=ExprIdentifier(name="x")), offset=ExprIdentifier(name="y") + ), + offset=ExprDeref(addr=ExprIdentifier(name="z")), + ) with pytest.raises(ParserError): - parse_expr('x[)]') + parse_expr("x[)]") with pytest.raises(ParserError): - parse_expr('x[]') + parse_expr("x[]") def test_operator_precedence(): - code = '(5 + 2) - (3 - 9) * (7 + (-(8 ** 2))) - 10 * (-2) * 5 ** 3 + (((7)))' + code = "(5 + 2) - (3 - 9) * (7 + (-(8 ** 2))) - 10 * (-2) * 5 ** 3 + (((7)))" expr = parse_expr(code) # Test formatting. assert expr.format() == code # Compute the value of expr from the tree and compare it with the correct value. - PRIME = 3 * 2**30 + 1 + PRIME = 3 * 2 ** 30 + 1 simplified_expr = ExpressionSimplifier(PRIME).visit(expr) assert isinstance(simplified_expr, ExprConst) assert simplified_expr.val == eval(code) def test_mul_expr(): - assert parse_expr('[ap]*[fp]').format() == '[ap] * [fp]' - assert parse_expr('[ap]*37').format() == '[ap] * 37' + assert parse_expr("[ap]*[fp]").format() == "[ap] * [fp]" + assert parse_expr("[ap]*37").format() == "[ap] * 37" def test_div_expr(): - assert parse_expr('[ap]/[fp]/3/[ap+1]').format() == '[ap] / [fp] / 3 / [ap + 1]' + assert parse_expr("[ap]/[fp]/3/[ap+1]").format() == "[ap] / [fp] / 3 / [ap + 1]" - code = '120 / 2 / 3 / 4' + code = "120 / 2 / 3 / 4" expr = parse_expr(code) # Compute the value of expr from the tree and compare it with the correct value. - PRIME = 3 * 2**30 + 1 + PRIME = 3 * 2 ** 30 + 1 simplified_expr = ExpressionSimplifier(PRIME).visit(expr) assert isinstance(simplified_expr, ExprConst) assert simplified_expr.val == 5 def test_cast_expr(): - assert parse_expr('cast( ap , T * * )').format() == 'cast(ap, T**)' - assert parse_expr('cast( ap , T * * ) * (cast(fp, felt))').format() == \ - 'cast(ap, T**) * (cast(fp, felt))' - assert parse_expr('cast( \n ap , T * * )').format() == 'cast(\n ap, T**)' + assert parse_expr("cast( ap , T * * )").format() == "cast(ap, T**)" + assert ( + parse_expr("cast( ap , T * * ) * (cast(fp, felt))").format() + == "cast(ap, T**) * (cast(fp, felt))" + ) + assert parse_expr("cast( \n ap , T * * )").format() == "cast(\n ap, T**)" def test_tuple_expr(): - assert parse_expr('( )').format() == '()' - assert parse_expr('( 2)').format() == '(2)' # Not a tuple. - assert parse_expr('(a= 2)').format() == '(a=2)' # Tuple. - assert parse_expr('( 2,)').format() == '(2,)' - assert parse_expr('( 1 , ap)').format() == '(1, ap)' - assert parse_expr('( 1 , ap, )').format() == '(1, ap,)' - assert parse_expr('( 1 , a=2, b=(c=()))').format() == '(1, a=2, b=(c=()))' + assert parse_expr("( )").format() == "()" + assert parse_expr("( 2)").format() == "(2)" # Not a tuple. + assert parse_expr("(a= 2)").format() == "(a=2)" # Tuple. + assert parse_expr("( 2,)").format() == "(2,)" + assert parse_expr("( 1 , ap)").format() == "(1, ap)" + assert parse_expr("( 1 , ap, )").format() == "(1, ap,)" + assert parse_expr("( 1 , a=2, b=(c=()))").format() == "(1, a=2, b=(c=()))" def test_tuple_expr_with_notes(): - assert parse_expr("""\ + assert ( + parse_expr( + """\ ( 1 , # a. ( # c. ) #b. - , (fp,[3]))""").format() == """\ + , (fp,[3]))""" + ).format() + == """\ (1, # a. ( # c. ), # b. (fp, [3]))""" - assert parse_expr("""\ + ) + assert ( + parse_expr( + """\ ( 1 # b. , # a. - )""").format() == """\ + )""" + ).format() + == """\ (1, # b. # a. )""" + ) def test_hint_expr(): - expr = parse_expr('a*nondet %{6 %}+ 7') - assert expr.format() == 'a * nondet %{ 6 %} + 7' + expr = parse_expr("a*nondet %{6 %}+ 7") + assert expr.format() == "a * nondet %{ 6 %} + 7" def test_pow_expr(): - assert parse_expr('2 ** 3').format() == '2 ** 3' - verify_exception('let x = 2 * * 3', """ + assert parse_expr("2 ** 3").format() == "2 ** 3" + verify_exception( + "let x = 2 * * 3", + """ file:?:?: Unexpected operator. Did you mean "**"? let x = 2 * * 3 ^*^ -""") +""", + ) def test_offsets(): - assert parse_expr(' [ [ ap] -x ]').format() == '[[ap] - x]' - assert parse_expr(' [ [ ap+foo] -x ]').format() == '[[ap + foo] - x]' - assert parse_expr(' [ [ fp+ 0 ] - 0]').format() == '[[fp + 0] - 0]' - assert parse_expr(' [ap+-5]').format() == '[ap + (-5)]' - assert parse_expr(' [ap--5]').format() == '[ap - (-5)]' + assert parse_expr(" [ [ ap] -x ]").format() == "[[ap] - x]" + assert parse_expr(" [ [ ap+foo] -x ]").format() == "[[ap + foo] - x]" + assert parse_expr(" [ [ fp+ 0 ] - 0]").format() == "[[fp + 0] - 0]" + assert parse_expr(" [ap+-5]").format() == "[ap + (-5)]" + assert parse_expr(" [ap--5]").format() == "[ap - (-5)]" def test_instruction(): # AssertEq. - expr = parse_instruction('[ap] = [fp]; ap++') - assert expr == \ - InstructionAst( - body=AssertEqInstruction( - a=ExprDeref( - addr=ExprReg(reg=Register.AP)), - b=ExprDeref( - addr=ExprReg(reg=Register.FP))), - inc_ap=True) - assert expr.format() == '[ap] = [fp]; ap++' - assert parse_instruction('[ap+5] = [fp]+[ap] - 5').format() == '[ap + 5] = [fp] + [ap] - 5' - assert parse_instruction('[ap+5]+3= [fp]*7;ap ++ ').format() == \ - '[ap + 5] + 3 = [fp] * 7; ap++' + expr = parse_instruction("[ap] = [fp]; ap++") + assert expr == InstructionAst( + body=AssertEqInstruction( + a=ExprDeref(addr=ExprReg(reg=Register.AP)), b=ExprDeref(addr=ExprReg(reg=Register.FP)) + ), + inc_ap=True, + ) + assert expr.format() == "[ap] = [fp]; ap++" + assert parse_instruction("[ap+5] = [fp]+[ap] - 5").format() == "[ap + 5] = [fp] + [ap] - 5" + assert parse_instruction("[ap+5]+3= [fp]*7;ap ++ ").format() == "[ap + 5] + 3 = [fp] * 7; ap++" # Jump. - expr = parse_instruction('jmp rel [ap] + x; ap++') - assert expr == \ - InstructionAst( - body=JumpInstruction( - val=ExprOperator( - a=ExprDeref(addr=ExprReg(reg=Register.AP)), - op='+', - b=ExprIdentifier(name='x')), - relative=True), - inc_ap=True) - assert expr.format() == 'jmp rel [ap] + x; ap++' - assert parse_instruction(' jmp abs[ap]+x').format() == 'jmp abs [ap] + x' + expr = parse_instruction("jmp rel [ap] + x; ap++") + assert expr == InstructionAst( + body=JumpInstruction( + val=ExprOperator( + a=ExprDeref(addr=ExprReg(reg=Register.AP)), op="+", b=ExprIdentifier(name="x") + ), + relative=True, + ), + inc_ap=True, + ) + assert expr.format() == "jmp rel [ap] + x; ap++" + assert parse_instruction(" jmp abs[ap]+x").format() == "jmp abs [ap] + x" # Make sure the following are not OK. with pytest.raises(ParserError): - parse_instruction('jmp abs') + parse_instruction("jmp abs") with pytest.raises(ParserError): - parse_instruction('jmpabs[ap]') + parse_instruction("jmpabs[ap]") # JumpToLabel. - expr = parse_instruction('jmp label') - assert expr == \ - InstructionAst( - body=JumpToLabelInstruction( - label=ExprIdentifier(name='label'), - condition=None), - inc_ap=False) - assert expr.format() == 'jmp label' + expr = parse_instruction("jmp label") + assert expr == InstructionAst( + body=JumpToLabelInstruction(label=ExprIdentifier(name="label"), condition=None), + inc_ap=False, + ) + assert expr.format() == "jmp label" # Make sure the following are not OK. with pytest.raises(ParserError): - parse_instruction('jmp [fp]') + parse_instruction("jmp [fp]") with pytest.raises(ParserError): - parse_instruction('jmp 7') + parse_instruction("jmp 7") # Jnz. - expr = parse_instruction('jmp rel [ap] + x if [fp + 3] != 0') - assert expr == \ - InstructionAst( - body=JnzInstruction( - jump_offset=ExprOperator( - a=ExprDeref(addr=ExprReg(reg=Register.AP)), - op='+', - b=ExprIdentifier(name='x')), - condition=ExprDeref( - addr=ExprOperator( - a=ExprReg(reg=Register.FP), - op='+', - b=ExprConst(val=3)))), - inc_ap=False) - assert expr.format() == 'jmp rel [ap] + x if [fp + 3] != 0' - assert parse_instruction(' jmp rel 17 if[fp]!=0;ap++').format() == \ - 'jmp rel 17 if [fp] != 0; ap++' + expr = parse_instruction("jmp rel [ap] + x if [fp + 3] != 0") + assert expr == InstructionAst( + body=JnzInstruction( + jump_offset=ExprOperator( + a=ExprDeref(addr=ExprReg(reg=Register.AP)), op="+", b=ExprIdentifier(name="x") + ), + condition=ExprDeref( + addr=ExprOperator(a=ExprReg(reg=Register.FP), op="+", b=ExprConst(val=3)) + ), + ), + inc_ap=False, + ) + assert expr.format() == "jmp rel [ap] + x if [fp + 3] != 0" + assert ( + parse_instruction(" jmp rel 17 if[fp]!=0;ap++").format() + == "jmp rel 17 if [fp] != 0; ap++" + ) # Make sure the following are not OK. with pytest.raises(ParserError): - parse_instruction('jmprel 17 if x != 0') + parse_instruction("jmprel 17 if x != 0") with pytest.raises(ParserError): - parse_instruction('jmp 17 if x') - with pytest.raises(ParserError, match='!= 0'): - parse_instruction('jmp rel 17 if x != 2') + parse_instruction("jmp 17 if x") + with pytest.raises(ParserError, match="!= 0"): + parse_instruction("jmp rel 17 if x != 2") with pytest.raises(ParserError): - parse_instruction('jmp rel [fp] ifx != 0') + parse_instruction("jmp rel [fp] ifx != 0") # Jnz to label. - expr = parse_instruction('jmp label if [fp] != 0') - assert expr == \ - InstructionAst( - body=JumpToLabelInstruction( - label=ExprIdentifier('label'), - condition=ExprDeref(addr=ExprReg(reg=Register.FP))), - inc_ap=False) - assert expr.format() == 'jmp label if [fp] != 0' + expr = parse_instruction("jmp label if [fp] != 0") + assert expr == InstructionAst( + body=JumpToLabelInstruction( + label=ExprIdentifier("label"), condition=ExprDeref(addr=ExprReg(reg=Register.FP)) + ), + inc_ap=False, + ) + assert expr.format() == "jmp label if [fp] != 0" # Make sure the following are not OK. with pytest.raises(ParserError): - parse_instruction('jmp [fp] if [fp] != 0') + parse_instruction("jmp [fp] if [fp] != 0") with pytest.raises(ParserError): - parse_instruction('jmp 7 if [fp] != 0') + parse_instruction("jmp 7 if [fp] != 0") # Call abs. - expr = parse_instruction('call abs [fp] + x') - assert expr == \ - InstructionAst( - body=CallInstruction( - val=ExprOperator( - a=ExprDeref(addr=ExprReg(reg=Register.FP)), - op='+', - b=ExprIdentifier(name='x')), - relative=False), - inc_ap=False) - assert expr.format() == 'call abs [fp] + x' - assert parse_instruction('call abs 17;ap++').format() == 'call abs 17; ap++' + expr = parse_instruction("call abs [fp] + x") + assert expr == InstructionAst( + body=CallInstruction( + val=ExprOperator( + a=ExprDeref(addr=ExprReg(reg=Register.FP)), op="+", b=ExprIdentifier(name="x") + ), + relative=False, + ), + inc_ap=False, + ) + assert expr.format() == "call abs [fp] + x" + assert parse_instruction("call abs 17;ap++").format() == "call abs 17; ap++" # Make sure the following are not OK. with pytest.raises(ParserError): - parse_instruction('call abs') + parse_instruction("call abs") with pytest.raises(ParserError): - parse_instruction('callabs 7') + parse_instruction("callabs 7") # Call rel. - expr = parse_instruction('call rel [ap] + x') - assert expr == \ - InstructionAst( - body=CallInstruction( - val=ExprOperator( - a=ExprDeref(addr=ExprReg(reg=Register.AP)), - op='+', - b=ExprIdentifier(name='x')), - relative=True), - inc_ap=False) - assert expr.format() == 'call rel [ap] + x' - assert parse_instruction('call rel 17;ap++').format() == 'call rel 17; ap++' + expr = parse_instruction("call rel [ap] + x") + assert expr == InstructionAst( + body=CallInstruction( + val=ExprOperator( + a=ExprDeref(addr=ExprReg(reg=Register.AP)), op="+", b=ExprIdentifier(name="x") + ), + relative=True, + ), + inc_ap=False, + ) + assert expr.format() == "call rel [ap] + x" + assert parse_instruction("call rel 17;ap++").format() == "call rel 17; ap++" # Make sure the following are not OK. with pytest.raises(ParserError): - parse_instruction('call rel') + parse_instruction("call rel") with pytest.raises(ParserError): - parse_instruction('callrel 7') + parse_instruction("callrel 7") # Call label. - expr = parse_instruction('call label') - assert expr == \ - InstructionAst( - body=CallLabelInstruction( - label=ExprIdentifier(name='label')), - inc_ap=False) - assert expr.format() == 'call label' - assert parse_instruction('call label ;ap++').format() == 'call label; ap++' + expr = parse_instruction("call label") + assert expr == InstructionAst( + body=CallLabelInstruction(label=ExprIdentifier(name="label")), inc_ap=False + ) + assert expr.format() == "call label" + assert parse_instruction("call label ;ap++").format() == "call label; ap++" # Make sure the following are not OK. with pytest.raises(ParserError): - parse_instruction('call [fp]') + parse_instruction("call [fp]") with pytest.raises(ParserError): - parse_instruction('call 7') + parse_instruction("call 7") # Ret. - expr = parse_instruction('ret') - assert expr == \ - InstructionAst( - body=RetInstruction(), - inc_ap=False) - assert expr.format() == 'ret' + expr = parse_instruction("ret") + assert expr == InstructionAst(body=RetInstruction(), inc_ap=False) + assert expr.format() == "ret" # AddAp. - expr = parse_instruction('ap += [fp] + 2') - assert expr == \ - InstructionAst( - body=AddApInstruction( - expr=ExprOperator( - a=ExprDeref( - addr=ExprReg(reg=Register.FP)), - op='+', - b=ExprConst(val=2))), - inc_ap=False) - assert expr.format() == 'ap += [fp] + 2' - assert parse_instruction('ap +=[ fp]+ 2').format() == 'ap += [fp] + 2' - assert parse_instruction('ap +=[ fp]+ 2;ap ++').format() == 'ap += [fp] + 2; ap++' + expr = parse_instruction("ap += [fp] + 2") + assert expr == InstructionAst( + body=AddApInstruction( + expr=ExprOperator( + a=ExprDeref(addr=ExprReg(reg=Register.FP)), op="+", b=ExprConst(val=2) + ) + ), + inc_ap=False, + ) + assert expr.format() == "ap += [fp] + 2" + assert parse_instruction("ap +=[ fp]+ 2").format() == "ap += [fp] + 2" + assert parse_instruction("ap +=[ fp]+ 2;ap ++").format() == "ap += [fp] + 2; ap++" def test_import(): # Test module names without periods. - res = parse_code_element('from a import b') + res = parse_code_element("from a import b") assert res == CodeElementImport( - path=ExprIdentifier(name='a'), - import_items=[AliasedIdentifier( - orig_identifier=ExprIdentifier(name='b'), - local_name=None)]) - assert res.format(allowed_line_length=100) == 'from a import b' + path=ExprIdentifier(name="a"), + import_items=[AliasedIdentifier(orig_identifier=ExprIdentifier(name="b"), local_name=None)], + ) + assert res.format(allowed_line_length=100) == "from a import b" # Test module names without periods, with aliasing. - res = parse_code_element('from a import b as c') + res = parse_code_element("from a import b as c") assert res == CodeElementImport( - path=ExprIdentifier(name='a'), - import_items=[AliasedIdentifier( - orig_identifier=ExprIdentifier(name='b'), - local_name=ExprIdentifier(name='c'))]) - assert res.format(allowed_line_length=100) == 'from a import b as c' + path=ExprIdentifier(name="a"), + import_items=[ + AliasedIdentifier( + orig_identifier=ExprIdentifier(name="b"), local_name=ExprIdentifier(name="c") + ) + ], + ) + assert res.format(allowed_line_length=100) == "from a import b as c" # Test module names with periods. - res = parse_code_element('from a.b12.c4 import lib345') + res = parse_code_element("from a.b12.c4 import lib345") assert res == CodeElementImport( - path=ExprIdentifier(name='a.b12.c4'), - import_items=[AliasedIdentifier( - orig_identifier=ExprIdentifier(name='lib345'), - local_name=None)]) - assert res.format(allowed_line_length=100) == 'from a.b12.c4 import lib345' + path=ExprIdentifier(name="a.b12.c4"), + import_items=[ + AliasedIdentifier(orig_identifier=ExprIdentifier(name="lib345"), local_name=None) + ], + ) + assert res.format(allowed_line_length=100) == "from a.b12.c4 import lib345" # Test multiple imports. - res = parse_code_element('from lib import a,b as b2, c') + res = parse_code_element("from lib import a,b as b2, c") assert res == CodeElementImport( - path=ExprIdentifier(name='lib'), + path=ExprIdentifier(name="lib"), import_items=[ + AliasedIdentifier(orig_identifier=ExprIdentifier(name="a"), local_name=None), AliasedIdentifier( - orig_identifier=ExprIdentifier(name='a'), - local_name=None), - AliasedIdentifier( - orig_identifier=ExprIdentifier(name='b'), - local_name=ExprIdentifier(name='b2')), - AliasedIdentifier( - orig_identifier=ExprIdentifier(name='c'), - local_name=None), - ]) - assert res.format(allowed_line_length=100) == 'from lib import a, b as b2, c' - assert res.format(allowed_line_length=20) == 'from lib import (\n a, b as b2, c)' + orig_identifier=ExprIdentifier(name="b"), local_name=ExprIdentifier(name="b2") + ), + AliasedIdentifier(orig_identifier=ExprIdentifier(name="c"), local_name=None), + ], + ) + assert res.format(allowed_line_length=100) == "from lib import a, b as b2, c" + assert res.format(allowed_line_length=20) == "from lib import (\n a, b as b2, c)" - assert res == parse_code_element('from lib import (\n a, b as b2, c)') + assert res == parse_code_element("from lib import (\n a, b as b2, c)") # Test module with bad identifier (with periods). with pytest.raises(ParserError): - parse_expr('from a.b import c.d') + parse_expr("from a.b import c.d") # Test module with bad local name (with periods). with pytest.raises(ParserError): - parse_expr('from a.b import c as d.d') + parse_expr("from a.b import c as d.d") def test_return_value_reference(): - res = parse_code_element('let z=call x') - assert res.format(allowed_line_length=100) == 'let z = call x' + res = parse_code_element("let z=call x") + assert res.format(allowed_line_length=100) == "let z = call x" - res = parse_code_element('let z:y.z=call x') - assert res.format(allowed_line_length=100) == 'let z : y.z = call x' + res = parse_code_element("let z:y.z=call x") + assert res.format(allowed_line_length=100) == "let z : y.z = call x" - res = parse_code_element('let z:y.z=call rel x') - assert res.format(allowed_line_length=100) == 'let z : y.z = call rel x' + res = parse_code_element("let z:y.z=call rel x") + assert res.format(allowed_line_length=100) == "let z : y.z = call rel x" res = parse_code_element( - 'let very_long_prefix = foo(a=1, b= 1, very_long_arg_1=1, very_long_arg_2 =1)') - assert res.format( - allowed_line_length=40) == """\ + "let very_long_prefix = foo(a=1, b= 1, very_long_arg_1=1, very_long_arg_2 =1)" + ) + assert ( + res.format(allowed_line_length=40) + == """\ let very_long_prefix = foo( a=1, b=1, very_long_arg_1=1, very_long_arg_2=1)""" + ) res = parse_code_element( - 'let (very_long_prefix ,b,c: T) = foo(a=1, b= 1, very_long_arg_1=1, very_long_arg_2 =1)') - assert res.format( - allowed_line_length=40) == """\ + "let (very_long_prefix ,b,c: T) = foo(a=1, b= 1, very_long_arg_1=1, very_long_arg_2 =1)" + ) + assert ( + res.format(allowed_line_length=40) + == """\ let (very_long_prefix, b, c : T) = foo( a=1, b=1, very_long_arg_1=1, very_long_arg_2=1)""" + ) with pytest.raises(ParserError): # Const in the unpacking tuple. - parse_expr('let (1,b,c) = foo(a=1, b= 1)') + parse_expr("let (1,b,c) = foo(a=1, b= 1)") with pytest.raises(ParserError): # Missing identifier after call. - parse_expr('let z = call') + parse_expr("let z = call") with pytest.raises(ParserError): # 'ap++' cannot be used in the return value reference syntax. - parse_expr('let z = call x; ap++') + parse_expr("let z = call x; ap++") def test_return(): - res = parse_code_element('return( 1, \na= 2 )') - assert res.format(allowed_line_length=100) == 'return (1, a=2)' + res = parse_code_element("return( 1, \na= 2 )") + assert res.format(allowed_line_length=100) == "return (1, a=2)" def test_func_call(): - res = parse_code_element('fibonacci( 1, \na= 2 )') - assert res.format(allowed_line_length=100) == 'fibonacci(1, a=2)' + res = parse_code_element("fibonacci( 1, \na= 2 )") + assert res.format(allowed_line_length=100) == "fibonacci(1, a=2)" - res = parse_code_element('fibonacci {a=b,c = d}( 1, \na= 2 )') - assert res.format(allowed_line_length=100) == 'fibonacci{a=b, c=d}(1, a=2)' - assert res.format(allowed_line_length=20) == 'fibonacci{a=b, c=d}(\n 1, a=2)' - assert res.format(allowed_line_length=15) == 'fibonacci{\n a=b, c=d}(\n 1, a=2)' + res = parse_code_element("fibonacci {a=b,c = d}( 1, \na= 2 )") + assert res.format(allowed_line_length=100) == "fibonacci{a=b, c=d}(1, a=2)" + assert res.format(allowed_line_length=20) == "fibonacci{a=b, c=d}(\n 1, a=2)" + assert res.format(allowed_line_length=15) == "fibonacci{\n a=b, c=d}(\n 1, a=2)" def test_tail_call(): - res = parse_code_element('return fibonacci( 1, \na= 2 )') - assert res.format(allowed_line_length=100) == 'return fibonacci(1, a=2)' + res = parse_code_element("return fibonacci( 1, \na= 2 )") + assert res.format(allowed_line_length=100) == "return fibonacci(1, a=2)" def test_func_with_args(): @@ -546,42 +571,43 @@ def def_func(args_str): [ap] = 4 end""" - def test_format(args_str_wrong, args_str_right=''): + def test_format(args_str_wrong, args_str_right=""): assert parse_code_element(def_func(args_str_wrong)).format( - allowed_line_length=100) == def_func(args_str_right) + allowed_line_length=100 + ) == def_func(args_str_right) - test_format(' ( x : T, y : S, z ) ', '(x : T, y : S, z)') - test_format('(x,y,z)', '(x, y, z)') - test_format('(x,y,z,)', '(x, y, z)') - test_format('(x,\ny,\nz)', '(x, y, z)') - test_format('(\nx,\ny,\nz)', '(x, y, z)') - test_format('( )', '()') - test_format('(\n\n)', '()') + test_format(" ( x : T, y : S, z ) ", "(x : T, y : S, z)") + test_format("(x,y,z)", "(x, y, z)") + test_format("(x,y,z,)", "(x, y, z)") + test_format("(x,\ny,\nz)", "(x, y, z)") + test_format("(\nx,\ny,\nz)", "(x, y, z)") + test_format("( )", "()") + test_format("(\n\n)", "()") - test_format('(x,y,z,)-> (a,b,c)', '(x, y, z) -> (a, b, c)') - test_format('()->(a,b,c)', '() -> (a, b, c)') - test_format('(x,y,z) ->()', '(x, y, z) -> ()') + test_format("(x,y,z,)-> (a,b,c)", "(x, y, z) -> (a, b, c)") + test_format("()->(a,b,c)", "() -> (a, b, c)") + test_format("(x,y,z) ->()", "(x, y, z) -> ()") # Implicit arguments. - test_format('{x,y\n\n}(z,w)->()', '{x, y}(z, w) -> ()') + test_format("{x,y\n\n}(z,w)->()", "{x, y}(z, w) -> ()") with pytest.raises(ParserError): - test_format('') + test_format("") with pytest.raises(ParserError): # Argument name cannot contain dots. - test_format('(x.y, z)') + test_format("(x.y, z)") with pytest.raises(ParserError): # Arguments must be separated by a comma. - test_format('(x y)') + test_format("(x y)") with pytest.raises(ParserError): # Double trailing comma is not allowed. - test_format('(x,y,z,,)') + test_format("(x,y,z,,)") with pytest.raises(FormattingError): - test_format('(x #comment\n,y,z)->()') + test_format("(x #comment\n,y,z)->()") def test_decoractor(): @@ -593,71 +619,82 @@ def test_decoractor(): return () end""" - assert parse_code_element(code=code).format(allowed_line_length=100) == """\ + assert ( + parse_code_element(code=code).format(allowed_line_length=100) + == """\ @hello @world @external func myfunc(): return () end""" + ) def test_decoractor_errors(): - verify_exception(""" + verify_exception( + """ @hello world func myfunc(): return() end -""", """ +""", + """ file:?:?: Unexpected token Token(IDENTIFIER, \'world\'). Expected one of: "@", "func", \ "namespace", "struct". @hello world ^***^ -""") +""", + ) - verify_exception(""" + verify_exception( + """ @hello-world func myfunc(): return() end -""", """ +""", + """ file:?:?: Unexpected token Token(MINUS, \'-\'). Expected one of: "@", "func", "namespace", "struct". @hello-world ^ -""") +""", + ) def test_reference_type_annotation(): - res = parse_code_element('let s : T * = ap') - assert res.format(allowed_line_length=100) == 'let s : T* = ap' + res = parse_code_element("let s : T * = ap") + assert res.format(allowed_line_length=100) == "let s : T* = ap" with pytest.raises(ParserError): - parse_expr('local x : = 0') + parse_expr("local x : = 0") def test_addressof(): - res = parse_code_element('static_assert & s.SIZE == ap ') - assert res.format(allowed_line_length=100) == 'static_assert &s.SIZE == ap' + res = parse_code_element("static_assert & s.SIZE == ap ") + assert res.format(allowed_line_length=100) == "static_assert &s.SIZE == ap" def test_func_expr(): - res = parse_code_element('let x = f()') + res = parse_code_element("let x = f()") assert isinstance(res, CodeElementReturnValueReference) - assert res.format(allowed_line_length=100) == 'let x = f()' + assert res.format(allowed_line_length=100) == "let x = f()" - res = parse_code_element('let x = (f())') + res = parse_code_element("let x = (f())") assert isinstance(res, CodeElementReference) - assert res.format(allowed_line_length=100) == 'let x = (f())' + assert res.format(allowed_line_length=100) == "let x = (f())" def test_parent_location(): - parent_location = ( - parse_expr('1 + 2').location, 'An error ocurred while processing:') - - location = parse_code_element('let x = 3 + 4', parser_context=ParserContext( - parent_location=parent_location)).expr.location - location_err = LocationError(message='Error', location=location) - assert str(location_err) == """\ + parent_location = (parse_expr("1 + 2").location, "An error ocurred while processing:") + + location = parse_code_element( + "let x = 3 + 4", parser_context=ParserContext(parent_location=parent_location) + ).expr.location + location_err = LocationError(message="Error", location=location) + assert ( + str(location_err) + == """\ :1:1: An error ocurred while processing: 1 + 2 ^***^ @@ -665,6 +702,7 @@ def test_parent_location(): let x = 3 + 4 ^***^\ """ + ) def test_locations(): @@ -694,4 +732,4 @@ def test_locations(): expr.body.b.addr.b, ] for expr, mark in safe_zip(exprs, marks): - assert get_location_marks(code, expr.location) == code + '\n' + mark + assert get_location_marks(code, expr.location) == code + "\n" + mark diff --git a/src/starkware/cairo/lang/compiler/parser_test_utils.py b/src/starkware/cairo/lang/compiler/parser_test_utils.py index ec775d5c..649397c4 100644 --- a/src/starkware/cairo/lang/compiler/parser_test_utils.py +++ b/src/starkware/cairo/lang/compiler/parser_test_utils.py @@ -10,6 +10,6 @@ def verify_exception(code: str, error: str): Verifies that parsing the code results in the given error. """ with pytest.raises(ParserError) as e: - parse_file(code, '') + parse_file(code, "") # Remove line and column information from the error using a regular expression. - assert re.sub(':[0-9]+:[0-9]+: ', 'file:?:?: ', str(e.value)) == error.strip() + assert re.sub(":[0-9]+:[0-9]+: ", "file:?:?: ", str(e.value)) == error.strip() diff --git a/src/starkware/cairo/lang/compiler/parser_transformer.py b/src/starkware/cairo/lang/compiler/parser_transformer.py index 2e06d179..e932d7f0 100644 --- a/src/starkware/cairo/lang/compiler/parser_transformer.py +++ b/src/starkware/cairo/lang/compiler/parser_transformer.py @@ -7,30 +7,85 @@ from starkware.cairo.lang.compiler.ast.arguments import IdentifierList from starkware.cairo.lang.compiler.ast.bool_expr import BoolExpr from starkware.cairo.lang.compiler.ast.cairo_types import ( - TypeFelt, TypePointer, TypeStruct, TypeTuple) + TypeFelt, + TypePointer, + TypeStruct, + TypeTuple, +) from starkware.cairo.lang.compiler.ast.code_elements import ( - BuiltinsDirective, CodeBlock, CodeElementAllocLocals, CodeElementCompoundAssertEq, - CodeElementConst, CodeElementDirective, CodeElementEmptyLine, CodeElementFuncCall, - CodeElementFunction, CodeElementHint, CodeElementIf, CodeElementImport, CodeElementInstruction, - CodeElementLabel, CodeElementLocalVariable, CodeElementMember, CodeElementReference, - CodeElementReturn, CodeElementReturnValueReference, CodeElementStaticAssert, - CodeElementTailCall, CodeElementTemporaryVariable, CodeElementUnpackBinding, CodeElementWith, - CommentedCodeElement, LangDirective) + BuiltinsDirective, + CodeBlock, + CodeElementAllocLocals, + CodeElementCompoundAssertEq, + CodeElementConst, + CodeElementDirective, + CodeElementEmptyLine, + CodeElementFuncCall, + CodeElementFunction, + CodeElementHint, + CodeElementIf, + CodeElementImport, + CodeElementInstruction, + CodeElementLabel, + CodeElementLocalVariable, + CodeElementMember, + CodeElementReference, + CodeElementReturn, + CodeElementReturnValueReference, + CodeElementStaticAssert, + CodeElementTailCall, + CodeElementTemporaryVariable, + CodeElementUnpackBinding, + CodeElementWith, + CommentedCodeElement, + LangDirective, +) from starkware.cairo.lang.compiler.ast.expr import ( - ArgList, ExprAddressOf, ExprAssignment, ExprCast, ExprConst, ExprDeref, ExprDot, ExprHint, - ExprIdentifier, ExprNeg, ExprOperator, ExprParentheses, ExprPow, ExprPyConst, ExprReg, - ExprSubscript, ExprTuple) + ArgList, + ExprAddressOf, + ExprAssignment, + ExprCast, + ExprConst, + ExprDeref, + ExprDot, + ExprHint, + ExprIdentifier, + ExprNeg, + ExprOperator, + ExprParentheses, + ExprPow, + ExprPyConst, + ExprReg, + ExprSubscript, + ExprTuple, +) from starkware.cairo.lang.compiler.ast.expr_func_call import ExprFuncCall from starkware.cairo.lang.compiler.ast.instructions import ( - AddApInstruction, AssertEqInstruction, CallInstruction, CallLabelInstruction, InstructionAst, - JnzInstruction, JumpInstruction, JumpToLabelInstruction, RetInstruction) + AddApInstruction, + AssertEqInstruction, + CallInstruction, + CallLabelInstruction, + InstructionAst, + JnzInstruction, + JumpInstruction, + JumpToLabelInstruction, + RetInstruction, +) from starkware.cairo.lang.compiler.ast.module import CairoFile from starkware.cairo.lang.compiler.ast.notes import Notes from starkware.cairo.lang.compiler.ast.rvalue import ( - RvalueCall, RvalueCallInst, RvalueExpr, RvalueFuncCall) + RvalueCall, + RvalueCallInst, + RvalueExpr, + RvalueFuncCall, +) from starkware.cairo.lang.compiler.ast.types import Modifier, TypedIdentifier from starkware.cairo.lang.compiler.error_handling import ( - InputFile, Location, LocationError, ParentLocation) + InputFile, + Location, + LocationError, + ParentLocation, +) from starkware.cairo.lang.compiler.instruction import Register from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -40,8 +95,12 @@ class ParserContext: """ Represents information that affects the parsing process. """ + parent_location: Optional[ParentLocation] = None + # If True, treat type identifiers as resolved. + resolved_types: bool = False + class ParserError(LocationError): pass @@ -57,7 +116,7 @@ def __init__(self, input_file: InputFile, parser_context: Optional[ParserContext self.parser_context = ParserContext() if parser_context is None else parser_context def __default__(self, data: str, children, meta): - raise TypeError(f'Unable to parse tree node of type {data}') + raise TypeError(f"Unable to parse tree node of type {data}") # Types. @@ -69,8 +128,9 @@ def type_struct(self, value): assert len(value) == 1 and isinstance(value[0], ExprIdentifier) return TypeStruct( scope=ScopedName.from_string(value[0].name), - is_fully_resolved=False, - location=value[0].location) + is_fully_resolved=self.parser_context.resolved_types, + location=value[0].location, + ) @v_args(meta=True) def type_pointer(self, value, meta): @@ -92,12 +152,14 @@ def arg_list(self, value, meta): args = value[1::3] # Join the notes before and after the comma. notes = [ - prev_after + before - for before, prev_after - in zip(value[::3], [Notes()] + value[2::3])] + prev_after + before for before, prev_after in zip(value[::3], [Notes()] + value[2::3]) + ] return ArgList( - args=args, notes=notes, has_trailing_comma=has_trailing_comma, - location=self.meta2loc(meta)) + args=args, + notes=notes, + has_trailing_comma=has_trailing_comma, + location=self.meta2loc(meta), + ) @v_args(meta=True) def expr_assignment(self, value, meta): @@ -107,12 +169,12 @@ def expr_assignment(self, value, meta): elif len(value) == 2: identifier, expr = value else: - raise NotImplementedError(f'Unexpected argument: value={value}') + raise NotImplementedError(f"Unexpected argument: value={value}") return ExprAssignment(identifier=identifier, expr=expr, location=self.meta2loc(meta)) @v_args(meta=True) def identifier(self, value, meta): - return ExprIdentifier(name='.'.join(x.value for x in value), location=self.meta2loc(meta)) + return ExprIdentifier(name=".".join(x.value for x in value), location=self.meta2loc(meta)) @v_args(meta=True) def identifier_def(self, value, meta): @@ -125,7 +187,8 @@ def atom_number(self, value, meta): @v_args(meta=True) def atom_hex_number(self, value, meta): return ExprConst( - val=int(value[0], 16), format_str=value[0].value, location=self.meta2loc(meta)) + val=int(value[0], 16), format_str=value[0].value, location=self.meta2loc(meta) + ) @v_args(meta=True) def atom_pyconst(self, value, meta): @@ -146,22 +209,26 @@ def atom_func_call(self, value, meta): @v_args(meta=True) def expr_add(self, value, meta): return ExprOperator( - a=value[0], op='+', b=value[2], notes=value[1], location=self.meta2loc(meta)) + a=value[0], op="+", b=value[2], notes=value[1], location=self.meta2loc(meta) + ) @v_args(meta=True) def expr_sub(self, value, meta): return ExprOperator( - a=value[0], op='-', b=value[2], notes=value[1], location=self.meta2loc(meta)) + a=value[0], op="-", b=value[2], notes=value[1], location=self.meta2loc(meta) + ) @v_args(meta=True) def expr_mul(self, value, meta): return ExprOperator( - a=value[0], op='*', b=value[2], notes=value[1], location=self.meta2loc(meta)) + a=value[0], op="*", b=value[2], notes=value[1], location=self.meta2loc(meta) + ) @v_args(meta=True) def expr_div(self, value, meta): return ExprOperator( - a=value[0], op='/', b=value[2], notes=value[1], location=self.meta2loc(meta)) + a=value[0], op="/", b=value[2], notes=value[1], location=self.meta2loc(meta) + ) @v_args(meta=True) def unary_addressof(self, value, meta): @@ -176,12 +243,12 @@ def two_stars(self, value, meta): is_two_chars = meta.end_pos == meta.start_pos + 2 if not is_two_chars: raise ParserError( - 'Unexpected operator. Did you mean "**"?', location=self.meta2loc(meta)) + 'Unexpected operator. Did you mean "**"?', location=self.meta2loc(meta) + ) @v_args(meta=True) def expr_pow(self, value, meta): - return ExprPow( - a=value[0], b=value[3], notes=value[2], location=self.meta2loc(meta)) + return ExprPow(a=value[0], b=value[3], notes=value[2], location=self.meta2loc(meta)) @v_args(meta=True) def atom_parentheses(self, value, meta): @@ -194,7 +261,8 @@ def atom_deref(self, value, meta): @v_args(meta=True) def atom_subscript(self, value, meta): return ExprSubscript( - expr=value[0], offset=value[2], notes=value[1], location=self.meta2loc(meta)) + expr=value[0], offset=value[2], notes=value[1], location=self.meta2loc(meta) + ) @v_args(meta=True) def atom_dot(self, value, meta): @@ -203,7 +271,8 @@ def atom_dot(self, value, meta): @v_args(meta=True) def atom_cast(self, value, meta): return ExprCast( - expr=value[1], notes=value[0], dest_type=value[2], location=self.meta2loc(meta)) + expr=value[1], notes=value[0], dest_type=value[2], location=self.meta2loc(meta) + ) @v_args(meta=True) def atom_tuple(self, value, meta): @@ -231,11 +300,11 @@ def bool_expr_neq(self, value, meta): @v_args(meta=True) def modifier_local(self, value, meta): - return Modifier(name='local', location=self.meta2loc(meta)) + return Modifier(name="local", location=self.meta2loc(meta)) @v_args(meta=True) def typed_identifier(self, value, meta): - assert len(value) in [1, 2, 3], f'Unexpected argument: value={value}' + assert len(value) in [1, 2, 3], f"Unexpected argument: value={value}" modifier = None if isinstance(value[0], Modifier): modifier = value.pop(0) @@ -270,17 +339,19 @@ def inst_jmp_to_label(self, value, meta): @v_args(meta=True) def inst_jnz(self, value, meta): - if value[2] != '0': + if value[2] != "0": raise ParserError('Invalid syntax, expected "!= 0".', location=self.meta2loc(meta)) return JnzInstruction( - jump_offset=value[0], condition=value[1], location=self.meta2loc(meta)) + jump_offset=value[0], condition=value[1], location=self.meta2loc(meta) + ) @v_args(meta=True) def inst_jnz_to_label(self, value, meta): - if value[2] != '0': + if value[2] != "0": raise ParserError('Invalid syntax, expected "!= 0".', location=self.meta2loc(meta)) return JumpToLabelInstruction( - label=value[0], condition=value[1], location=self.meta2loc(meta)) + label=value[0], condition=value[1], location=self.meta2loc(meta) + ) @v_args(meta=True) def inst_call_rel(self, value, meta): @@ -313,11 +384,11 @@ def instruction_ap(self, value, meta): # RValues. def rvalue_expr(self, value): - expr, = value + (expr,) = value return RvalueExpr(expr=expr) def rvalue_call_instruction(self, value): - call_inst, = value + (call_inst,) = value return RvalueCallInst(call_inst=call_inst) @v_args(meta=True) @@ -328,11 +399,14 @@ def function_call(self, value, meta): elif len(value) == 3: func_ident, implicit_args, arg_list = value else: - raise NotImplementedError(f'Unexpected argument: value={value}') + raise NotImplementedError(f"Unexpected argument: value={value}") return RvalueFuncCall( - func_ident=func_ident, arguments=arg_list, implicit_arguments=implicit_args, - location=self.meta2loc(meta)) + func_ident=func_ident, + arguments=arg_list, + implicit_arguments=implicit_args, + location=self.meta2loc(meta), + ) # CairoFile. @@ -348,8 +422,7 @@ def code_element_member(self, value): def code_element_reference(self, value): ref_binding, rvalue = value if isinstance(ref_binding, IdentifierList): - return CodeElementUnpackBinding( - unpacking_list=ref_binding, rvalue=rvalue) + return CodeElementUnpackBinding(unpacking_list=ref_binding, rvalue=rvalue) elif isinstance(ref_binding, TypedIdentifier): typed_identifier = ref_binding if isinstance(rvalue, RvalueCall): @@ -360,7 +433,7 @@ def code_element_reference(self, value): elif isinstance(rvalue, RvalueExpr): return CodeElementReference(typed_identifier=typed_identifier, expr=rvalue.expr) - raise NotImplementedError(f'Unexpected argument: value={value}') + raise NotImplementedError(f"Unexpected argument: value={value}") @v_args(meta=True) def code_element_local_var(self, value, meta): @@ -370,15 +443,16 @@ def code_element_local_var(self, value, meta): elif len(value) == 2: typed_identifier, expr = value else: - raise NotImplementedError(f'Unexpected argument: value={value}') + raise NotImplementedError(f"Unexpected argument: value={value}") return CodeElementLocalVariable( - typed_identifier=typed_identifier, expr=expr, location=self.meta2loc(meta)) + typed_identifier=typed_identifier, expr=expr, location=self.meta2loc(meta) + ) @v_args(meta=True) def code_element_temp_var(self, value, meta): typed_identifier, *maybe_expr = value - expr, = maybe_expr if len(maybe_expr) > 0 else [None] + (expr,) = maybe_expr if len(maybe_expr) > 0 else [None] return CodeElementTemporaryVariable( typed_identifier=typed_identifier, @@ -392,7 +466,7 @@ def code_element_static_assert(self, value, meta): @v_args(meta=True) def code_element_return(self, value, meta): - arglist, = value + (arglist,) = value return CodeElementReturn(exprs=arglist.args, location=self.meta2loc(meta)) @v_args(meta=True) @@ -419,7 +493,8 @@ def code_element_empty_line(self, value): def commented_code_element(self, value, meta): comment = value[1][1:] if len(value) == 2 else None return CommentedCodeElement( - code_elm=value[0], comment=comment, location=self.meta2loc(meta)) + code_elm=value[0], comment=comment, location=self.meta2loc(meta) + ) def code_block(self, value): return CodeBlock(code_elements=value) @@ -437,7 +512,7 @@ def implicit_arguments(self, value): elif len(value) == 1: return value[0] else: - raise NotImplementedError(f'Unexpected argument: value={value}') + raise NotImplementedError(f"Unexpected argument: value={value}") def decorator_list(self, value): return value @@ -453,10 +528,10 @@ def code_element_function(self, value): returns = None code_block = value[4] else: - raise NotImplementedError(f'Unexpected argument: value={value}') + raise NotImplementedError(f"Unexpected argument: value={value}") return CodeElementFunction( - element_type='func', + element_type="func", identifier=identifier, arguments=arguments, implicit_arguments=implicit_arguments, @@ -493,7 +568,7 @@ def code_element_if(self, value, meta): elif len(value) == 3: else_code_block = value[2] else: - raise NotImplementedError(f'Unexpected argument: value={value}') + raise NotImplementedError(f"Unexpected argument: value={value}") # Create a location for the if keyword. location: Optional[Location] = None @@ -502,13 +577,16 @@ def code_element_if(self, value, meta): start_line=meta.line, start_col=meta.column, end_line=meta.line, - end_col=meta.column + len('if'), + end_col=meta.column + len("if"), input_file=self.input_file, ) return CodeElementIf( - condition=condition, main_code_block=main_code_block, else_code_block=else_code_block, - location=location) + condition=condition, + main_code_block=main_code_block, + else_code_block=else_code_block, + location=location, + ) @v_args(meta=True) def code_element_directive(self, value, meta): @@ -527,18 +605,17 @@ def directive_lang(self, value, meta): def aliased_identifier(self, value, meta): if len(value) == 1: # Element of the form: . - identifier, = value + (identifier,) = value local_name = None elif len(value) == 2: # Element of the form: as . identifier, local_name = value else: - raise NotImplementedError(f'Unexpected argument: value={value}') + raise NotImplementedError(f"Unexpected argument: value={value}") return AliasedIdentifier( - orig_identifier=identifier, - local_name=local_name, - location=self.meta2loc(meta)) + orig_identifier=identifier, local_name=local_name, location=self.meta2loc(meta) + ) @v_args(meta=True) def code_element_import(self, value, meta): @@ -549,7 +626,7 @@ def code_element_import(self, value, meta): notes = [] else: # Multiline. - assert len(value) % 3 == 2, f'Unexpected value {value}.' + assert len(value) % 3 == 2, f"Unexpected value {value}." import_items = value[2::3] # Join the notes before and after the comma. notes = [value[1]] + [value[i] + value[i + 1] for i in range(3, len(value) - 1, 3)] @@ -571,7 +648,7 @@ def cairo_file(self, value): # Notes. def note_new_line(self, value): - return '\n' + return "\n" @v_args(meta=True) def notes(self, value, meta): @@ -586,14 +663,15 @@ def notes(self, value, meta): comments = [] for v in value: - if v == '\n': + if v == "\n": if not saw_comment: starts_new_line = True else: comments.append(v.value) saw_comment = True return Notes( - comments=comments, starts_new_line=starts_new_line, location=self.meta2loc(meta)) + comments=comments, starts_new_line=starts_new_line, location=self.meta2loc(meta) + ) def meta2loc(self, meta): if meta.empty: diff --git a/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions.py b/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions.py index 583d5aac..af8f217f 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions.py @@ -4,14 +4,27 @@ from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt from starkware.cairo.lang.compiler.ast.code_elements import ( - CodeElement, CodeElementTemporaryVariable) + CodeElement, + CodeElementTemporaryVariable, +) from starkware.cairo.lang.compiler.ast.expr import ( - ExprConst, ExprDeref, Expression, ExprHint, ExprIdentifier, ExprNeg, ExprOperator, ExprReg) + ExprConst, + ExprDeref, + Expression, + ExprHint, + ExprIdentifier, + ExprNeg, + ExprOperator, + ExprReg, +) from starkware.cairo.lang.compiler.ast.types import TypedIdentifier from starkware.cairo.lang.compiler.error_handling import Location from starkware.cairo.lang.compiler.instruction import Register from starkware.cairo.lang.compiler.instruction_builder import ( - InstructionBuilderError, _parse_offset, _parse_register_offset) + InstructionBuilderError, + _parse_offset, + _parse_register_offset, +) from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError from starkware.cairo.lang.compiler.references import translate_ap @@ -79,7 +92,7 @@ def rewrite(self, expr: Expression, sim: SimplicityLevel): will be replaced by a variable for DEREF. The expression "[ap] + 6" will be left unchanged for OPERATION but will be replaced by a variable for DEREF and DEREF_CONST. """ - funcname = f'rewrite_{type(expr).__name__}' + funcname = f"rewrite_{type(expr).__name__}" return getattr(self, funcname)(expr, sim) def rewrite_ExprConst(self, expr: ExprConst, sim: SimplicityLevel): @@ -90,19 +103,21 @@ def rewrite_ExprConst(self, expr: ExprConst, sim: SimplicityLevel): def rewrite_ExprReg(self, expr: ExprReg, sim: SimplicityLevel): if expr.reg is Register.AP: raise PreprocessorError( - 'ap may only be used in an expression of the form [ap + ].', - location=expr.location) + "ap may only be used in an expression of the form [ap + ].", + location=expr.location, + ) elif expr.reg is Register.FP: return self.rewrite(expr=self.context.get_fp_val(expr.location), sim=sim) else: - raise NotImplementedError(f'Unknown register {expr.reg}.') + raise NotImplementedError(f"Unknown register {expr.reg}.") def rewrite_ExprOperator(self, expr: ExprOperator, sim: SimplicityLevel): expr = ExprOperator( a=self.rewrite(expr.a, SimplicityLevel.DEREF), op=expr.op, b=self.rewrite(expr.b, SimplicityLevel.DEREF_CONST), - location=expr.location) + location=expr.location, + ) if sim is SimplicityLevel.OPERATION: return expr @@ -121,16 +136,20 @@ def rewrite_ExprOperator(self, expr: ExprOperator, sim: SimplicityLevel): def rewrite_ExprPow(self, expr: ExprReg, sim: SimplicityLevel): raise PreprocessorError( - "Operator '**' is only supported for constant values.", - location=expr.location) + "Operator '**' is only supported for constant values.", location=expr.location + ) def rewrite_ExprNeg(self, expr: ExprNeg, sim: SimplicityLevel): # Treat "-val" as "val * (-1)". - return self.rewrite(ExprOperator( - a=expr.val, - op='*', - b=ExprConst(val=-1, location=expr.location), - location=expr.location), sim) + return self.rewrite( + ExprOperator( + a=expr.val, + op="*", + b=ExprConst(val=-1, location=expr.location), + location=expr.location, + ), + sim, + ) def rewrite_ExprDeref(self, expr: ExprDeref, sim: SimplicityLevel): if is_simple_deref(expr): @@ -138,7 +157,8 @@ def rewrite_ExprDeref(self, expr: ExprDeref, sim: SimplicityLevel): return expr expr = ExprDeref( - addr=self.rewrite(expr.addr, SimplicityLevel.DEREF_OFFSET), location=expr.location) + addr=self.rewrite(expr.addr, SimplicityLevel.DEREF_OFFSET), location=expr.location + ) return expr if sim is SimplicityLevel.OPERATION else self.wrap(expr) def rewrite_ExprHint(self, expr: ExprHint, sim: SimplicityLevel): @@ -150,11 +170,15 @@ def wrap(self, expr: Expression) -> ExprIdentifier: expr = self.translate_ap(expr) self.n_vars += 1 - self.code_elements.append(CodeElementTemporaryVariable( - typed_identifier=TypedIdentifier( - identifier=identifier, expr_type=TypeFelt(location=expr.location)), - expr=expr, - location=expr.location)) + self.code_elements.append( + CodeElementTemporaryVariable( + typed_identifier=TypedIdentifier( + identifier=identifier, expr_type=TypeFelt(location=expr.location) + ), + expr=expr, + location=expr.location, + ) + ) return identifier def translate_ap(self, expr): @@ -162,8 +186,10 @@ def translate_ap(self, expr): def process_compound_expressions( - exprs: List[Expression], simplicity: Union[SimplicityLevel, List[SimplicityLevel]], - context: CompoundExpressionContext) -> Tuple[List[CodeElement], List[Expression]]: + exprs: List[Expression], + simplicity: Union[SimplicityLevel, List[SimplicityLevel]], + context: CompoundExpressionContext, +) -> Tuple[List[CodeElement], List[Expression]]: """ Rewrites the given list of expressions, by adding temporary variables, in the required simiplicity levels. @@ -192,7 +218,8 @@ def process_compound_expressions( def process_compound_assert( - expr_a: Expression, expr_b: Expression, context: CompoundExpressionContext): + expr_a: Expression, expr_b: Expression, context: CompoundExpressionContext +): """ A version of process_compound_expressions() for assert instructions. Takes two expressions and returns them simplified to levels [DEREF, OPERATION] or [OPERATION, DEREF], @@ -213,5 +240,6 @@ def process_compound_assert( simplicity = [SimplicityLevel.OPERATION, SimplicityLevel.DEREF] code_elements, exprs = process_compound_expressions( - exprs=[expr_a, expr_b], simplicity=simplicity, context=context) + exprs=[expr_a, expr_b], simplicity=simplicity, context=context + ) return code_elements, exprs diff --git a/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions_test.py b/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions_test.py index a5cdbab2..1a8fd147 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/compound_expressions_test.py @@ -7,10 +7,17 @@ from starkware.cairo.lang.compiler.error_handling import Location from starkware.cairo.lang.compiler.parser import parse_expr from starkware.cairo.lang.compiler.preprocessor.compound_expressions import ( - CompoundExpressionContext, CompoundExpressionVisitor, SimplicityLevel, - process_compound_expressions) + CompoundExpressionContext, + CompoundExpressionVisitor, + SimplicityLevel, + process_compound_expressions, +) from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import ( - PRIME, preprocess_str, strip_comments_and_linebreaks, verify_exception) + PRIME, + preprocess_str, + strip_comments_and_linebreaks, + verify_exception, +) class CompoundExpressionTestContext(CompoundExpressionContext): @@ -18,97 +25,111 @@ def __init__(self): self.tempvar_name_counter = itertools.count(0) def new_tempvar_name(self) -> str: - return f'x{next(self.tempvar_name_counter)}' + return f"x{next(self.tempvar_name_counter)}" def get_fp_val(self, location: Optional[Location]) -> Expression: - raise NotImplementedError('fp is not supported in the test.') + raise NotImplementedError("fp is not supported in the test.") -@pytest.mark.parametrize('expr_str, to_operation, to_deref_const, to_deref_offset, to_deref', [ +@pytest.mark.parametrize( + "expr_str, to_operation, to_deref_const, to_deref_offset, to_deref", [ - '5', - '5', - '5', - 'tempvar x0 : felt = 5; x0', - 'tempvar x0 : felt = 5; x0', + [ + "5", + "5", + "5", + "tempvar x0 : felt = 5; x0", + "tempvar x0 : felt = 5; x0", + ], + [ + "[ap + 5]", + "[ap + 5]", + "[ap + 5]", + "[ap + 5]", + "[ap + 5]", + ], + [ + "[ap + 5] + 3", + "[ap + 5] + 3", + "tempvar x0 : felt = [ap - 0 + 5] + 3; x0", + "[ap + 5] + 3", + "tempvar x0 : felt = [ap - 0 + 5] + 3; x0", + ], + [ + "3 + [ap + 5]", + "tempvar x0 : felt = 3; x0 + [ap + 5]", + "tempvar x0 : felt = 3; tempvar x1 : felt = x0 + [ap - 1 + 5]; x1", + "tempvar x0 : felt = 3; tempvar x1 : felt = x0 + [ap - 1 + 5]; x1", + "tempvar x0 : felt = 3; tempvar x1 : felt = x0 + [ap - 1 + 5]; x1", + ], + [ + "[[ap + 5]]", + "[[ap + 5]]", + "tempvar x0 : felt = [[ap - 0 + 5]]; x0", + "tempvar x0 : felt = [[ap - 0 + 5]]; x0", + "tempvar x0 : felt = [[ap - 0 + 5]]; x0", + ], + [ + "[[[ap + 5]]]", + "tempvar x0 : felt = [[ap - 0 + 5]]; [x0]", + "tempvar x0 : felt = [[ap - 0 + 5]]; tempvar x1 : felt = [x0]; x1", + "tempvar x0 : felt = [[ap - 0 + 5]]; tempvar x1 : felt = [x0]; x1", + "tempvar x0 : felt = [[ap - 0 + 5]]; tempvar x1 : felt = [x0]; x1", + ], + [ + "[3]", + "tempvar x0 : felt = 3; [x0]", + "tempvar x0 : felt = 3; tempvar x1 : felt = [x0]; x1", + "tempvar x0 : felt = 3; tempvar x1 : felt = [x0]; x1", + "tempvar x0 : felt = 3; tempvar x1 : felt = [x0]; x1", + ], + [ + "-[ap + 3]", + "[ap + 3] * (-1)", + "tempvar x0 : felt = [ap - 0 + 3] * (-1); x0", + "tempvar x0 : felt = [ap - 0 + 3] * (-1); x0", + "tempvar x0 : felt = [ap - 0 + 3] * (-1); x0", + ], ], - [ - '[ap + 5]', - '[ap + 5]', - '[ap + 5]', - '[ap + 5]', - '[ap + 5]', - ], - [ - '[ap + 5] + 3', - '[ap + 5] + 3', - 'tempvar x0 : felt = [ap - 0 + 5] + 3; x0', - '[ap + 5] + 3', - 'tempvar x0 : felt = [ap - 0 + 5] + 3; x0', - ], - [ - '3 + [ap + 5]', - 'tempvar x0 : felt = 3; x0 + [ap + 5]', - 'tempvar x0 : felt = 3; tempvar x1 : felt = x0 + [ap - 1 + 5]; x1', - 'tempvar x0 : felt = 3; tempvar x1 : felt = x0 + [ap - 1 + 5]; x1', - 'tempvar x0 : felt = 3; tempvar x1 : felt = x0 + [ap - 1 + 5]; x1', - ], - [ - '[[ap + 5]]', - '[[ap + 5]]', - 'tempvar x0 : felt = [[ap - 0 + 5]]; x0', - 'tempvar x0 : felt = [[ap - 0 + 5]]; x0', - 'tempvar x0 : felt = [[ap - 0 + 5]]; x0', - ], - [ - '[[[ap + 5]]]', - 'tempvar x0 : felt = [[ap - 0 + 5]]; [x0]', - 'tempvar x0 : felt = [[ap - 0 + 5]]; tempvar x1 : felt = [x0]; x1', - 'tempvar x0 : felt = [[ap - 0 + 5]]; tempvar x1 : felt = [x0]; x1', - 'tempvar x0 : felt = [[ap - 0 + 5]]; tempvar x1 : felt = [x0]; x1', - ], - [ - '[3]', - 'tempvar x0 : felt = 3; [x0]', - 'tempvar x0 : felt = 3; tempvar x1 : felt = [x0]; x1', - 'tempvar x0 : felt = 3; tempvar x1 : felt = [x0]; x1', - 'tempvar x0 : felt = 3; tempvar x1 : felt = [x0]; x1', - ], - [ - '-[ap + 3]', - '[ap + 3] * (-1)', - 'tempvar x0 : felt = [ap - 0 + 3] * (-1); x0', - 'tempvar x0 : felt = [ap - 0 + 3] * (-1); x0', - 'tempvar x0 : felt = [ap - 0 + 3] * (-1); x0', - ], -]) +) def test_compound_expression_visitor( - expr_str: str, to_operation: str, to_deref_const: str, to_deref_offset: str, to_deref: str): + expr_str: str, to_operation: str, to_deref_const: str, to_deref_offset: str, to_deref: str +): """ Tests rewriting various expression, to the different simplicity levels. For example, to_operation is the expected result when the simplicity level is OPERATION. """ expr = parse_expr(expr_str) for sim, expected_result in [ - (SimplicityLevel.OPERATION, to_operation), - (SimplicityLevel.DEREF_CONST, to_deref_const), - (SimplicityLevel.DEREF_OFFSET, to_deref_offset), - (SimplicityLevel.DEREF, to_deref)]: + (SimplicityLevel.OPERATION, to_operation), + (SimplicityLevel.DEREF_CONST, to_deref_const), + (SimplicityLevel.DEREF_OFFSET, to_deref_offset), + (SimplicityLevel.DEREF, to_deref), + ]: visitor = CompoundExpressionVisitor(context=CompoundExpressionTestContext()) res = visitor.rewrite(expr, sim) - assert ''.join( - code_element.format(allowed_line_length=100) + '; ' - for code_element in visitor.code_elements) + res.format() == expected_result + assert ( + "".join( + code_element.format(allowed_line_length=100) + "; " + for code_element in visitor.code_elements + ) + + res.format() + == expected_result + ) def test_compound_expression_visitor_long(): visitor = CompoundExpressionVisitor(context=CompoundExpressionTestContext()) res = visitor.rewrite( - parse_expr('[ap + 100] - [fp] * [[-[ap + 200] / [ap + 300]]] + [ap] * [ap]'), - SimplicityLevel.OPERATION) - assert ''.join( - code_element.format(allowed_line_length=100) + '\n' - for code_element in visitor.code_elements) == """\ + parse_expr("[ap + 100] - [fp] * [[-[ap + 200] / [ap + 300]]] + [ap] * [ap]"), + SimplicityLevel.OPERATION, + ) + assert ( + "".join( + code_element.format(allowed_line_length=100) + "\n" + for code_element in visitor.code_elements + ) + == """\ tempvar x0 : felt = [ap - 0 + 200] * (-1) tempvar x1 : felt = x0 / [ap - 1 + 300] tempvar x2 : felt = [x1] @@ -117,15 +138,19 @@ def test_compound_expression_visitor_long(): tempvar x5 : felt = [ap - 5 + 100] - x4 tempvar x6 : felt = [ap - 6] * [ap - 6] """ - assert res.format() == 'x5 + x6' + ) + assert res.format() == "x5 + x6" def test_compound_expression_visitor_inverses(): visitor = CompoundExpressionVisitor(context=CompoundExpressionTestContext()) - res = visitor.rewrite(parse_expr('2 - 1 / [ap] + [ap] / 3'), SimplicityLevel.DEREF) - assert ''.join( - code_element.format(allowed_line_length=100) + '\n' - for code_element in visitor.code_elements) == """\ + res = visitor.rewrite(parse_expr("2 - 1 / [ap] + [ap] / 3"), SimplicityLevel.DEREF) + assert ( + "".join( + code_element.format(allowed_line_length=100) + "\n" + for code_element in visitor.code_elements + ) + == """\ tempvar x0 : felt = 2 tempvar x1 : felt = 1 tempvar x2 : felt = x1 / [ap - 2] @@ -133,35 +158,48 @@ def test_compound_expression_visitor_inverses(): tempvar x4 : felt = [ap - 4] / 3 tempvar x5 : felt = x3 + x4 """ - assert res.format() == 'x5' + ) + assert res.format() == "x5" def test_process_compound_expressions(): - code_elements, res = process_compound_expressions(list(map(parse_expr, [ - '[ap - 1] + 5', - '[ap - 1] * [ap - 1]', - '[ap - 1] * [ap - 1]', - '[ap - 2] * [ap - 2] * [ap - 3]', - '[ap - 1]', - ])), [ - SimplicityLevel.OPERATION, - SimplicityLevel.OPERATION, - SimplicityLevel.DEREF, - SimplicityLevel.OPERATION, - SimplicityLevel.OPERATION, - ], context=CompoundExpressionTestContext()) - assert ''.join( - code_element.format(allowed_line_length=100) + '\n' - for code_element in code_elements) == """\ + code_elements, res = process_compound_expressions( + list( + map( + parse_expr, + [ + "[ap - 1] + 5", + "[ap - 1] * [ap - 1]", + "[ap - 1] * [ap - 1]", + "[ap - 2] * [ap - 2] * [ap - 3]", + "[ap - 1]", + ], + ) + ), + [ + SimplicityLevel.OPERATION, + SimplicityLevel.OPERATION, + SimplicityLevel.DEREF, + SimplicityLevel.OPERATION, + SimplicityLevel.OPERATION, + ], + context=CompoundExpressionTestContext(), + ) + assert ( + "".join( + code_element.format(allowed_line_length=100) + "\n" for code_element in code_elements + ) + == """\ tempvar x0 : felt = [ap - 0 - 1] * [ap - 0 - 1] tempvar x1 : felt = [ap - 1 - 2] * [ap - 1 - 2] """ + ) assert [x.format() for x in res] == [ - '[ap - 2 - 1] + 5', - '[ap - 2 - 1] * [ap - 2 - 1]', - 'x0', - 'x1 * [ap - 2 - 3]', - '[ap - 2 - 1]', + "[ap - 2 - 1] + 5", + "[ap - 2 - 1] * [ap - 2 - 1]", + "x0", + "x1 * [ap - 2 - 3]", + "[ap - 2 - 1]", ] @@ -183,7 +221,9 @@ def test_compound_expressions(): assert [fp] = fp + fp """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ [ap] = [ap + 1] * [ap + 2] [ap] = [[ap + (-1)]]; ap++ @@ -213,7 +253,10 @@ def test_compound_expressions(): [ap] = [ap + (-1)] + [ap + (-1)]; ap++ [ap] = [ap + (-2)] + [ap + (-2)]; ap++ [fp] = [ap + (-2)] + [ap + (-1)] -""".replace('\n\n', '\n') +""".replace( + "\n\n", "\n" + ) + ) def test_compound_expressions_long(): @@ -249,7 +292,9 @@ def test_compound_expressions_tempvars(): tempvar z = 5 + nondet %{ val %} * 15 + nondet %{ 1 %} """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ [ap] = [ap + (-1)] * [ap + (-1)]; ap++ [ap] = [ap + (-2)] * [ap + (-3)]; ap++ [ap] = [ap + (-2)] + [ap + (-1)]; ap++ @@ -263,7 +308,10 @@ def test_compound_expressions_tempvars(): %{ memory[ap] = int(1) %} ap += 1 [ap] = [ap + (-2)] + [ap + (-1)]; ap++ -""".replace('\n\n', '\n') +""".replace( + "\n\n", "\n" + ) + ) def test_compound_expressions_localvar(): @@ -323,41 +371,56 @@ def test_compound_expressions_args(): def test_compound_expressions_failures(): - verify_exception("""\ + verify_exception( + """\ assert [ap + [ap]] = [ap] -""", """ +""", + """ file:?:?: ap may only be used in an expression of the form [ap + ]. assert [ap + [ap]] = [ap] ^^ -""") - verify_exception("""\ +""", + ) + verify_exception( + """\ assert [[ap]] = ap -""", """ +""", + """ file:?:?: ap may only be used in an expression of the form [ap + ]. assert [[ap]] = ap ^^ -""") - verify_exception("""\ +""", + ) + verify_exception( + """\ assert [[fp]] = fp -""", """ +""", + """ file:?:?: Using the value of fp directly, requires defining a variable named __fp__. assert [[fp]] = fp ^^ -""") - verify_exception("""\ +""", + ) + verify_exception( + """\ assert [ap] = [ap + 32768] # Offset is out of bounds. -""", """ +""", + """ file:?:?: ap may only be used in an expression of the form [ap + ]. assert [ap] = [ap + 32768] # Offset is out of bounds. ^^ -""") - verify_exception("""\ +""", + ) + verify_exception( + """\ struct T: member a : felt end assert 7 = cast(7, T*) -""", """ +""", + """ file:?:?: Cannot compare 'felt' and 'test_scope.T*'. assert 7 = cast(7, T*) ^********************^ -""") +""", + ) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/conftest.py b/src/starkware/cairo/lang/compiler/preprocessor/conftest.py index 98e6bbb1..d6dd000d 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/conftest.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/conftest.py @@ -3,4 +3,4 @@ # Instruct pytest to print full information (e.g., the values on both sides of the equality) # about asserts that failed in the module below. # Normally, pytest prints full information only for test files (according to their name). -pytest.register_assert_rewrite('starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils') +pytest.register_assert_rewrite("starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils") diff --git a/src/starkware/cairo/lang/compiler/preprocessor/default_pass_manager.py b/src/starkware/cairo/lang/compiler/preprocessor/default_pass_manager.py index 468aea5f..595f2da0 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/default_pass_manager.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/default_pass_manager.py @@ -5,7 +5,11 @@ from starkware.cairo.lang.compiler.preprocessor.dependency_graph import DependencyGraphStage from starkware.cairo.lang.compiler.preprocessor.identifier_collector import IdentifierCollector from starkware.cairo.lang.compiler.preprocessor.pass_manager import ( - PassManager, PassManagerContext, Stage, VisitorStage) + PassManager, + PassManagerContext, + Stage, + VisitorStage, +) from starkware.cairo.lang.compiler.preprocessor.preprocessor import Preprocessor from starkware.cairo.lang.compiler.preprocessor.struct_collector import StructCollector from starkware.cairo.lang.compiler.preprocessor.unique_labels import UniqueLabelCreator @@ -13,30 +17,40 @@ def default_pass_manager( - prime: int, - read_module: Callable[[str], Tuple[str, str]], - preprocessor_cls: Optional[Type[Preprocessor]] = None, - opt_unused_functions: bool = True, - preprocessor_kwargs: Optional[Dict] = None) -> PassManager: + prime: int, + read_module: Callable[[str], Tuple[str, str]], + preprocessor_cls: Optional[Type[Preprocessor]] = None, + opt_unused_functions: bool = True, + preprocessor_kwargs: Optional[Dict] = None, +) -> PassManager: manager = PassManager() - manager.add_stage('module_collector', ModuleCollector(read_module=read_module)) - manager.add_stage('unique_label_creator', VisitorStage( - lambda context: UniqueLabelCreator(), modify_ast=True)) - manager.add_stage('identifier_collector', VisitorStage( - lambda context: IdentifierCollector(identifiers=context.identifiers))) + manager.add_stage("module_collector", ModuleCollector(read_module=read_module)) + manager.add_stage( + "unique_label_creator", VisitorStage(lambda context: UniqueLabelCreator(), modify_ast=True) + ) + manager.add_stage( + "identifier_collector", + VisitorStage(lambda context: IdentifierCollector(identifiers=context.identifiers)), + ) if opt_unused_functions: - manager.add_stage('dependency_graph', DependencyGraphStage()) - manager.add_stage('struct_collector', VisitorStage( - lambda context: StructCollector(identifiers=context.identifiers))) - manager.add_stage('preprocessor', PreprocessorStage( - prime, preprocessor_cls, preprocessor_kwargs)) + manager.add_stage("dependency_graph", DependencyGraphStage()) + manager.add_stage( + "struct_collector", + VisitorStage(lambda context: StructCollector(identifiers=context.identifiers)), + ) + manager.add_stage( + "preprocessor", PreprocessorStage(prime, preprocessor_cls, preprocessor_kwargs) + ) return manager class PreprocessorStage(Stage): def __init__( - self, prime: int, preprocessor_cls: Optional[Type[Preprocessor]] = None, - preprocessor_kwargs: Optional[Dict] = None): + self, + prime: int, + preprocessor_cls: Optional[Type[Preprocessor]] = None, + preprocessor_kwargs: Optional[Dict] = None, + ): self.prime = prime if preprocessor_cls is None: self.preprocessor_cls = Preprocessor @@ -46,8 +60,11 @@ def __init__( def run(self, context: PassManagerContext): preprocessor = self.preprocessor_cls( - prime=self.prime, identifiers=context.identifiers, - functions_to_compile=context.functions_to_compile, **self.preprocessor_kwargs) + prime=self.prime, + identifiers=context.identifiers, + functions_to_compile=context.functions_to_compile, + **self.preprocessor_kwargs, + ) preprocessor.identifier_locations = context.identifier_locations for module in context.modules: @@ -59,8 +76,10 @@ def run(self, context: PassManagerContext): class ModuleCollector(Stage): def __init__( - self, read_module: Callable[[str], Tuple[str, str]], - additional_modules: Optional[Sequence[str]] = None): + self, + read_module: Callable[[str], Tuple[str, str]], + additional_modules: Optional[Sequence[str]] = None, + ): self.read_module = read_module self.additional_modules = [] if additional_modules is None else list(additional_modules) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph.py b/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph.py index be5e952a..3092a31b 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph.py @@ -7,7 +7,9 @@ from starkware.cairo.lang.compiler.error_handling import Location from starkware.cairo.lang.compiler.identifier_definition import AliasDefinition from starkware.cairo.lang.compiler.identifier_manager import ( - IdentifierManager, MissingIdentifierError) + IdentifierManager, + MissingIdentifierError, +) from starkware.cairo.lang.compiler.preprocessor.pass_manager import PassManagerContext, Stage from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -33,21 +35,22 @@ def _visit_default(self, obj): self.visit(child) def add_identifier( - self, name: ScopedName, location: Optional[Location], is_resolved: bool = False): - if name.path[-1] == '_': + self, name: ScopedName, location: Optional[Location], is_resolved: bool = False + ): + if name.path[-1] == "_": return if is_resolved: canonical_name = name else: try: canonical_name = self.identifiers.search( - accessible_scopes=self.accessible_scopes, name=name).canonical_name + accessible_scopes=self.accessible_scopes, name=name + ).canonical_name except MissingIdentifierError as e: raise PreprocessorError(str(e), location=location) if self.current_function is not None: - self.visited_identifiers.setdefault(self.current_function, []).append( - canonical_name) + self.visited_identifiers.setdefault(self.current_function, []).append(canonical_name) def visit_CodeElementMember(self, elm): pass @@ -57,7 +60,7 @@ def visit_ExprDot(self, expr: ExprDot): self.visit(expr.expr) def visit_CodeElementFunction(self, elm: CodeElementFunction): - if elm.element_type == 'func': + if elm.element_type == "func": # Update self.current_function. old_current_function = self.current_function try: @@ -80,10 +83,11 @@ def visit_ExprIdentifier(self, expr: ExprIdentifier): def visit_CodeElementImport(self, code_elm: CodeElementImport): for import_item in code_elm.import_items: self.add_identifier( - ScopedName.from_string(code_elm.path.name) + - ScopedName.from_string(import_item.orig_identifier.name), + ScopedName.from_string(code_elm.path.name) + + ScopedName.from_string(import_item.orig_identifier.name), is_resolved=True, - location=code_elm.location) + location=code_elm.location, + ) def find_function_dependencies(self, functions: Set[ScopedName]) -> Set[ScopedName]: """ @@ -121,7 +125,8 @@ def visit(self, name: ScopedName): def get_main_functions_to_compile( - identifiers: IdentifierManager, main_scope: ScopedName) -> Set[ScopedName]: + identifiers: IdentifierManager, main_scope: ScopedName +) -> Set[ScopedName]: """ Retrieves the root functions to compile from a main scope. The definition of which functions we need to compile is somewhat arbitrary: @@ -134,15 +139,16 @@ def get_main_functions_to_compile( main_functions |= { identifier_definition.destination for identifier_definition in scope.identifiers.values() - if isinstance(identifier_definition, AliasDefinition)} + if isinstance(identifier_definition, AliasDefinition) + } except MissingIdentifierError: return set() return main_functions def get_functions_to_compile( - modules: List[CairoModule], identifiers: IdentifierManager, - main_scope: ScopedName) -> Set[ScopedName]: + modules: List[CairoModule], identifiers: IdentifierManager, main_scope: ScopedName +) -> Set[ScopedName]: """ Returns a set of reachable function (starting from the functions in the main scope). """ @@ -150,12 +156,14 @@ def get_functions_to_compile( dependency_graph = DependencyGraphVisitor(identifiers) for module in modules: dependency_graph.visit(module) - return dependency_graph.find_function_dependencies(get_main_functions_to_compile( - identifiers=identifiers, main_scope=main_scope)) + return dependency_graph.find_function_dependencies( + get_main_functions_to_compile(identifiers=identifiers, main_scope=main_scope) + ) class DependencyGraphStage(Stage): def run(self, context: PassManagerContext): assert context.functions_to_compile is None context.functions_to_compile = get_functions_to_compile( - modules=context.modules, identifiers=context.identifiers, main_scope=context.main_scope) + modules=context.modules, identifiers=context.identifiers, main_scope=context.main_scope + ) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph_test.py b/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph_test.py index 8291b8e8..e78a9688 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/dependency_graph_test.py @@ -3,7 +3,9 @@ from starkware.cairo.lang.compiler.ast.module import CairoModule from starkware.cairo.lang.compiler.parser import parse_file from starkware.cairo.lang.compiler.preprocessor.dependency_graph import ( - DependencyGraphVisitor, get_main_functions_to_compile) + DependencyGraphVisitor, + get_main_functions_to_compile, +) from starkware.cairo.lang.compiler.preprocessor.identifier_collector import IdentifierCollector from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -19,7 +21,9 @@ def _extract_dependency_graph(codes: Dict[str, str]) -> DependencyGraphVisitor: CairoModule( cairo_file=parse_file(code), module_name=ScopedName.from_string(name), - ) for name, code in codes.items()] + ) + for name, code in codes.items() + ] identifier_collector = IdentifierCollector() for module in modules: identifier_collector.visit(module) @@ -31,7 +35,7 @@ def _extract_dependency_graph(codes: Dict[str, str]) -> DependencyGraphVisitor: def test_dependency_graph(): modules = { - 'module': """ + "module": """ func func0() -> (res): return (res=0) end @@ -45,7 +49,7 @@ def test_dependency_graph(): return () end """, - '__main__': """ + "__main__": """ from module import func1 as func1_alias func foo(): @@ -98,98 +102,98 @@ def test_dependency_graph(): call bar # This line will be ignored since it's outside of any function. """, - '': """ + "": """ from module import func2 -"""} +""", + } dependency_graph_visitor = _extract_dependency_graph(modules) dependencies = { str(scope): set(map(str, deps)) - for scope, deps in dependency_graph_visitor.visited_identifiers.items()} + for scope, deps in dependency_graph_visitor.visited_identifiers.items() + } assert dependencies == { - '__main__.foo': { - '__main__.foo._tempvar', - '__main__.foo._const', - '__main__.foo._local', - '__main__.foo._reference', - '__main__.foo._label', - '__main__.foo._typed_reference', - '__main__.ns.myfunc', - 'module.func0', - 'module.func2', + "__main__.foo": { + "__main__.foo._tempvar", + "__main__.foo._const", + "__main__.foo._local", + "__main__.foo._reference", + "__main__.foo._label", + "__main__.foo._typed_reference", + "__main__.ns.myfunc", + "module.func0", + "module.func2", }, - '__main__.ns.myfunc': { - '__main__.ns.myfunc', - 'module.func1', + "__main__.ns.myfunc": { + "__main__.ns.myfunc", + "module.func1", }, - '__main__.bar': { - '__main__.bar', - '__main__.bar.a', - '__main__.bar.w', - '__main__.bar.w_x', - '__main__.foo.S', - 'module.func0', + "__main__.bar": { + "__main__.bar", + "__main__.bar.a", + "__main__.bar.w", + "__main__.bar.w_x", + "__main__.foo.S", + "module.func0", }, - '__main__.main': { - '__main__.bar', - '__main__.foo._label', + "__main__.main": { + "__main__.bar", + "__main__.foo._label", }, - 'module.func0': set(), - 'module.func1': set(), - 'module.func2': set(), - 'module.func3': set(), + "module.func0": set(), + "module.func1": set(), + "module.func2": set(), + "module.func3": set(), } - assert dependency_graph_visitor.find_function_dependencies( - {scope('__main__.main')}) == { - ScopedName(path=('__main__', 'bar')), - ScopedName(path=('__main__', 'foo')), - ScopedName(path=('__main__', 'main')), - ScopedName(path=('__main__', 'ns', 'myfunc')), - ScopedName(path=('module', 'func0')), - ScopedName(path=('module', 'func1')), - ScopedName(path=('module', 'func2')), + assert dependency_graph_visitor.find_function_dependencies({scope("__main__.main")}) == { + ScopedName(path=("__main__", "bar")), + ScopedName(path=("__main__", "foo")), + ScopedName(path=("__main__", "main")), + ScopedName(path=("__main__", "ns", "myfunc")), + ScopedName(path=("module", "func0")), + ScopedName(path=("module", "func1")), + ScopedName(path=("module", "func2")), } - assert dependency_graph_visitor.find_function_dependencies( - {scope('__main__.ns.myfunc')}) == { - ScopedName(path=('__main__', 'ns', 'myfunc')), - ScopedName(path=('module', 'func1')), + assert dependency_graph_visitor.find_function_dependencies({scope("__main__.ns.myfunc")}) == { + ScopedName(path=("__main__", "ns", "myfunc")), + ScopedName(path=("module", "func1")), } assert dependency_graph_visitor.find_function_dependencies( - {scope('__main__.ns.myfunc'), scope('__main__.bar')}) == { - ScopedName(path=('__main__', 'bar')), - ScopedName(path=('__main__', 'foo')), - ScopedName(path=('__main__', 'ns', 'myfunc')), - ScopedName(path=('module', 'func0')), - ScopedName(path=('module', 'func1')), - ScopedName(path=('module', 'func2')), + {scope("__main__.ns.myfunc"), scope("__main__.bar")} + ) == { + ScopedName(path=("__main__", "bar")), + ScopedName(path=("__main__", "foo")), + ScopedName(path=("__main__", "ns", "myfunc")), + ScopedName(path=("module", "func0")), + ScopedName(path=("module", "func1")), + ScopedName(path=("module", "func2")), } - assert dependency_graph_visitor.find_function_dependencies( - {scope('foo')}) == set() + assert dependency_graph_visitor.find_function_dependencies({scope("foo")}) == set() # Test get_main_functions_to_compile(). assert get_main_functions_to_compile( - identifiers=dependency_graph_visitor.identifiers, - main_scope=scope('module')) == { - scope('module.func0'), - scope('module.func1'), - scope('module.func2'), - scope('module.func3'), + identifiers=dependency_graph_visitor.identifiers, main_scope=scope("module") + ) == { + scope("module.func0"), + scope("module.func1"), + scope("module.func2"), + scope("module.func3"), } assert get_main_functions_to_compile( - identifiers=dependency_graph_visitor.identifiers, - main_scope=scope('__main__')) == { - scope('module.func1'), - scope('__main__.foo'), - scope('__main__.ns'), - scope('__main__.bar'), - scope('__main__.main'), + identifiers=dependency_graph_visitor.identifiers, main_scope=scope("__main__") + ) == { + scope("module.func1"), + scope("__main__.foo"), + scope("__main__.ns"), + scope("__main__.bar"), + scope("__main__.main"), } assert get_main_functions_to_compile( - identifiers=dependency_graph_visitor.identifiers, - main_scope=scope('')) == { - scope('module.func2'), - scope('module'), - scope('__main__'), + identifiers=dependency_graph_visitor.identifiers, main_scope=scope("") + ) == { + scope("module.func2"), + scope("module"), + scope("__main__"), } diff --git a/src/starkware/cairo/lang/compiler/preprocessor/flow.py b/src/starkware/cairo/lang/compiler/preprocessor/flow.py index 09a6f7e4..91df5188 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/flow.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/flow.py @@ -9,7 +9,10 @@ from starkware.cairo.lang.compiler.expression_simplifier import ExpressionSimplifier from starkware.cairo.lang.compiler.preprocessor.reg_tracking import ( - RegChange, RegChangeLike, RegTrackingData) + RegChange, + RegChangeLike, + RegTrackingData, +) from starkware.cairo.lang.compiler.references import FlowTrackingError, Reference from starkware.cairo.lang.compiler.scoped_name import ScopedName, ScopedNameAsStr @@ -31,10 +34,11 @@ class FlowTrackingData(ABC): Tracking data representing the values of the references in a specific location in the program, considering all possible flows that may reach there. """ + @abstractmethod def converge( - self, reference_manager: ReferenceManager, other: 'FlowTrackingData', - group_alloc: Callable): + self, reference_manager: ReferenceManager, other: "FlowTrackingData", group_alloc: Callable + ): """ Returns a new tracking data representing all references that are valid coming from either self or from other. @@ -57,12 +61,12 @@ class FlowTrackingDataUnreachable(FlowTrackingData): """ def converge( - self, reference_manager: ReferenceManager, other: 'FlowTrackingData', - group_alloc: Callable): + self, reference_manager: ReferenceManager, other: "FlowTrackingData", group_alloc: Callable + ): return other def resolve_reference(self, reference_manager: ReferenceManager, name: ScopedName) -> Reference: - raise FlowTrackingError(f'Reference {name} revoked.') + raise FlowTrackingError(f"Reference {name} revoked.") @dataclasses.dataclass(frozen=True) @@ -70,16 +74,19 @@ class FlowTrackingDataActual(FlowTrackingData): """ Tracking data for a reachable location in the program. """ + # Current ap tracking. ap_tracking: RegTrackingData # Mapping from full reference name to the Reference instance. reference_ids: Dict[ScopedName, int] = field( metadata=dict( - marshmallow_field=mfields.Dict(keys=ScopedNameAsStr, values=mfields.Integer())), - default_factory=dict) + marshmallow_field=mfields.Dict(keys=ScopedNameAsStr, values=mfields.Integer()) + ), + default_factory=dict, + ) @classmethod - def new(cls, group_alloc: Callable) -> 'FlowTrackingDataActual': + def new(cls, group_alloc: Callable) -> "FlowTrackingDataActual": return cls( ap_tracking=RegTrackingData.new(group_alloc), ) @@ -87,12 +94,12 @@ def new(cls, group_alloc: Callable) -> 'FlowTrackingDataActual': def resolve_reference(self, reference_manager: ReferenceManager, name: ScopedName) -> Reference: ref_id = self.reference_ids.get(name) if ref_id is None: - raise FlowTrackingError(f'Reference {name} revoked.') + raise FlowTrackingError(f"Reference {name} revoked.") return reference_manager.get_ref(ref_id) def converge( - self, reference_manager: ReferenceManager, other: 'FlowTrackingData', - group_alloc: Callable): + self, reference_manager: ReferenceManager, other: "FlowTrackingData", group_alloc: Callable + ): if not isinstance(other, FlowTrackingDataActual): return other.converge(reference_manager, self, group_alloc) @@ -110,8 +117,9 @@ def converge( other_ref = reference_manager.get_ref(other_ref_id) try: ref_expr = reference.eval(self.ap_tracking) - if simplifier.visit(ref_expr) == \ - simplifier.visit(other_ref.eval(other.ap_tracking)): + if simplifier.visit(ref_expr) == simplifier.visit( + other_ref.eval(other.ap_tracking) + ): # Same expression. # Create a new reference on the new ap tracking. new_reference = Reference( @@ -130,13 +138,13 @@ def converge( reference_ids=reference_ids, ) - def add_ap(self, ap_change: RegChangeLike, group_alloc: Callable) -> 'FlowTrackingData': + def add_ap(self, ap_change: RegChangeLike, group_alloc: Callable) -> "FlowTrackingData": new_ap_tracking = self.ap_tracking.add(ap_change, group_alloc) return dataclasses.replace(self, ap_tracking=new_ap_tracking) def add_reference( - self, reference_manager: ReferenceManager, name: ScopedName, - ref: Reference) -> 'FlowTrackingData': + self, reference_manager: ReferenceManager, name: ScopedName, ref: Reference + ) -> "FlowTrackingData": """ Adds or rebinds a reference. """ @@ -181,8 +189,9 @@ def __init__(self): # Mapping from a fully qualified label name to its tracking data. # This begines unconstrained, and for every flow to this label, we 'converge' this data # with the new tracking data. - self.labels_data: Dict[ScopedName, FlowTrackingData] = \ - defaultdict(FlowTrackingDataUnreachable) + self.labels_data: Dict[ScopedName, FlowTrackingData] = defaultdict( + FlowTrackingDataUnreachable + ) self.groups = itertools.count(0) self.reference_manager = ReferenceManager() @@ -208,7 +217,8 @@ def add_flow_to_label(self, label_name: ScopedName, ap_change: RegChangeLike): ap_change = RegChange.from_expr(ap_change) new_data = self.get().add_ap(ap_change, self._group_alloc) self.labels_data[label_name] = self.labels_data[label_name].converge( - self.reference_manager, new_data, self._group_alloc) + self.reference_manager, new_data, self._group_alloc + ) def converge_with_label(self, label_name: ScopedName): """ @@ -216,7 +226,8 @@ def converge_with_label(self, label_name: ScopedName): label's definition location. """ self.data = self.data.converge( - self.reference_manager, self.labels_data[label_name], self._group_alloc) + self.reference_manager, self.labels_data[label_name], self._group_alloc + ) def revoke(self): self.data = FlowTrackingDataUnreachable() diff --git a/src/starkware/cairo/lang/compiler/preprocessor/flow_test.py b/src/starkware/cairo/lang/compiler/preprocessor/flow_test.py index 165f5712..f851c3db 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/flow_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/flow_test.py @@ -26,59 +26,65 @@ def test_flow_tracking(): assert flow_tracking.get_ap_tracking() - loc2 == RegChangeKnown(11) -@pytest.mark.parametrize('changes', [ - # Good case. - SimpleNamespace(valid=True, label0=3, body0=1, label1=2, body1=2), - # Bad case - one mismatching jump. - SimpleNamespace(valid=False, label0=3, body0=1, label1=5, body1=2), - # Bad case - jump mismatch with current. - SimpleNamespace(valid=False, label0=3, body0=1, label1=2, body1=5), -]) +@pytest.mark.parametrize( + "changes", + [ + # Good case. + SimpleNamespace(valid=True, label0=3, body0=1, label1=2, body1=2), + # Bad case - one mismatching jump. + SimpleNamespace(valid=False, label0=3, body0=1, label1=5, body1=2), + # Bad case - jump mismatch with current. + SimpleNamespace(valid=False, label0=3, body0=1, label1=2, body1=5), + ], +) def test_flow_tracking_labels(changes): # Good case. flow_tracking = FlowTracking() - flow_tracking.add_flow_to_label(ScopedName.from_string('a'), changes.label0) + flow_tracking.add_flow_to_label(ScopedName.from_string("a"), changes.label0) flow_tracking.add_ap(changes.body0) - flow_tracking.add_flow_to_label(ScopedName.from_string('a'), changes.label1) + flow_tracking.add_flow_to_label(ScopedName.from_string("a"), changes.label1) flow_tracking.add_ap(changes.body1) current_data = flow_tracking.get() - flow_tracking.converge_with_label(ScopedName.from_string('a')) + flow_tracking.converge_with_label(ScopedName.from_string("a")) assert (flow_tracking.get() == current_data) is changes.valid -@pytest.mark.parametrize('changes', [ - SimpleNamespace(valid=True, to_a=1, to_b=4, at_a=7, at_b=4), - SimpleNamespace(valid=False, to_a=1, to_b=4, at_a=6, at_b=4), - SimpleNamespace(valid=False, to_a=1, to_b=4, at_a=6, at_b=5), - SimpleNamespace(valid=False, to_a=2, to_b=4, at_a=6, at_b=5), - SimpleNamespace(valid=False, to_a=1, to_b=3, at_a=6, at_b=5), -]) +@pytest.mark.parametrize( + "changes", + [ + SimpleNamespace(valid=True, to_a=1, to_b=4, at_a=7, at_b=4), + SimpleNamespace(valid=False, to_a=1, to_b=4, at_a=6, at_b=4), + SimpleNamespace(valid=False, to_a=1, to_b=4, at_a=6, at_b=5), + SimpleNamespace(valid=False, to_a=2, to_b=4, at_a=6, at_b=5), + SimpleNamespace(valid=False, to_a=1, to_b=3, at_a=6, at_b=5), + ], +) def test_flow_tracking_labels_diverge(changes): """ Tests a case of divergence. Diverge to a, b with different ap diffs, then converge at c. """ flow_tracking = FlowTracking() - flow_tracking.add_flow_to_label(ScopedName.from_string('a'), changes.to_a) - flow_tracking.add_flow_to_label(ScopedName.from_string('b'), changes.to_b) + flow_tracking.add_flow_to_label(ScopedName.from_string("a"), changes.to_a) + flow_tracking.add_flow_to_label(ScopedName.from_string("b"), changes.to_b) # Label a. flow_tracking.revoke() - flow_tracking.converge_with_label(ScopedName.from_string('a')) + flow_tracking.converge_with_label(ScopedName.from_string("a")) flow_tracking.add_ap(changes.at_a) data_after_a = flow_tracking.get() - flow_tracking.add_flow_to_label(ScopedName.from_string('c'), 0) + flow_tracking.add_flow_to_label(ScopedName.from_string("c"), 0) # Label b. flow_tracking.revoke() - flow_tracking.converge_with_label(ScopedName.from_string('b')) + flow_tracking.converge_with_label(ScopedName.from_string("b")) flow_tracking.add_ap(changes.at_b) data_after_b = flow_tracking.get() - flow_tracking.add_flow_to_label(ScopedName.from_string('c'), 0) + flow_tracking.add_flow_to_label(ScopedName.from_string("c"), 0) # Label c. flow_tracking.revoke() - flow_tracking.converge_with_label(ScopedName.from_string('c')) + flow_tracking.converge_with_label(ScopedName.from_string("c")) data_at_c = flow_tracking.get() if changes.valid: @@ -87,49 +93,54 @@ def test_flow_tracking_labels_diverge(changes): assert data_after_a != data_at_c and data_after_b != data_at_c -@pytest.mark.parametrize('refs', [ - SimpleNamespace(valid=True, expr_a=parse_expr('[fp+3]*2'), expr_b=parse_expr('[fp+3]*2')), - SimpleNamespace(valid=False, expr_a=parse_expr('[fp+3]*2'), expr_b=parse_expr('[fp+2]*2')), - SimpleNamespace(valid=True, expr_a=parse_expr('[ap-3]*2'), expr_b=parse_expr('[ap-1]*2')), - SimpleNamespace(valid=False, expr_a=parse_expr('[ap-3]*2'), expr_b=parse_expr('[ap-3]*2')), -]) +@pytest.mark.parametrize( + "refs", + [ + SimpleNamespace(valid=True, expr_a=parse_expr("[fp+3]*2"), expr_b=parse_expr("[fp+3]*2")), + SimpleNamespace(valid=False, expr_a=parse_expr("[fp+3]*2"), expr_b=parse_expr("[fp+2]*2")), + SimpleNamespace(valid=True, expr_a=parse_expr("[ap-3]*2"), expr_b=parse_expr("[ap-1]*2")), + SimpleNamespace(valid=False, expr_a=parse_expr("[ap-3]*2"), expr_b=parse_expr("[ap-3]*2")), + ], +) def test_flow_tracking_converge_references(refs): flow_tracking = FlowTracking() - flow_tracking.add_flow_to_label(ScopedName.from_string('a'), RegChangeUnknown()) - flow_tracking.add_flow_to_label(ScopedName.from_string('b'), RegChangeUnknown()) + flow_tracking.add_flow_to_label(ScopedName.from_string("a"), RegChangeUnknown()) + flow_tracking.add_flow_to_label(ScopedName.from_string("b"), RegChangeUnknown()) # Label a. flow_tracking.revoke() - flow_tracking.converge_with_label(ScopedName.from_string('a')) + flow_tracking.converge_with_label(ScopedName.from_string("a")) flow_tracking.add_reference( - name=ScopedName.from_string('x'), + name=ScopedName.from_string("x"), ref=Reference( pc=0, value=refs.expr_a, ap_tracking_data=flow_tracking.get_ap_tracking(), - )) + ), + ) flow_tracking.add_ap(13) - flow_tracking.add_flow_to_label(ScopedName.from_string('c'), 0) + flow_tracking.add_flow_to_label(ScopedName.from_string("c"), 0) # Label b. flow_tracking.revoke() - flow_tracking.converge_with_label(ScopedName.from_string('b')) + flow_tracking.converge_with_label(ScopedName.from_string("b")) flow_tracking.add_reference( - name=ScopedName.from_string('x'), + name=ScopedName.from_string("x"), ref=Reference( pc=0, value=refs.expr_b, ap_tracking_data=flow_tracking.get_ap_tracking(), - )) + ), + ) flow_tracking.add_ap(15) - flow_tracking.add_flow_to_label(ScopedName.from_string('c'), 0) + flow_tracking.add_flow_to_label(ScopedName.from_string("c"), 0) # Label c - convergence. flow_tracking.revoke() - flow_tracking.converge_with_label(ScopedName.from_string('c')) + flow_tracking.converge_with_label(ScopedName.from_string("c")) if refs.valid: - flow_tracking.resolve_reference(ScopedName.from_string('x')) + flow_tracking.resolve_reference(ScopedName.from_string("x")) else: with pytest.raises(FlowTrackingError): - flow_tracking.resolve_reference(ScopedName.from_string('x')) + flow_tracking.resolve_reference(ScopedName.from_string("x")) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor.py b/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor.py index 7d909dca..3fa96fc2 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor.py @@ -2,12 +2,21 @@ from typing import Dict, Optional from starkware.cairo.lang.compiler.ast.cairo_types import ( - CairoType, TypeFelt, TypePointer, TypeStruct, TypeTuple) + CairoType, + TypeFelt, + TypePointer, + TypeStruct, + TypeTuple, +) from starkware.cairo.lang.compiler.ast.code_elements import CodeElementFunction from starkware.cairo.lang.compiler.ast.visitor import Visitor from starkware.cairo.lang.compiler.error_handling import Location from starkware.cairo.lang.compiler.identifier_definition import ( - DefinitionError, FutureIdentifierDefinition, IdentifierDefinition, StructDefinition) + DefinitionError, + FutureIdentifierDefinition, + IdentifierDefinition, + StructDefinition, +) from starkware.cairo.lang.compiler.identifier_manager import IdentifierError, IdentifierManager from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError @@ -28,12 +37,16 @@ def __init__(self, identifiers: Optional[IdentifierManager] = None): def handle_missing_future_definition(self, name: ScopedName, location): raise PreprocessorError( - f"Identifier '{name}' not found by IdentifierCollector.", - location=location) + f"Identifier '{name}' not found by IdentifierCollector.", location=location + ) def add_name_definition( - self, name: ScopedName, identifier_definition: IdentifierDefinition, location, - require_future_definition=True): + self, + name: ScopedName, + identifier_definition: IdentifierDefinition, + location, + require_future_definition=True, + ): """ Adds a definition of an identifier named 'name' at 'location'. The identifier must already be found as a FutureIdentifierDefinition in 'self.identifiers' @@ -52,21 +65,22 @@ def add_name_definition( f"Identifier '{name}' expected to be of type " f"'{future_definition.identifier_type.__name__}', not " f"'{type(identifier_definition).__name__}'.", - location=location) + location=location, + ) self.identifiers.add_identifier(name, identifier_definition) self.identifier_locations[name] = location def get_struct_definition( - self, name: ScopedName, location: Optional[Location]) -> StructDefinition: + self, name: ScopedName, location: Optional[Location] + ) -> StructDefinition: """ Returns the struct definition that corresponds to the given identifier. location is used if there is an error. """ try: - res = self.identifiers.search( - accessible_scopes=self.accessible_scopes, name=name) + res = self.identifiers.search(accessible_scopes=self.accessible_scopes, name=name) res.assert_fully_parsed() except IdentifierError as exc: raise PreprocessorError(str(exc), location=location) @@ -76,7 +90,8 @@ def get_struct_definition( raise PreprocessorError( f"""\ Expected '{res.canonical_name}' to be a {StructDefinition.TYPE}. Found: '{struct_def.TYPE}'.""", - location=location) + location=location, + ) return struct_def @@ -103,8 +118,7 @@ def get_canonical_struct_name(self, scoped_name: ScopedName, location: Optional[ location is used if there is an error. """ - result = self.identifiers.search( - self.accessible_scopes, scoped_name) + result = self.identifiers.search(self.accessible_scopes, scoped_name) canonical_name = result.get_canonical_name() identifier_def = result.identifier_definition @@ -116,7 +130,8 @@ def get_canonical_struct_name(self, scoped_name: ScopedName, location: Optional[ raise PreprocessorError( f"""\ Expected '{scoped_name}' to be a {StructDefinition.TYPE}. Found: '{identifier_type}'.""", - location=location) + location=location, + ) return canonical_name @@ -135,16 +150,18 @@ def resolve_type(self, cairo_type: CairoType) -> CairoType: return dataclasses.replace( cairo_type, scope=self.get_canonical_struct_name( - scoped_name=cairo_type.scope, location=cairo_type.location), - is_fully_resolved=True) + scoped_name=cairo_type.scope, location=cairo_type.location + ), + is_fully_resolved=True, + ) except IdentifierError as exc: raise PreprocessorError(str(exc), location=cairo_type.location) elif isinstance(cairo_type, TypeTuple): return dataclasses.replace( - cairo_type, - members=[self.resolve_type(subtype) for subtype in cairo_type.members]) + cairo_type, members=[self.resolve_type(subtype) for subtype in cairo_type.members] + ) else: - raise NotImplementedError(f'Type {type(cairo_type).__name__} is not supported.') + raise NotImplementedError(f"Type {type(cairo_type).__name__} is not supported.") def get_struct_size(self, struct_name: ScopedName, location: Optional[Location]): return self.get_struct_definition(name=struct_name, location=location).size @@ -159,16 +176,18 @@ def get_size(self, cairo_type: CairoType): if cairo_type.is_fully_resolved: try: return get_struct_definition( - struct_name=cairo_type.scope, identifier_manager=self.identifiers).size + struct_name=cairo_type.scope, identifier_manager=self.identifiers + ).size except DefinitionError as exc: raise PreprocessorError(str(exc), location=cairo_type.location) else: return self.get_struct_size( - struct_name=cairo_type.scope, location=cairo_type.location) + struct_name=cairo_type.scope, location=cairo_type.location + ) elif isinstance(cairo_type, TypeTuple): return sum(self.get_size(member_type) for member_type in cairo_type.members) else: - raise NotImplementedError(f'Type {type(cairo_type).__name__} is not supported.') + raise NotImplementedError(f"Type {type(cairo_type).__name__} is not supported.") def inside_a_struct(self) -> bool: if len(self.parents) == 0: @@ -178,4 +197,4 @@ def inside_a_struct(self) -> bool: if not isinstance(parent, CodeElementFunction): return False - return parent.element_type == 'struct' + return parent.element_type == "struct" diff --git a/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor_test.py b/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor_test.py index 2153a8d4..dad52667 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/identifier_aware_visitor_test.py @@ -2,7 +2,8 @@ from starkware.cairo.lang.compiler.identifier_definition import ConstDefinition from starkware.cairo.lang.compiler.preprocessor.identifier_aware_visitor import ( - IdentifierAwareVisitor) + IdentifierAwareVisitor, +) from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -10,14 +11,20 @@ def test_add_name_definition_no_future(): visitor = IdentifierAwareVisitor() - test_id = ScopedName.from_string('test_id') + test_id = ScopedName.from_string("test_id") location = None visitor.add_name_definition( - name=test_id, identifier_definition=ConstDefinition(value=1), location=location, - require_future_definition=False) + name=test_id, + identifier_definition=ConstDefinition(value=1), + location=location, + require_future_definition=False, + ) with pytest.raises(PreprocessorError, match=f"Redefinition of 'test_id'."): visitor.add_name_definition( - name=test_id, identifier_definition=ConstDefinition(value=1), location=location, - require_future_definition=False) + name=test_id, + identifier_definition=ConstDefinition(value=1), + location=location, + require_future_definition=False, + ) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector.py b/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector.py index 9cdcdd77..6069d370 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector.py @@ -2,15 +2,32 @@ from starkware.cairo.lang.compiler.ast.arguments import IdentifierList from starkware.cairo.lang.compiler.ast.code_elements import ( - CodeBlock, CodeElement, CodeElementConst, CodeElementFunction, CodeElementIf, CodeElementImport, - CodeElementLabel, CodeElementLocalVariable, CodeElementReference, - CodeElementReturnValueReference, CodeElementTemporaryVariable, CodeElementUnpackBinding, - CodeElementWith) + CodeBlock, + CodeElement, + CodeElementConst, + CodeElementFunction, + CodeElementIf, + CodeElementImport, + CodeElementLabel, + CodeElementLocalVariable, + CodeElementReference, + CodeElementReturnValueReference, + CodeElementTemporaryVariable, + CodeElementUnpackBinding, + CodeElementWith, +) from starkware.cairo.lang.compiler.ast.visitor import Visitor from starkware.cairo.lang.compiler.error_handling import Location from starkware.cairo.lang.compiler.identifier_definition import ( - AliasDefinition, ConstDefinition, FunctionDefinition, FutureIdentifierDefinition, - IdentifierDefinition, LabelDefinition, ReferenceDefinition, StructDefinition) + AliasDefinition, + ConstDefinition, + FunctionDefinition, + FutureIdentifierDefinition, + IdentifierDefinition, + LabelDefinition, + ReferenceDefinition, + StructDefinition, +) from starkware.cairo.lang.compiler.identifier_manager import IdentifierError, IdentifierManager from starkware.cairo.lang.compiler.preprocessor.local_variables import N_LOCALS_CONSTANT from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError @@ -22,12 +39,13 @@ def _get_identifier(obj): Gets the name of the identifier defined by an object with either an 'identifier' attribute or a 'typed_identifier' attribute. """ - if hasattr(obj, 'identifier'): + if hasattr(obj, "identifier"): return obj.identifier - if hasattr(obj, 'typed_identifier'): + if hasattr(obj, "typed_identifier"): return obj.typed_identifier raise AttributeError( - f"Object of type '{type(obj).__name__}' has no 'identifier' or 'typed_identifier'.") + f"Object of type '{type(obj).__name__}' has no 'identifier' or 'typed_identifier'." + ) class IdentifierCollector(Visitor): @@ -35,6 +53,7 @@ class IdentifierCollector(Visitor): Collects all the identifiers in a code element. Uses a partial visitor. """ + # A dict from code element types to the identifier type they define. IDENTIFIER_DEFINERS = { CodeElementConst: ConstDefinition, @@ -50,26 +69,33 @@ def __init__(self, identifiers: Optional[IdentifierManager] = None): self.identifiers = IdentifierManager() if identifiers is None else identifiers def add_identifier( - self, name: ScopedName, identifier_definition: IdentifierDefinition, - location: Optional[Location]): + self, + name: ScopedName, + identifier_definition: IdentifierDefinition, + location: Optional[Location], + ): """ Adds an identifier with name 'name' and the given identifier definition at location 'location'. """ existing_definition = self.identifiers.get_by_full_name(name) if existing_definition is not None: - if not isinstance(existing_definition, FutureIdentifierDefinition) or \ - not isinstance(identifier_definition, FutureIdentifierDefinition): + if not isinstance(existing_definition, FutureIdentifierDefinition) or not isinstance( + identifier_definition, FutureIdentifierDefinition + ): raise PreprocessorError(f"Redefinition of '{name}'.", location=location) if (existing_definition.identifier_type, identifier_definition.identifier_type) != ( - ReferenceDefinition, ReferenceDefinition): + ReferenceDefinition, + ReferenceDefinition, + ): # Redefinition is only allowed in reference rebinding. raise PreprocessorError(f"Redefinition of '{name}'.", location=location) self.identifiers.add_identifier(name, identifier_definition) def add_future_identifier( - self, name: ScopedName, identifier_type: type, location: Optional[Location]): + self, name: ScopedName, identifier_type: type, location: Optional[Location] + ): """ Adds a future identifier with name 'name' of type 'identifier_type' at location 'location'. """ @@ -77,21 +103,22 @@ def add_future_identifier( self.add_identifier( name=name, identifier_definition=FutureIdentifierDefinition(identifier_type=identifier_type), - location=location) + location=location, + ) def visit(self, obj): if type(obj) in self.IDENTIFIER_DEFINERS: definition_type = self.IDENTIFIER_DEFINERS[type(obj)] identifier = _get_identifier(obj) self.add_future_identifier( - self.current_scope + identifier.name, - definition_type, - identifier.location) + self.current_scope + identifier.name, definition_type, identifier.location + ) return super().visit(obj) def _visit_default(self, obj): - assert isinstance(obj, (CodeBlock, CodeElement)), \ - f'Received unexpected object of type {type(obj).__name__}.' + assert isinstance( + obj, (CodeBlock, CodeElement) + ), f"Received unexpected object of type {type(obj).__name__}." def visit_CodeElementFunction(self, elm: CodeElementFunction): """ @@ -99,9 +126,8 @@ def visit_CodeElementFunction(self, elm: CodeElementFunction): visits the code block contained in the function. """ function_scope = self.current_scope + elm.name - if elm.element_type == 'struct': - self.add_future_identifier( - function_scope, StructDefinition, elm.identifier.location) + if elm.element_type == "struct": + self.add_future_identifier(function_scope, StructDefinition, elm.identifier.location) return args_scope = function_scope + CodeElementFunction.ARGUMENT_SCOPE @@ -114,10 +140,12 @@ def handle_struct_def(identifier_list: Optional[IdentifierList], struct_name: Sc location = identifier_list.location self.add_future_identifier( - name=struct_name, identifier_type=StructDefinition, location=location) + name=struct_name, identifier_type=StructDefinition, location=location + ) def handle_function_arguments( - identifier_list: Optional[IdentifierList], struct_name: ScopedName): + identifier_list: Optional[IdentifierList], struct_name: ScopedName + ): handle_struct_def(identifier_list=identifier_list, struct_name=struct_name) if identifier_list is None: return @@ -126,15 +154,18 @@ def handle_function_arguments( if arg_id.name == N_LOCALS_CONSTANT: raise PreprocessorError( f"The name '{N_LOCALS_CONSTANT}' is reserved and cannot be used as an " - 'argument name.', - location=arg_id.location) + "argument name.", + location=arg_id.location, + ) # Within a function, arguments are also accessible directly. self.add_future_identifier( - function_scope + arg_id.name, ReferenceDefinition, arg_id.location) + function_scope + arg_id.name, ReferenceDefinition, arg_id.location + ) handle_function_arguments(identifier_list=elm.arguments, struct_name=args_scope) handle_function_arguments( - identifier_list=elm.implicit_arguments, struct_name=implicit_args_scope) + identifier_list=elm.implicit_arguments, struct_name=implicit_args_scope + ) handle_struct_def(identifier_list=elm.returns, struct_name=rets_scope) @@ -148,17 +179,18 @@ def handle_function_arguments( for arg_id in arg_and_return_identifiers: if arg_id.name in implicit_arg_names: raise PreprocessorError( - 'Arguments and return values cannot have the same name of an implicit ' - 'argument.', - location=arg_id.location) + "Arguments and return values cannot have the same name of an implicit " + "argument.", + location=arg_id.location, + ) - ident_type = FunctionDefinition if elm.element_type == 'func' else LabelDefinition - self.add_future_identifier( - function_scope, ident_type, elm.identifier.location) + ident_type = FunctionDefinition if elm.element_type == "func" else LabelDefinition + self.add_future_identifier(function_scope, ident_type, elm.identifier.location) # Add SIZEOF_LOCALS for current block at identifier definition location if available. self.add_future_identifier( - function_scope + N_LOCALS_CONSTANT, ConstDefinition, elm.identifier.location) + function_scope + N_LOCALS_CONSTANT, ConstDefinition, elm.identifier.location + ) super().visit_CodeElementFunction(elm) def visit_CodeElementUnpackBinding(self, elm: CodeElementUnpackBinding): @@ -166,23 +198,25 @@ def visit_CodeElementUnpackBinding(self, elm: CodeElementUnpackBinding): Registers all the unpacked identifiers. """ for identifier in elm.unpacking_list.identifiers: - if identifier.name == '_': + if identifier.name == "_": continue self.add_future_identifier( - self.current_scope + - identifier.name, - ReferenceDefinition, - identifier.location) + self.current_scope + identifier.name, ReferenceDefinition, identifier.location + ) def visit_CodeElementIf(self, obj: CodeElementIf): assert obj.label_neq is not None assert obj.label_end is not None self.add_future_identifier( - name=self.current_scope + obj.label_neq, identifier_type=LabelDefinition, - location=obj.location) + name=self.current_scope + obj.label_neq, + identifier_type=LabelDefinition, + location=obj.location, + ) self.add_future_identifier( - name=self.current_scope + obj.label_end, identifier_type=LabelDefinition, - location=obj.location) + name=self.current_scope + obj.label_end, + identifier_type=LabelDefinition, + location=obj.location, + ) self.visit(obj.main_code_block) if obj.else_code_block is not None: self.visit(obj.else_code_block) @@ -207,13 +241,15 @@ def visit_CodeElementImport(self, elm: CodeElementImport): raise PreprocessorError( f"Cannot import '{import_item.orig_identifier.name}' " f"from '{elm.path.name}'.", - location=import_item.orig_identifier.location) + location=import_item.orig_identifier.location, + ) # Add alias to identifiers. self.add_identifier( name=self.current_scope + local_identifier.name, identifier_definition=AliasDefinition(destination=alias_dst), - location=import_item.identifier.location) + location=import_item.identifier.location, + ) def visit_CodeElementWith(self, elm: CodeElementWith): for aliased_identifier in elm.identifiers: @@ -221,5 +257,6 @@ def visit_CodeElementWith(self, elm: CodeElementWith): self.add_future_identifier( name=self.current_scope + aliased_identifier.local_name.name, identifier_type=ReferenceDefinition, - location=aliased_identifier.local_name.location) + location=aliased_identifier.local_name.location, + ) self.visit(elm.code_block) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector_test.py b/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector_test.py index 75cf3d4e..1ca8f304 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/identifier_collector_test.py @@ -1,6 +1,11 @@ from starkware.cairo.lang.compiler.identifier_definition import ( - AliasDefinition, ConstDefinition, FunctionDefinition, LabelDefinition, ReferenceDefinition, - StructDefinition) + AliasDefinition, + ConstDefinition, + FunctionDefinition, + LabelDefinition, + ReferenceDefinition, + StructDefinition, +) from starkware.cairo.lang.compiler.parser import parse_file from starkware.cairo.lang.compiler.preprocessor.identifier_collector import IdentifierCollector from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import verify_exception @@ -17,7 +22,8 @@ def _extract_identifiers(code): collector.visit(ast.code_block) return [ (str(name), identifier_definition.identifier_type) - for name, identifier_definition in collector.identifiers.as_dict().items()] + for name, identifier_definition in collector.identifiers.as_dict().items() + ] def test_collect_single_binds(): @@ -30,12 +36,12 @@ def test_collect_single_binds(): let g : H = f(1, 2, 3) """ assert set(_extract_identifiers(code)) == { - ('a', ReferenceDefinition), - ('b', ConstDefinition), - ('c', ReferenceDefinition), - ('d', ReferenceDefinition), - ('f', LabelDefinition), - ('g', ReferenceDefinition), + ("a", ReferenceDefinition), + ("b", ConstDefinition), + ("c", ReferenceDefinition), + ("d", ReferenceDefinition), + ("f", LabelDefinition), + ("g", ReferenceDefinition), } @@ -47,15 +53,15 @@ def test_collect_multi_binds(): let (e, f) = g() """ assert set(_extract_identifiers(code)) == { - ('a', FunctionDefinition), - ('a.SIZEOF_LOCALS', ConstDefinition), - ('a.Args', StructDefinition), - ('a.ImplicitArgs', StructDefinition), - ('a.Return', StructDefinition), - ('a.b', ReferenceDefinition), - ('a.c', ReferenceDefinition), - ('e', ReferenceDefinition), - ('f', ReferenceDefinition), + ("a", FunctionDefinition), + ("a.SIZEOF_LOCALS", ConstDefinition), + ("a.Args", StructDefinition), + ("a.ImplicitArgs", StructDefinition), + ("a.Return", StructDefinition), + ("a.b", ReferenceDefinition), + ("a.c", ReferenceDefinition), + ("e", ReferenceDefinition), + ("f", ReferenceDefinition), } @@ -69,21 +75,21 @@ def test_nested_funcs(): end """ assert set(_extract_identifiers(code)) == { - ('foo', FunctionDefinition), - ('foo.SIZEOF_LOCALS', ConstDefinition), - ('foo.Args', StructDefinition), - ('foo.ImplicitArgs', StructDefinition), - ('foo.Return', StructDefinition), - ('foo.x', ReferenceDefinition), - ('foo.z', ReferenceDefinition), - ('foo.a', ReferenceDefinition), - ('foo.bar', FunctionDefinition), - ('foo.bar.SIZEOF_LOCALS', ConstDefinition), - ('foo.bar.Args', StructDefinition), - ('foo.bar.ImplicitArgs', StructDefinition), - ('foo.bar.Return', StructDefinition), - ('foo.bar.y', ReferenceDefinition), - ('foo.bar.b', ReferenceDefinition), + ("foo", FunctionDefinition), + ("foo.SIZEOF_LOCALS", ConstDefinition), + ("foo.Args", StructDefinition), + ("foo.ImplicitArgs", StructDefinition), + ("foo.Return", StructDefinition), + ("foo.x", ReferenceDefinition), + ("foo.z", ReferenceDefinition), + ("foo.a", ReferenceDefinition), + ("foo.bar", FunctionDefinition), + ("foo.bar.SIZEOF_LOCALS", ConstDefinition), + ("foo.bar.Args", StructDefinition), + ("foo.bar.ImplicitArgs", StructDefinition), + ("foo.bar.Return", StructDefinition), + ("foo.bar.y", ReferenceDefinition), + ("foo.bar.b", ReferenceDefinition), } @@ -93,31 +99,37 @@ def test_redefinition(): local name = [ap] """ assert _extract_identifiers(code) == [ - ('name', ReferenceDefinition), + ("name", ReferenceDefinition), ] def test_redefinition_failures(): - verify_exception(""" + verify_exception( + """ name: local name = [ap] -""", """ +""", + """ file:?:?: Redefinition of 'test_scope.name'. local name = [ap] ^**^ -""") +""", + ) def test_imports(): collector = IdentifierCollector() collector.identifiers.add_identifier( - ScopedName.from_string('foo.bar'), ConstDefinition(value=0)) - ast = parse_file(""" + ScopedName.from_string("foo.bar"), ConstDefinition(value=0) + ) + ast = parse_file( + """ from foo import bar as bar0 -""") +""" + ) with collector.scoped(ScopedName(), parent=ast): collector.visit(ast.code_block) assert collector.identifiers.get_scope(ScopedName()).identifiers == { - 'bar0': AliasDefinition(destination=ScopedName.from_string('foo.bar')), + "bar0": AliasDefinition(destination=ScopedName.from_string("foo.bar")), } diff --git a/src/starkware/cairo/lang/compiler/preprocessor/local_variables.py b/src/starkware/cairo/lang/compiler/preprocessor/local_variables.py index 43ba7936..3567d56b 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/local_variables.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/local_variables.py @@ -3,9 +3,20 @@ from starkware.cairo.lang.compiler.ast.cairo_types import CairoType, CastType from starkware.cairo.lang.compiler.ast.code_elements import ( - CodeBlock, CodeElement, CodeElementAllocLocals, CodeElementCompoundAssertEq, CodeElementConst, - CodeElementIf, CodeElementInstruction, CodeElementLocalVariable, CodeElementReference, - CodeElementStaticAssert, CodeElementUnpackBinding, CodeElementWith, CommentedCodeElement) + CodeBlock, + CodeElement, + CodeElementAllocLocals, + CodeElementCompoundAssertEq, + CodeElementConst, + CodeElementIf, + CodeElementInstruction, + CodeElementLocalVariable, + CodeElementReference, + CodeElementStaticAssert, + CodeElementUnpackBinding, + CodeElementWith, + CommentedCodeElement, +) from starkware.cairo.lang.compiler.ast.expr import ExprCast, ExprConst, ExprIdentifier from starkware.cairo.lang.compiler.ast.instructions import AddApInstruction, InstructionAst from starkware.cairo.lang.compiler.error_handling import Location @@ -17,7 +28,7 @@ from starkware.cairo.lang.compiler.references import create_simple_ref_expr from starkware.cairo.lang.compiler.scoped_name import ScopedName -N_LOCALS_CONSTANT = 'SIZEOF_LOCALS' +N_LOCALS_CONSTANT = "SIZEOF_LOCALS" class NLocalsUsedVisitor(ExpressionTransformer): @@ -41,11 +52,13 @@ class LocalVariableHandler: """ def __init__( - self, - new_unique_id_callback: Callable[[], str], - get_size_callback: Callable[[CairoType], int], - get_unpacking_struct_definition_callback: Callable[ - [CodeElementUnpackBinding], StructDefinition]): + self, + new_unique_id_callback: Callable[[], str], + get_size_callback: Callable[[CairoType], int], + get_unpacking_struct_definition_callback: Callable[ + [CodeElementUnpackBinding], StructDefinition + ], + ): # The size of the local variables in this scope. self.local_vars_size: int = 0 @@ -62,7 +75,7 @@ def alloc_unique_id(self) -> str: return self.new_unique_id_callback() def visit(self, obj): - funcname = f'visit_{type(obj).__name__}' + funcname = f"visit_{type(obj).__name__}" if hasattr(self, funcname): return getattr(self, funcname)(obj) else: @@ -71,16 +84,18 @@ def visit(self, obj): def visit_CodeElementIf(self, obj: CodeElementIf): obj = dataclasses.replace(obj, main_code_block=self.visit(obj.main_code_block)) if obj.else_code_block is not None: - obj = dataclasses.replace( - obj, else_code_block=self.visit(obj.else_code_block)) + obj = dataclasses.replace(obj, else_code_block=self.visit(obj.else_code_block)) return [obj] def visit_CodeBlock(self, obj: CodeBlock): new_commented_code_elements = [] for code_element in obj.code_elements: for new_elm in self.visit(code_element.code_elm): - new_commented_code_elements.append(CommentedCodeElement( - code_elm=new_elm, comment=None, location=code_element.location)) + new_commented_code_elements.append( + CommentedCodeElement( + code_elm=new_elm, comment=None, location=code_element.location + ) + ) return dataclasses.replace(obj, code_elements=new_commented_code_elements) @@ -106,7 +121,7 @@ def visit_CodeElementAllocLocals(self, elm: CodeElementAllocLocals) -> List[Code location=location, ), inc_ap=False, - location=location + location=location, ), ) # Return the original element so that the preprocessor can check that ap was not advanced. @@ -128,18 +143,25 @@ def visit_CodeElementLocalVariable(self, elm: CodeElementLocalVariable) -> List[ result: List[CodeElement] = [] if elm.expr is not None: - result.append(CodeElementCompoundAssertEq( - a=ref_expr, - b=ExprCast( - expr=elm.expr, dest_type=local_type, cast_type=CastType.ASSIGN, - location=elm.expr.location), - location=elm.location)) + result.append( + CodeElementCompoundAssertEq( + a=ref_expr, + b=ExprCast( + expr=elm.expr, + dest_type=local_type, + cast_type=CastType.ASSIGN, + location=elm.expr.location, + ), + location=elm.location, + ) + ) result.append( CodeElementReference( typed_identifier=elm.typed_identifier, expr=ref_expr, - )) + ) + ) self.local_vars_size += self.get_size_callback(local_type) return result @@ -158,8 +180,9 @@ def visit_CodeElementUnpackBinding(self, elm: CodeElementUnpackBinding): struct_def = self.get_unpacking_struct_definition_callback(elm) unpacking_identifiers = [] for typed_identifier, member_def in zip( - elm.unpacking_list.identifiers, struct_def.members.values()): - if typed_identifier.modifier is None or typed_identifier.modifier.name != 'local': + elm.unpacking_list.identifiers, struct_def.members.values() + ): + if typed_identifier.modifier is None or typed_identifier.modifier.name != "local": unpacking_identifiers.append(typed_identifier) continue @@ -168,23 +191,35 @@ def visit_CodeElementUnpackBinding(self, elm: CodeElementUnpackBinding): # Add type if missing. if typed_identifier.expr_type is None: typed_identifier = dataclasses.replace( - typed_identifier, expr_type=member_def.cairo_type) + typed_identifier, expr_type=member_def.cairo_type + ) temp_ref = dataclasses.replace( typed_identifier, identifier=ExprIdentifier(name=self.alloc_unique_id()), - modifier=None) + modifier=None, + ) unpacking_identifiers.append(temp_ref) - result.extend(self.visit(CodeElementLocalVariable( - typed_identifier=typed_identifier.strip_modifier(), - expr=temp_ref.identifier, - location=typed_identifier.location, - ))) - - result.insert(0, dataclasses.replace( - elm, unpacking_list=dataclasses.replace( - elm.unpacking_list, identifiers=unpacking_identifiers))) + result.extend( + self.visit( + CodeElementLocalVariable( + typed_identifier=typed_identifier.strip_modifier(), + expr=temp_ref.identifier, + location=typed_identifier.location, + ) + ) + ) + + result.insert( + 0, + dataclasses.replace( + elm, + unpacking_list=dataclasses.replace( + elm.unpacking_list, identifiers=unpacking_identifiers + ), + ), + ) return result @@ -193,12 +228,15 @@ def visit_CodeElementWith(self, elm: CodeElementWith): def preprocess_local_variables( - code_elements: List[CodeElement], scope: ScopedName, - new_unique_id_callback: Callable[[], str], - get_size_callback: Callable[[CairoType], int], - get_unpacking_struct_definition_callback: Callable[ - [CodeElementUnpackBinding], StructDefinition], - default_location: Optional[Location]) -> List[CodeElement]: + code_elements: List[CodeElement], + scope: ScopedName, + new_unique_id_callback: Callable[[], str], + get_size_callback: Callable[[CairoType], int], + get_unpacking_struct_definition_callback: Callable[ + [CodeElementUnpackBinding], StructDefinition + ], + default_location: Optional[Location], +) -> List[CodeElement]: """ Preprocesses the local variables of one function. new_unique_id_callback is a callback that allocates a unique identifier. @@ -207,19 +245,22 @@ def preprocess_local_variables( handler = LocalVariableHandler( new_unique_id_callback=new_unique_id_callback, get_size_callback=get_size_callback, - get_unpacking_struct_definition_callback=get_unpacking_struct_definition_callback) + get_unpacking_struct_definition_callback=get_unpacking_struct_definition_callback, + ) result = [] for elm in code_elements: result += handler.visit(elm) n_locals_code_element = CodeElementConst( identifier=ExprIdentifier(name=N_LOCALS_CONSTANT, location=default_location), - expr=ExprConst(val=handler.local_vars_size, location=default_location)) + expr=ExprConst(val=handler.local_vars_size, location=default_location), + ) if handler.local_vars_size > 0 and not handler.n_locals_used_visitor.saw_n_locals_const: raise PreprocessorError( - 'A function with local variables must use alloc_locals.', - location=handler.first_location) + "A function with local variables must use alloc_locals.", + location=handler.first_location, + ) result.insert(0, n_locals_code_element) return result diff --git a/src/starkware/cairo/lang/compiler/preprocessor/local_variables_test.py b/src/starkware/cairo/lang/compiler/preprocessor/local_variables_test.py index 4e2714e6..236a63ed 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/local_variables_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/local_variables_test.py @@ -1,5 +1,8 @@ from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import ( - PRIME, preprocess_str, verify_exception) + PRIME, + preprocess_str, + verify_exception, +) from starkware.cairo.lang.compiler.type_casts import CairoTypeError @@ -30,7 +33,9 @@ def test_local_variable(): end """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ ap += 11 [fp + 3] = [fp] * [fp + 1] [fp] = [fp + 1] @@ -42,6 +47,7 @@ def test_local_variable(): ap += 0 ret """ + ) def test_local_variable_unpack_binding(): @@ -66,7 +72,9 @@ def test_local_variable_unpack_binding(): end """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ ret ap += 2 call rel -3 @@ -78,6 +86,7 @@ def test_local_variable_unpack_binding(): [fp + 1] = [fp + 1] ret """ + ) def test_local_rebinding(): @@ -92,7 +101,9 @@ def test_local_rebinding(): end """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ ap += 4 [fp] = 5 [fp + 1] = [fp] * [fp] @@ -100,6 +111,7 @@ def test_local_rebinding(): [fp + 3] = [fp + 2] * [fp + 2] ret """ + ) def test_n_locals_used_in_static_assert(): @@ -111,63 +123,82 @@ def test_n_locals_used_in_static_assert(): end """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ ret """ + ) def test_local_variable_failures(): - verify_exception(""" + verify_exception( + """ func main(SIZEOF_LOCALS): static_assert SIZEOF_LOCALS == SIZEOF_LOCALS local x end -""", """ +""", + """ file:?:?: The name 'SIZEOF_LOCALS' is reserved and cannot be used as an argument name. func main(SIZEOF_LOCALS): ^***********^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ func main(): local x end -""", """ +""", + """ file:?:?: A function with local variables must use alloc_locals. local x ^*****^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ func main(): alloc_locals local x = x + x end -""", """ +""", + """ file:?:?: Identifier 'x' referenced before definition. local x = x + x ^ -""") - for inst in ['tempvar a = 0', 'ret', 'ap += [ap]']: - verify_exception(f""" +""", + ) + for inst in ["tempvar a = 0", "ret", "ap += [ap]"]: + verify_exception( + f""" func main(): {inst} alloc_locals end -""", """ +""", + """ file:?:?: alloc_locals must be used before any instruction that changes the ap register. alloc_locals ^**********^ -""") - verify_exception(f""" +""", + ) + verify_exception( + f""" alloc_locals -""", """ +""", + """ file:?:?: alloc_locals cannot be used outside of a function. alloc_locals ^**********^ -""") +""", + ) def test_local_variable_type_failures(): - verify_exception(f""" + verify_exception( + f""" struct T: member a : felt end @@ -177,20 +208,26 @@ def test_local_variable_type_failures(): local x : T* = [ap] ret end -""", """ +""", + """ file:?:?: Cannot cast 'felt' to 'test_scope.T*'. local x : T* = [ap] ^**^ -""", exc_type=CairoTypeError) +""", + exc_type=CairoTypeError, + ) def test_local_variable_modifier_failures(): - verify_exception(""" + verify_exception( + """ func main(): local local x end -""", """ +""", + """ file:?:?: Unexpected modifier 'local'. local local x ^***^ -""") +""", + ) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/pass_manager.py b/src/starkware/cairo/lang/compiler/preprocessor/pass_manager.py index 066bd7b1..bb31512f 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/pass_manager.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/pass_manager.py @@ -52,7 +52,7 @@ def run(self, context: PassManagerContext): def get_stage_index(self, name: str): assert name in self.stage_names - index, = [i for i, (stage_name, _) in enumerate(self.stages) if stage_name == name] + (index,) = [i for i, (stage_name, _) in enumerate(self.stages) if stage_name == name] return index # Functions for manipulating the stages: diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocess_codes.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocess_codes.py index 79b29c8d..853c1423 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocess_codes.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocess_codes.py @@ -7,8 +7,10 @@ def preprocess_codes( - codes: Sequence[Tuple[str, str]], pass_manager: PassManager, - main_scope: ScopedName = ScopedName()) -> PreprocessedProgram: + codes: Sequence[Tuple[str, str]], + pass_manager: PassManager, + main_scope: ScopedName = ScopedName(), +) -> PreprocessedProgram: """ Preprocesses a list of Cairo files and returns a PreprocessedProgram instance. codes is a list of pairs (code_string, file_name). diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py index e8832d22..c02a26e1 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor.py @@ -5,52 +5,122 @@ from starkware.cairo.lang.compiler.ast.arguments import IdentifierList from starkware.cairo.lang.compiler.ast.cairo_types import ( - CairoType, CastType, TypeFelt, TypePointer, TypeStruct, TypeTuple) + CairoType, + CastType, + TypeFelt, + TypePointer, + TypeStruct, + TypeTuple, +) from starkware.cairo.lang.compiler.ast.code_elements import ( - BuiltinsDirective, CodeBlock, CodeElement, CodeElementAllocLocals, CodeElementCompoundAssertEq, - CodeElementConst, CodeElementDirective, CodeElementEmptyLine, CodeElementFuncCall, - CodeElementFunction, CodeElementHint, CodeElementIf, CodeElementImport, CodeElementInstruction, - CodeElementLabel, CodeElementLocalVariable, CodeElementMember, CodeElementReference, - CodeElementReturn, CodeElementReturnValueReference, CodeElementStaticAssert, - CodeElementTailCall, CodeElementTemporaryVariable, CodeElementUnpackBinding, CodeElementWith, - LangDirective) + BuiltinsDirective, + CodeBlock, + CodeElement, + CodeElementAllocLocals, + CodeElementCompoundAssertEq, + CodeElementConst, + CodeElementDirective, + CodeElementEmptyLine, + CodeElementFuncCall, + CodeElementFunction, + CodeElementHint, + CodeElementIf, + CodeElementImport, + CodeElementInstruction, + CodeElementLabel, + CodeElementLocalVariable, + CodeElementMember, + CodeElementReference, + CodeElementReturn, + CodeElementReturnValueReference, + CodeElementStaticAssert, + CodeElementTailCall, + CodeElementTemporaryVariable, + CodeElementUnpackBinding, + CodeElementWith, + LangDirective, +) from starkware.cairo.lang.compiler.ast.expr import ( - ExprAssignment, ExprCast, ExprConst, ExprDeref, Expression, ExprFutureLabel, ExprHint, - ExprIdentifier, ExprOperator, ExprReg, ExprTuple) + ExprAssignment, + ExprCast, + ExprConst, + ExprDeref, + Expression, + ExprFutureLabel, + ExprHint, + ExprIdentifier, + ExprOperator, + ExprReg, + ExprTuple, +) from starkware.cairo.lang.compiler.ast.expr_func_call import ExprFuncCall from starkware.cairo.lang.compiler.ast.formatting_utils import get_max_line_length from starkware.cairo.lang.compiler.ast.instructions import ( - AddApInstruction, AssertEqInstruction, CallInstruction, CallLabelInstruction, InstructionAst, - InstructionBody, JnzInstruction, JumpInstruction, JumpToLabelInstruction, RetInstruction) + AddApInstruction, + AssertEqInstruction, + CallInstruction, + CallLabelInstruction, + InstructionAst, + InstructionBody, + JnzInstruction, + JumpInstruction, + JumpToLabelInstruction, + RetInstruction, +) from starkware.cairo.lang.compiler.ast.module import CairoModule from starkware.cairo.lang.compiler.ast.rvalue import RvalueCallInst, RvalueFuncCall from starkware.cairo.lang.compiler.constants import SIZE_CONSTANT from starkware.cairo.lang.compiler.error_handling import Location from starkware.cairo.lang.compiler.expression_simplifier import ExpressionSimplifier from starkware.cairo.lang.compiler.identifier_definition import ( - ConstDefinition, DefinitionError, FunctionDefinition, FutureIdentifierDefinition, - IdentifierDefinition, LabelDefinition, MemberDefinition, ReferenceDefinition, StructDefinition) + ConstDefinition, + DefinitionError, + FunctionDefinition, + FutureIdentifierDefinition, + IdentifierDefinition, + LabelDefinition, + MemberDefinition, + ReferenceDefinition, + StructDefinition, +) from starkware.cairo.lang.compiler.identifier_manager import IdentifierError, IdentifierManager from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition from starkware.cairo.lang.compiler.instruction import Register from starkware.cairo.lang.compiler.instruction_builder import ( - InstructionBuilderError, get_instruction_size) + InstructionBuilderError, + get_instruction_size, +) from starkware.cairo.lang.compiler.location_utils import add_parent_location from starkware.cairo.lang.compiler.offset_reference import OffsetReferenceDefinition from starkware.cairo.lang.compiler.preprocessor.compound_expressions import ( - CompoundExpressionContext, SimplicityLevel, process_compound_assert, - process_compound_expressions) + CompoundExpressionContext, + SimplicityLevel, + process_compound_assert, + process_compound_expressions, +) from starkware.cairo.lang.compiler.preprocessor.flow import ( - FlowTracking, FlowTrackingDataActual, FlowTrackingDataUnreachable, InstructionFlows, - ReferenceManager) + FlowTracking, + FlowTrackingDataActual, + FlowTrackingDataUnreachable, + InstructionFlows, + ReferenceManager, +) from starkware.cairo.lang.compiler.preprocessor.identifier_aware_visitor import ( - IdentifierAwareVisitor) + IdentifierAwareVisitor, +) from starkware.cairo.lang.compiler.preprocessor.local_variables import ( - create_simple_ref_expr, preprocess_local_variables) + create_simple_ref_expr, + preprocess_local_variables, +) from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError from starkware.cairo.lang.compiler.preprocessor.preprocessor_utils import assert_no_modifier from starkware.cairo.lang.compiler.preprocessor.reg_tracking import ( - RegChange, RegChangeKnown, RegChangeUnconstrained, RegChangeUnknown, RegTrackingData) + RegChange, + RegChangeKnown, + RegChangeUnconstrained, + RegChangeUnknown, + RegTrackingData, +) from starkware.cairo.lang.compiler.references import FlowTrackingError, Reference, translate_ap from starkware.cairo.lang.compiler.resolve_search_result import resolve_search_result from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -61,7 +131,7 @@ # Indicates that the compiler should be able to deduce the change in the ap register for this # function. -KNOWN_AP_CHANGE_DECORATOR = 'known_ap_change' +KNOWN_AP_CHANGE_DECORATOR = "known_ap_change" @dataclasses.dataclass @@ -74,11 +144,15 @@ class PreprocessedInstruction: def format(self, with_locations: bool = False) -> str: location_str = ( - f' # {self.instruction.location.topmost_location()}.' + f" # {self.instruction.location.topmost_location()}." if with_locations and self.instruction.location is not None - else '') - return ''.join(hint.format(get_max_line_length()) + '\n' for hint, _ in self.hints) + \ - self.instruction.format() + location_str + else "" + ) + return ( + "".join(hint.format(get_max_line_length()) + "\n" for hint, _ in self.hints) + + self.instruction.format() + + location_str + ) @dataclasses.dataclass @@ -98,16 +172,17 @@ def format(self, with_locations: bool = False) -> str: This can be used to print the preprocessor intermediate output. """ code = self._directives_code() - code += ''.join( - inst.format(with_locations=with_locations) + '\n' for inst in self.instructions) + code += "".join( + inst.format(with_locations=with_locations) + "\n" for inst in self.instructions + ) return code def _directives_code(self) -> str: - code = '' + code = "" if self.builtins: - code += BuiltinsDirective(builtins=self.builtins).format() + '\n' + code += BuiltinsDirective(builtins=self.builtins).format() + "\n" if code: - code += '\n' + code += "\n" return code @@ -148,10 +223,15 @@ class Preprocessor(IdentifierAwareVisitor): """ def __init__( - self, prime: int, identifiers: Optional[IdentifierManager] = None, - supported_decorators: Optional[Set[str]] = None, - functions_to_compile: Optional[Set[ScopedName]] = None): - super().__init__(identifiers=identifiers,) + self, + prime: int, + identifiers: Optional[IdentifierManager] = None, + supported_decorators: Optional[Set[str]] = None, + functions_to_compile: Optional[Set[ScopedName]] = None, + ): + super().__init__( + identifiers=identifiers, + ) self.prime: int = prime self.instructions: List[PreprocessedInstruction] = [] # Stores the program counter of the next instruction (where the first instruction is at 0). @@ -193,7 +273,8 @@ def __init__( self.removed_prefixes: Set[ScopedName] = set() def search_identifier( - self, name: str, location: Optional[Location]) -> Optional[IdentifierDefinition]: + self, name: str, location: Optional[Location] + ) -> Optional[IdentifierDefinition]: """ Searches for the given identifier in self.identifiers and returns the corresponding IdentifierDefinition. @@ -216,7 +297,8 @@ def update_identifiers(self, identifiers: IdentifierManager): self.add_future_definition(name, future_definition) def add_future_definition( - self, name: ScopedName, future_definition: FutureIdentifierDefinition): + self, name: ScopedName, future_definition: FutureIdentifierDefinition + ): """ Adds a future definition of an identifier. """ @@ -234,7 +316,8 @@ def visit_uncommented_code_block(self, code_elements: List[CodeElement]): self.directives_allowed = False # Make sure there are no hints at the end of the code block. self.check_no_hints( - 'Found a hint at the end of a code block. Hints must be followed by an instruction.') + "Found a hint at the end of a code block. Hints must be followed by an instruction." + ) def visit_CodeBlock(self, code_block: CodeBlock): # Remove the CommentedCodeElement wrapper. @@ -247,7 +330,8 @@ def visit_CairoModule(self, module: CairoModule): raise PreprocessorError( f"Scope '{module.module_name}' collides with a different identifier " f"of type '{identifier_value.TYPE}'.", - location=None) + location=None, + ) self.flow_tracking.revoke() super().visit_CairoModule(module) @@ -260,18 +344,21 @@ def resolve_labels(self): self.flow_tracking, old_flow_tracking = FlowTracking(), self.flow_tracking self.function_metadata, old_function_metadata = {}, self.function_metadata - assert self.accessible_scopes == [], 'Unexpected preprocessor state.' + assert self.accessible_scopes == [], "Unexpected preprocessor state." for preprocessed_instruction in old_instructions: self.accessible_scopes = preprocessed_instruction.accessible_scopes new_instruction = self.visit(preprocessed_instruction.instruction) self.check_preprocessed_instruction(new_instruction) self.current_pc += self.get_instruction_size(new_instruction) - self.instructions.append(PreprocessedInstruction( - instruction=new_instruction, - accessible_scopes=preprocessed_instruction.accessible_scopes, - hints=preprocessed_instruction.hints, - flow_tracking_data=preprocessed_instruction.flow_tracking_data)) + self.instructions.append( + PreprocessedInstruction( + instruction=new_instruction, + accessible_scopes=preprocessed_instruction.accessible_scopes, + hints=preprocessed_instruction.hints, + flow_tracking_data=preprocessed_instruction.flow_tracking_data, + ) + ) self.accessible_scopes = [] assert old_pc == self.current_pc @@ -291,8 +378,11 @@ def get_program(self): ) def create_struct_from_identifier_list( - self, identifier_list: Optional[IdentifierList], struct_name: ScopedName, - location: Optional[Location]): + self, + identifier_list: Optional[IdentifierList], + struct_name: ScopedName, + location: Optional[Location], + ): """ Creates a struct based on the given 'identifier_list'. """ @@ -306,17 +396,21 @@ def create_struct_from_identifier_list( self.add_name_definition( name=struct_name + arg.identifier.name, identifier_definition=member_def, - location=arg.location) + location=arg.location, + ) offset += self.get_size(cairo_type) self.add_name_definition( - struct_name + SIZE_CONSTANT, - ConstDefinition(value=offset), - location=location) + struct_name + SIZE_CONSTANT, ConstDefinition(value=offset), location=location + ) def add_references_from_struct_members( - self, identifier_list: Optional[IdentifierList], members: Dict[str, MemberDefinition], - scope: ScopedName, start_offset: int): + self, + identifier_list: Optional[IdentifierList], + members: Dict[str, MemberDefinition], + scope: ScopedName, + start_offset: int, + ): """ Adds a reference to an expression of the form '[fp + *]' for each of the struct members, starting from '[fp + start_offset]'. @@ -328,25 +422,28 @@ def add_references_from_struct_members( # Add a reference for the argument. assert_no_modifier(arg) self.add_simple_reference( - name=scope + arg.identifier.name, reg=Register.FP, - cairo_type=member_def.cairo_type, offset=start_offset + member_def.offset, - location=arg.location) + name=scope + arg.identifier.name, + reg=Register.FP, + cairo_type=member_def.cairo_type, + offset=start_offset + member_def.offset, + location=arg.location, + ) def visit_CodeElementFunction(self, elm: CodeElementFunction): - self.check_no_hints('Hints before functions are not allowed.') - if elm.element_type == 'struct': + self.check_no_hints("Hints before functions are not allowed.") + if elm.element_type == "struct": return # Check decorator. known_ap_change_decorator: Optional[ExprIdentifier] = None for decorator in elm.decorators: - if decorator.name == KNOWN_AP_CHANGE_DECORATOR and elm.element_type == 'func': + if decorator.name == KNOWN_AP_CHANGE_DECORATOR and elm.element_type == "func": known_ap_change_decorator = decorator continue if decorator.name not in self.supported_decorators: raise PreprocessorError( - f"Unsupported decorator: '{decorator.name}'.", - location=decorator.location) + f"Unsupported decorator: '{decorator.name}'.", location=decorator.location + ) self.flow_tracking.revoke() @@ -356,17 +453,18 @@ def visit_CodeElementFunction(self, elm: CodeElementFunction): outer_function_location = self.identifier_locations.get(self.current_scope) notes = [] if outer_function_location is not None: - loc_str = outer_function_location.to_string_with_content('') - notes.append(f'Outer function was defined here: {loc_str}') + loc_str = outer_function_location.to_string_with_content("") + notes.append(f"Outer function was defined here: {loc_str}") raise PreprocessorError( - 'Nested functions are not supported.' if elm.element_type == 'func' - else 'Cannot define a namespace inside a function.', + "Nested functions are not supported." + if elm.element_type == "func" + else "Cannot define a namespace inside a function.", location=elm.identifier.location, notes=notes, ) - if elm.element_type == 'func': + if elm.element_type == "func": # Check if this function should be skipped. if self.functions_to_compile is not None and new_scope not in self.functions_to_compile: self.removed_prefixes.add(new_scope) @@ -374,7 +472,9 @@ def visit_CodeElementFunction(self, elm: CodeElementFunction): self.add_function(elm) else: - assert elm.element_type == 'namespace', f"""\ + assert ( + elm.element_type == "namespace" + ), f"""\ Expected 'elm.element_type' to be a 'namespace'. Found: '{elm.element_type}'.""" self.add_label(identifier=elm.identifier) @@ -385,18 +485,25 @@ def visit_CodeElementFunction(self, elm: CodeElementFunction): # Create the references for the arguments. args_struct = get_struct_definition(args_scope, self.identifiers) self.add_references_from_struct_members( - identifier_list=elm.arguments, members=args_struct.members, scope=new_scope, - start_offset=-(2 + args_struct.size)) + identifier_list=elm.arguments, + members=args_struct.members, + scope=new_scope, + start_offset=-(2 + args_struct.size), + ) implicit_args_struct = get_struct_definition(implicit_args_scope, self.identifiers) self.add_references_from_struct_members( - identifier_list=elm.implicit_arguments, members=implicit_args_struct.members, - scope=new_scope, start_offset=-(2 + args_struct.size + implicit_args_struct.size)) + identifier_list=elm.implicit_arguments, + members=implicit_args_struct.members, + scope=new_scope, + start_offset=-(2 + args_struct.size + implicit_args_struct.size), + ) new_reference_states = dict(self.reference_states) if elm.implicit_arguments is not None: for typed_identifier in elm.implicit_arguments.identifiers: - new_reference_states[new_scope + typed_identifier.name] = \ - ReferenceState.ALLOW_IMPLICIT + new_reference_states[ + new_scope + typed_identifier.name + ] = ReferenceState.ALLOW_IMPLICIT # Process code_elements. with self.scoped(new_scope, parent=elm), self.set_reference_states(new_reference_states): @@ -412,20 +519,23 @@ def visit_CodeElementFunction(self, elm: CodeElementFunction): self.visit_uncommented_code_block(code_elements) - if elm.element_type == 'func': + if elm.element_type == "func": if self.flow_tracking.data != FlowTrackingDataUnreachable(): raise PreprocessorError( - 'Function must end with a return instruction or a jump.', - location=elm.identifier.location) + "Function must end with a return instruction or a jump.", + location=elm.identifier.location, + ) self.function_metadata[new_scope].completed = True if known_ap_change_decorator is not None: if not isinstance( - self.function_metadata[new_scope].total_ap_change, RegChangeKnown): + self.function_metadata[new_scope].total_ap_change, RegChangeKnown + ): raise PreprocessorError( - 'The compiler was unable to deduce the change of the ap register, as ' - 'required by this decorator.', - location=known_ap_change_decorator.location) + "The compiler was unable to deduce the change of the ap register, as " + "required by this decorator.", + location=known_ap_change_decorator.location, + ) if self.function_metadata[new_scope].total_ap_change == RegChangeUnconstrained(): # No returns occured. self.function_metadata[new_scope].total_ap_change = RegChangeUnknown() @@ -438,18 +548,21 @@ def visit_CodeElementWith(self, elm: CodeElementWith): if aliased_identifier.local_name is not None: raise PreprocessorError( "The 'as' keyword is not supported in 'with' statements.", - location=aliased_identifier.local_name.location) + location=aliased_identifier.local_name.location, + ) src_identifier_definition = self.identifiers.get_by_full_name(src_full_name) if src_identifier_definition is None: raise PreprocessorError( - f"Unknown reference '{src_identifier.name}'.", location=src_identifier.location) + f"Unknown reference '{src_identifier.name}'.", location=src_identifier.location + ) if not isinstance(src_identifier_definition, ReferenceDefinition): raise PreprocessorError( f"Expected '{src_identifier.name}' to be a reference, " - f'found: {src_identifier_definition.TYPE}.', - location=src_identifier.location) + f"found: {src_identifier_definition.TYPE}.", + location=src_identifier.location, + ) new_reference_states[src_full_name] = ReferenceState.ALLOW_IMPLICIT @@ -474,11 +587,14 @@ def set_reference_states(self, new_reference_states: Dict[ScopedName, ReferenceS def visit_CodeElementIf(self, elm: CodeElementIf): # Prepare branch compound expression. - cond_expr = self.simplify_expr_as_felt(ExprOperator( - a=elm.condition.a, op='-', b=elm.condition.b, location=elm.condition.location)) + cond_expr = self.simplify_expr_as_felt( + ExprOperator( + a=elm.condition.a, op="-", b=elm.condition.b, location=elm.condition.location + ) + ) compound_expressions_code_elements, (res_cond_expr,) = process_compound_expressions( - [cond_expr], [SimplicityLevel.DEREF], - context=self._compound_expression_context) + [cond_expr], [SimplicityLevel.DEREF], context=self._compound_expression_context + ) for code_element in compound_expressions_code_elements: self.visit(code_element) @@ -489,10 +605,17 @@ def visit_CodeElementIf(self, elm: CodeElementIf): label_end = ExprIdentifier(name=elm.label_end, location=elm.location) # Add conditional jump. - self.visit(CodeElementInstruction(InstructionAst( - body=JumpToLabelInstruction( - label=label_neq, condition=res_cond_expr, location=elm.location), - inc_ap=False, location=elm.location))) + self.visit( + CodeElementInstruction( + InstructionAst( + body=JumpToLabelInstruction( + label=label_neq, condition=res_cond_expr, location=elm.location + ), + inc_ap=False, + location=elm.location, + ) + ) + ) # Determine code blocks. eq_code_block: Optional[CodeBlock] @@ -510,9 +633,17 @@ def visit_CodeElementIf(self, elm: CodeElementIf): if self.flow_tracking.data != FlowTrackingDataUnreachable() and neq_code_block is not None: # Code block ended with a flow to next line. Since we have a "Not equal" block, we # add a jump to skip it. - self.visit(CodeElementInstruction(InstructionAst( - body=JumpToLabelInstruction(label=label_end, condition=None, location=elm.location), - inc_ap=False, location=elm.location))) + self.visit( + CodeElementInstruction( + InstructionAst( + body=JumpToLabelInstruction( + label=label_end, condition=None, location=elm.location + ), + inc_ap=False, + location=elm.location, + ) + ) + ) # Add the neq label. self.visit(CodeElementLabel(identifier=label_neq)) @@ -529,8 +660,8 @@ def visit_CodeElementDirective(self, elm: CodeElementDirective): # Visit directive. if not self.directives_allowed: raise PreprocessorError( - 'Directives must appear at the top of the file.', - location=elm.location) + "Directives must appear at the top of the file.", location=elm.location + ) self.visit(elm.directive) def visit_CodeElementImport(self, elm: CodeElementImport): @@ -539,15 +670,17 @@ def visit_CodeElementImport(self, elm: CodeElementImport): def visit_CodeElementAllocLocals(self, elm: CodeElementAllocLocals): if self.current_scope not in self.function_metadata: raise PreprocessorError( - 'alloc_locals cannot be used outside of a function.', - location=elm.location) + "alloc_locals cannot be used outside of a function.", location=elm.location + ) # Check that ap did not change from the beginning of the function. if not isinstance(self.flow_tracking.data, FlowTrackingDataActual) or ( - self.flow_tracking.data.ap_tracking != - self.function_metadata[self.current_scope].initial_ap_data): + self.flow_tracking.data.ap_tracking + != self.function_metadata[self.current_scope].initial_ap_data + ): raise PreprocessorError( - 'alloc_locals must be used before any instruction that changes the ap register.', - location=elm.location) + "alloc_locals must be used before any instruction that changes the ap register.", + location=elm.location, + ) def visit_CodeElementInstruction(self, elm: CodeElementInstruction): current_flow_tracking_data = self.flow_tracking.get() @@ -555,10 +688,12 @@ def visit_CodeElementInstruction(self, elm: CodeElementInstruction): instruction=self.visit(elm.instruction), accessible_scopes=self.accessible_scopes.copy(), hints=self.next_instruction_hints, - flow_tracking_data=current_flow_tracking_data) + flow_tracking_data=current_flow_tracking_data, + ) self._clear_next_hints() self.current_pc += self.get_instruction_size( - preprocessed_instruction.instruction, allow_auto_deduction=True) + preprocessed_instruction.instruction, allow_auto_deduction=True + ) self.instructions.append(preprocessed_instruction) def visit_CodeElementConst(self, elm: CodeElementConst): @@ -569,22 +704,22 @@ def visit_CodeElementConst(self, elm: CodeElementConst): name = self.current_scope + elm.identifier.name val = self.simplify_expr_as_felt(elm.expr) if not isinstance(val, ExprConst): - raise PreprocessorError('Expected a constant expression.', location=elm.expr.location) + raise PreprocessorError("Expected a constant expression.", location=elm.expr.location) self.add_name_definition( - name, - ConstDefinition(value=val.val), - location=elm.identifier.location) + name, ConstDefinition(value=val.val), location=elm.identifier.location + ) def visit_CodeElementMember(self, elm: CodeElementMember): - self.check_no_hints('Hints before member definitions are not allowed.') + self.check_no_hints("Hints before member definitions are not allowed.") if self.inside_a_struct(): # Was already handled by the struct collector. return raise PreprocessorError( - 'The member keyword may only be used inside a struct.', - location=elm.typed_identifier.location) + "The member keyword may only be used inside a struct.", + location=elm.typed_identifier.location, + ) def visit_CodeElementReference(self, elm: CodeElementReference): name = self.current_scope + elm.typed_identifier.identifier.name @@ -598,14 +733,16 @@ def visit_CodeElementReference(self, elm: CodeElementReference): # Copy the type from the value. dst_type = val_type if not check_cast( - src_type=val_type, - dest_type=dst_type, - identifier_manager=self.identifiers, - cast_type=CastType.ASSIGN): + src_type=val_type, + dest_type=dst_type, + identifier_manager=self.identifiers, + cast_type=CastType.ASSIGN, + ): raise PreprocessorError( f"Cannot assign an expression of type '{val_type.format()}' " f"to a reference of type '{dst_type.format()}'.", - location=dst_type.location) + location=dst_type.location, + ) location = val.location @@ -619,11 +756,14 @@ def visit_CodeElementReference(self, elm: CodeElementReference): addr=ExprCast( expr=addr, dest_type=TypePointer(pointee=dst_type, location=location), - location=addr.location), - location=location) + location=addr.location, + ), + location=location, + ) else: ref_expr = ExprCast( - expr=val, dest_type=dst_type, cast_type=CastType.FORCED, location=location) + expr=val, dest_type=dst_type, cast_type=CastType.FORCED, location=location + ) self.add_reference( name=name, @@ -634,7 +774,8 @@ def visit_CodeElementReference(self, elm: CodeElementReference): def visit_CodeElementLocalVariable(self, elm: CodeElementLocalVariable): raise PreprocessorError( - 'Local variables are not supported outside of functions.', location=elm.location) + "Local variables are not supported outside of functions.", location=elm.location + ) def visit_CodeElementTemporaryVariable(self, elm: CodeElementTemporaryVariable): assert_no_modifier(elm.typed_identifier) @@ -647,12 +788,12 @@ def visit_CodeElementTemporaryVariable(self, elm: CodeElementTemporaryVariable): if isinstance(elm.expr, ExprHint): if not isinstance(dest_type, TypeFelt): raise PreprocessorError( - 'Hint tempvars must be of type felt.', - location=elm.expr.location) + "Hint tempvars must be of type felt.", location=elm.expr.location + ) self.visit( CodeElementHint( hint=ExprHint( - hint_code=f'memory[ap] = int({elm.expr.hint_code})', + hint_code=f"memory[ap] = int({elm.expr.hint_code})", n_prefix_newlines=0, location=elm.location, ), @@ -681,15 +822,19 @@ def visit_CodeElementTemporaryVariable(self, elm: CodeElementTemporaryVariable): else: dest_type = self.resolve_type(elm.typed_identifier.expr_type) if not check_cast( - src_type=src_type, dest_type=dest_type, identifier_manager=self.identifiers, - cast_type=CastType.ASSIGN): + src_type=src_type, + dest_type=dest_type, + identifier_manager=self.identifiers, + cast_type=CastType.ASSIGN, + ): raise PreprocessorError( f"Cannot assign an expression of type '{src_type.format()}' " f"to a temporary variable of type '{dest_type.format()}'.", - location=dest_type.location) + location=dest_type.location, + ) dest_size = self.get_size(dest_type) - assert src_size == dest_size, 'Expecting src and dest types to have the same size.' + assert src_size == dest_size, "Expecting src and dest types to have the same size." src_exprs = self.simplified_expr_to_felt_expr_list(expr=expr, expr_type=src_type) self.push_compound_expressions(compound_expressions=src_exprs, location=elm.location) @@ -699,7 +844,8 @@ def visit_CodeElementTemporaryVariable(self, elm: CodeElementTemporaryVariable): reg=Register.AP, cairo_type=dest_type, offset=-src_size, - location=elm.typed_identifier.identifier.location) + location=elm.typed_identifier.identifier.location, + ) def visit_CodeElementCompoundAssertEq(self, instruction: CodeElementCompoundAssertEq): expr_a, expr_type_a = self.simplify_expr(instruction.a) @@ -707,7 +853,8 @@ def visit_CodeElementCompoundAssertEq(self, instruction: CodeElementCompoundAsse if expr_type_a != expr_type_b: raise PreprocessorError( f"Cannot compare '{expr_type_a.format()}' and '{expr_type_b.format()}'.", - location=instruction.location) + location=instruction.location, + ) src_exprs = self.simplified_expr_to_felt_expr_list(expr=expr_a, expr_type=expr_type_a) dst_exprs = self.simplified_expr_to_felt_expr_list(expr=expr_b, expr_type=expr_type_b) @@ -718,17 +865,15 @@ def visit_CodeElementCompoundAssertEq(self, instruction: CodeElementCompoundAsse src = self.simplifier.visit(translate_ap(src, ap_diff)) dst = self.simplifier.visit(translate_ap(dst, ap_diff)) compound_expressions_code_elements, (expr_a, expr_b) = process_compound_assert( - src, - dst, - self._compound_expression_context) + src, dst, self._compound_expression_context + ) assert_eq = CodeElementInstruction( instruction=InstructionAst( - body=AssertEqInstruction( - a=expr_a, - b=expr_b, - location=instruction.location), + body=AssertEqInstruction(a=expr_a, b=expr_b, location=instruction.location), inc_ap=False, - location=instruction.location)) + location=instruction.location, + ) + ) for code_element in compound_expressions_code_elements: self.visit(code_element) @@ -739,7 +884,8 @@ def visit_CodeElementStaticAssert(self, elm: CodeElementStaticAssert): b = self.simplify_expr_as_felt(elm.b) if a != b: raise PreprocessorError( - f'Static assert failed: {a.format()} != {b.format()}.', location=elm.location) + f"Static assert failed: {a.format()} != {b.format()}.", location=elm.location + ) def optimize_expressions_for_push(self, exprs: List[Expression]) -> List[Expression]: """ @@ -765,7 +911,7 @@ def get_ap_minus_n(expr: Expression) -> Optional[int]: """ if not isinstance(expr, ExprDeref): return None - if not isinstance(expr.addr, ExprOperator) or expr.addr.op != '+': + if not isinstance(expr.addr, ExprOperator) or expr.addr.op != "+": return None if not isinstance(expr.addr.a, ExprReg) or expr.addr.a.reg != Register.AP: return None @@ -788,8 +934,8 @@ def get_ap_minus_n(expr: Expression) -> Optional[int]: return exprs[prefix_size:] def process_expr_assignment_list( - self, exprs: List[ExprAssignment], struct_name: ScopedName, - location: Optional[Location]) -> List[Expression]: + self, exprs: List[ExprAssignment], struct_name: ScopedName, location: Optional[Location] + ) -> List[Expression]: """ Returns the expressions for an argument list. Used both for function call and a return instruction. @@ -805,8 +951,8 @@ def process_expr_assignment_list( # Make sure we have the correct number of expressions. if len(exprs) != n_members: raise PreprocessorError( - f'Expected exactly {n_members} expressions, got {len(exprs)}.', - location=location) + f"Expected exactly {n_members} expressions, got {len(exprs)}.", location=location + ) passed_args = list(struct_def.members.items()) reached_named = False @@ -816,26 +962,31 @@ def process_expr_assignment_list( # Make sure all named args are after positional args. if reached_named: raise PreprocessorError( - 'Positional arguments must not appear after named arguments.', - location=expr_assignment.location) + "Positional arguments must not appear after named arguments.", + location=expr_assignment.location, + ) else: reached_named = True name = expr_assignment.identifier.name if name != member_name: raise PreprocessorError( f"Expected named arg '{member_name}' found '{name}'.", - location=expr_assignment.identifier.location) + location=expr_assignment.identifier.location, + ) felt_expr_list = self.simplify_expr_to_felt_expr_list( - expr_assignment.expr, member_def.cairo_type) + expr_assignment.expr, member_def.cairo_type + ) compound_expressions.extend(felt_expr_list) return compound_expressions def process_implicit_argument_binding( - self, implicit_args: List[ExprAssignment], - implicit_args_struct_name: ScopedName, - location: Optional[Location]) -> List[Optional[ExprIdentifier]]: + self, + implicit_args: List[ExprAssignment], + implicit_args_struct_name: ScopedName, + location: Optional[Location], + ) -> List[Optional[ExprIdentifier]]: """ Processes the implicit argument bindings. Returns a list whose size is the number of implicit arguments of the called function, with the binding variable for each argument @@ -844,27 +995,31 @@ def process_implicit_argument_binding( will be [None, w, None]. """ implicit_args_struct = self.get_struct_definition( - name=implicit_args_struct_name, location=location) + name=implicit_args_struct_name, location=location + ) # A list of (arg_name, binding). processed_implicit_args: List[Tuple[ExprIdentifier, ExprIdentifier]] = [] for arg in implicit_args: if arg.identifier is None: raise PreprocessorError( - 'Implicit argument binding must be of the form: arg_name=var.', - location=arg.location) + "Implicit argument binding must be of the form: arg_name=var.", + location=arg.location, + ) - if not isinstance(arg.expr, ExprIdentifier) or '.' in arg.expr.name: + if not isinstance(arg.expr, ExprIdentifier) or "." in arg.expr.name: raise PreprocessorError( - 'Implicit argument binding must be an identifier.', - location=arg.expr.location) + "Implicit argument binding must be an identifier.", location=arg.expr.location + ) processed_implicit_args.append((arg.identifier, arg.expr)) result: List[Optional[ExprIdentifier]] = [] for member_name in implicit_args_struct.members.keys(): - if len(processed_implicit_args) == 0 or \ - processed_implicit_args[0][0].name != member_name: + if ( + len(processed_implicit_args) == 0 + or processed_implicit_args[0][0].name != member_name + ): result.append(None) continue @@ -874,15 +1029,18 @@ def process_implicit_argument_binding( # Make sure all implicit argument bindings were processed. if len(processed_implicit_args) > 0: raise PreprocessorError( - f'Unexpected implicit argument binding: {processed_implicit_args[0][0].name}.', - location=processed_implicit_args[0][0].location) + f"Unexpected implicit argument binding: {processed_implicit_args[0][0].name}.", + location=processed_implicit_args[0][0].location, + ) return result def process_implicit_arguments( - self, implicit_args: Optional[List[Optional[ExprIdentifier]]], - implicit_args_struct_name: ScopedName, - location: Optional[Location]) -> List[Expression]: + self, + implicit_args: Optional[List[Optional[ExprIdentifier]]], + implicit_args_struct_name: ScopedName, + location: Optional[Location], + ) -> List[Expression]: """ Returns the expressions for the implicit arguments. Used both for function call and a return instruction. @@ -892,14 +1050,16 @@ def process_implicit_arguments( location - location to attach to errors if no finer location is relevant. """ implicit_args_struct = self.get_struct_definition( - name=implicit_args_struct_name, location=location) + name=implicit_args_struct_name, location=location + ) if implicit_args is None: implicit_args = [None] * len(implicit_args_struct.members) compound_expressions = [] for (member_name, member_def), implicit_arg in safe_zip( - implicit_args_struct.members.items(), implicit_args): + implicit_args_struct.members.items(), implicit_args + ): expr: Expression if implicit_arg is not None: # Explicit binding is given, use it. @@ -909,7 +1069,8 @@ def process_implicit_arguments( expr = add_parent_location( expr=ExprIdentifier(name=member_name, location=member_def.location), new_parent_location=location, - message=f"While trying to retrieve the implicit argument '{member_name}' in:") + message=f"While trying to retrieve the implicit argument '{member_name}' in:", + ) felt_expr_list = self.simplify_expr_to_felt_expr_list(expr, member_def.cairo_type) compound_expressions.extend(felt_expr_list) @@ -917,7 +1078,8 @@ def process_implicit_arguments( return compound_expressions def push_compound_expressions( - self, compound_expressions: List[Expression], location: Optional[Location]): + self, compound_expressions: List[Expression], location: Optional[Location] + ): """ Generates instructions to push all the given expressions onto the stack. In more detail: translates a list of expressions to a set of instructions evaluating the @@ -930,17 +1092,17 @@ def push_compound_expressions( compound_expressions_code_elements, simple_exprs = process_compound_expressions( compound_expressions, SimplicityLevel.OPERATION, - context=self._compound_expression_context) + context=self._compound_expression_context, + ) for code_element in compound_expressions_code_elements: self.visit(code_element) assert len(simple_exprs) == len(compound_expressions) simple_exprs = self.optimize_expressions_for_push(simple_exprs) - compound_expressions = compound_expressions[-len(simple_exprs):] + compound_expressions = compound_expressions[-len(simple_exprs) :] - for i, (simple_expr, original_expr) in enumerate( - zip(simple_exprs, compound_expressions)): + for i, (simple_expr, original_expr) in enumerate(zip(simple_exprs, compound_expressions)): location = original_expr.location code_elm_inst = CodeElementInstruction( instruction=InstructionAst( @@ -954,14 +1116,18 @@ def push_compound_expressions( ), inc_ap=True, location=location, - )) + ) + ) self.visit(code_elm_inst) def push_arguments( - self, arguments: List[ExprAssignment], - implicit_args: Optional[List[Optional[ExprIdentifier]]], - struct_name: ScopedName, implicit_args_struct_name: ScopedName, - location: Optional[Location]): + self, + arguments: List[ExprAssignment], + implicit_args: Optional[List[Optional[ExprIdentifier]]], + struct_name: ScopedName, + implicit_args_struct_name: ScopedName, + location: Optional[Location], + ): """ Generates instructions to push all arguments (including the implicit arguments) onto the stack. @@ -977,11 +1143,14 @@ def push_arguments( location - location to attach to errors if no finer location is relevant. """ args_expressions = self.process_expr_assignment_list( - exprs=arguments, struct_name=struct_name, location=location) + exprs=arguments, struct_name=struct_name, location=location + ) implicit_args_expressions = self.process_implicit_arguments( implicit_args=implicit_args, - implicit_args_struct_name=implicit_args_struct_name, location=location) + implicit_args_struct_name=implicit_args_struct_name, + location=location, + ) self.push_compound_expressions( compound_expressions=implicit_args_expressions + args_expressions, @@ -991,7 +1160,8 @@ def push_arguments( def visit_CodeElementReturn(self, elm: CodeElementReturn): if self.current_scope not in self.function_metadata: raise PreprocessorError( - f'return cannot be used outside of a function.', location=elm.location) + f"return cannot be used outside of a function.", location=elm.location + ) self.push_arguments( arguments=cast(List[ExprAssignment], elm.exprs), @@ -1001,14 +1171,13 @@ def visit_CodeElementReturn(self, elm: CodeElementReturn): location=elm.location, ) code_elm_ret = CodeElementInstruction( - instruction=InstructionAst( - body=RetInstruction(), - inc_ap=False, - location=elm.location)) + instruction=InstructionAst(body=RetInstruction(), inc_ap=False, location=elm.location) + ) self.visit(code_elm_ret) def check_tail_call_cast( - self, src_struct: StructDefinition, dest_struct: StructDefinition) -> bool: + self, src_struct: StructDefinition, dest_struct: StructDefinition + ) -> bool: """ Checks if src_struct can be converted to dest_struct in the context of a tail call. """ @@ -1020,10 +1189,11 @@ def check_tail_call_cast( for src_member, dest_member in zip(src_members.values(), dest_members.values()): if not check_cast( - src_type=src_member.cairo_type, - dest_type=dest_member.cairo_type, - identifier_manager=self.identifiers, - cast_type=CastType.ASSIGN): + src_type=src_member.cairo_type, + dest_type=dest_member.cairo_type, + identifier_manager=self.identifiers, + cast_type=CastType.ASSIGN, + ): return False return True @@ -1031,7 +1201,8 @@ def check_tail_call_cast( def visit_CodeElementTailCall(self, elm: CodeElementTailCall): if self.current_scope not in self.function_metadata: raise PreprocessorError( - f'return cannot be used outside of a function.', location=elm.location) + f"return cannot be used outside of a function.", location=elm.location + ) # Visit function call before type check to get better error message. self.visit(CodeElementFuncCall(func_call=elm.func_call)) @@ -1040,33 +1211,42 @@ def visit_CodeElementTailCall(self, elm: CodeElementTailCall): src_struct = self.get_struct_definition( name=ScopedName.from_string(func_name) + CodeElementFunction.RETURN_SCOPE, - location=elm.location) + location=elm.location, + ) dest_struct = get_struct_definition( struct_name=self.current_scope + CodeElementFunction.RETURN_SCOPE, - identifier_manager=self.identifiers) + identifier_manager=self.identifiers, + ) if not self.check_tail_call_cast(src_struct=src_struct, dest_struct=dest_struct): raise PreprocessorError( f"""\ Cannot convert the return type of {func_name} to the return type of {self.current_scope[-1:]}.""", - location=elm.location) + location=elm.location, + ) src_struct = self.get_struct_definition( name=ScopedName.from_string(func_name) + CodeElementFunction.IMPLICIT_ARGUMENT_SCOPE, - location=elm.location) + location=elm.location, + ) dest_struct = get_struct_definition( struct_name=self.current_scope + CodeElementFunction.IMPLICIT_ARGUMENT_SCOPE, - identifier_manager=self.identifiers) + identifier_manager=self.identifiers, + ) if list(src_struct.members.items()) != list(dest_struct.members.items()): - notes = [] if src_struct.location is None or dest_struct.location is None else [ - f"The implicit arguments of '{func_name}' were defined here:\n" + - src_struct.location.to_string_with_content(), - f"The implicit arguments of '{self.current_scope[-1:]}' were defined here:\n" + - dest_struct.location.to_string_with_content(), - ] + notes = ( + [] + if src_struct.location is None or dest_struct.location is None + else [ + f"The implicit arguments of '{func_name}' were defined here:\n" + + src_struct.location.to_string_with_content(), + f"The implicit arguments of '{self.current_scope[-1:]}' were defined here:\n" + + dest_struct.location.to_string_with_content(), + ] + ) raise PreprocessorError( f"""\ @@ -1076,29 +1256,34 @@ def visit_CodeElementTailCall(self, elm: CodeElementTailCall): notes=notes, ) - self.visit(CodeElementInstruction( - instruction=InstructionAst( - body=RetInstruction(), - inc_ap=False, - location=elm.location))) + self.visit( + CodeElementInstruction( + instruction=InstructionAst( + body=RetInstruction(), inc_ap=False, location=elm.location + ) + ) + ) def add_implicit_return_references( - self, implicit_args: List[Optional[ExprIdentifier]], - called_function: ScopedName, - location: Optional[Location]): + self, + implicit_args: List[Optional[ExprIdentifier]], + called_function: ScopedName, + location: Optional[Location], + ): """ Adds references that allow accessing the implicit return values of a called function. """ implicit_args_struct = self.get_struct_definition( - name=called_function + CodeElementFunction.IMPLICIT_ARGUMENT_SCOPE, - location=location) + name=called_function + CodeElementFunction.IMPLICIT_ARGUMENT_SCOPE, location=location + ) return_size = self.get_struct_size( - struct_name=called_function + CodeElementFunction.RETURN_SCOPE, - location=location) + struct_name=called_function + CodeElementFunction.RETURN_SCOPE, location=location + ) assert len(implicit_args_struct.members) == len(implicit_args) for (name, member_def), implicit_arg in zip( - implicit_args_struct.members.items(), implicit_args): + implicit_args_struct.members.items(), implicit_args + ): if implicit_arg is not None: # Use the implicit argument binding. binding_var = implicit_arg.name @@ -1109,50 +1294,62 @@ def add_implicit_return_references( if location is not None and implicit_arg_location is not None: implicit_arg_location = implicit_arg_location.with_parent_location( new_parent_location=location, - message=f"While trying to update the implicit return value '{name}' in:") + message=f"While trying to update the implicit return value '{name}' in:", + ) self.add_simple_reference( name=self.current_scope + binding_var, reg=Register.AP, cairo_type=member_def.cairo_type, offset=member_def.offset - (return_size + implicit_args_struct.size), - location=implicit_arg_location) + location=implicit_arg_location, + ) - if implicit_arg is None and self.reference_states.get( - self.current_scope + name) is not ReferenceState.ALLOW_IMPLICIT: + if ( + implicit_arg is None + and self.reference_states.get(self.current_scope + name) + is not ReferenceState.ALLOW_IMPLICIT + ): raise PreprocessorError( f"'{name}' cannot be used as an implicit return value. " "Consider using a 'with' statement.", - location=implicit_arg_location) + location=implicit_arg_location, + ) def visit_CodeElementFuncCall(self, elm: CodeElementFuncCall): # Make sure the identifier for the called function refers to a function. called_function = ScopedName.from_string(elm.func_call.func_ident.name) try: res = self.identifiers.search( - accessible_scopes=self.accessible_scopes, name=called_function) + accessible_scopes=self.accessible_scopes, name=called_function + ) res.assert_fully_parsed() except IdentifierError as exc: raise PreprocessorError(str(exc), location=elm.func_call.func_ident.location) called_function_def = res.identifier_definition - called_function_def_type = called_function_def.identifier_type \ - if isinstance(called_function_def, FutureIdentifierDefinition) \ + called_function_def_type = ( + called_function_def.identifier_type + if isinstance(called_function_def, FutureIdentifierDefinition) else type(called_function_def) + ) if called_function_def_type is not FunctionDefinition: raise PreprocessorError( - f'Expected {called_function} to be a function name. ' - f'Found: {called_function_def.TYPE}.', - location=elm.func_call.func_ident.location) + f"Expected {called_function} to be a function name. " + f"Found: {called_function_def.TYPE}.", + location=elm.func_call.func_ident.location, + ) implicit_args_struct_name = called_function + CodeElementFunction.IMPLICIT_ARGUMENT_SCOPE implicit_args = ( cast(List[ExprAssignment], elm.func_call.implicit_arguments.args) if elm.func_call.implicit_arguments is not None - else []) + else [] + ) processed_implicit_args = self.process_implicit_argument_binding( implicit_args=implicit_args, implicit_args_struct_name=implicit_args_struct_name, - location=elm.func_call.location) + location=elm.func_call.location, + ) self.push_arguments( arguments=cast(List[ExprAssignment], elm.func_call.arguments.args), @@ -1168,7 +1365,9 @@ def visit_CodeElementFuncCall(self, elm: CodeElementFuncCall): location=elm.func_call.location, ), inc_ap=False, - location=elm.func_call.location)) + location=elm.func_call.location, + ) + ) self.visit(code_elm_call) self.add_implicit_return_references( @@ -1178,17 +1377,20 @@ def visit_CodeElementFuncCall(self, elm: CodeElementFuncCall): ) def add_simple_reference( - self, name: ScopedName, reg: Register, cairo_type: CairoType, offset: int, - location: Optional[Location]): + self, + name: ScopedName, + reg: Register, + cairo_type: CairoType, + offset: int, + location: Optional[Location], + ): """ Creates a simple reference with the given name to "[reg + offset]". """ ref_expr = create_simple_ref_expr( - reg=reg, - offset=offset, - cairo_type=cairo_type, - location=location) + reg=reg, offset=offset, cairo_type=cairo_type, location=location + ) self.add_reference( name=name, value=ref_expr, @@ -1204,31 +1406,37 @@ def visit_CodeElementReturnValueReference(self, elm: CodeElementReturnValueRefer body=elm.func_call.call_inst, inc_ap=False, location=elm.func_call.call_inst.location, - )) + ) + ) func_ident = None if isinstance(elm.func_call.call_inst, CallLabelInstruction): func_ident = elm.func_call.call_inst.label elif isinstance(elm.func_call, RvalueFuncCall): # If the function name is the name of a struct, replace the # CodeElementReturnValueReference with a regular reference. - if self.try_get_struct_definition( - ScopedName.from_string(elm.func_call.func_ident.name)) is not None: - return self.visit(CodeElementReference( - typed_identifier=elm.typed_identifier, - expr=ExprFuncCall( - rvalue=elm.func_call, - location=elm.func_call.location))) + if ( + self.try_get_struct_definition( + ScopedName.from_string(elm.func_call.func_ident.name) + ) + is not None + ): + return self.visit( + CodeElementReference( + typed_identifier=elm.typed_identifier, + expr=ExprFuncCall(rvalue=elm.func_call, location=elm.func_call.location), + ) + ) call_elm = CodeElementFuncCall(func_call=elm.func_call) func_ident = elm.func_call.func_ident else: - raise NotImplementedError(f'Unsupported func_call={elm.func_call}.') + raise NotImplementedError(f"Unsupported func_call={elm.func_call}.") expr_type = elm.typed_identifier.expr_type if expr_type is None: if func_ident is not None: expr_type = TypeStruct( - scope=ScopedName.from_string(func_ident.name) + - CodeElementFunction.RETURN_SCOPE, + scope=ScopedName.from_string(func_ident.name) + + CodeElementFunction.RETURN_SCOPE, is_fully_resolved=False, location=func_ident.location, ) @@ -1251,41 +1459,46 @@ def visit_CodeElementReturnValueReference(self, elm: CodeElementReturnValueRefer def get_unpacking_struct_definition(self, elm: CodeElementUnpackBinding): if not isinstance(elm.rvalue, RvalueFuncCall): raise PreprocessorError( - f'Cannot unpack {elm.rvalue.format()}.', - location=elm.rvalue.location) + f"Cannot unpack {elm.rvalue.format()}.", location=elm.rvalue.location + ) func_ident = elm.rvalue.func_ident - return_type = self.resolve_type(TypeStruct( - scope=ScopedName.from_string(func_ident.name) + CodeElementFunction.RETURN_SCOPE, - is_fully_resolved=False, - location=func_ident.location, - )) - assert isinstance(return_type, TypeStruct), f'Unexpected type {return_type}.' + return_type = self.resolve_type( + TypeStruct( + scope=ScopedName.from_string(func_ident.name) + CodeElementFunction.RETURN_SCOPE, + is_fully_resolved=False, + location=func_ident.location, + ) + ) + assert isinstance(return_type, TypeStruct), f"Unexpected type {return_type}." struct_def = get_struct_definition(return_type.scope, identifier_manager=self.identifiers) expected_len = len(struct_def.members) unpacking_identifiers = elm.unpacking_list.identifiers if len(unpacking_identifiers) != expected_len: - suffix = 's' if expected_len > 1 else '' + suffix = "s" if expected_len > 1 else "" raise PreprocessorError( f"""\ Expected {expected_len} unpacking identifier{suffix}, found {len(unpacking_identifiers)}.""", - location=elm.unpacking_list.location) + location=elm.unpacking_list.location, + ) return struct_def def visit_CodeElementUnpackBinding(self, elm: CodeElementUnpackBinding): struct_def = self.get_unpacking_struct_definition(elm) - assert isinstance(elm.rvalue, RvalueFuncCall), \ - f'Invalid type for elm.rvalue: {type(elm.rvalue).__name__}.' + assert isinstance( + elm.rvalue, RvalueFuncCall + ), f"Invalid type for elm.rvalue: {type(elm.rvalue).__name__}." self.visit(CodeElementFuncCall(func_call=elm.rvalue)) for typed_identifier, member_def in zip( - elm.unpacking_list.identifiers, struct_def.members.values()): + elm.unpacking_list.identifiers, struct_def.members.values() + ): assert_no_modifier(typed_identifier) - if typed_identifier.name == '_': + if typed_identifier.name == "_": continue if typed_identifier.expr_type is not None: @@ -1294,32 +1507,41 @@ def visit_CodeElementUnpackBinding(self, elm: CodeElementUnpackBinding): cairo_type = member_def.cairo_type if not check_cast( - src_type=member_def.cairo_type, dest_type=cairo_type, - identifier_manager=self.identifiers, - cast_type=CastType.UNPACKING): + src_type=member_def.cairo_type, + dest_type=cairo_type, + identifier_manager=self.identifiers, + cast_type=CastType.UNPACKING, + ): raise PreprocessorError( f"""\ Expected expression of type '{member_def.cairo_type.format()}', got '{cairo_type.format()}'.""", - location=typed_identifier.location + location=typed_identifier.location, ) self.add_simple_reference( - name=self.current_scope + typed_identifier.identifier.name, reg=Register.AP, - cairo_type=cairo_type, offset=member_def.offset - struct_def.size, - location=typed_identifier.location) + name=self.current_scope + typed_identifier.identifier.name, + reg=Register.AP, + cairo_type=cairo_type, + offset=member_def.offset - struct_def.size, + location=typed_identifier.location, + ) def add_label(self, identifier: ExprIdentifier): name = self.current_scope + identifier.name self.flow_tracking.converge_with_label(name) self.add_name_definition( - name, - LabelDefinition(pc=self.current_pc), # type: ignore - location=identifier.location) + name, LabelDefinition(pc=self.current_pc), location=identifier.location # type: ignore + ) def add_reference( - self, name: ScopedName, value: Expression, cairo_type: CairoType, - location: Optional[Location], require_future_definition=True): - if name.path[-1] == '_': + self, + name: ScopedName, + value: Expression, + cairo_type: CairoType, + location: Optional[Location], + require_future_definition=True, + ): + if name.path[-1] == "_": raise PreprocessorError("Reference name cannot be '_'.", location=location) reference = Reference( @@ -1335,16 +1557,19 @@ def add_reference( # Rebind reference. if existing_definition.cairo_type != cairo_type: raise PreprocessorError( - 'Reference rebinding must preserve the reference type. ' + "Reference rebinding must preserve the reference type. " f"Previous type: '{existing_definition.cairo_type.format()}', " f"new type: '{cairo_type.format()}'.", - location=location) + location=location, + ) existing_definition.references.append(reference) else: self.add_name_definition( name, ReferenceDefinition(full_name=name, cairo_type=cairo_type, references=[reference]), - location=location, require_future_definition=require_future_definition) + location=location, + require_future_definition=require_future_definition, + ) def add_function(self, elm: CodeElementFunction): name = self.current_scope + elm.name @@ -1354,13 +1579,15 @@ def add_function(self, elm: CodeElementFunction): pc=self.current_pc, decorators=[identifier.name for identifier in elm.decorators], ), - location=elm.identifier.location) + location=elm.identifier.location, + ) self.function_metadata[name] = FunctionMetadata( - initial_ap_data=self.flow_tracking.get_ap_tracking()) + initial_ap_data=self.flow_tracking.get_ap_tracking() + ) def visit_CodeElementLabel(self, elm: CodeElementLabel): - self.check_no_hints('Hints before labels are not allowed.') + self.check_no_hints("Hints before labels are not allowed.") self.add_label(elm.identifier) def visit_CodeElementHint(self, elm: CodeElementHint): @@ -1374,9 +1601,8 @@ def visit_CodeElementEmptyLine(self, elm: CodeElementEmptyLine): def visit_InstructionAst(self, instruction: InstructionAst): flows, instruction_body = self.visit(instruction.body) res = InstructionAst( - body=instruction_body, - inc_ap=instruction.inc_ap, - location=instruction.location) + body=instruction_body, inc_ap=instruction.inc_ap, location=instruction.location + ) added_ap = 1 if instruction.inc_ap else 0 # Add jump flows. @@ -1396,14 +1622,16 @@ def visit_AssertEqInstruction(self, instruction: AssertEqInstruction): return InstructionFlows(next_inst=RegChangeKnown(0)), AssertEqInstruction( a=self.simplify_expr_as_felt(instruction.a), b=self.simplify_expr_as_felt(instruction.b), - location=instruction.location) + location=instruction.location, + ) def visit_JumpInstruction(self, instruction: JumpInstruction): self.revoke_function_ap_change() return InstructionFlows(), JumpInstruction( val=self.simplify_expr_as_felt(instruction.val), relative=instruction.relative, - location=instruction.location) + location=instruction.location, + ) def visit_JumpToLabelInstruction(self, instruction: JumpToLabelInstruction): label_name = instruction.label.name @@ -1418,18 +1646,19 @@ def visit_JumpToLabelInstruction(self, instruction: JumpToLabelInstruction): res_instruction = dataclasses.replace(instruction, condition=condition) else: jump_offset = ExprConst( - val=label_pc - self.current_pc, location=instruction.label.location) + val=label_pc - self.current_pc, location=instruction.label.location + ) if instruction.condition is None: self.current_instruction_ended_flow = True res_instruction = JumpInstruction( - val=jump_offset, - relative=True, - location=instruction.location) + val=jump_offset, relative=True, location=instruction.location + ) else: res_instruction = JnzInstruction( jump_offset=jump_offset, condition=self.simplify_expr_as_felt(instruction.condition), - location=instruction.location) + location=instruction.location, + ) if label_pc <= self.current_pc: self.revoke_function_ap_change() @@ -1437,8 +1666,8 @@ def visit_JumpToLabelInstruction(self, instruction: JumpToLabelInstruction): flow_next = None if instruction.condition is None else RegChangeKnown(0) if label_full_name is None: raise PreprocessorError( - f'Unknown label {label_name}.', - location=instruction.label.location) + f"Unknown label {label_name}.", location=instruction.label.location + ) jumps: Dict[ScopedName, RegChange] = {label_full_name: RegChangeKnown(0)} return InstructionFlows(next_inst=flow_next, jumps=jumps), res_instruction @@ -1447,7 +1676,8 @@ def visit_JnzInstruction(self, instruction: JnzInstruction): return InstructionFlows(next_inst=RegChangeKnown(0)), JnzInstruction( jump_offset=self.simplify_expr_as_felt(instruction.jump_offset), condition=self.simplify_expr_as_felt(instruction.condition), - location=instruction.location) + location=instruction.location, + ) def revoke_function_ap_change(self): """ @@ -1461,7 +1691,8 @@ def visit_CallInstruction(self, instruction: CallInstruction): return InstructionFlows(next_inst=RegChangeUnknown()), CallInstruction( val=self.simplify_expr_as_felt(instruction.val), relative=instruction.relative, - location=instruction.location) + location=instruction.location, + ) def visit_CallLabelInstruction(self, instruction: CallLabelInstruction): label_name = instruction.label.name @@ -1479,16 +1710,16 @@ def visit_CallLabelInstruction(self, instruction: CallLabelInstruction): # Add 2 for call instruction. ap_change = 2 + metadata.total_ap_change - jump_offset = ExprConst( - val=label_pc - self.current_pc, location=instruction.label.location) - return InstructionFlows(next_inst=ap_change), \ - CallInstruction(val=jump_offset, relative=True, location=instruction.location) + jump_offset = ExprConst(val=label_pc - self.current_pc, location=instruction.label.location) + return InstructionFlows(next_inst=ap_change), CallInstruction( + val=jump_offset, relative=True, location=instruction.location + ) def visit_AddApInstruction(self, instruction: AddApInstruction): expr = self.simplify_expr_as_felt(instruction.expr) return InstructionFlows(next_inst=RegChange.from_expr(expr)), AddApInstruction( - expr=expr, - location=instruction.location) + expr=expr, location=instruction.location + ) def visit_RetInstruction(self, instruction: RetInstruction): if self.current_scope in self.function_metadata: @@ -1501,7 +1732,7 @@ def visit_RetInstruction(self, instruction: RetInstruction): def visit_BuiltinsDirective(self, directive: BuiltinsDirective): if self.builtins is not None: raise PreprocessorError( - 'Redefinition of builtins directive.', + "Redefinition of builtins directive.", location=directive.location, ) @@ -1519,7 +1750,7 @@ def visit_BuiltinsDirective(self, directive: BuiltinsDirective): def visit_LangDirective(self, directive: LangDirective): raise PreprocessorError( - f'Unsupported %lang directive. Are you using the correct compiler?', + f"Unsupported %lang directive. Are you using the correct compiler?", location=directive.location, ) @@ -1532,7 +1763,8 @@ def simplify_expr(self, expr) -> Tuple[Expression, CairoType]: expr = substitute_identifiers( expr=expr, get_identifier_callback=self.get_variable, - resolve_type_callback=self.resolve_type) + resolve_type_callback=self.resolve_type, + ) expr, expr_type = simplify_type_system(expr, identifiers=self.identifiers) return self.simplifier.visit(expr), self.resolve_type(expr_type) @@ -1545,11 +1777,13 @@ def simplify_expr_as_felt(self, expr) -> Expression: if not isinstance(expr_type, (TypeFelt, TypePointer)): raise PreprocessorError( f"Expected a 'felt' or a pointer type. Got: '{expr_type.format()}'.", - location=expr.location) + location=expr.location, + ) return expr def simplify_expr_to_felt_expr_list( - self, expr: Expression, expected_type: CairoType) -> List[Expression]: + self, expr: Expression, expected_type: CairoType + ) -> List[Expression]: """ Takes a possibly typed expression, checks that it can be assigned to expected_type and splits it into a list of typeless expressions that can be passed to @@ -1561,18 +1795,22 @@ def simplify_expr_to_felt_expr_list( expr, expr_type = self.simplify_expr(expr) if not check_cast( - src_type=expr_type, dest_type=expected_type, identifier_manager=self.identifiers, - cast_type=CastType.ASSIGN): + src_type=expr_type, + dest_type=expected_type, + identifier_manager=self.identifiers, + cast_type=CastType.ASSIGN, + ): raise PreprocessorError( f"""\ Expected expression of type '{expected_type.format()}', got '{expr_type.format()}'.""", - location=location + location=location, ) return self.simplified_expr_to_felt_expr_list(expr=expr, expr_type=expr_type) def simplified_expr_to_felt_expr_list( - self, expr: Expression, expr_type: CairoType) -> List[Expression]: + self, expr: Expression, expr_type: CairoType + ) -> List[Expression]: """ Takes a simplified expression and its type and splits it into a list of typeless expressions that can be passed to process_compound_expressions. @@ -1586,11 +1824,13 @@ def simplified_expr_to_felt_expr_list( member_types = expr_type.members elif isinstance(expr_type, TypeStruct): struct_definition = get_struct_definition( - expr_type.scope, identifier_manager=self.identifiers) + expr_type.scope, identifier_manager=self.identifiers + ) member_types = [ - member_def.cairo_type for member_def in struct_definition.members.values()] + member_def.cairo_type for member_def in struct_definition.members.values() + ] else: - raise PreprocessorError(f'Unexpected type {expr_type}.', location=expr_type.location) + raise PreprocessorError(f"Unexpected type {expr_type}.", location=expr_type.location) # Get the list of member expressions. if isinstance(expr, ExprTuple): @@ -1609,31 +1849,35 @@ def simplified_expr_to_felt_expr_list( ExprDeref( ExprOperator( a=addr, - op='+', + op="+", b=ExprConst(offset, location=location), location=location, ), location=location, - ))) + ) + ) + ) offset += self.get_size(member_type) expr_list = [] for member_expr, member_type in zip(member_exprs, member_types): - expr_list.extend(self.simplified_expr_to_felt_expr_list( - expr=member_expr, expr_type=member_type)) + expr_list.extend( + self.simplified_expr_to_felt_expr_list(expr=member_expr, expr_type=member_type) + ) return expr_list - def get_label(self, label_name: str, location: Optional[Location]) -> \ - Tuple[Optional[int], Optional[ScopedName]]: + def get_label( + self, label_name: str, location: Optional[Location] + ) -> Tuple[Optional[int], Optional[ScopedName]]: """ Returns a pair (pc, canonical_name) for the given label, or (None, None) if this label hasn't been processed yet. """ try: search_result = self.identifiers.search( - accessible_scopes=self.accessible_scopes, - name=ScopedName.from_string(label_name)) + accessible_scopes=self.accessible_scopes, name=ScopedName.from_string(label_name) + ) search_result.assert_fully_parsed() except IdentifierError as exc: raise PreprocessorError(str(exc), location=location) @@ -1644,7 +1888,9 @@ def get_label(self, label_name: str, location: Optional[Location]) -> \ if not isinstance(search_result.identifier_definition, LabelDefinition): raise PreprocessorError( f"Expected a label name. Identifier '{label_name}' is of type " - f'{search_result.identifier_definition.TYPE}.', location=location) + f"{search_result.identifier_definition.TYPE}.", + location=location, + ) return search_result.identifier_definition.pc, search_result.canonical_name def get_variable(self, var: ExprIdentifier): @@ -1657,8 +1903,8 @@ def get_variable(self, var: ExprIdentifier): # Allow future label assignment. return ExprFutureLabel(identifier=var) raise PreprocessorError( - f"Identifier '{var.name}' referenced before definition.", - location=var.location) + f"Identifier '{var.name}' referenced before definition.", location=var.location + ) if isinstance(identifier_definition, ConstDefinition): return identifier_definition.value @@ -1673,25 +1919,28 @@ def get_variable(self, var: ExprIdentifier): try: res_expr = identifier_definition.eval( reference_manager=self.flow_tracking.reference_manager, - flow_tracking_data=self.flow_tracking.data) + flow_tracking_data=self.flow_tracking.data, + ) if var.location is not None: res_expr = add_parent_location( expr=res_expr, new_parent_location=var.location, - message=f"While expanding the reference '{var.name}' in:") + message=f"While expanding the reference '{var.name}' in:", + ) return res_expr except FlowTrackingError as exc: raise PreprocessorError( - f"Reference '{var.name}' was revoked.", location=var.location, notes=exc.notes) + f"Reference '{var.name}' was revoked.", location=var.location, notes=exc.notes + ) except DefinitionError as exc: raise PreprocessorError(str(exc), location=var.location) raise PreprocessorError( - f'Unexpected identifier {var.name} of type {identifier_definition.TYPE}.', - location=var.location) + f"Unexpected identifier {var.name} of type {identifier_definition.TYPE}.", + location=var.location, + ) - def get_instruction_size( - self, instruction: InstructionAst, allow_auto_deduction: bool = False): + def get_instruction_size(self, instruction: InstructionAst, allow_auto_deduction: bool = False): """ Returns the size of the instruction in field elements by calling build_instruction(). If allow_auto_deduction is True, then in some cases (where labels are involved) @@ -1702,9 +1951,9 @@ def get_instruction_size( except InstructionBuilderError as exc: # If for some reason location is not known, use the location of the full instruction. if exc.location is None: - exc.notes.append('Missing exact location information on this error.') + exc.notes.append("Missing exact location information on this error.") exc.location = instruction.location - exc.notes.append(f'Preprocessed instruction:\n{instruction.format()}') + exc.notes.append(f"Preprocessed instruction:\n{instruction.format()}") raise exc def check_preprocessed_instruction(self, instruction: InstructionAst): @@ -1715,7 +1964,7 @@ def check_preprocessed_instruction(self, instruction: InstructionAst): """ if isinstance(instruction.body, (JumpToLabelInstruction, CallLabelInstruction)): label = instruction.body.label - raise PreprocessorError(f'Unknown label {label.name}.', location=label.location) + raise PreprocessorError(f"Unknown label {label.name}.", location=label.location) def check_no_hints(self, msg): """ @@ -1729,7 +1978,7 @@ def new_unique_id(self) -> str: """ Returns a new identifier name. """ - name = f'__temp{self.next_temp_id}' + name = f"__temp{self.next_temp_id}" self.next_temp_id += 1 self.scoped_temp_ids.add(self.current_scope + name) return name @@ -1751,10 +2000,12 @@ def new_tempvar_name(self) -> str: def get_fp_val(self, location: Optional[Location]) -> Expression: try: return self.preprocessor.simplify_expr_as_felt( - ExprIdentifier(name='__fp__', location=location)) + ExprIdentifier(name="__fp__", location=location) + ) except PreprocessorError as exc: - if 'Unknown identifier' not in exc.message: + if "Unknown identifier" not in exc.message: raise raise PreprocessorError( - 'Using the value of fp directly, requires defining a variable named __fp__.', - location=exc.location) + "Using the value of fp directly, requires defining a variable named __fp__.", + location=exc.location, + ) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py index 3beff714..63275ff1 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py @@ -4,14 +4,22 @@ from starkware.cairo.lang.compiler.ast.module import CairoModule from starkware.cairo.lang.compiler.error_handling import LocationError from starkware.cairo.lang.compiler.identifier_definition import ( - ConstDefinition, LabelDefinition, ReferenceDefinition) + ConstDefinition, + LabelDefinition, + ReferenceDefinition, +) from starkware.cairo.lang.compiler.identifier_manager import IdentifierError from starkware.cairo.lang.compiler.instruction_builder import InstructionBuilderError from starkware.cairo.lang.compiler.parser import parse_type from starkware.cairo.lang.compiler.preprocessor.default_pass_manager import default_pass_manager from starkware.cairo.lang.compiler.preprocessor.preprocess_codes import preprocess_codes from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import ( - PRIME, TEST_SCOPE, preprocess_str, strip_comments_and_linebreaks, verify_exception) + PRIME, + TEST_SCOPE, + preprocess_str, + strip_comments_and_linebreaks, + verify_exception, +) from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.compiler.test_utils import read_file_from_dict from starkware.cairo.lang.compiler.type_casts import CairoTypeError @@ -20,7 +28,8 @@ def test_compiler(): - program = preprocess_str(code=""" + program = preprocess_str( + code=""" const x = 5 const y = 2 * x @@ -35,8 +44,12 @@ def test_compiler(): ret label: jmp label if [fp + 3 + 1] != 0 -""", prime=PRIME) - assert program.format() == """\ +""", + prime=PRIME, + ) + assert ( + program.format() + == """\ [ap] = [[fp + 6] + 16]; ap++ ap += 1028 [ap] = [fp] @@ -45,6 +58,7 @@ def test_compiler(): ret jmp rel 0 if [fp + 4] != 0 """ + ) def test_scope_const(): @@ -61,7 +75,9 @@ def test_scope_const(): [ap + 4] = f.x; ap++ """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ [ap] = 5; ap++ [ap + 1] = 1234; ap++ [ap + 2] = 1234; ap++ @@ -69,51 +85,65 @@ def test_scope_const(): [ap + 3] = 5; ap++ [ap + 4] = 1234; ap++ """ + ) def test_pow_failure(): - verify_exception("""\ + verify_exception( + """\ func foo(x : felt): tempvar y = x ** 2 end -""", """ +""", + """ file:?:?: Operator '**' is only supported for constant values. tempvar y = x ** 2 ^****^ -""") - verify_exception("""\ +""", + ) + verify_exception( + """\ const X = 2 const Y = 2 ** (2 * 3) const Z = 2 ** (X * 3) -""", """ +""", + """ file:?:?: Identifier 'X' is not allowed in this context. const Z = 2 ** (X * 3) ^ -""", exc_type=CairoTypeError) +""", + exc_type=CairoTypeError, + ) def test_referenced_before_definition_failure(): - verify_exception(""" + verify_exception( + """ const x = 5 func f(): [ap + 1] = x; ap++ const x = 1234 end -""", """ +""", + """ file:?:?: Identifier 'x' referenced before definition. [ap + 1] = x; ap++ ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ foo.x = 6 func foo(): const x = 6 end -""", """ +""", + """ file:?:?: Identifier 'foo.x' referenced before definition. foo.x = 6 ^***^ -""") +""", + ) def test_assign_future_label(): @@ -126,12 +156,15 @@ def test_assign_future_label(): [ap] = 8; ap++ """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ [ap] = 2; ap++ [ap] = 4; ap++ [ap] = 6; ap++ [ap] = 8; ap++ """ + ) def test_temporary_variable(): @@ -155,7 +188,9 @@ def test_temporary_variable(): tempvar h = nondet %{ 5**i %} """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ [ap] = [ap + (-1)] + [fp + (-3)]; ap++ ap += 3 [ap] = [ap + (-4)]; ap++ @@ -171,51 +206,66 @@ def test_temporary_variable(): %{ memory[ap] = int(5**i) %} ap += 1 """ + ) def test_temporary_variable_failures(): - verify_exception(""" + verify_exception( + """ tempvar x : felt = cast([ap], felt*) -""", """ +""", + """ file:?:?: Cannot assign an expression of type 'felt*' to a temporary variable of type 'felt'. tempvar x : felt = cast([ap], felt*) ^**^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ tempvar _ = 0 -""", """ +""", + """ file:?:?: Reference name cannot be '_'. tempvar _ = 0 ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ struct T: member x : felt member y : felt end tempvar a : T = nondet %{ 1 %} -""", """ +""", + """ file:?:?: Hint tempvars must be of type felt. tempvar a : T = nondet %{ 1 %} ^************^ -""") +""", + ) def test_tempvar_modifier_failures(): - verify_exception(""" + verify_exception( + """ func main(): tempvar local x = 5 end -""", """ +""", + """ file:?:?: Unexpected modifier 'local'. tempvar local x = 5 ^***^ -""") +""", + ) - verify_exception(""" + verify_exception( + """ tempvar x = [ap - 1] + [fp - 3] [x] = [[ap]] -""", """ +""", + """ file:?:?: While expanding the reference 'x' in: [x] = [[ap]] ^ @@ -224,7 +274,9 @@ def test_tempvar_modifier_failures(): ^ Preprocessed instruction: [[ap + (-1)]] = [[ap]] -""", exc_type=InstructionBuilderError) +""", + exc_type=InstructionBuilderError, + ) def test_static_assert(): @@ -235,70 +287,91 @@ def test_static_assert(): static_assert x + 7 == ap + 4 """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ ap += 3 """ + ) def test_static_assert_failures(): - verify_exception(""" + verify_exception( + """ static_assert 3 + fp + 10 == 0 + fp + 14 -""", """ +""", + """ file:?:?: Static assert failed: fp + 13 != fp + 14. static_assert 3 + fp + 10 == 0 + fp + 14 ^**************************************^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ let x = ap ap += 3 static_assert x + 7 == 0 -""", """ +""", + """ file:?:?: Static assert failed: ap + 4 != 0. static_assert x + 7 == 0 ^**********************^ -""") - - -@pytest.mark.parametrize('last_statement', [ - 'jmp body if [ap] != 0', - 'ap += 0', - '[ap] = [ap]', - '[ap] = [ap]; ap++', -]) +""", + ) + + +@pytest.mark.parametrize( + "last_statement", + [ + "jmp body if [ap] != 0", + "ap += 0", + "[ap] = [ap]", + "[ap] = [ap]; ap++", + ], +) def test_func_failures(last_statement): - verify_exception(f""" + verify_exception( + f""" func f(x): body: ret {last_statement} end -""", """ +""", + """ file:?:?: Function must end with a return instruction or a jump. func f(x): ^ -""") +""", + ) def test_func_modifier_failures(): - verify_exception(f""" + verify_exception( + f""" func f(local x): ret end -""", """ +""", + """ file:?:?: Unexpected modifier 'local'. func f(local x): ^***^ -""") +""", + ) - verify_exception(f""" + verify_exception( + f""" func f(x) -> (local y): ret end -""", """ +""", + """ file:?:?: Unexpected modifier 'local'. func f(x) -> (local y): ^***^ -""") +""", + ) def test_return(): @@ -320,7 +393,9 @@ def test_return(): end """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ [ap] = 1; ap++ [ap] = [fp]; ap++ [ap] = [fp + 1] + 2; ap++ @@ -339,57 +414,73 @@ def test_return(): ret ret """ + ) def test_return_failures(): # Named after positional. - verify_exception(""" + verify_exception( + """ func f() -> (a, b, c): return (a=1, b=1, [fp] + 1) end -""", """ +""", + """ file:?:?: Positional arguments must not appear after named arguments. return (a=1, b=1, [fp] + 1) ^******^ -""") +""", + ) # Wrong num. - verify_exception(""" + verify_exception( + """ func f() -> (a, b, c, d): return (1, [fp] + 1) end -""", """ +""", + """ file:?:?: Expected exactly 4 expressions, got 2. return (1, [fp] + 1) ^******************^ -""") +""", + ) # Wrong num. - verify_exception(""" + verify_exception( + """ func f() -> (a, b): return () end -""", """ +""", + """ file:?:?: Expected exactly 2 expressions, got 0. return () ^*******^ -""") +""", + ) # Unknown name. - verify_exception(""" + verify_exception( + """ func f() -> (a, b, c): return (a=1, d=1, [fp] + 1) end -""", """ +""", + """ file:?:?: Expected named arg 'b' found 'd'. return (a=1, d=1, [fp] + 1) ^ -""") +""", + ) # Not in func. - verify_exception(""" + verify_exception( + """ return (a=1, [fp] + 1) -""", """ +""", + """ file:?:?: return cannot be used outside of a function. return (a=1, [fp] + 1) ^********************^ -""") +""", + ) def test_tail_call(): @@ -402,8 +493,11 @@ def test_tail_call(): end """ program = preprocess_str( - code=code, prime=PRIME, main_scope=ScopedName.from_string('test_scope')) - assert program.format() == """\ + code=code, prime=PRIME, main_scope=ScopedName.from_string("test_scope") + ) + assert ( + program.format() + == """\ [ap] = [fp + (-3)]; ap++ call rel -1 ret @@ -411,41 +505,50 @@ def test_tail_call(): call rel -5 ret """ + ) def test_tail_call_failure(): - verify_exception(""" + verify_exception( + """ func g() -> (a): return (a=0) end return g() -""", """ +""", + """ file:?:?: return cannot be used outside of a function. return g() ^********^ -""") +""", + ) - verify_exception(""" + verify_exception( + """ func g() -> (a): return (a=0) end func f(x, y) -> (a, b, c, d, e): return g() end -""", """ +""", + """ file:?:?: Cannot convert the return type of g to the return type of f. return g() ^********^ -""") +""", + ) - verify_exception(""" + verify_exception( + """ func g{x, y}() -> (a): return (a=0) end func f{y, x}() -> (a): return g() end -""", """ +""", + """ file:?:?: Cannot convert the implicit arguments of g to the implicit arguments of f. return g() ^********^ @@ -457,30 +560,37 @@ def test_tail_call_failure(): file:?:? func f{y, x}() -> (a): ^**^ -""") +""", + ) - verify_exception(""" + verify_exception( + """ func f(x, y) -> (a, b, c, d, e): return g() end -""", """ +""", + """ file:?:?: Unknown identifier 'g'. return g() ^ -""") +""", + ) - verify_exception(""" + verify_exception( + """ func g(x, y) -> (a : felt): return (a=5) end func f(x, y) -> (a : felt*): return g(x, y) end -""", """ +""", + """ file:?:?: Cannot convert the return type of g to the return type of f. return g(x, y) ^************^ -""") +""", + ) def test_function_call(): @@ -498,7 +608,9 @@ def test_function_call(): res.c = 1 """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ [ap] = [fp + (-4)]; ap++ call rel 5 [ap] = 1; ap++ @@ -515,6 +627,7 @@ def test_function_call(): call rel -23 [ap + (-1)] = 1 """ + ) def test_func_args(): @@ -533,28 +646,36 @@ def test_func_args(): """ program = preprocess_str(code=code, prime=PRIME, main_scope=scope) reference_x = program.instructions[-1].flow_tracking_data.resolve_reference( - reference_manager=program.reference_manager, name=scope + 'f.x') - assert reference_x.value.format() == '[cast(fp + (-6), felt*)]' + reference_manager=program.reference_manager, name=scope + "f.x" + ) + assert reference_x.value.format() == "[cast(fp + (-6), felt*)]" reference_y = program.instructions[-1].flow_tracking_data.resolve_reference( - reference_manager=program.reference_manager, name=scope + 'f.y') - assert reference_y.value.format() == f'[cast(fp + (-5), {scope}.T*)]' + reference_manager=program.reference_manager, name=scope + "f.y" + ) + assert reference_y.value.format() == f"[cast(fp + (-5), {scope}.T*)]" reference_z = program.instructions[-1].flow_tracking_data.resolve_reference( - reference_manager=program.reference_manager, name=scope + 'f.z') - assert reference_z.value.format() == f'[cast(fp + (-3), {scope}.T**)]' - assert program.format() == """\ + reference_manager=program.reference_manager, name=scope + "f.z" + ) + assert reference_z.value.format() == f"[cast(fp + (-3), {scope}.T**)]" + assert ( + program.format() + == """\ [fp + (-6)] = 1; ap++ [fp + (-5)] = 2; ap++ [[fp + (-3)] + 1] = [fp + (-4)]; ap++ ret """ + ) def test_func_args_failures(): - verify_exception(""" + verify_exception( + """ func f(x): [ap] = [x] + 1 end -""", """ +""", + """ file:?:?: While expanding the reference 'x' in: [ap] = [x] + 1 ^ @@ -563,7 +684,9 @@ def test_func_args_failures(): ^ Preprocessed instruction: [ap] = [[fp + (-3)]] + 1 -""", exc_type=InstructionBuilderError) +""", + exc_type=InstructionBuilderError, + ) def test_with_statement(): @@ -579,13 +702,16 @@ def test_with_statement(): [ap] = x """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ [ap] = 0 [ap] = 1 [ap] = 2 [ap] = 1000 [ap] = 1001 """ + ) def test_with_statement_locals(): @@ -604,7 +730,9 @@ def test_with_statement_locals(): end """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ ret ap += 2 [fp] = 0 @@ -612,38 +740,48 @@ def test_with_statement_locals(): [fp + 1] = [ap + (-1)] ret """ + ) def test_with_statement_failure(): - verify_exception(""" + verify_exception( + """ with x: [ap] = [ap] end -""", """ +""", + """ file:?:?: Unknown reference 'x'. with x: ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ const x = 0 with x: [ap] = [ap] end -""", """ +""", + """ file:?:?: Expected 'x' to be a reference, found: const. with x: ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ let x = 0 with x as y: [ap] = [ap] end -""", """ +""", + """ file:?:?: The 'as' keyword is not supported in 'with' statements. with x as y: ^ -""") +""", + ) def test_implicit_args(): @@ -685,7 +823,9 @@ def test_implicit_args(): end """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ [ap] = [fp + (-1234)]; ap++ [ap] = [fp + (-1233)]; ap++ ret @@ -710,28 +850,36 @@ def test_implicit_args(): [ap] = [ap + (-2)]; ap++ ret """ + ) def test_implicit_args_failures(): - verify_exception(""" + verify_exception( + """ func f{x}(x): ret end -""", """ +""", + """ file:?:?: Arguments and return values cannot have the same name of an implicit argument. func f{x}(x): ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ func f{x}() -> (x): ret end -""", """ +""", + """ file:?:?: Arguments and return values cannot have the same name of an implicit argument. func f{x}() -> (x): ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ func f{x}(): ret end @@ -740,15 +888,18 @@ def test_implicit_args_failures(): f() ret end -""", """ +""", + """ file:?:?: While trying to retrieve the implicit argument 'x' in: f() ^*^ file:?:?: Unknown identifier 'x'. func f{x}(): ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ func f{x}(y): ret end @@ -760,26 +911,32 @@ def test_implicit_args_failures(): f(1) ret end -""", """ +""", + """ file:?:?: While trying to update the implicit return value 'x' in: f(1) ^**^ file:?:?: 'x' cannot be used as an implicit return value. Consider using a 'with' statement. func f{x}(y): ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ func f{x}(): let x = cast(0, felt*) return () end -""", """ +""", + """ file:?:?: Reference rebinding must preserve the reference type. Previous type: 'felt', new type: \ 'felt*'. let x = cast(0, felt*) ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ func f{x}(): ret end @@ -788,14 +945,16 @@ def test_implicit_args_failures(): f() ret end -""", """ +""", + """ file:?:?: While trying to update the implicit return value 'x' in: f() ^*^ file:?:?: Redefinition of 'test_scope.g.x'. func f{x}(): ^ -""") +""", + ) def test_implcit_argument_bindings(): @@ -810,7 +969,9 @@ def test_implcit_argument_bindings(): end """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ ret [ap] = [fp + (-5)]; ap++ [ap] = [fp + (-3)]; ap++ @@ -820,10 +981,12 @@ def test_implcit_argument_bindings(): [ap] = [ap + (-3)]; ap++ ret """ + ) def test_implcit_argument_bindings_failures(): - verify_exception(""" + verify_exception( + """ func foo{x}(y) -> (z): ret end @@ -832,12 +995,15 @@ def test_implcit_argument_bindings_failures(): let x = foo{5}(0) ret end -""", """ +""", + """ file:?:?: Implicit argument binding must be of the form: arg_name=var. let x = foo{5}(0) ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ func foo{x}(y) -> (z): ret end @@ -847,12 +1013,15 @@ def test_implcit_argument_bindings_failures(): let (res) = foo{y=x}(0) ret end -""", """ +""", + """ file:?:?: Unexpected implicit argument binding: y. let (res) = foo{y=x}(0) ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ func foo{x}(y) -> (z): ret end @@ -861,11 +1030,13 @@ def test_implcit_argument_bindings_failures(): foo{x=2}(0) ret end -""", """ +""", + """ file:?:?: Implicit argument binding must be an identifier. foo{x=2}(0) ^ -""") +""", + ) def test_func_args_scope(): @@ -882,7 +1053,9 @@ def test_func_args_scope(): [ap + 5] = f.Args.z; ap++ """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ [ap] = 1234; ap++ [fp + (-5)] = 1; ap++ [fp + (-4)] = 2; ap++ @@ -891,6 +1064,7 @@ def test_func_args_scope(): [ap + 4] = 1234; ap++ [ap + 5] = 2; ap++ """ + ) def test_func_args_and_rets_scope(): @@ -908,7 +1082,9 @@ def test_func_args_and_rets_scope(): [ap + 6] = f.Return.x; ap++ """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ [ap] = 1234; ap++ [fp + (-5)] = 1; ap++ [fp + (-4)] = 2; ap++ @@ -918,6 +1094,7 @@ def test_func_args_and_rets_scope(): [ap + 5] = 0; ap++ [ap + 6] = 2; ap++ """ + ) def test_func_named_args(): @@ -934,17 +1111,21 @@ def test_func_named_args(): call f """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ ret [ap + 2] = 2; ap++ [ap + (-1)] = 0; ap++ [ap + (-1)] = 1; ap++ call rel -7 """ + ) def test_func_named_args_failures(): - verify_exception(""" + verify_exception( + """ func f(x, y, z): ret end @@ -954,11 +1135,13 @@ def test_func_named_args_failures(): f_args.x = 0; ap++ static_assert f_args + f.Args.SIZE == ap call f -""", """ +""", + """ file:?:?: Static assert failed: ap + 1 != ap. static_assert f_args + f.Args.SIZE == ap ^**************************************^ -""") +""", + ) def test_function_call_by_value_args(): @@ -979,7 +1162,9 @@ def test_function_call_by_value_args(): end """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ [ap] = 2; ap++ [ap] = [fp + (-5)]; ap++ [ap] = [fp + (-4)]; ap++ @@ -990,15 +1175,20 @@ def test_function_call_by_value_args(): call rel -8 ret """ + ) -@pytest.mark.parametrize('test_line, expected_type, actual_type, arrow', [ - ('f(1, y=13)', 'T', 'felt', '^^'), - ('f(1, y=&y)', 'T', 'T*', '^^'), - ('f(1, y=t)', 'T', 'S', '^'), -]) +@pytest.mark.parametrize( + "test_line, expected_type, actual_type, arrow", + [ + ("f(1, y=13)", "T", "felt", "^^"), + ("f(1, y=&y)", "T", "T*", "^^"), + ("f(1, y=t)", "T", "S", "^"), + ], +) def test_func_by_value_args_failures(test_line, expected_type, actual_type, arrow): - verify_exception(f""" + verify_exception( + f""" struct T: member s : felt member t : felt @@ -1013,11 +1203,14 @@ def test_func_by_value_args_failures(test_line, expected_type, actual_type, arro {test_line} ret end -""", f""" +""", + f""" file:?:?: Expected expression of type '{expected_type}', got '{actual_type}'. {test_line} {arrow} -""", main_scope=ScopedName()) +""", + main_scope=ScopedName(), + ) def test_func_by_value_return(): @@ -1032,23 +1225,30 @@ def test_func_by_value_return(): end """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ [ap] = [fp + (-4)]; ap++ [ap] = [fp + (-3)]; ap++ [ap] = [ap + (-102)]; ap++ [ap] = [ap + (-102)]; ap++ ret """ + ) -@pytest.mark.parametrize('jmp_code', [ - 'jmp loop if [ap] != 0', - 'jmp rel 3', - 'jmp abs 3', - 'jmp rel [ap + 3] if [ap] != 0', -]) +@pytest.mark.parametrize( + "jmp_code", + [ + "jmp loop if [ap] != 0", + "jmp rel 3", + "jmp abs 3", + "jmp rel [ap + 3] if [ap] != 0", + ], +) def test_function_flow_revoke(jmp_code): - verify_exception(f""" + verify_exception( + f""" func foo(): loop: {jmp_code} @@ -1061,7 +1261,8 @@ def test_function_flow_revoke(jmp_code): assert x = 0 ret end -""", """ +""", + """ file:?:?: Reference 'x' was revoked. assert x = 0 ^ @@ -1069,7 +1270,8 @@ def test_function_flow_revoke(jmp_code): file:?:? tempvar x = 0 ^ -""") +""", + ) def test_scope_label(): @@ -1090,7 +1292,9 @@ def test_scope_label(): call f """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ jmp rel 0 jmp rel 4 call rel 2 @@ -1102,16 +1306,17 @@ def test_scope_label(): jmp rel -10 call rel -12 """ + ) def test_import(): files = { - '.': """ + ".": """ from a import f as g, h as h2 call g call h2 """, - 'a': """ + "a": """ func f(): jmp f end @@ -1119,132 +1324,164 @@ def test_import(): func h(): jmp h end -""" +""", } program = preprocess_codes( - codes=[(files['.'], '.')], - pass_manager=default_pass_manager(prime=PRIME, read_module=read_file_from_dict(files))) + codes=[(files["."], ".")], + pass_manager=default_pass_manager(prime=PRIME, read_module=read_file_from_dict(files)), + ) - assert program.format() == """\ + assert ( + program.format() + == """\ jmp rel 0 jmp rel 0 call rel -4 call rel -4 """ + ) def test_import_identifiers(): # Define files used in this test. files = { - '.': """ + ".": """ from a.b.c import alpha as x from a.b.c import beta from a.b.c import xi """, - 'a.b.c': """ + "a.b.c": """ from tau import xi const alpha = 0 const beta = 1 const gamma = 2 """, - 'tau': """ + "tau": """ const xi = 42 -""" +""", } # Prepare auxiliary functions for tests. scope = ScopedName.from_string - def get_full_name(name, curr_scope=''): + def get_full_name(name, curr_scope=""): try: return program.identifiers.search( - accessible_scopes=[scope(curr_scope)], name=scope(name)).get_canonical_name() + accessible_scopes=[scope(curr_scope)], name=scope(name) + ).get_canonical_name() except IdentifierError: return None # Preprocess program. program = preprocess_codes( - codes=[(files['.'], '.')], + codes=[(files["."], ".")], pass_manager=default_pass_manager(prime=PRIME, read_module=read_file_from_dict(files)), - main_scope=scope('__main__')) + main_scope=scope("__main__"), + ) # Verify identifiers are resolved correctly. - assert get_full_name('x', '__main__') == scope('a.b.c.alpha') - assert get_full_name('beta', '__main__') == scope('a.b.c.beta') - assert get_full_name('xi', '__main__') == scope('tau.xi') + assert get_full_name("x", "__main__") == scope("a.b.c.alpha") + assert get_full_name("beta", "__main__") == scope("a.b.c.beta") + assert get_full_name("xi", "__main__") == scope("tau.xi") - assert get_full_name('alpha', 'a.b.c') == scope('a.b.c.alpha') - assert get_full_name('beta', 'a.b.c') == scope('a.b.c.beta') - assert get_full_name('gamma', 'a.b.c') == scope('a.b.c.gamma') - assert get_full_name('xi', 'a.b.c') == scope('tau.xi') + assert get_full_name("alpha", "a.b.c") == scope("a.b.c.alpha") + assert get_full_name("beta", "a.b.c") == scope("a.b.c.beta") + assert get_full_name("gamma", "a.b.c") == scope("a.b.c.gamma") + assert get_full_name("xi", "a.b.c") == scope("tau.xi") - assert get_full_name('xi', 'tau') == scope('tau.xi') + assert get_full_name("xi", "tau") == scope("tau.xi") # Verify inaccessible identifiers. - assert get_full_name('alpha', '__main__') is None - assert get_full_name('gamma', '__main__') is None - assert get_full_name('a.b.c.alpha', '__main__') is None - assert get_full_name('tau.xi', '__main__') is None + assert get_full_name("alpha", "__main__") is None + assert get_full_name("gamma", "__main__") is None + assert get_full_name("a.b.c.alpha", "__main__") is None + assert get_full_name("tau.xi", "__main__") is None def test_import_errors(): # Inaccessible import. - verify_exception(""" + verify_exception( + """ from foo import bar -""", """ +""", + """ file:?:?: Could not load module 'foo'. Error: 'foo' from foo import bar ^*^ -""", files={}, exc_type=LocationError) +""", + files={}, + exc_type=LocationError, + ) # Ignoring aliasing. - verify_exception(""" + verify_exception( + """ from foo import bar as notbar [ap] = bar -""", """ +""", + """ file:?:?: Unknown identifier 'bar'. [ap] = bar ^*^ -""", files={'foo': 'const bar = 3'}) +""", + files={"foo": "const bar = 3"}, + ) # Identifier redefinition. - verify_exception(""" + verify_exception( + """ const bar = 0 from foo import bar -""", """ +""", + """ file:?:?: Redefinition of 'test_scope.bar'. from foo import bar ^*^ -""", files={'foo': 'const bar=0'}) +""", + files={"foo": "const bar=0"}, + ) - verify_exception(f""" + verify_exception( + f""" const lambda = 0 from foo import bar as lambda -""", """ +""", + """ file:?:?: Redefinition of 'test_scope.lambda'. from foo import bar as lambda ^****^ -""", files={'foo': 'const bar=0'}) +""", + files={"foo": "const bar=0"}, + ) - verify_exception('from foo import bar', """ \ + verify_exception( + "from foo import bar", + """ \ file:?:?: Cannot import 'bar' from 'foo'. from foo import bar ^*^ -""", files={'foo': ''}) +""", + files={"foo": ""}, + ) def test_error_scope_redefinition(): - verify_exception(""" + verify_exception( + """ from a import b from a.b import c -""", """ +""", + """ Scope 'a.b' collides with a different identifier of type 'const'. -""", files={'a': 'const b = 0', 'a.b': 'const c = 1'}) +""", + files={"a": "const b = 0", "a.b": "const c = 1"}, + ) def test_scope_failures(): - verify_exception(""" + verify_exception( + """ func f(): const x = 5 ret @@ -1253,12 +1490,15 @@ def test_scope_failures(): [ap] = x; ap++ ret end -""", """ +""", + """ file:?:?: Unknown identifier 'x'. [ap] = x; ap++ ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ func f(): label: ret @@ -1267,42 +1507,54 @@ def test_scope_failures(): call label ret end -""", """ +""", + """ file:?:?: Unknown identifier 'label'. call label ^***^ -""") +""", + ) def test_const_failures(): - verify_exception(""" + verify_exception( + """ const x = y -""", """ +""", + """ file:?:?: Unknown identifier 'y'. const x = y ^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ const x = 0 [ap] = x.y.z -""", """ +""", + """ file:?:?: Unexpected '.' after 'test_scope.x' which is const. [ap] = x.y.z ^***^ -""") +""", + ) - verify_exception(""" + verify_exception( + """ const x = [ap] + 5 -""", """ +""", + """ file:?:?: Expected a constant expression. const x = [ap] + 5 ^******^ -""") +""", + ) def test_labels(): - scope = ScopedName.from_string('my.cool.scope') - program = preprocess_str(""" + scope = ScopedName.from_string("my.cool.scope") + program = preprocess_str( + """ const x = 7 a0: [ap] = x; ap++ # Size: 2. @@ -1317,27 +1569,30 @@ def test_labels(): jmp a3 if [ap] != 0 # Size: 2. call a3 # Size: 2. a3: -""", prime=PRIME, main_scope=scope) +""", + prime=PRIME, + main_scope=scope, + ) program_labels = { name: identifier_definition.pc for name, identifier_definition in program.identifiers.get_scope(scope).identifiers.items() - if isinstance(identifier_definition, LabelDefinition)} - assert program_labels == {'a0': 0, 'a1': 4, 'a2': 6, 'a3': 14} + if isinstance(identifier_definition, LabelDefinition) + } + assert program_labels == {"a0": 0, "a1": 4, "a2": 6, "a3": 14} def test_process_file_scope(): # Verify the good scenario. - valid_scope = ScopedName.from_string('some.valid.scope') - program = preprocess_str('const x = 4', prime=PRIME, main_scope=valid_scope) + valid_scope = ScopedName.from_string("some.valid.scope") + program = preprocess_str("const x = 4", prime=PRIME, main_scope=valid_scope) module = CairoModule(cairo_file=program, module_name=valid_scope) - assert program.identifiers.as_dict() == { - valid_scope + 'x': ConstDefinition(4) - } + assert program.identifiers.as_dict() == {valid_scope + "x": ConstDefinition(4)} def test_label_resolution(): - program = preprocess_str(code=""" + program = preprocess_str( + code=""" [ap] = 7; ap++ # Size: 2. loop: @@ -1350,8 +1605,12 @@ def test_label_resolution(): jmp loop # Size: 2. jmp loop if [ap] != 0 # Size: 2. call loop # Size 2. -""", prime=PRIME) - assert program.format() == """\ +""", + prime=PRIME, + ) + assert ( + program.format() + == """\ [ap] = 7; ap++ [ap] = [ap + (-1)] + 1 jmp rel 7 @@ -1362,110 +1621,145 @@ def test_label_resolution(): jmp rel -11 if [ap] != 0 call rel -13 """ + ) def test_labels_failures(): - verify_exception(""" + verify_exception( + """ jmp x.y.z -""", """ +""", + """ file:?:?: Unknown identifier 'x'. jmp x.y.z ^***^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ const x = 0 jmp x -""", """ +""", + """ file:?:?: Expected a label name. Identifier 'x' is of type const. jmp x ^ -""") +""", + ) def test_redefinition_failures(): - verify_exception(""" + verify_exception( + """ name: const name = 0 -""", """ +""", + """ file:?:?: Redefinition of 'test_scope.name'. const name = 0 ^**^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ const name = 0 let name = ap -""", """ +""", + """ file:?:?: Redefinition of 'test_scope.name'. let name = ap ^**^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ let name = ap name: -""", """ +""", + """ file:?:?: Redefinition of 'test_scope.name'. name: ^**^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ func f(name, x, name): [ap + name] = 1 [ap + x] = 2 end -""", """ +""", + """ file:?:?: Redefinition of 'test_scope.f.Args.name'. func f(name, x, name): ^**^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ func f() -> (name, x, name): [ap] = 1 [ap] = 2 end -""", """ +""", + """ file:?:?: Redefinition of 'test_scope.f.Return.name'. func f() -> (name, x, name): ^**^ -""") +""", + ) def test_directives(): - program = preprocess_str(code="""\ + program = preprocess_str( + code="""\ # This is a comment. %builtins ab cd ef [fp] = [fp] -""", prime=PRIME) - assert program.builtins == ['ab', 'cd', 'ef'] - assert program.format() == """\ +""", + prime=PRIME, + ) + assert program.builtins == ["ab", "cd", "ef"] + assert ( + program.format() + == """\ %builtins ab cd ef [fp] = [fp] """ + ) def test_directives_failures(): - verify_exception(""" + verify_exception( + """ [fp] = [fp] %builtins ab cd ef -""", """ +""", + """ file:?:?: Directives must appear at the top of the file. %builtins ab cd ef ^****************^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ %lang abc -""", """ +""", + """ file:?:?: Unsupported %lang directive. Are you using the correct compiler? %lang abc ^*******^ -""") +""", + ) def test_conditionals(): - program = preprocess_str(code=""" + program = preprocess_str( + code=""" let x = 2 if [ap] * 2 == [fp] + 3: let x = 3 @@ -1474,8 +1768,12 @@ def test_conditionals(): let x = 4 [ap] = x; ap++ end -""", prime=PRIME) - assert program.format() == """\ +""", + prime=PRIME, + ) + assert ( + program.format() + == """\ [ap] = [ap] * 2; ap++ [ap] = [fp] + 3; ap++ [ap] = [ap + (-2)] - [ap + (-1)]; ap++ @@ -1484,58 +1782,84 @@ def test_conditionals(): jmp rel 4 [ap] = 4; ap++ """ - program = preprocess_str(code=""" + ) + program = preprocess_str( + code=""" if [ap] == [fp]: ret else: [ap] = [ap] end [fp] = [fp] -""", prime=PRIME) - assert program.format() == """\ +""", + prime=PRIME, + ) + assert ( + program.format() + == """\ [ap] = [ap] - [fp]; ap++ jmp rel 3 if [ap + (-1)] != 0 ret [ap] = [ap] [fp] = [fp] """ - program = preprocess_str(code=""" + ) + program = preprocess_str( + code=""" if [ap] == 0: ret end [fp] = [fp] -""", prime=PRIME) - assert program.format() == """\ +""", + prime=PRIME, + ) + assert ( + program.format() + == """\ jmp rel 3 if [ap] != 0 ret [fp] = [fp] """ + ) # No jump if there is no "Non-equal" block. - program = preprocess_str(code=""" + program = preprocess_str( + code=""" if [ap] == 0: [fp + 1] = [fp + 1] end [fp] = [fp] -""", prime=PRIME) - assert program.format() == """\ +""", + prime=PRIME, + ) + assert ( + program.format() + == """\ jmp rel 3 if [ap] != 0 [fp + 1] = [fp + 1] [fp] = [fp] """ - program = preprocess_str(code=""" + ) + program = preprocess_str( + code=""" if [ap] != 0: ret end [fp] = [fp] -""", prime=PRIME) - assert program.format() == """\ +""", + prime=PRIME, + ) + assert ( + program.format() + == """\ jmp rel 4 if [ap] != 0 jmp rel 3 ret [fp] = [fp] """ + ) # With locals. - program = preprocess_str(code=""" + program = preprocess_str( + code=""" func a(): alloc_locals local a @@ -1551,8 +1875,12 @@ def test_conditionals(): [fp] = [fp] ret end -""", prime=PRIME) - assert program.format() == """\ +""", + prime=PRIME, + ) + assert ( + program.format() + == """\ ap += 3 jmp rel 8 if [ap] != 0 [fp + 2] = 6 @@ -1563,6 +1891,7 @@ def test_conditionals(): [fp] = [fp] ret """ + ) def test_hints_good(): @@ -1613,64 +1942,80 @@ def test_hints_unindent(): def test_hints_failures(): - verify_exception(""" + verify_exception( + """ %{ hint %} -""", """ +""", + """ file:?:?: Found a hint at the end of a code block. Hints must be followed by an instruction. %{ ^^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ func f(): %{ hint %} end [ap] = 1 -""", """ +""", + """ file:?:?: Found a hint at the end of a code block. Hints must be followed by an instruction. %{ ^^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ [fp] = [fp] %{ hint %} label: [fp] = [fp] -""", """ +""", + """ file:?:?: Hints before labels are not allowed. %{ ^^ -""") +""", + ) def test_builtins_failures(): - verify_exception(""" + verify_exception( + """ %builtins a %builtins b -""", """ +""", + """ file:?:?: Redefinition of builtins directive. %builtins b ^*********^ -""") +""", + ) def test_builtin_directive_duplicate_entry(): - verify_exception(""" + verify_exception( + """ %builtins pedersen ecdsa pedersen -""", """ +""", + """ file:?:?: The builtin 'pedersen' appears twice in builtins directive. %builtins pedersen ecdsa pedersen ^*******************************^ -""") +""", + ) def test_references(): - program = preprocess_str(code=""" + program = preprocess_str( + code=""" call label1 label1: ret @@ -1697,8 +2042,12 @@ def test_references(): let y = ap [y] = 0; ap++ [y] = 0; ap++ -""", prime=PRIME) - assert program.format() == """\ +""", + prime=PRIME, + ) + assert ( + program.format() + == """\ call rel 2 ret [ap + 1] = 1; ap++ @@ -1718,11 +2067,13 @@ def test_references(): [ap] = 0; ap++ [ap + (-1)] = 0; ap++ """ + ) def test_reference_type_deduction(): scope = TEST_SCOPE - program = preprocess_str(code=""" + program = preprocess_str( + code=""" struct T: member t : felt end @@ -1735,7 +2086,10 @@ def test_reference_type_deduction(): let e : felt* = [b] return () end -""", prime=PRIME, main_scope=scope) +""", + prime=PRIME, + main_scope=scope, + ) def get_reference_type(name): identifier_definition = program.identifiers.get_by_full_name(scope + name) @@ -1744,15 +2098,16 @@ def get_reference_type(name): _, expr_type = simplify_type_system(identifier_definition.references[0].value) return expr_type - assert get_reference_type('foo.a').format() == f'{scope}.T***' - assert get_reference_type('foo.b').format() == f'{scope}.T**' - assert get_reference_type('foo.c').format() == 'felt*' - assert get_reference_type('foo.d').format() == f'{scope}.T*' - assert get_reference_type('foo.e').format() == 'felt*' + assert get_reference_type("foo.a").format() == f"{scope}.T***" + assert get_reference_type("foo.b").format() == f"{scope}.T**" + assert get_reference_type("foo.c").format() == "felt*" + assert get_reference_type("foo.d").format() == f"{scope}.T*" + assert get_reference_type("foo.e").format() == "felt*" def test_rebind_reference(): - program = preprocess_str(code=""" + program = preprocess_str( + code=""" struct T: member pad0 : felt member pad1 : felt @@ -1765,28 +2120,37 @@ def test_rebind_reference(): let x : T* = cast(fp - 3, T*) [cast(x, felt)] = x.t [y] = [y] -""", prime=PRIME) - assert program.format() == """\ +""", + prime=PRIME, + ) + assert ( + program.format() + == """\ [ap + 1] = [ap + 3] [fp + (-3)] = [fp + (-1)] [ap + 3] = [ap + 3] """ + ) def test_rebind_reference_failures(): - verify_exception(""" + verify_exception( + """ let x = cast(ap, felt*) let x = cast(ap, felt**) -""", """ +""", + """ file:?:?: Reference rebinding must preserve the reference type. Previous type: 'felt*', \ new type: 'felt**'. let x = cast(ap, felt**) ^ -""") +""", + ) def test_reference_over_calls(): - program = preprocess_str(code=""" + program = preprocess_str( + code=""" func f(): ap += 3 jmp label1 if [ap] != 0; ap++ @@ -1801,8 +2165,12 @@ def test_reference_over_calls(): [x] = 0 call f [x] = 0 -""", prime=PRIME) - assert program.format() == """\ +""", + prime=PRIME, + ) + assert ( + program.format() + == """\ ap += 3 jmp rel 4 if [ap] != 0; ap++ [ap] = [ap]; ap++ @@ -1813,10 +2181,12 @@ def test_reference_over_calls(): call rel -11 [ap + (-6)] = 0 """ + ) def test_reference_over_calls_failures(): - verify_exception(f""" + verify_exception( + f""" func f(): ap += 3 jmp label1 if [ap] != 0 @@ -1828,7 +2198,8 @@ def test_reference_over_calls_failures(): let x = ap + 1 call f [x] = 0 -""", """ +""", + """ file:?:?: Reference 'x' was revoked. [x] = 0 ^ @@ -1836,9 +2207,11 @@ def test_reference_over_calls_failures(): file:?:? let x = ap + 1 ^ -""") +""", + ) - verify_exception(f""" + verify_exception( + f""" func f(): ap += 3 jmp label1 if [ap] != 0 @@ -1851,7 +2224,8 @@ def test_reference_over_calls_failures(): let x = ap + 1 call f [x] = 0 -""", """ +""", + """ file:?:?: Reference 'x' was revoked. [x] = 0 ^ @@ -1859,41 +2233,53 @@ def test_reference_over_calls_failures(): file:?:? let x = ap + 1 ^ -""") - - -@pytest.mark.parametrize('revoking_instruction, has_def_location', [ - ('ap += [fp]', True), - ('call label', True), - ('call rel 0', True), - ('ret', False), - ('jmp label', False), - ('jmp rel 0', False), - ('jmp abs 0', False), -]) +""", + ) + + +@pytest.mark.parametrize( + "revoking_instruction, has_def_location", + [ + ("ap += [fp]", True), + ("call label", True), + ("call rel 0", True), + ("ret", False), + ("jmp label", False), + ("jmp rel 0", False), + ("jmp abs 0", False), + ], +) def test_references_revoked(revoking_instruction, has_def_location): - def_loction_str = """\ + def_loction_str = ( + """\ Reference was defined here: file:?:? let x = ap ^ -""" if has_def_location else '' +""" + if has_def_location + else "" + ) - verify_exception(f""" + verify_exception( + f""" label: let x = ap {revoking_instruction} [x] = 0 -""", f""" +""", + f""" file:?:?: Reference 'x' was revoked. [x] = 0 ^ {def_loction_str} -""") +""", + ) def test_references_revoked_multiple_location(): - verify_exception(f""" + verify_exception( + f""" if [ap] == 0: let x = ap else: @@ -1902,7 +2288,8 @@ def test_references_revoked_multiple_location(): end ap += [fp] [x] = 0 -""", """ +""", + """ file:?:?: Reference 'x' was revoked. [x] = 0 @@ -1914,15 +2301,18 @@ def test_references_revoked_multiple_location(): file:?:? let x = ap ^ -""") +""", + ) def test_references_failures(): - verify_exception(""" + verify_exception( + """ let ref = [fp] let ref2 = ref [ref2] = [[fp]] -""", """ +""", + """ file:?:?: While expanding the reference 'ref2' in: [ref2] = [[fp]] ^**^ @@ -1934,22 +2324,27 @@ def test_references_failures(): ^**^ Preprocessed instruction: [[fp]] = [[fp]] -""", exc_type=InstructionBuilderError) - - -@pytest.mark.parametrize('valid, has0, has1, has2', [ - (False, True, True, True), - (False, False, True, True), - (False, True, False, True), - (False, True, True, False), - (False, False, True, False), - (False, False, False, True), - (True, True, False, False), -]) +""", + exc_type=InstructionBuilderError, + ) + + +@pytest.mark.parametrize( + "valid, has0, has1, has2", + [ + (False, True, True, True), + (False, False, True, True), + (False, True, False, True), + (False, True, True, False), + (False, False, True, False), + (False, False, False, True), + (True, True, False, False), + ], +) def test_reference_flow_revokes(valid, has0, has1, has2): - def0 = 'let ref = [fp]' if has0 else '' - def1 = 'let ref = [fp + 1]' if has1 else '' - def2 = 'let ref = [fp + 2]' if has2 else '' + def0 = "let ref = [fp]" if has0 else "" + def1 = "let ref = [fp + 1]" if has1 else "" + def2 = "let ref = [fp + 2]" if has2 else "" code = f""" {def0} jmp b if [ap] != 0 @@ -1964,21 +2359,26 @@ def test_reference_flow_revokes(valid, has0, has1, has2): if valid: preprocess_str(code, prime=PRIME) else: - verify_exception(code, """ + verify_exception( + code, + """ file:?:?: Reference 'ref' was revoked. [ref] = [fp + 3] ^*^ -""") +""", + ) def test_implicit_arg_revocation(): - verify_exception(""" + verify_exception( + """ func foo{x}(y): foo(y=1) ap += [fp] return foo(y=2) end -""", """ +""", + """ file:?:?: While trying to retrieve the implicit argument 'x' in: return foo(y=2) ^******^ @@ -1989,11 +2389,13 @@ def test_implicit_arg_revocation(): file:?:? foo(y=1) ^******^ -""") +""", + ) def test_reference_flow_converge(): - program = preprocess_str(""" + program = preprocess_str( + """ if [ap] != 0: tempvar a = 1 else: @@ -2001,19 +2403,25 @@ def test_reference_flow_converge(): end assert a = a -""", prime=PRIME) - assert program.format() == """\ +""", + prime=PRIME, + ) + assert ( + program.format() + == """\ jmp rel 6 if [ap] != 0 [ap] = 2; ap++ jmp rel 4 [ap] = 1; ap++ [ap + (-1)] = [ap + (-1)] """ + ) def test_typed_references(): scope = TEST_SCOPE - program = preprocess_str(code=""" + program = preprocess_str( + code=""" func main(): struct T: member pad0 : felt @@ -2038,24 +2446,30 @@ def test_typed_references(): [fp] = y.a + 1 ret end -""", prime=PRIME, main_scope=scope) +""", + prime=PRIME, + main_scope=scope, + ) def get_reference(name): scoped_name = scope + name assert isinstance(program.identifiers.get_by_full_name(scoped_name), ReferenceDefinition) return program.instructions[-1].flow_tracking_data.resolve_reference( - reference_manager=program.reference_manager, name=scoped_name) + reference_manager=program.reference_manager, name=scoped_name + ) - expected_type_x = mark_type_resolved(parse_type(f'{scope}.main.Struct*')) - assert simplify_type_system(get_reference('main.x').value)[1] == expected_type_x + expected_type_x = mark_type_resolved(parse_type(f"{scope}.main.Struct*")) + assert simplify_type_system(get_reference("main.x").value)[1] == expected_type_x - expected_type_y = mark_type_resolved(parse_type(f'{scope}.main.Struct')) - reference = get_reference('main.y') + expected_type_y = mark_type_resolved(parse_type(f"{scope}.main.Struct")) + reference = get_reference("main.y") assert simplify_type_system(reference.value)[1] == expected_type_y - assert reference.value.format() == f'[cast(ap + 10, {scope}.main.Struct*)]' - assert program.format() == """\ + assert reference.value.format() == f"[cast(ap + 10, {scope}.main.Struct*)]" + assert ( + program.format() + == """\ [fp] = [ap + 12] [fp] = [[ap + 12] + 3] [ap] = [[ap + 12] + 3]; ap++ @@ -2063,45 +2477,57 @@ def get_reference(name): [fp] = [ap + 11] + 1 ret """ + ) def test_typed_references_failures(): - verify_exception(f""" + verify_exception( + f""" let x = fp x.a = x.a -""", """ +""", + """ file:?:?: Cannot apply dot-operator to non-struct type 'felt'. x.a = x.a ^*^ -""", exc_type=CairoTypeError) - verify_exception(f""" +""", + exc_type=CairoTypeError, + ) + verify_exception( + f""" struct T: member z : felt end let x : T = ap x.z = x.z -""", """ +""", + """ file:?:?: Cannot assign an expression of type 'felt' to a reference of type 'test_scope.T'. let x : T = ap ^ -""") - verify_exception(f""" +""", + ) + verify_exception( + f""" struct T: member z : felt end let x : T* = [cast(ap, T*)] -""", """ +""", + """ file:?:?: Cannot assign an expression of type 'test_scope.T' to a reference of type 'test_scope.T*'. let x : T* = [cast(ap, T*)] ^^ -""") +""", + ) def test_return_value_reference(): scope = TEST_SCOPE - program = preprocess_str(code=""" + program = preprocess_str( + code=""" func foo() -> (val, x, y): ret end @@ -2116,27 +2542,35 @@ def test_return_value_reference(): let z = call abs 0 ret end -""", prime=PRIME, main_scope=scope) +""", + prime=PRIME, + main_scope=scope, + ) def get_reference(name): scoped_name = scope + name assert isinstance(program.identifiers.get_by_full_name(scoped_name), ReferenceDefinition) return program.instructions[-1].flow_tracking_data.resolve_reference( - reference_manager=program.reference_manager, name=scoped_name) + reference_manager=program.reference_manager, name=scoped_name + ) - expected_type = mark_type_resolved(parse_type( - f'{scope}.foo.{CodeElementFunction.RETURN_SCOPE}')) - assert simplify_type_system(get_reference('main.x').value)[1] == expected_type + expected_type = mark_type_resolved( + parse_type(f"{scope}.foo.{CodeElementFunction.RETURN_SCOPE}") + ) + assert simplify_type_system(get_reference("main.x").value)[1] == expected_type - expected_type = mark_type_resolved(parse_type( - f'{scope}.main.{CodeElementFunction.RETURN_SCOPE}')) - assert simplify_type_system(get_reference('main.y').value)[1] == expected_type + expected_type = mark_type_resolved( + parse_type(f"{scope}.main.{CodeElementFunction.RETURN_SCOPE}") + ) + assert simplify_type_system(get_reference("main.y").value)[1] == expected_type - expected_type = parse_type('felt') - assert simplify_type_system(get_reference('main.z').value)[1] == expected_type + expected_type = parse_type("felt") + assert simplify_type_system(get_reference("main.z").value)[1] == expected_type - assert program.format() == """\ + assert ( + program.format() + == """\ ret call rel -1 [ap] = 0; ap++ @@ -2145,48 +2579,63 @@ def get_reference(name): call abs 0 ret """ + ) def test_return_value_reference_failures(): - verify_exception(f""" + verify_exception( + f""" let x = call foo -""", """ +""", + """ file:?:?: Unknown identifier 'foo'. let x = call foo ^*^ -""") - verify_exception(f""" +""", + ) + verify_exception( + f""" func foo(): ret end let x = call foo [x.a] = 0 -""", """ +""", + """ file:?:?: Member 'a' does not appear in definition of struct 'test_scope.foo.Return'. [x.a] = 0 ^*^ -""", exc_type=CairoTypeError) - verify_exception(f""" +""", + exc_type=CairoTypeError, + ) + verify_exception( + f""" func foo(): ret end let x : unknown_type* = call foo -""", """ +""", + """ file:?:?: Unknown identifier 'unknown_type'. let x : unknown_type* = call foo ^**********^ -""") - verify_exception(f""" +""", + ) + verify_exception( + f""" struct T: member s : felt end let x : T* = cast(ap, T*) [ap] = x.a -""", """ +""", + """ file:?:?: Member 'a' does not appear in definition of struct 'test_scope.T'. [ap] = x.a ^*^ -""", exc_type=CairoTypeError) +""", + exc_type=CairoTypeError, + ) def test_unpacking(): @@ -2211,7 +2660,9 @@ def test_unpacking(): end """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ [ap] = 5; ap++ [ap] = 6; ap++ [ap] = 1; ap++ @@ -2233,29 +2684,37 @@ def test_unpacking(): [fp + 2] = [ap + (-4)] ret """ + ) def test_unpacking_failures(): - verify_exception(f""" + verify_exception( + f""" func foo() -> (a): ret end let (a, b) = foo() -""", """ +""", + """ file:?:?: Expected 1 unpacking identifier, found 2. let (a, b) = foo() ^**^ -""") +""", + ) - verify_exception(f""" + verify_exception( + f""" let (a, b) = 1 + 3 -""", """ +""", + """ file:?:?: Cannot unpack 1 + 3. let (a, b) = 1 + 3 ^***^ -""") +""", + ) - verify_exception(f""" + verify_exception( + f""" struct T: member a : felt member b : felt @@ -2264,13 +2723,16 @@ def test_unpacking_failures(): ret end let (a, b, c) = foo() -""", """ +""", + """ file:?:?: Expected 2 unpacking identifiers, found 3. let (a, b, c) = foo() ^*****^ -""") +""", + ) - verify_exception(f""" + verify_exception( + f""" struct T: member a : felt member b : felt @@ -2279,13 +2741,16 @@ def test_unpacking_failures(): ret end let (a, b : T) = foo() -""", """ +""", + """ file:?:?: Expected expression of type 'felt', got 'test_scope.T'. let (a, b : T) = foo() ^***^ -""") +""", + ) - verify_exception(f""" + verify_exception( + f""" struct T: member a : felt member b : felt @@ -2302,14 +2767,17 @@ def test_unpacking_failures(): let (a, local b : S) = foo() ret end -""", """ +""", + """ file:?:?: Expected expression of type 'test_scope.T', got 'test_scope.S'. let (a, local b : S) = foo() ^*********^ -""") +""", + ) - verify_exception(f""" + verify_exception( + f""" struct T: end @@ -2322,19 +2790,23 @@ def test_unpacking_failures(): let (local _ : T*) = foo() ret end -""", """ +""", + """ file:?:?: Reference name cannot be '_'. let (local _ : T*) = foo() ^**********^ -""") +""", + ) - verify_exception(f""" + verify_exception( + f""" func foo() -> (a): ret end let (a) = foo() [a] = [a] -""", """ +""", + """ file:?:?: While expanding the reference 'a' in: [a] = [a] ^ @@ -2343,77 +2815,98 @@ def test_unpacking_failures(): ^ Preprocessed instruction: [[ap + (-1)]] = [[ap + (-1)]] -""", exc_type=InstructionBuilderError) +""", + exc_type=InstructionBuilderError, + ) def test_unpacking_modifier_failure(): - verify_exception(""" + verify_exception( + """ func foo() -> (a, b): ret end let (a, local b) = foo() -""", """ +""", + """ file:?:?: Unexpected modifier 'local'. let (a, local b) = foo() ^***^ -""") +""", + ) def test_member_def_failures(): - verify_exception(""" + verify_exception( + """ struct T: member t end -""", """ +""", + """ file:?:?: Struct members must be explicitly typed (e.g., member x : felt). member t ^ -""") +""", + ) - verify_exception(""" + verify_exception( + """ member t : felt -""", """ +""", + """ file:?:?: The member keyword may only be used inside a struct. member t : felt ^******^ -""") +""", + ) - verify_exception(""" + verify_exception( + """ struct T: member local t end -""", """ +""", + """ file:?:?: Unexpected modifier 'local'. member local t ^***^ -""") +""", + ) def test_bad_struct(): - verify_exception(""" + verify_exception( + """ struct T: return() end -""", """ +""", + """ file:?:?: Unexpected statement inside a struct definition. return() ^******^ -""") +""", + ) def test_bad_type_annotation(): - verify_exception(""" + verify_exception( + """ func foo(): local a : foo ret end -""", """ +""", + """ file:?:?: Expected 'test_scope.foo' to be a struct. Found: 'function'. local a : foo ^*^ -""") +""", + ) - verify_exception(""" + verify_exception( + """ func foo(): struct test: member a : foo* @@ -2421,13 +2914,16 @@ def test_bad_type_annotation(): ret end -""", """ +""", + """ file:?:?: Expected 'foo' to be a struct. Found: 'function'. member a : foo* ^*^ -""") +""", + ) - verify_exception(""" + verify_exception( + """ func foo(): struct test: member a : foo.abc* @@ -2435,15 +2931,18 @@ def test_bad_type_annotation(): ret end -""", """ +""", + """ file:?:?: Unknown identifier 'test_scope.foo.abc'. member a : foo.abc* ^*****^ -""") +""", + ) def test_cast_failure(): - verify_exception(""" + verify_exception( + """ struct A: end @@ -2451,33 +2950,40 @@ def test_cast_failure(): let a = cast(5, A) return () end -""", """ +""", + """ file:?:?: Cannot cast 'felt' to 'test_scope.A'. let a = cast(5, A) ^********^ -""", exc_type=CairoTypeError) +""", + exc_type=CairoTypeError, + ) def test_nested_function_failure(): - verify_exception(""" + verify_exception( + """ func foo(): func bar(): return() end return () end -""", """ +""", + """ file:?:?: Nested functions are not supported. func bar(): ^*^ Outer function was defined here: file:?:? func foo(): ^*^ -""") +""", + ) def test_namespace_inside_function_failure(): - verify_exception(""" + verify_exception( + """ func foo(): namespace MyNamespace: end @@ -2485,14 +2991,16 @@ def test_namespace_inside_function_failure(): end -""", """ +""", + """ file:?:?: Cannot define a namespace inside a function. namespace MyNamespace: ^*********^ Outer function was defined here: file:?:? func foo(): ^*^ -""") +""", + ) def test_struct_assignments(): @@ -2517,13 +3025,16 @@ def test_struct_assignments(): end """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ ap += 3 [fp] = [[fp + (-3)]] [fp + 1] = [[fp + (-3)] + 1] [fp + 2] = [[fp + (-3)] + 2] ret """ + ) code = f"""\ {struct_def} @@ -2533,7 +3044,9 @@ def test_struct_assignments(): end """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ [ap] = [[fp + (-3)]]; ap++ [ap] = [[fp + (-4)]]; ap++ [ap] = [[ap + (-1)]]; ap++ @@ -2548,6 +3061,7 @@ def test_struct_assignments(): [[ap + (-3)] + 2] = [ap + (-1)] ret """ + ) def test_continuous_structs(): @@ -2577,7 +3091,9 @@ def test_continuous_structs(): end """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ [fp + (-8)] = 1 [fp + (-7)] = 2 [fp + (-6)] = 3 @@ -2586,6 +3102,7 @@ def test_continuous_structs(): [fp + (-3)] = 6 ret """ + ) def test_subscript_operator(): @@ -2775,12 +3292,15 @@ def test_tuple_assertions(): end """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ ap += 2 [fp] = [ap] [fp + 1] = [ap + 1] ret """ + ) def test_tuple_expression(): @@ -2804,7 +3324,9 @@ def test_tuple_expression(): end """ program = preprocess_str(code=code, prime=PRIME) - assert program.format() == """\ + assert ( + program.format() + == """\ ap += 4 [fp] = 1 [fp + 1] = [[fp]] @@ -2815,10 +3337,12 @@ def test_tuple_expression(): [fp] = [fp] ret """ + ) def test_tuple_expression_failures(): - verify_exception(""" + verify_exception( + """ struct A: member x : felt end @@ -2826,14 +3350,18 @@ def test_tuple_expression_failures(): end let a = cast(fp, A*) let b = cast((1, a), B) -""", """ +""", + """ file:?:?: Cannot cast an expression of type '(felt, test_scope.A*)' to 'test_scope.B'. The former has 2 members while the latter has 0 members. let b = cast((1, a), B) ^****^ -""", exc_type=CairoTypeError) +""", + exc_type=CairoTypeError, + ) - verify_exception(""" + verify_exception( + """ struct A: member x : felt member y : felt @@ -2844,16 +3372,20 @@ def test_tuple_expression_failures(): end let a = [cast(fp, A*)] let b = cast((a, 1), B) -""", """ +""", + """ file:?:?: While expanding the reference 'a' in: let b = cast((a, 1), B) ^ file:?:?: Cannot cast 'test_scope.A' to 'felt'. let a = [cast(fp, A*)] ^************^ -""", exc_type=CairoTypeError) +""", + exc_type=CairoTypeError, + ) - verify_exception(""" + verify_exception( + """ struct A: member x : felt member y : felt @@ -2863,37 +3395,50 @@ def test_tuple_expression_failures(): member b : A end let b = cast([cast(ap, (felt, felt*)*)], B) -""", """ +""", + """ file:?:?: Cannot cast 'felt*' to 'test_scope.A'. let b = cast([cast(ap, (felt, felt*)*)], B) ^************************^ -""", exc_type=CairoTypeError) +""", + exc_type=CairoTypeError, + ) - verify_exception(""" + verify_exception( + """ struct B: end let b = cast([cast(ap, (felt, felt*)*)], B) -""", """ +""", + """ file:?:?: Cannot cast an expression of type '(felt, felt*)' to 'test_scope.B'. The former has 2 members while the latter has 0 members. let b = cast([cast(ap, (felt, felt*)*)], B) ^************************^ -""", exc_type=CairoTypeError) - verify_exception(""" +""", + exc_type=CairoTypeError, + ) + verify_exception( + """ (1, 1) = 1 -""", """ +""", + """ file:?:?: Expected a 'felt' or a pointer type. Got: '(felt, felt)'. (1, 1) = 1 ^****^ -""") +""", + ) - verify_exception(""" + verify_exception( + """ assert (1, 1) = 1 -""", """ +""", + """ file:?:?: Cannot compare '(felt, felt)' and 'felt'. assert (1, 1) = 1 ^***************^ -""") +""", + ) def test_struct_constructor(): @@ -2969,31 +3514,39 @@ def test_struct_constructor(): def test_struct_constructor_failures(): - verify_exception(""" + verify_exception( + """ func foo(): ret end foo(3) = foo(4) -""", """ +""", + """ file:?:?: Expected 'foo' to be a struct. Found: 'function'. foo(3) = foo(4) ^****^ -""") - verify_exception(""" +""", + ) + verify_exception( + """ struct A: member next: A* end assert A(next=0) = A(next=0) -""", """ +""", + """ file:?:?: Cannot cast 'felt' to 'test_scope.A*'. assert A(next=0) = A(next=0) ^ -""", exc_type=CairoTypeError) +""", + exc_type=CairoTypeError, + ) def verify_exception_for_expr(expr_str: str, expected_error: str): - verify_exception(f""" + verify_exception( + f""" struct T: member x : felt member y : felt @@ -3003,49 +3556,68 @@ def verify_exception_for_expr(expr_str: str, expected_error: str): alloc_locals local a : T = {expr_str} end -""", expected_error, exc_type=CairoTypeError) +""", + expected_error, + exc_type=CairoTypeError, + ) - verify_exception_for_expr('T(5, 6, 7)', """ + verify_exception_for_expr( + "T(5, 6, 7)", + """ file:?:?: Cannot cast an expression of type '(felt, felt, felt)' to 'test_scope.T'. The former has 3 members while the latter has 2 members. local a : T = T(5, 6, 7) ^********^ -""") +""", + ) - verify_exception_for_expr('&T(5, 6)', """ + verify_exception_for_expr( + "&T(5, 6)", + """ file:?:?: Expression has no address. local a : T = &T(5, 6) ^*****^ -""") +""", + ) - verify_exception_for_expr('T(5, 6).x', """ + verify_exception_for_expr( + "T(5, 6).x", + """ file:?:?: Accessing struct members for r-value structs is not supported yet. local a : T = T(5, 6).x ^*******^ -""") +""", + ) - verify_exception_for_expr('T{a}(5, 6)', """ + verify_exception_for_expr( + "T{a}(5, 6)", + """ file:?:?: Implicit arguments cannot be used with struct constructors. local a : T = T{a}(5, 6) ^ -""") +""", + ) def test_unsupported_decorator(): - verify_exception(""" + verify_exception( + """ @external func foo(): return() end -""", """ +""", + """ file:?:?: Unsupported decorator: 'external'. @external ^*******^ -""") +""", + ) def test_skipped_functions(): - files = {'module': """ + files = { + "module": """ func func0(): tempvar x = 0 return () @@ -3058,14 +3630,19 @@ def test_skipped_functions(): tempvar x = 2 return func1() end -""", '.': """ +""", + ".": """ from module import func2 func2() -"""} +""", + } program = preprocess_codes( - codes=[(files['.'], '.')], - pass_manager=default_pass_manager(prime=PRIME, read_module=read_file_from_dict(files))) - assert program.format() == """\ + codes=[(files["."], ".")], + pass_manager=default_pass_manager(prime=PRIME, read_module=read_file_from_dict(files)), + ) + assert ( + program.format() + == """\ [ap] = 1; ap++ ret [ap] = 2; ap++ @@ -3073,13 +3650,16 @@ def test_skipped_functions(): ret call rel -5 """ + ) program = preprocess_codes( - codes=[(files['.'], '.')], + codes=[(files["."], ".")], pass_manager=default_pass_manager( - prime=PRIME, - read_module=read_file_from_dict(files), - opt_unused_functions=False)) - assert program.format() == """\ + prime=PRIME, read_module=read_file_from_dict(files), opt_unused_functions=False + ), + ) + assert ( + program.format() + == """\ [ap] = 0; ap++ ret [ap] = 1; ap++ @@ -3089,6 +3669,7 @@ def test_skipped_functions(): ret call rel -5 """ + ) def test_known_ap_change_decorator(): @@ -3111,15 +3692,18 @@ def test_known_ap_change_decorator(): preprocess_str(code=code, prime=PRIME) # Negative case. - verify_exception(""" + verify_exception( + """ @known_ap_change func foo(): foo() return () end -""", """ +""", + """ file:?:?: The compiler was unable to deduce the change of the ap register, as required by this \ decorator. @known_ap_change ^**************^ -""") +""", + ) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test_utils.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test_utils.py index 3ca17a83..0c72d120 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test_utils.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test_utils.py @@ -7,54 +7,62 @@ from starkware.cairo.lang.compiler.preprocessor.pass_manager import PassManager from starkware.cairo.lang.compiler.preprocessor.preprocess_codes import preprocess_codes from starkware.cairo.lang.compiler.preprocessor.preprocessor import ( - PreprocessedProgram, Preprocessor) + PreprocessedProgram, + Preprocessor, +) from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.compiler.test_utils import read_file_from_dict -PRIME = 3 * 2**30 + 1 +PRIME = 3 * 2 ** 30 + 1 # Note that the TEST_SCOPE is hardcoded in the tests. -TEST_SCOPE = ScopedName.from_string('test_scope') +TEST_SCOPE = ScopedName.from_string("test_scope") def strip_comments_and_linebreaks(program: str): """ Removes all comments and empty lines from the given program. """ - program = re.sub(r'\s*#.*\n', '\n', program) - return re.sub('\n+', '\n', program) + program = re.sub(r"\s*#.*\n", "\n", program) + return re.sub("\n+", "\n", program) def default_read_module(module_name: str): - raise Exception( - f'Error: trying to read module {module_name}, no reading algorithm provided.') + raise Exception(f"Error: trying to read module {module_name}, no reading algorithm provided.") def preprocess_str( - code: str, prime: int, main_scope: Optional[ScopedName] = None, - preprocessor_cls: Optional[Type[Preprocessor]] = None) -> PreprocessedProgram: + code: str, + prime: int, + main_scope: Optional[ScopedName] = None, + preprocessor_cls: Optional[Type[Preprocessor]] = None, +) -> PreprocessedProgram: return preprocess_str_ex( code=code, pass_manager=default_pass_manager( - prime=prime, read_module=default_read_module, preprocessor_cls=preprocessor_cls), - main_scope=main_scope) + prime=prime, read_module=default_read_module, preprocessor_cls=preprocessor_cls + ), + main_scope=main_scope, + ) def preprocess_str_ex( - code: str, pass_manager: PassManager, - main_scope: Optional[ScopedName] = None) -> PreprocessedProgram: + code: str, pass_manager: PassManager, main_scope: Optional[ScopedName] = None +) -> PreprocessedProgram: if main_scope is None: main_scope = TEST_SCOPE - return preprocess_codes( - [(code, '')], - pass_manager=pass_manager, - main_scope=main_scope) + return preprocess_codes([(code, "")], pass_manager=pass_manager, main_scope=main_scope) def verify_exception( - code: str, error: str, files: Dict[str, str] = {}, main_scope: Optional[ScopedName] = None, - exc_type=PreprocessorError, pass_manager: Optional[PassManager] = None): + code: str, + error: str, + files: Dict[str, str] = {}, + main_scope: Optional[ScopedName] = None, + exc_type=PreprocessorError, + pass_manager: Optional[PassManager] = None, +): """ Verifies that compiling the code results in the given error. """ @@ -65,9 +73,6 @@ def verify_exception( pass_manager = default_pass_manager(prime=PRIME, read_module=read_file_from_dict(files)) with pytest.raises(exc_type) as e: - preprocess_codes( - codes=[(code, '')], - pass_manager=pass_manager, - main_scope=main_scope) + preprocess_codes(codes=[(code, "")], pass_manager=pass_manager, main_scope=main_scope) # Remove line and column information from the error using a regular expression. - assert re.sub(':[0-9]+:[0-9]+', 'file:?:?', str(e.value)) == error.strip() + assert re.sub(":[0-9]+:[0-9]+", "file:?:?", str(e.value)) == error.strip() diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_utils.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_utils.py index 08e63ef5..cfe39838 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_utils.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_utils.py @@ -3,7 +3,7 @@ from starkware.cairo.lang.compiler.ast.code_elements import CodeBlock, CodeElementEmptyLine from starkware.cairo.lang.compiler.ast.types import TypedIdentifier -from starkware.cairo.lang.compiler.error_handling import Location, ParentLocation +from starkware.cairo.lang.compiler.error_handling import Location from starkware.cairo.lang.compiler.parser import ParserContext, parse from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError @@ -15,11 +15,13 @@ def assert_no_modifier(typed_identifier: TypedIdentifier): if typed_identifier.modifier is not None: raise PreprocessorError( f"Unexpected modifier '{typed_identifier.modifier.format()}'.", - location=typed_identifier.modifier.location) + location=typed_identifier.modifier.location, + ) def verify_empty_code_block( - code_block: CodeBlock, error_message: str, default_location: Optional[Location]): + code_block: CodeBlock, error_message: str, default_location: Optional[Location] +): """ Verifies that the given code_block is empty (except for empty lines) and raises an exception otherwise. @@ -27,7 +29,7 @@ def verify_empty_code_block( for commented_code_elm in code_block.code_elements: code_elm = commented_code_elm.code_elm if not isinstance(code_elm, CodeElementEmptyLine): - if hasattr(code_elm, 'location'): + if hasattr(code_elm, "location"): location = code_elm.location # type: ignore elif commented_code_elm.location is not None: location = commented_code_elm.location @@ -36,19 +38,17 @@ def verify_empty_code_block( raise PreprocessorError(error_message, location=location) -def autogen_parse_code_block( - path: str, code: str, parent_location: Optional[ParentLocation] -) -> CodeBlock: +def autogen_parse_code_block(path: str, code: str, parser_context: ParserContext) -> CodeBlock: """ Parses the given code as CodeBlock. Can be used for auto-generation of code during compilation. """ code_hash = hashlib.sha256(code.encode()).hexdigest() - filename = f'{path}/{code_hash}.cairo' + filename = f"{path}/{code_hash}.cairo" return parse( filename=filename, code=code, - code_type='code_block', + code_type="code_block", expected_type=CodeBlock, - parser_context=ParserContext(parent_location=parent_location), + parser_context=parser_context, ) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/reg_tracking.py b/src/starkware/cairo/lang/compiler/preprocessor/reg_tracking.py index 04caf2de..535ac4ad 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/reg_tracking.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/reg_tracking.py @@ -4,7 +4,7 @@ from starkware.cairo.lang.compiler.ast.expr import ExprConst, Expression -RegChangeLike = Union[Union[int, Expression, 'RegChange']] +RegChangeLike = Union[Union[int, Expression, "RegChange"]] class RegChange(ABC): @@ -91,17 +91,18 @@ class RegTrackingData: possible to deduce the register at another pointer (as long as both points belong to the same group). """ + group: int = 0 offset: int = 0 @classmethod - def new(cls, group_alloc: Callable) -> 'RegTrackingData': + def new(cls, group_alloc: Callable) -> "RegTrackingData": return cls( group=group_alloc(), offset=0, ) - def __sub__(self, other: 'RegTrackingData') -> RegChange: + def __sub__(self, other: "RegTrackingData") -> RegChange: """ If possible, returns the difference between the values of ap between self and other. Otherwise, returns RegChangeUnknown. @@ -112,15 +113,15 @@ def __sub__(self, other: 'RegTrackingData') -> RegChange: return RegChangeUnknown() return RegChangeKnown(self.offset - other.offset) - def add(self, change: RegChangeLike, group_alloc: Callable) -> 'RegTrackingData': + def add(self, change: RegChangeLike, group_alloc: Callable) -> "RegTrackingData": change = RegChange.from_expr(change) if isinstance(change, RegChangeKnown): return RegTrackingData(group=self.group, offset=self.offset + change.value) if isinstance(change, RegChangeUnknown): return RegTrackingData(group=group_alloc(), offset=0) - raise NotImplementedError(f'Unsupported change type {type(change).__name__}') + raise NotImplementedError(f"Unsupported change type {type(change).__name__}") - def converge(self, other: 'RegTrackingData', group_alloc: Callable): + def converge(self, other: "RegTrackingData", group_alloc: Callable): if not isinstance(other, RegTrackingData): return other.converge(self, group_alloc) if self != other: diff --git a/src/starkware/cairo/lang/compiler/preprocessor/reg_tracking_test.py b/src/starkware/cairo/lang/compiler/preprocessor/reg_tracking_test.py index 7b9bd2fa..229de61e 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/reg_tracking_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/reg_tracking_test.py @@ -4,17 +4,22 @@ from starkware.cairo.lang.compiler.ast.expr import ExprConst, ExprIdentifier from starkware.cairo.lang.compiler.preprocessor.reg_tracking import ( - RegChange, RegChangeKnown, RegChangeUnconstrained, RegChangeUnknown, RegTrackingData) + RegChange, + RegChangeKnown, + RegChangeUnconstrained, + RegChangeUnknown, + RegTrackingData, +) def test_from_expr(): assert RegChange.from_expr(5) == RegChangeKnown(5) assert RegChange.from_expr(RegChangeKnown(6)) == RegChangeKnown(6) assert RegChange.from_expr(ExprConst(7)) == RegChangeKnown(7) - assert RegChange.from_expr(ExprIdentifier('asd')) == RegChangeUnknown() + assert RegChange.from_expr(ExprIdentifier("asd")) == RegChangeUnknown() with pytest.raises(TypeError): - RegChange.from_expr('wrong type') + RegChange.from_expr("wrong type") def test_reg_change_add(): @@ -23,7 +28,7 @@ def test_reg_change_add(): assert RegChangeUnknown() + RegChangeKnown(2) == RegChangeUnknown() with pytest.raises(TypeError): - RegChangeKnown(3) + 'asd' + RegChangeKnown(3) + "asd" with pytest.raises(TypeError): RegChangeUnconstrained() + RegChangeKnown(0) @@ -36,10 +41,13 @@ def test_reg_change_and(): def test_reg_tracking_data(): - assert RegTrackingData(group=3, offset=5) - RegTrackingData(group=3, offset=17) == \ - RegChangeKnown(-12) - assert RegTrackingData(group=3, offset=5) - RegTrackingData(group=4, offset=17) == \ - RegChangeUnknown() + assert RegTrackingData(group=3, offset=5) - RegTrackingData( + group=3, offset=17 + ) == RegChangeKnown(-12) + assert ( + RegTrackingData(group=3, offset=5) - RegTrackingData(group=4, offset=17) + == RegChangeUnknown() + ) def test_reg_tracking_data_add(): @@ -48,6 +56,7 @@ def test_reg_tracking_data_add(): def group_alloc(): return next(groups) + assert initial_data.add(3, group_alloc) == RegTrackingData(group=3, offset=8) assert initial_data.add(RegChangeUnknown(), group_alloc) == RegTrackingData(group=4, offset=0) assert initial_data.add(RegChangeUnknown(), group_alloc) == RegTrackingData(group=5, offset=0) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/struct_collector.py b/src/starkware/cairo/lang/compiler/preprocessor/struct_collector.py index 5f70c882..3eba92ac 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/struct_collector.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/struct_collector.py @@ -4,13 +4,19 @@ from starkware.cairo.lang.compiler.ast.arguments import IdentifierList from starkware.cairo.lang.compiler.ast.cairo_types import CairoType from starkware.cairo.lang.compiler.ast.code_elements import ( - CodeBlock, CodeElement, CodeElementEmptyLine, CodeElementFunction, CodeElementMember) + CodeBlock, + CodeElement, + CodeElementEmptyLine, + CodeElementFunction, + CodeElementMember, +) from starkware.cairo.lang.compiler.ast.formatting_utils import LocationField from starkware.cairo.lang.compiler.error_handling import Location from starkware.cairo.lang.compiler.identifier_definition import MemberDefinition, StructDefinition from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager from starkware.cairo.lang.compiler.preprocessor.identifier_aware_visitor import ( - IdentifierAwareVisitor) + IdentifierAwareVisitor, +) from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError from starkware.cairo.lang.compiler.preprocessor.preprocessor_utils import assert_no_modifier from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -21,6 +27,7 @@ class MemberInfo: """ Represents a member that wasn't assigned an offset yet. """ + name: str cairo_type: CairoType # Unresolved type. location: Optional[Location] = LocationField @@ -35,12 +42,13 @@ def __init__(self, identifiers: IdentifierManager): super().__init__(identifiers=identifiers) def _visit_default(self, obj): - assert isinstance(obj, (CodeBlock, CodeElement)), \ - f'Received unexpected object of type {type(obj).__name__}.' + assert isinstance( + obj, (CodeBlock, CodeElement) + ), f"Received unexpected object of type {type(obj).__name__}." def add_struct_definition( - self, members_list: List[MemberInfo], struct_name: ScopedName, - location: Optional[Location]): + self, members_list: List[MemberInfo], struct_name: ScopedName, location: Optional[Location] + ): offset = 0 members: Dict[str, MemberDefinition] = {} @@ -50,11 +58,12 @@ def add_struct_definition( name = member_info.name if name in members: raise PreprocessorError( - f"Redefinition of '{struct_name + name}'.", - location=member_info.location) + f"Redefinition of '{struct_name + name}'.", location=member_info.location + ) members[name] = MemberDefinition( - offset=offset, cairo_type=cairo_type, location=member_info.location) + offset=offset, cairo_type=cairo_type, location=member_info.location + ) offset += self.get_size(cairo_type) self.add_name_definition( @@ -65,11 +74,15 @@ def add_struct_definition( size=offset, location=location, ), - location=location) + location=location, + ) def create_struct_from_identifier_list( - self, identifier_list: Optional[IdentifierList], struct_name: ScopedName, - location: Optional[Location]): + self, + identifier_list: Optional[IdentifierList], + struct_name: ScopedName, + location: Optional[Location], + ): """ Creates a struct based on the given 'identifier_list'. """ @@ -77,18 +90,19 @@ def create_struct_from_identifier_list( if identifier_list is not None: for arg in identifier_list.identifiers: assert_no_modifier(arg) - members_list.append(MemberInfo( - name=arg.identifier.name, - cairo_type=arg.get_type(), - location=arg.location)) + members_list.append( + MemberInfo( + name=arg.identifier.name, cairo_type=arg.get_type(), location=arg.location + ) + ) location = identifier_list.location self.add_struct_definition( - members_list=members_list, struct_name=struct_name, location=location) + members_list=members_list, struct_name=struct_name, location=location + ) - def handle_struct_definition( - self, struct_name: ScopedName, code_block: CodeBlock, location): + def handle_struct_definition(self, struct_name: ScopedName, code_block: CodeBlock, location): members_list: List[MemberInfo] = [] for commented_code_element in code_block.code_elements: elm = commented_code_element.code_elm @@ -98,36 +112,42 @@ def handle_struct_definition( if not isinstance(elm, CodeElementMember): raise PreprocessorError( - 'Unexpected statement inside a struct definition.', - location=getattr(elm, 'location', location)) + "Unexpected statement inside a struct definition.", + location=getattr(elm, "location", location), + ) assert_no_modifier(elm.typed_identifier) if elm.typed_identifier.expr_type is None: raise PreprocessorError( - 'Struct members must be explicitly typed (e.g., member x : felt).', - location=elm.typed_identifier.location) + "Struct members must be explicitly typed (e.g., member x : felt).", + location=elm.typed_identifier.location, + ) identifier = elm.typed_identifier.identifier - members_list.append(MemberInfo( - name=identifier.name, - cairo_type=elm.typed_identifier.get_type(), - location=identifier.location)) + members_list.append( + MemberInfo( + name=identifier.name, + cairo_type=elm.typed_identifier.get_type(), + location=identifier.location, + ) + ) self.add_struct_definition( - members_list=members_list, struct_name=struct_name, location=location) + members_list=members_list, struct_name=struct_name, location=location + ) def visit_CodeElementFunction(self, elm: CodeElementFunction): new_scope = self.current_scope + elm.name - if elm.element_type == 'struct': + if elm.element_type == "struct": if len(elm.decorators) != 0: raise PreprocessorError( - 'Decorators for structs are not supported.', - location=elm.decorators[0].location + "Decorators for structs are not supported.", location=elm.decorators[0].location ) self.handle_struct_definition( - struct_name=new_scope, code_block=elm.code_block, location=elm.identifier.location) + struct_name=new_scope, code_block=elm.code_block, location=elm.identifier.location + ) return # Process code_elements. diff --git a/src/starkware/cairo/lang/compiler/preprocessor/struct_collector_test.py b/src/starkware/cairo/lang/compiler/preprocessor/struct_collector_test.py index 17584b43..75ac3120 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/struct_collector_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/struct_collector_test.py @@ -5,7 +5,11 @@ from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt, TypePointer, TypeStruct from starkware.cairo.lang.compiler.ast.module import CairoModule from starkware.cairo.lang.compiler.identifier_definition import ( - AliasDefinition, FutureIdentifierDefinition, MemberDefinition, StructDefinition) + AliasDefinition, + FutureIdentifierDefinition, + MemberDefinition, + StructDefinition, +) from starkware.cairo.lang.compiler.parser import parse_file from starkware.cairo.lang.compiler.preprocessor.identifier_collector import IdentifierCollector from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError @@ -25,7 +29,9 @@ def _collect_struct_definitions(codes: Dict[str, str]) -> Dict[str, Set[str]]: CairoModule( cairo_file=parse_file(code), module_name=ScopedName.from_string(name), - ) for name, code in codes.items()] + ) + for name, code in codes.items() + ] identifier_collector = IdentifierCollector() for module in modules: identifier_collector.visit(module) @@ -35,16 +41,19 @@ def _collect_struct_definitions(codes: Dict[str, str]) -> Dict[str, Set[str]]: return { str(name): identifier_definition for name, identifier_definition in struct_collector.identifiers.as_dict().items() - if not isinstance(identifier_definition, FutureIdentifierDefinition)} + if not isinstance(identifier_definition, FutureIdentifierDefinition) + } def test_struct_collector(): - modules = {'module': """ + modules = { + "module": """ struct S: member x : S* member y : S* end -""", '__main__': """ +""", + "__main__": """ from module import S func foo{z}(a : S, b) -> (c : S): @@ -55,82 +64,116 @@ def test_struct_collector(): return (c=a + X) end const Y = 1 + 1 -"""} +""", + } scope = ScopedName.from_string struct_defs = _collect_struct_definitions(modules) expected_def = { - 'module.S': StructDefinition( - full_name=scope('module.S'), + "module.S": StructDefinition( + full_name=scope("module.S"), members={ - 'x': MemberDefinition(offset=0, cairo_type=TypePointer(pointee=TypeStruct( - scope=scope('module.S'), is_fully_resolved=True))), - 'y': MemberDefinition(offset=1, cairo_type=TypePointer(pointee=TypeStruct( - scope=scope('module.S'), is_fully_resolved=True))), - }, size=2), - '__main__.S': AliasDefinition(destination=scope('module.S')), - '__main__.foo.Args': StructDefinition( - full_name=scope('__main__.foo.Args'), + "x": MemberDefinition( + offset=0, + cairo_type=TypePointer( + pointee=TypeStruct(scope=scope("module.S"), is_fully_resolved=True) + ), + ), + "y": MemberDefinition( + offset=1, + cairo_type=TypePointer( + pointee=TypeStruct(scope=scope("module.S"), is_fully_resolved=True) + ), + ), + }, + size=2, + ), + "__main__.S": AliasDefinition(destination=scope("module.S")), + "__main__.foo.Args": StructDefinition( + full_name=scope("__main__.foo.Args"), members={ - 'a': MemberDefinition(offset=0, cairo_type=TypeStruct( - scope=scope('module.S'), is_fully_resolved=True)), - 'b': MemberDefinition(offset=2, cairo_type=TypeFelt()), - }, size=3), - '__main__.foo.ImplicitArgs': StructDefinition( - full_name=scope('__main__.foo.ImplicitArgs'), - members={'z': MemberDefinition(offset=0, cairo_type=TypeFelt())}, size=1), - '__main__.foo.Return': StructDefinition( - full_name=scope('__main__.foo.Return'), + "a": MemberDefinition( + offset=0, cairo_type=TypeStruct(scope=scope("module.S"), is_fully_resolved=True) + ), + "b": MemberDefinition(offset=2, cairo_type=TypeFelt()), + }, + size=3, + ), + "__main__.foo.ImplicitArgs": StructDefinition( + full_name=scope("__main__.foo.ImplicitArgs"), + members={"z": MemberDefinition(offset=0, cairo_type=TypeFelt())}, + size=1, + ), + "__main__.foo.Return": StructDefinition( + full_name=scope("__main__.foo.Return"), members={ - 'c': MemberDefinition(offset=0, cairo_type=TypeStruct( - scope=scope('module.S'), is_fully_resolved=True)) - }, size=2), - '__main__.foo.T': StructDefinition( - full_name=scope('__main__.foo.T'), + "c": MemberDefinition( + offset=0, cairo_type=TypeStruct(scope=scope("module.S"), is_fully_resolved=True) + ) + }, + size=2, + ), + "__main__.foo.T": StructDefinition( + full_name=scope("__main__.foo.T"), members={ - 'x': MemberDefinition(offset=0, cairo_type=TypePointer(pointee=TypeStruct( - scope=scope('module.S'), is_fully_resolved=True))), - }, size=1) + "x": MemberDefinition( + offset=0, + cairo_type=TypePointer( + pointee=TypeStruct(scope=scope("module.S"), is_fully_resolved=True) + ), + ), + }, + size=1, + ), } assert struct_defs == expected_def def test_struct_collector_failure(): - modules = {'module': """ + modules = { + "module": """ struct S: member x : S* member x : S* end -"""} +""" + } with pytest.raises(PreprocessorError, match="Redefinition of 'module.S.x'."): _collect_struct_definitions(modules) - modules = {'module': """ + modules = { + "module": """ struct S: member local a end -"""} +""" + } with pytest.raises(PreprocessorError, match="Unexpected modifier 'local'."): _collect_struct_definitions(modules) - modules = {'module': """ + modules = { + "module": """ struct S: return() end -"""} - with pytest.raises(PreprocessorError, match='Unexpected statement inside a struct definition.'): +""" + } + with pytest.raises(PreprocessorError, match="Unexpected statement inside a struct definition."): _collect_struct_definitions(modules) - verify_exception(""" + verify_exception( + """ @decorator struct Struct: end -""", """ +""", + """ file:?:?: Decorators for structs are not supported. @decorator ^********^ -""") +""", + ) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/unique_labels.py b/src/starkware/cairo/lang/compiler/preprocessor/unique_labels.py index 4c61de53..fe4ddf3f 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/unique_labels.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/unique_labels.py @@ -13,7 +13,7 @@ def __init__(self): self.anon_label_counter = 0 def get(self): - label_name = f'_anon_label{self.anon_label_counter}' + label_name = f"_anon_label{self.anon_label_counter}" self.anon_label_counter += 1 return label_name diff --git a/src/starkware/cairo/lang/compiler/preprocessor/unique_labels_test.py b/src/starkware/cairo/lang/compiler/preprocessor/unique_labels_test.py index 18cf361c..aab1b230 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/unique_labels_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/unique_labels_test.py @@ -2,7 +2,8 @@ def test_unique_label_creator(): - program = preprocess_str(code=""" + program = preprocess_str( + code=""" namespace B: func foo(x, y) -> (res): if x == 0: @@ -24,8 +25,12 @@ def test_unique_label_creator(): B.foo(1, 2) ret end -""", prime=PRIME) - assert program.format() == """\ +""", + prime=PRIME, + ) + assert ( + program.format() + == """\ jmp rel 10 if [fp + (-4)] != 0 jmp rel 5 if [fp + (-3)] != 0 [ap] = 0; ap++ @@ -42,3 +47,4 @@ def test_unique_label_creator(): call rel -22 ret """ + ) diff --git a/src/starkware/cairo/lang/compiler/program.py b/src/starkware/cairo/lang/compiler/program.py index e705e8f5..9cf9f2f7 100644 --- a/src/starkware/cairo/lang/compiler/program.py +++ b/src/starkware/cairo/lang/compiler/program.py @@ -9,9 +9,15 @@ from starkware.cairo.lang.compiler.debug_info import DebugInfo from starkware.cairo.lang.compiler.identifier_definition import ( - ConstDefinition, IdentifierDefinition, LabelDefinition, ReferenceDefinition) + ConstDefinition, + IdentifierDefinition, + LabelDefinition, + ReferenceDefinition, +) from starkware.cairo.lang.compiler.identifier_manager import ( - IdentifierManager, MissingIdentifierError) + IdentifierManager, + MissingIdentifierError, +) from starkware.cairo.lang.compiler.identifier_manager_field import IdentifierManagerField from starkware.cairo.lang.compiler.preprocessor.flow import FlowTrackingDataActual, ReferenceManager from starkware.cairo.lang.compiler.references import Reference @@ -23,13 +29,14 @@ class CairoHint: code: str accessible_scopes: List[ScopedName] = field( - metadata=dict(marshmallow_field=mfields.List(ScopedNameAsStr))) + metadata=dict(marshmallow_field=mfields.List(ScopedNameAsStr)) + ) flow_tracking_data: FlowTrackingDataActual class ProgramBase(ABC): @abstractmethod - def stripped(self) -> 'StrippedProgram': + def stripped(self) -> "StrippedProgram": """ Returns the program as a StrippedProgram. """ @@ -46,25 +53,28 @@ class StrippedProgram(ProgramBase): Cairo program minimal information (stripped from hints, identifiers, etc.). The absence of hints is crucial for security reasons. Can be used for verifying execution. """ + prime: int data: List[int] builtins: List[str] main: int - def stripped(self) -> 'StrippedProgram': + def stripped(self) -> "StrippedProgram": return self def run_validity_checks(self): - assert isinstance(self.prime, int) and self.prime > 2**63, 'Invalid prime.' + assert isinstance(self.prime, int) and self.prime > 2 ** 63, "Invalid prime." assert isinstance(self.data, list) and all( - isinstance(x, int) and 0 <= x < self.prime for x in self.data), \ - 'Invalid program data.' - assert isinstance(self.builtins, list) and \ - all(is_valid_builtin_name(builtin) for builtin in self.builtins) and \ - len(set(self.builtins)) == len(self.builtins), \ - 'Invalid builtin list.' - assert isinstance(self.main, int) and 0 <= self.main < len(self.data), \ - 'Invalid main() address.' + isinstance(x, int) and 0 <= x < self.prime for x in self.data + ), "Invalid program data." + assert ( + isinstance(self.builtins, list) + and all(is_valid_builtin_name(builtin) for builtin in self.builtins) + and len(set(self.builtins)) == len(self.builtins) + ), "Invalid builtin list." + assert isinstance(self.main, int) and 0 <= self.main < len( + self.data + ), "Invalid main() address." @marshmallow_dataclass.dataclass(repr=False) @@ -75,7 +85,8 @@ class Program(ProgramBase, SerializableMarshmallowDataclass): builtins: List[str] main_scope: ScopedName = field(metadata=dict(marshmallow_field=ScopedNameAsStr())) identifiers: IdentifierManager = field( - metadata=dict(marshmallow_field=IdentifierManagerField())) + metadata=dict(marshmallow_field=IdentifierManagerField()) + ) # Holds all the allocated references in the program. reference_manager: ReferenceManager debug_info: Optional[DebugInfo] = None @@ -90,29 +101,33 @@ def stripped(self) -> StrippedProgram: ) def get_identifier( - self, name: Union[str, ScopedName], expected_type: Type[IdentifierDefinition], - full_name_lookup: Optional[bool] = None): + self, + name: Union[str, ScopedName], + expected_type: Type[IdentifierDefinition], + full_name_lookup: Optional[bool] = None, + ): scoped_name = name if isinstance(name, ScopedName) else ScopedName.from_string(name) if full_name_lookup is True: result = self.identifiers.root.get(scoped_name) else: - result = self.identifiers.search( - accessible_scopes=[self.main_scope], - name=scoped_name) + result = self.identifiers.search(accessible_scopes=[self.main_scope], name=scoped_name) result.assert_fully_parsed() identifier_definition = result.identifier_definition assert isinstance(identifier_definition, expected_type), ( - f"'{scoped_name}' is expected to be {expected_type.TYPE}, " + # type: ignore - f'found {identifier_definition.TYPE}.') # type: ignore + f"'{scoped_name}' is expected to be {expected_type.TYPE}, " + + f"found {identifier_definition.TYPE}." # type: ignore + ) # type: ignore return identifier_definition def get_label(self, name: Union[str, ScopedName], full_name_lookup: Optional[bool] = None): return self.get_identifier( - name=name, expected_type=LabelDefinition, full_name_lookup=full_name_lookup).pc + name=name, expected_type=LabelDefinition, full_name_lookup=full_name_lookup + ).pc def get_const(self, name: Union[str, ScopedName], full_name_lookup: Optional[bool] = None): return self.get_identifier( - name=name, expected_type=ConstDefinition, full_name_lookup=full_name_lookup).value + name=name, expected_type=ConstDefinition, full_name_lookup=full_name_lookup + ).value def get_reference_binds(self, name: Union[str, ScopedName]) -> List[Reference]: """ @@ -124,14 +139,14 @@ def get_reference_binds(self, name: Union[str, ScopedName]) -> List[Reference]: @property def main(self) -> Optional[int]: # type: ignore try: - return self.get_label('main') + return self.get_label("main") except MissingIdentifierError: return None @property def start(self) -> int: try: - return self.get_label('__start__') + return self.get_label("__start__") except MissingIdentifierError: return 0 @@ -140,5 +155,8 @@ def is_valid_builtin_name(name: str) -> bool: """ Returns true if name may be used as a builtin name. """ - return isinstance(name, str) and len(name) < 1000 and set(name) <= { - *string.ascii_lowercase, *string.digits, '_'} + return ( + isinstance(name, str) + and len(name) < 1000 + and set(name) <= {*string.ascii_lowercase, *string.digits, "_"} + ) diff --git a/src/starkware/cairo/lang/compiler/references.py b/src/starkware/cairo/lang/compiler/references.py index d97b8f02..c4d34f8c 100644 --- a/src/starkware/cairo/lang/compiler/references.py +++ b/src/starkware/cairo/lang/compiler/references.py @@ -6,13 +6,23 @@ from starkware.cairo.lang.compiler.ast.cairo_types import CairoType, TypePointer from starkware.cairo.lang.compiler.ast.expr import ( - ExprCast, ExprConst, ExprDeref, Expression, ExprOperator, ExprReg) + ExprCast, + ExprConst, + ExprDeref, + Expression, + ExprOperator, + ExprReg, +) from starkware.cairo.lang.compiler.error_handling import Location from starkware.cairo.lang.compiler.expression_transformer import ExpressionTransformer from starkware.cairo.lang.compiler.fields import ExpressionAsStr from starkware.cairo.lang.compiler.instruction import Register from starkware.cairo.lang.compiler.preprocessor.reg_tracking import ( - RegChange, RegChangeKnown, RegChangeLike, RegTrackingData) + RegChange, + RegChangeKnown, + RegChangeLike, + RegTrackingData, +) class FlowTrackingError(Exception): @@ -22,8 +32,8 @@ def __init__(self, message): def create_simple_ref_expr( - reg: Register, offset: int, cairo_type: CairoType, - location: Optional[Location]) -> Expression: + reg: Register, offset: int, cairo_type: CairoType, location: Optional[Location] +) -> Expression: """ Creates an expression of the form '[cast(reg + offset, cairo_type*)]'. """ @@ -31,12 +41,15 @@ def create_simple_ref_expr( addr=ExprCast( expr=ExprOperator( a=ExprReg(reg=reg, location=location), - op='+', + op="+", b=ExprConst(val=offset, location=location), - location=location), + location=location, + ), dest_type=TypePointer(pointee=cairo_type, location=location), - location=location), - location=location) + location=location, + ), + location=location, + ) @marshmallow_dataclass.dataclass @@ -52,6 +65,7 @@ class Reference: [ap] = [x] * 2; ap++ # Thus, this instruction will translate to '[ap] = [ap - 1] * 2; ap++' # and will set [ap] to 10. """ + pc: int value: Expression = field(metadata=dict(marshmallow_field=ExpressionAsStr(required=True))) # The value of flow_tracking when this reference was defined. @@ -60,8 +74,11 @@ class Reference: # A list of definition sites for the reference. # The list may hold more then once location if the reference is defined from the # convergence of multiple reference definitions. - locations: List[Location] = field(default_factory=list, compare=False, metadata=dict( - marshmallow_field=marshmallow.fields.Field(load_only=True, dump_only=True))) + locations: List[Location] = field( + default_factory=list, + compare=False, + metadata=dict(marshmallow_field=marshmallow.fields.Field(load_only=True, dump_only=True)), + ) Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema @@ -74,9 +91,9 @@ def eval(self, ap_tracking_data: RegTrackingData): return translate_ap(self.value, ap_diff) except FlowTrackingError as exc: if len(self.locations) > 0: - exc.notes.append('Reference was defined here:') + exc.notes.append("Reference was defined here:") for location in self.locations: - exc.notes.append(location.topmost_location().to_string_with_content('')) + exc.notes.append(location.topmost_location().to_string_with_content("")) raise @@ -89,28 +106,32 @@ def translate_ap(expr, ap_diff: RegChangeLike): def ap(location): return ExprOperator( ExprReg(reg=Register.AP, location=location), - '-', + "-", ExprConst(val=diff, location=location), - location=location) + location=location, + ) + else: ap = None - fp = (lambda location: ExprReg(reg=Register.FP, location=location)) + fp = lambda location: ExprReg(reg=Register.FP, location=location) return SubstituteRegisterTransformer(ap, fp).visit(expr) class SubstituteRegisterTransformer(ExpressionTransformer): def __init__( - self, ap: Optional[Callable[[Optional[Location]], Expression]], - fp: Callable[[Optional[Location]], Expression]): + self, + ap: Optional[Callable[[Optional[Location]], Expression]], + fp: Callable[[Optional[Location]], Expression], + ): self.ap = ap self.fp = fp def visit_ExprReg(self, expr: ExprReg): if expr.reg is Register.AP: if self.ap is None: - raise FlowTrackingError('Failed to deduce ap.') + raise FlowTrackingError("Failed to deduce ap.") return self.ap(expr.location) elif expr.reg is Register.FP: return self.fp(expr.location) else: - raise NotImplementedError(f'Register of type {expr.reg} is not supported') + raise NotImplementedError(f"Register of type {expr.reg} is not supported") diff --git a/src/starkware/cairo/lang/compiler/references_test.py b/src/starkware/cairo/lang/compiler/references_test.py index 81bbee53..42c8b8de 100644 --- a/src/starkware/cairo/lang/compiler/references_test.py +++ b/src/starkware/cairo/lang/compiler/references_test.py @@ -8,16 +8,18 @@ def test_eval_reference(): x = Reference( pc=0, - value=parse_expr('2 * ap + 3 * fp - 5'), - ap_tracking_data=RegTrackingData(group=1, offset=5)) + value=parse_expr("2 * ap + 3 * fp - 5"), + ap_tracking_data=RegTrackingData(group=1, offset=5), + ) with pytest.raises(FlowTrackingError): x.eval(RegTrackingData(group=2, offset=5)) - assert x.eval(RegTrackingData(group=1, offset=8)).format() == '2 * (ap - 3) + 3 * fp - 5' + assert x.eval(RegTrackingData(group=1, offset=8)).format() == "2 * (ap - 3) + 3 * fp - 5" def test_eval_reference_fp_only(): x = Reference( pc=0, - value=parse_expr('3 * fp - 5 + fp * fp'), - ap_tracking_data=RegTrackingData(group=1, offset=5)) - assert x.eval(RegTrackingData(group=2, offset=7)) == parse_expr('3 * fp - 5 + fp * fp') + value=parse_expr("3 * fp - 5 + fp * fp"), + ap_tracking_data=RegTrackingData(group=1, offset=5), + ) + assert x.eval(RegTrackingData(group=2, offset=7)) == parse_expr("3 * fp - 5 + fp * fp") diff --git a/src/starkware/cairo/lang/compiler/resolve_search_result.py b/src/starkware/cairo/lang/compiler/resolve_search_result.py index 0387a5de..30128833 100644 --- a/src/starkware/cairo/lang/compiler/resolve_search_result.py +++ b/src/starkware/cairo/lang/compiler/resolve_search_result.py @@ -1,14 +1,22 @@ from starkware.cairo.lang.compiler.constants import SIZE_CONSTANT from starkware.cairo.lang.compiler.identifier_definition import ( - ConstDefinition, DefinitionError, IdentifierDefinition, ReferenceDefinition, StructDefinition) + ConstDefinition, + DefinitionError, + IdentifierDefinition, + ReferenceDefinition, + StructDefinition, +) from starkware.cairo.lang.compiler.identifier_manager import ( - IdentifierError, IdentifierManager, IdentifierSearchResult) + IdentifierError, + IdentifierManager, + IdentifierSearchResult, +) from starkware.cairo.lang.compiler.offset_reference import OffsetReferenceDefinition def resolve_search_result( - search_result: IdentifierSearchResult, - identifiers: IdentifierManager) -> IdentifierDefinition: + search_result: IdentifierSearchResult, identifiers: IdentifierManager +) -> IdentifierDefinition: """ Returns a fully parsed identifier definition for the given identifier search result. If search_result contains a reference with non_parsed data, returns an instance of @@ -27,18 +35,20 @@ def resolve_search_result( struct_name = identifier_definition.full_name if member_def is None: raise DefinitionError( - f"'{search_result.non_parsed}' is not a member of '{struct_name}'.") + f"'{search_result.non_parsed}' is not a member of '{struct_name}'." + ) if len(search_result.non_parsed) > 1: raise IdentifierError( f"Unexpected '.' after '{struct_name + search_result.non_parsed.path[0]}' which is " - f'{member_def.TYPE}.') + f"{member_def.TYPE}." + ) identifier_definition = member_def elif isinstance(identifier_definition, ReferenceDefinition): identifier_definition = OffsetReferenceDefinition( - parent=identifier_definition, - member_path=search_result.non_parsed) + parent=identifier_definition, member_path=search_result.non_parsed + ) else: search_result.assert_fully_parsed() diff --git a/src/starkware/cairo/lang/compiler/resolve_search_result_test.py b/src/starkware/cairo/lang/compiler/resolve_search_result_test.py index a5fb1d21..a34137b8 100644 --- a/src/starkware/cairo/lang/compiler/resolve_search_result_test.py +++ b/src/starkware/cairo/lang/compiler/resolve_search_result_test.py @@ -3,7 +3,10 @@ from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt from starkware.cairo.lang.compiler.identifier_definition import MemberDefinition, StructDefinition from starkware.cairo.lang.compiler.identifier_manager import ( - IdentifierError, IdentifierManager, IdentifierSearchResult) + IdentifierError, + IdentifierManager, + IdentifierSearchResult, +) from starkware.cairo.lang.compiler.resolve_search_result import resolve_search_result from starkware.cairo.lang.compiler.scoped_name import ScopedName @@ -12,11 +15,10 @@ def test_resolve_search_result(): struct_def = StructDefinition( - full_name=scope('T'), + full_name=scope("T"), members={ - 'a': MemberDefinition(offset=0, cairo_type=TypeFelt()), - - 'b': MemberDefinition(offset=1, cairo_type=TypeFelt()), + "a": MemberDefinition(offset=0, cairo_type=TypeFelt()), + "b": MemberDefinition(offset=1, cairo_type=TypeFelt()), }, size=2, ) @@ -32,5 +34,7 @@ def test_resolve_search_result(): search_result=IdentifierSearchResult( identifier_definition=struct_def, canonical_name=struct_def.full_name, - non_parsed=scope('a.z')), - identifiers=identifier) + non_parsed=scope("a.z"), + ), + identifiers=identifier, + ) diff --git a/src/starkware/cairo/lang/compiler/scoped_name.py b/src/starkware/cairo/lang/compiler/scoped_name.py index 42379e43..02498b8f 100644 --- a/src/starkware/cairo/lang/compiler/scoped_name.py +++ b/src/starkware/cairo/lang/compiler/scoped_name.py @@ -6,16 +6,16 @@ @dataclasses.dataclass(frozen=True) class ScopedName: - SEPARATOR: ClassVar[str] = '.' + SEPARATOR: ClassVar[str] = "." path: Tuple[str, ...] = () def __post_init__(self): - assert '' not in self.path, 'Empty namespace is not supported.' + assert "" not in self.path, "Empty namespace is not supported." assert all([self.SEPARATOR not in part for part in self.path]) @classmethod def from_string(cls, scope: str): - if scope == '': + if scope == "": # Handle the special case of an empty tuple. return cls() return cls(tuple(scope.split(cls.SEPARATOR))) @@ -29,14 +29,14 @@ def __len__(self) -> int: """ return len(self.path) - def startswith(self, other: Union[str, 'ScopedName']) -> bool: + def startswith(self, other: Union[str, "ScopedName"]) -> bool: if isinstance(other, str): return self.startswith(self.from_string(other)) assert isinstance(other, ScopedName) - return self[:len(other)] == other + return self[: len(other)] == other - def __add__(self, other: Union[str, 'ScopedName']): + def __add__(self, other: Union[str, "ScopedName"]): if isinstance(other, str): return self + ScopedName.from_string(other) diff --git a/src/starkware/cairo/lang/compiler/scoped_name_test.py b/src/starkware/cairo/lang/compiler/scoped_name_test.py index 181aaddf..40b98fb9 100644 --- a/src/starkware/cairo/lang/compiler/scoped_name_test.py +++ b/src/starkware/cairo/lang/compiler/scoped_name_test.py @@ -4,36 +4,36 @@ def test_scoped_name(): - assert ScopedName(('some', 'thing')).path == ('some', 'thing') - assert str(ScopedName(('some', 'thing'))) == 'some.thing' - assert ScopedName.from_string('some.thing').path == ('some', 'thing') - assert ScopedName(('some', 'thing')) + 'el.se' == ScopedName(('some', 'thing', 'el', 'se')) - assert ScopedName(('some', 'thing')) + 'el.se' != ScopedName(('some', 'thing', 'else')) + assert ScopedName(("some", "thing")).path == ("some", "thing") + assert str(ScopedName(("some", "thing"))) == "some.thing" + assert ScopedName.from_string("some.thing").path == ("some", "thing") + assert ScopedName(("some", "thing")) + "el.se" == ScopedName(("some", "thing", "el", "se")) + assert ScopedName(("some", "thing")) + "el.se" != ScopedName(("some", "thing", "else")) - assert ScopedName.from_string('aa.bb.cc.dd')[1:3] == ScopedName.from_string('bb.cc') + assert ScopedName.from_string("aa.bb.cc.dd")[1:3] == ScopedName.from_string("bb.cc") with pytest.raises(AssertionError): - ScopedName(('some', 'thing.else')) + ScopedName(("some", "thing.else")) def test_empty(): - assert str(ScopedName()) == '' - assert ScopedName.from_string('') == ScopedName() + assert str(ScopedName()) == "" + assert ScopedName.from_string("") == ScopedName() def test_len(): assert len(ScopedName()) == 0 - assert len(ScopedName.from_string('a')) == 1 - assert len(ScopedName.from_string('a.b')) == 2 - assert len(ScopedName.from_string('x.a.b')) == 3 - assert len(ScopedName.from_string('x.a.b.c')) == 4 + assert len(ScopedName.from_string("a")) == 1 + assert len(ScopedName.from_string("a.b")) == 2 + assert len(ScopedName.from_string("x.a.b")) == 3 + assert len(ScopedName.from_string("x.a.b.c")) == 4 def test_startswith(): - assert ScopedName.from_string('a.b').startswith(ScopedName.from_string('a')) - assert not ScopedName.from_string('x.a.b').startswith(ScopedName.from_string('a')) - assert not ScopedName.from_string('a.b').startswith(ScopedName.from_string('x')) - assert not ScopedName.from_string('a.b').startswith('b') - assert ScopedName.from_string('x.a.b').startswith('') - assert ScopedName.from_string('x.a.b').startswith('x.a') - assert not ScopedName.from_string('abc').startswith('a') + assert ScopedName.from_string("a.b").startswith(ScopedName.from_string("a")) + assert not ScopedName.from_string("x.a.b").startswith(ScopedName.from_string("a")) + assert not ScopedName.from_string("a.b").startswith(ScopedName.from_string("x")) + assert not ScopedName.from_string("a.b").startswith("b") + assert ScopedName.from_string("x.a.b").startswith("") + assert ScopedName.from_string("x.a.b").startswith("x.a") + assert not ScopedName.from_string("abc").startswith("a") diff --git a/src/starkware/cairo/lang/compiler/substitute_identifiers.py b/src/starkware/cairo/lang/compiler/substitute_identifiers.py index fefbd405..dcaa31c5 100644 --- a/src/starkware/cairo/lang/compiler/substitute_identifiers.py +++ b/src/starkware/cairo/lang/compiler/substitute_identifiers.py @@ -2,7 +2,14 @@ from starkware.cairo.lang.compiler.ast.cairo_types import CairoType, TypeStruct from starkware.cairo.lang.compiler.ast.expr import ( - ExprCast, ExprConst, Expression, ExprFutureLabel, ExprIdentifier, ExprPow, ExprTuple) + ExprCast, + ExprConst, + Expression, + ExprFutureLabel, + ExprIdentifier, + ExprPow, + ExprTuple, +) from starkware.cairo.lang.compiler.ast.expr_func_call import ExprFuncCall from starkware.cairo.lang.compiler.ast.rvalue import RvalueFuncCall from starkware.cairo.lang.compiler.expression_transformer import ExpressionTransformer @@ -15,14 +22,17 @@ class SubstituteIdentifiers(ExpressionTransformer): def __init__( - self, get_identifier_callback: GetIdentifierCallback, - resolve_type_callback: ResolveTypeCallback = None): + self, + get_identifier_callback: GetIdentifierCallback, + resolve_type_callback: ResolveTypeCallback = None, + ): super().__init__() self.get_identifier_callback = get_identifier_callback self.resolve_type_callback = ( resolve_type_callback if resolve_type_callback is not None - else (lambda cairo_type: cairo_type)) + else (lambda cairo_type: cairo_type) + ) def visit_ExprIdentifier(self, expr: ExprIdentifier) -> Expression: val = self.get_identifier_callback(expr) @@ -36,15 +46,16 @@ def visit_ExprCast(self, expr: ExprCast): dest_type=self.resolve_type_callback(expr.dest_type), cast_type=expr.cast_type, notes=expr.notes, - location=expr.location) + location=expr.location, + ) def visit_ExprPow(self, expr: ExprPow): # Same as super().visit_ExprPow, except that we don't visit expr.b. # The reason is that the exponent shouldn't be taken modulo PRIME, so we don't allow # using identifiers in the exponent. return ExprPow( - a=self.visit(expr.a), b=expr.b, - location=self.location_modifier(expr.location)) + a=self.visit(expr.a), b=expr.b, location=self.location_modifier(expr.location) + ) def visit_RvalueFuncCall(self, rvalue: RvalueFuncCall): # Same as super().visit_RvalueFuncCall, except that we don't visit rvalue.func_ident. @@ -53,35 +64,46 @@ def visit_RvalueFuncCall(self, rvalue: RvalueFuncCall): return RvalueFuncCall( func_ident=rvalue.func_ident, arguments=self.visit_ArgList(rvalue.arguments), - implicit_arguments=None if rvalue.implicit_arguments is None else self.visit_ArgList( - rvalue.implicit_arguments), - location=rvalue.location) + implicit_arguments=None + if rvalue.implicit_arguments is None + else self.visit_ArgList(rvalue.implicit_arguments), + location=rvalue.location, + ) def visit_ExprFuncCall(self, expr: ExprFuncCall): # Convert ExprFuncCall to ExprCast. rvalue = expr.rvalue if rvalue.implicit_arguments is not None: raise CairoTypeError( - 'Implicit arguments cannot be used with struct constructors.', - location=rvalue.implicit_arguments.location) + "Implicit arguments cannot be used with struct constructors.", + location=rvalue.implicit_arguments.location, + ) - struct_type = self.resolve_type_callback(TypeStruct( - scope=ScopedName.from_string(rvalue.func_ident.name), - is_fully_resolved=False, - location=expr.location)) + struct_type = self.resolve_type_callback( + TypeStruct( + scope=ScopedName.from_string(rvalue.func_ident.name), + is_fully_resolved=False, + location=expr.location, + ) + ) - return self.visit(ExprCast( - expr=ExprTuple(rvalue.arguments, location=expr.location), - dest_type=struct_type, - location=expr.location)) + return self.visit( + ExprCast( + expr=ExprTuple(rvalue.arguments, location=expr.location), + dest_type=struct_type, + location=expr.location, + ) + ) def visit_ExprFutureLabel(self, expr: ExprFutureLabel): return self.visit(expr.identifier) def substitute_identifiers( - expr: Expression, get_identifier_callback: GetIdentifierCallback, - resolve_type_callback: ResolveTypeCallback = None) -> Expression: + expr: Expression, + get_identifier_callback: GetIdentifierCallback, + resolve_type_callback: ResolveTypeCallback = None, +) -> Expression: """ Replaces identifiers by other expressions according to the given callback. """ diff --git a/src/starkware/cairo/lang/compiler/type_casts.py b/src/starkware/cairo/lang/compiler/type_casts.py index 9eb77374..346e4aa6 100644 --- a/src/starkware/cairo/lang/compiler/type_casts.py +++ b/src/starkware/cairo/lang/compiler/type_casts.py @@ -2,7 +2,13 @@ from typing import Optional from starkware.cairo.lang.compiler.ast.cairo_types import ( - CairoType, CastType, TypeFelt, TypePointer, TypeStruct, TypeTuple) + CairoType, + CastType, + TypeFelt, + TypePointer, + TypeStruct, + TypeTuple, +) from starkware.cairo.lang.compiler.ast.expr import ExprDeref, Expression, ExprTuple from starkware.cairo.lang.compiler.error_handling import LocationError from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager @@ -16,8 +22,12 @@ class CairoTypeError(LocationError): def check_cast( - src_type: CairoType, dest_type: CairoType, identifier_manager: IdentifierManager, - expr: Optional[Expression] = None, cast_type: CastType = CastType.EXPLICIT) -> bool: + src_type: CairoType, + dest_type: CairoType, + identifier_manager: IdentifierManager, + expr: Optional[Expression] = None, + cast_type: CastType = CastType.EXPLICIT, +) -> bool: """ Returns true if the given expression can be casted from src_type to dest_type according to the given 'cast_type'. @@ -41,19 +51,21 @@ def check_cast( # CastType.UNPACKING checks: # Allow explicit cast between felts and pointers. - if isinstance(src_type, (TypeFelt, TypePointer)) and \ - isinstance(dest_type, (TypeFelt, TypePointer)): + if isinstance(src_type, (TypeFelt, TypePointer)) and isinstance( + dest_type, (TypeFelt, TypePointer) + ): return True if cast_type is CastType.UNPACKING: return False # CastType.EXPLICIT checks: - assert expr is not None, f'CastType.EXPLICIT requires expr != None.' + assert expr is not None, f"CastType.EXPLICIT requires expr != None." if isinstance(src_type, TypeTuple) and isinstance(dest_type, TypeStruct): struct_def = get_struct_definition( - struct_name=dest_type.resolved_scope, identifier_manager=identifier_manager) + struct_name=dest_type.resolved_scope, identifier_manager=identifier_manager + ) n_src_members = len(src_type.members) n_dest_members = len(struct_def.members) @@ -62,24 +74,31 @@ def check_cast( f"""\ Cannot cast an expression of type '{src_type.format()}' to '{dest_type.format()}'. The former has {n_src_members} members while the latter has {n_dest_members} members.""", - location=expr.location) + location=expr.location, + ) src_exprs = ( [arg.expr for arg in expr.members.args] if isinstance(expr, ExprTuple) - else itertools.repeat(expr)) + else itertools.repeat(expr) + ) for (src_expr, src_member_type, dest_member) in zip( - src_exprs, src_type.members, struct_def.members.values()): + src_exprs, src_type.members, struct_def.members.values() + ): dest_member_type = dest_member.cairo_type if not check_cast( - src_type=src_member_type, dest_type=dest_member_type, - identifier_manager=identifier_manager, expr=src_expr, - cast_type=CastType.FORCED if cast_type is CastType.FORCED else CastType.ASSIGN): + src_type=src_member_type, + dest_type=dest_member_type, + identifier_manager=identifier_manager, + expr=src_expr, + cast_type=CastType.FORCED if cast_type is CastType.FORCED else CastType.ASSIGN, + ): raise CairoTypeError( f"Cannot cast '{src_member_type.format()}' to '{dest_member_type.format()}'.", - location=src_expr.location) + location=src_expr.location, + ) return True @@ -87,9 +106,12 @@ def check_cast( return False # CastType.FORCED checks: - if isinstance(src_type, TypeFelt) and isinstance(dest_type, TypeStruct) and isinstance( - expr, ExprDeref): + if ( + isinstance(src_type, TypeFelt) + and isinstance(dest_type, TypeStruct) + and isinstance(expr, ExprDeref) + ): return True - assert cast_type is CastType.FORCED, f'Unsupported cast type: {cast_type}.' + assert cast_type is CastType.FORCED, f"Unsupported cast type: {cast_type}." return False diff --git a/src/starkware/cairo/lang/compiler/type_casts_test.py b/src/starkware/cairo/lang/compiler/type_casts_test.py index 3947dafe..24a88231 100644 --- a/src/starkware/cairo/lang/compiler/type_casts_test.py +++ b/src/starkware/cairo/lang/compiler/type_casts_test.py @@ -6,27 +6,36 @@ from starkware.cairo.lang.compiler.type_casts import check_cast -@pytest.mark.parametrize('src, dest, explicit_cast, unpacking_cast, assign_cast', [ - ['T', 'T', True, True, True], - ['felt', 'felt*', True, True, False], - ['felt*', 'felt', True, True, False], - ['felt*', 'T*', True, True, False], - ['T*', 'felt*', True, True, True], - ['felt*', 'T', False, False, False], - ['T', 'felt*', False, False, False], - ['felt', '(felt,felt)', False, False, False], -]) +@pytest.mark.parametrize( + "src, dest, explicit_cast, unpacking_cast, assign_cast", + [ + ["T", "T", True, True, True], + ["felt", "felt*", True, True, False], + ["felt*", "felt", True, True, False], + ["felt*", "T*", True, True, False], + ["T*", "felt*", True, True, True], + ["felt*", "T", False, False, False], + ["T", "felt*", False, False, False], + ["felt", "(felt,felt)", False, False, False], + ], +) def test_type_casts( - src: str, dest: str, explicit_cast: bool, unpacking_cast: bool, assign_cast: bool): + src: str, dest: str, explicit_cast: bool, unpacking_cast: bool, assign_cast: bool +): identifier_manager = IdentifierManager() src_type = parse_type(src) dest_type = parse_type(dest) - expr = parse_expr('[ap]') + expr = parse_expr("[ap]") actual_results = [ check_cast( - src_type=src_type, dest_type=dest_type, identifier_manager=identifier_manager, - expr=expr, cast_type=cast_type) - for cast_type in [CastType.EXPLICIT, CastType.UNPACKING, CastType.ASSIGN]] + src_type=src_type, + dest_type=dest_type, + identifier_manager=identifier_manager, + expr=expr, + cast_type=cast_type, + ) + for cast_type in [CastType.EXPLICIT, CastType.UNPACKING, CastType.ASSIGN] + ] expected_results = [explicit_cast, unpacking_cast, assign_cast] assert actual_results == expected_results diff --git a/src/starkware/cairo/lang/compiler/type_system.py b/src/starkware/cairo/lang/compiler/type_system.py index 5cd88302..0fa92943 100644 --- a/src/starkware/cairo/lang/compiler/type_system.py +++ b/src/starkware/cairo/lang/compiler/type_system.py @@ -1,7 +1,12 @@ import dataclasses from starkware.cairo.lang.compiler.ast.cairo_types import ( - CairoType, TypeFelt, TypePointer, TypeStruct, TypeTuple) + CairoType, + TypeFelt, + TypePointer, + TypeStruct, + TypeTuple, +) from starkware.cairo.lang.compiler.ast.expr import ExprCast, Expression from starkware.cairo.lang.compiler.expression_transformer import ExpressionTransformer @@ -18,15 +23,13 @@ def mark_type_resolved(cairo_type: CairoType) -> CairoType: elif isinstance(cairo_type, TypeStruct): if cairo_type.is_fully_resolved: return cairo_type - return dataclasses.replace( - cairo_type, - is_fully_resolved=True) + return dataclasses.replace(cairo_type, is_fully_resolved=True) elif isinstance(cairo_type, TypeTuple): return dataclasses.replace( - cairo_type, - members=[mark_type_resolved(member) for member in cairo_type.members]) + cairo_type, members=[mark_type_resolved(member) for member in cairo_type.members] + ) else: - raise NotImplementedError(f'Type {type(cairo_type).__name__} is not supported.') + raise NotImplementedError(f"Type {type(cairo_type).__name__} is not supported.") def is_type_resolved(cairo_type: CairoType) -> bool: @@ -42,13 +45,14 @@ def is_type_resolved(cairo_type: CairoType) -> bool: elif isinstance(cairo_type, TypeTuple): return all(map(is_type_resolved, cairo_type.members)) else: - raise NotImplementedError(f'Type {type(cairo_type).__name__} is not supported.') + raise NotImplementedError(f"Type {type(cairo_type).__name__} is not supported.") class MarkResolved(ExpressionTransformer): def visit_ExprCast(self, expr: ExprCast): return dataclasses.replace( - expr, expr=self.visit(expr.expr), dest_type=mark_type_resolved(expr.dest_type)) + expr, expr=self.visit(expr.expr), dest_type=mark_type_resolved(expr.dest_type) + ) def mark_types_in_expr_resolved(expr: Expression): diff --git a/src/starkware/cairo/lang/compiler/type_system_visitor.py b/src/starkware/cairo/lang/compiler/type_system_visitor.py index 1470e487..f075dd70 100644 --- a/src/starkware/cairo/lang/compiler/type_system_visitor.py +++ b/src/starkware/cairo/lang/compiler/type_system_visitor.py @@ -2,23 +2,43 @@ from typing import Optional, Tuple from starkware.cairo.lang.compiler.ast.cairo_types import ( - CairoType, TypeFelt, TypePointer, TypeStruct, TypeTuple) + CairoType, + TypeFelt, + TypePointer, + TypeStruct, + TypeTuple, +) from starkware.cairo.lang.compiler.ast.expr import ( - ExprAddressOf, ExprCast, ExprConst, ExprDeref, ExprDot, Expression, ExprFutureLabel, ExprHint, - ExprIdentifier, ExprNeg, ExprOperator, ExprParentheses, ExprPyConst, ExprReg, ExprSubscript, - ExprTuple) + ExprAddressOf, + ExprCast, + ExprConst, + ExprDeref, + ExprDot, + Expression, + ExprFutureLabel, + ExprHint, + ExprIdentifier, + ExprNeg, + ExprOperator, + ExprParentheses, + ExprPyConst, + ExprReg, + ExprSubscript, + ExprTuple, +) from starkware.cairo.lang.compiler.error_handling import Location from starkware.cairo.lang.compiler.expression_simplifier import ExpressionSimplifier from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition from starkware.cairo.lang.compiler.preprocessor.identifier_aware_visitor import ( - IdentifierAwareVisitor) + IdentifierAwareVisitor, +) from starkware.cairo.lang.compiler.type_casts import CairoTypeError, check_cast def get_expr_addr(expr: Expression): if not isinstance(expr, ExprDeref): - raise CairoTypeError('Expression has no address.', location=expr.location) + raise CairoTypeError("Expression has no address.", location=expr.location) return expr.addr @@ -45,8 +65,8 @@ def visit_ExprFutureLabel(self, expr: ExprFutureLabel) -> Tuple[ExprFutureLabel, def visit_ExprIdentifier(self, expr: ExprIdentifier) -> Tuple[Expression, CairoType]: raise CairoTypeError( - f"Identifier '{expr.format()}' is not allowed in this context.", - location=expr.location) + f"Identifier '{expr.format()}' is not allowed in this context.", location=expr.location + ) def visit_ExprReg(self, expr: ExprReg) -> Tuple[ExprReg, TypeFelt]: return expr, TypeFelt(location=expr.location) @@ -59,17 +79,18 @@ def visit_ExprOperator(self, expr: ExprOperator) -> Tuple[ExprOperator, CairoTyp result_type: CairoType if isinstance(a_type, TypeFelt) and isinstance(b_type, TypeFelt): result_type = TypeFelt(location=expr.location) - elif isinstance(a_type, TypePointer) and isinstance(b_type, TypeFelt) and op in ['+', '-']: + elif isinstance(a_type, TypePointer) and isinstance(b_type, TypeFelt) and op in ["+", "-"]: result_type = a_type - elif isinstance(a_type, TypeFelt) and isinstance(b_type, TypePointer) and op == '+': + elif isinstance(a_type, TypeFelt) and isinstance(b_type, TypePointer) and op == "+": result_type = b_type - elif isinstance(a_type, TypePointer) and a_type == b_type and op == '-': + elif isinstance(a_type, TypePointer) and a_type == b_type and op == "-": result_type = TypeFelt(location=expr.location) else: raise CairoTypeError( f"Operator '{op}' is not implemented for types " f"'{a_type.format()}' and '{b_type.format()}'.", - location=expr.location) + location=expr.location, + ) return dataclasses.replace(expr, a=a_expr, b=b_expr), result_type def visit_ExprPow(self, expr: ExprOperator) -> Tuple[ExprOperator, CairoType]: @@ -80,7 +101,8 @@ def visit_ExprPow(self, expr: ExprOperator) -> Tuple[ExprOperator, CairoType]: raise CairoTypeError( f"Operator '**' is not implemented for types " f"'{a_type.format()}' and '{b_type.format()}'.", - location=expr.location) + location=expr.location, + ) return dataclasses.replace(expr, a=a_expr, b=b_expr), TypeFelt(location=expr.location) def visit_ExprAddressOf(self, expr: ExprAddressOf) -> Tuple[Expression, TypePointer]: @@ -92,7 +114,8 @@ def visit_ExprNeg(self, expr: ExprNeg) -> Tuple[ExprNeg, TypeFelt]: if not isinstance(inner_type, TypeFelt): raise CairoTypeError( f"Unary '-' is not supported for type '{inner_type.format()}'.", - location=expr.location) + location=expr.location, + ) return dataclasses.replace(expr, val=inner_expr), TypeFelt(location=expr.location) @@ -107,15 +130,17 @@ def visit_ExprDeref(self, expr: ExprDeref) -> Tuple[ExprDeref, CairoType]: return dataclasses.replace(expr, addr=addr_expr), addr_type.pointee else: raise CairoTypeError( - f"Cannot dereference type '{addr_type.format()}'.", - location=expr.location) + f"Cannot dereference type '{addr_type.format()}'.", location=expr.location + ) @staticmethod def verify_offset_is_felt(offset_type: CairoType, offset_location: Location): if not isinstance(offset_type, TypeFelt): raise CairoTypeError( - 'Cannot apply subscript-operator with offset of non-felt type ' - f"'{offset_type.format()}'.", location=offset_location) + "Cannot apply subscript-operator with offset of non-felt type " + f"'{offset_type.format()}'.", + location=offset_location, + ) def visit_ExprSubscript(self, expr: ExprSubscript) -> Tuple[Expression, CairoType]: inner_expr, inner_type = self.visit(expr.expr) @@ -126,16 +151,18 @@ def visit_ExprSubscript(self, expr: ExprSubscript) -> Tuple[Expression, CairoTyp offset_expr = ExpressionSimplifier().visit(offset_expr) if not isinstance(offset_expr, ExprConst): raise CairoTypeError( - 'Subscript-operator for tuples supports only constant offsets, found ' + "Subscript-operator for tuples supports only constant offsets, found " f"'{type(offset_expr).__name__}'.", - location=offset_expr.location) + location=offset_expr.location, + ) offset_value = offset_expr.val tuple_len = len(inner_type.members) if not 0 <= offset_value < tuple_len: raise CairoTypeError( - f'Tuple index {offset_value} is out of range [0, {tuple_len}).', - location=expr.location) + f"Tuple index {offset_value} is out of range [0, {tuple_len}).", + location=expr.location, + ) item_type = inner_type.members[offset_value] @@ -144,22 +171,27 @@ def visit_ExprSubscript(self, expr: ExprSubscript) -> Tuple[Expression, CairoTyp return ( # Take the inner item, but keep the original expression's location. dataclasses.replace( - inner_expr.members.args[offset_value].expr, location=expr.location), - item_type) + inner_expr.members.args[offset_value].expr, location=expr.location + ), + item_type, + ) elif isinstance(inner_expr, ExprDeref): # Handles pointers cast as tuples*, e.g. `[cast(ap, (felt, felt)*][0]`. addr = inner_expr.addr offset_in_felts = ExprConst( val=sum(map(self.get_size, inner_type.members[:offset_value])), - location=offset_expr.location) + location=offset_expr.location, + ) addr_with_offset = ExprOperator( - a=addr, op='+', b=offset_in_felts, location=expr.location) + a=addr, op="+", b=offset_in_felts, location=expr.location + ) return ExprDeref(addr=addr_with_offset, location=expr.location), item_type else: raise CairoTypeError( - 'Unexpected expression typed as TypeTuple. Expected ExprTuple or ExprDeref, ' + "Unexpected expression typed as TypeTuple. Expected ExprTuple or ExprDeref, " f"found '{type(inner_expr).__name__}'.", - location=expr.location) + location=expr.location, + ) elif isinstance(inner_type, TypePointer): self.verify_offset_is_felt(offset_type, offset_expr.location) try: @@ -171,25 +203,31 @@ def visit_ExprSubscript(self, expr: ExprSubscript) -> Tuple[Expression, CairoTyp element_size_expr = ExprConst(val=element_size, location=expr.location) modified_offset_expr = ExprOperator( - a=offset_expr, op='*', b=element_size_expr, location=expr.location) + a=offset_expr, op="*", b=element_size_expr, location=expr.location + ) simplified_expr = ExprDeref( addr=ExprOperator( - a=inner_expr, op='+', b=modified_offset_expr, location=expr.location), - location=expr.location) + a=inner_expr, op="+", b=modified_offset_expr, location=expr.location + ), + location=expr.location, + ) return simplified_expr, inner_type.pointee else: raise CairoTypeError( - 'Cannot apply subscript-operator to non-pointer, non-tuple type ' + "Cannot apply subscript-operator to non-pointer, non-tuple type " f"'{inner_type.format()}'.", - location=expr.location) + location=expr.location, + ) def verify_identifier_manager_initialized(self, location: Optional[Location]): if self.identifiers_initalized: return raise CairoTypeError( - 'Identifiers must be initialized for type-simplification of dot-operator ' - 'expressions.', location=location) + "Identifiers must be initialized for type-simplification of dot-operator " + "expressions.", + location=location, + ) def visit_ExprDot(self, expr: ExprDot) -> Tuple[ExprDeref, CairoType]: self.verify_identifier_manager_initialized(location=expr.location) @@ -198,32 +236,39 @@ def visit_ExprDot(self, expr: ExprDot) -> Tuple[ExprDeref, CairoType]: if isinstance(inner_type, TypePointer): if not isinstance(inner_type.pointee, TypeStruct): raise CairoTypeError( - f'Cannot apply dot-operator to pointer-to-non-struct type ' - f"'{inner_type.format()}'.", location=expr.location) + f"Cannot apply dot-operator to pointer-to-non-struct type " + f"'{inner_type.format()}'.", + location=expr.location, + ) # Allow for . as ->, once. inner_type = inner_type.pointee elif isinstance(inner_type, TypeStruct): if isinstance(inner_expr, ExprTuple): raise CairoTypeError( - 'Accessing struct members for r-value structs is not supported yet.', - location=expr.location) + "Accessing struct members for r-value structs is not supported yet.", + location=expr.location, + ) # Get the address, to evaluate . as ->. inner_expr = get_expr_addr(inner_expr) else: raise CairoTypeError( f"Cannot apply dot-operator to non-struct type '{inner_type.format()}'.", - location=expr.location) + location=expr.location, + ) try: struct_def = get_struct_definition( - struct_name=inner_type.resolved_scope, identifier_manager=self.identifiers) + struct_name=inner_type.resolved_scope, identifier_manager=self.identifiers + ) except Exception as exc: raise CairoTypeError(str(exc), location=expr.location) if expr.member.name not in struct_def.members: raise CairoTypeError( f"Member '{expr.member.name}' does not appear in definition of struct " - f"'{inner_type.format()}'.", location=expr.location) + f"'{inner_type.format()}'.", + location=expr.location, + ) member_definition = struct_def.members[expr.member.name] member_type = member_definition.cairo_type member_offset = member_definition.offset @@ -233,8 +278,9 @@ def visit_ExprDot(self, expr: ExprDot) -> Tuple[ExprDeref, CairoType]: else: mem_offset_expr = ExprConst(val=member_offset, location=expr.location) simplified_expr = ExprDeref( - addr=ExprOperator(a=inner_expr, op='+', b=mem_offset_expr, location=expr.location), - location=expr.location) + addr=ExprOperator(a=inner_expr, op="+", b=mem_offset_expr, location=expr.location), + location=expr.location, + ) return simplified_expr, member_type @@ -243,11 +289,16 @@ def visit_ExprCast(self, expr: ExprCast) -> Tuple[Expression, CairoType]: dest_type = expr.dest_type if not check_cast( - src_type=src_type, dest_type=dest_type, identifier_manager=self.identifiers, - expr=inner_expr, cast_type=expr.cast_type): + src_type=src_type, + dest_type=dest_type, + identifier_manager=self.identifiers, + expr=inner_expr, + cast_type=expr.cast_type, + ): raise CairoTypeError( f"Cannot cast '{src_type.format()}' to '{dest_type.format()}'.", - location=expr.location) + location=expr.location, + ) # Remove the cast() from the expression, but keep its original location. return dataclasses.replace(inner_expr, location=expr.location), dest_type @@ -257,18 +308,20 @@ def visit_ExprTuple(self, expr: ExprTuple) -> Tuple[ExprTuple, TypeTuple]: # Call visit on each member to obtain a list of the form (expr, type). member_expr_types = [self.visit(arg.expr) for arg in args] result_members = [ - dataclasses.replace(arg, expr=expr) for arg, (expr, _) in zip(args, member_expr_types)] + dataclasses.replace(arg, expr=expr) for arg, (expr, _) in zip(args, member_expr_types) + ] result_expr = dataclasses.replace( - expr, members=dataclasses.replace(expr.members, args=result_members)) + expr, members=dataclasses.replace(expr.members, args=result_members) + ) cairo_type = TypeTuple( - members=[expr_type for expr, expr_type in member_expr_types], - location=expr.location) + members=[expr_type for expr, expr_type in member_expr_types], location=expr.location + ) return result_expr, cairo_type def simplify_type_system( - expr: Expression, - identifiers: Optional[IdentifierManager] = None) -> Tuple[Expression, CairoType]: + expr: Expression, identifiers: Optional[IdentifierManager] = None +) -> Tuple[Expression, CairoType]: """ Given an expression returns a type-simplified expression and its Cairo type. This includes checking types in operations, removing casts, and expanding dot and subscript diff --git a/src/starkware/cairo/lang/compiler/type_system_visitor_test.py b/src/starkware/cairo/lang/compiler/type_system_visitor_test.py index ac2431c4..5a7faff7 100644 --- a/src/starkware/cairo/lang/compiler/type_system_visitor_test.py +++ b/src/starkware/cairo/lang/compiler/type_system_visitor_test.py @@ -5,7 +5,12 @@ from starkware.cairo.lang.compiler.ast.ast_objects_test_utils import remove_parentheses from starkware.cairo.lang.compiler.ast.cairo_types import ( - CairoType, TypeFelt, TypePointer, TypeStruct, TypeTuple) + CairoType, + TypeFelt, + TypePointer, + TypeStruct, + TypeTuple, +) from starkware.cairo.lang.compiler.ast_objects_test import remove_parentheses from starkware.cairo.lang.compiler.identifier_definition import MemberDefinition, StructDefinition from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager @@ -18,161 +23,208 @@ def simplify_type_system_test( - orig_expr: str, simplified_expr: str, simplified_type: CairoType, - identifiers: Optional[IdentifierManager] = None): + orig_expr: str, + simplified_expr: str, + simplified_type: CairoType, + identifiers: Optional[IdentifierManager] = None, +): parsed_expr = mark_types_in_expr_resolved(parse_expr(orig_expr)) assert simplify_type_system(parsed_expr, identifiers=identifiers) == ( - parse_expr(simplified_expr), simplified_type) + parse_expr(simplified_expr), + simplified_type, + ) def test_type_visitor(): - t = TypeStruct(scope=scope('T'), is_fully_resolved=True) + t = TypeStruct(scope=scope("T"), is_fully_resolved=True) t_star = TypePointer(pointee=t) t_star2 = TypePointer(pointee=t_star) - simplify_type_system_test('fp + 3 + [ap]', 'fp + 3 + [ap]', TypeFelt()) - simplify_type_system_test('cast(fp + 3 + [ap], T*)', 'fp + 3 + [ap]', t_star) + simplify_type_system_test("fp + 3 + [ap]", "fp + 3 + [ap]", TypeFelt()) + simplify_type_system_test("cast(fp + 3 + [ap], T*)", "fp + 3 + [ap]", t_star) # Two casts. - simplify_type_system_test('cast(cast(fp, T*), felt)', 'fp', TypeFelt()) + simplify_type_system_test("cast(cast(fp, T*), felt)", "fp", TypeFelt()) # Cast from T to T. - simplify_type_system_test('cast([cast(fp, T*)], T)', '[fp]', t) + simplify_type_system_test("cast([cast(fp, T*)], T)", "[fp]", t) # Dereference. - simplify_type_system_test('[cast(fp, T**)]', '[fp]', t_star) - simplify_type_system_test('[[cast(fp, T**)]]', '[[fp]]', t) + simplify_type_system_test("[cast(fp, T**)]", "[fp]", t_star) + simplify_type_system_test("[[cast(fp, T**)]]", "[[fp]]", t) # Address of. - simplify_type_system_test('&([[cast(fp, T**)]])', '[fp]', t_star) - simplify_type_system_test('&&[[cast(fp, T**)]]', 'fp', t_star2) + simplify_type_system_test("&([[cast(fp, T**)]])", "[fp]", t_star) + simplify_type_system_test("&&[[cast(fp, T**)]]", "fp", t_star2) def test_type_tuples(): - t = TypeStruct(scope=scope('T'), is_fully_resolved=True) + t = TypeStruct(scope=scope("T"), is_fully_resolved=True) t_star = TypePointer(pointee=t) # Simple tuple. simplify_type_system_test( - '(fp, [cast(fp, T*)], cast(fp,T*))', - '(fp, [fp], fp)', TypeTuple(members=[TypeFelt(), t, t_star],)) + "(fp, [cast(fp, T*)], cast(fp,T*))", + "(fp, [fp], fp)", + TypeTuple( + members=[TypeFelt(), t, t_star], + ), + ) # Nested. - simplify_type_system_test('(fp, (), ([cast(fp, T*)],))', '(fp, (), ([fp],))', TypeTuple( - members=[ - TypeFelt(), - TypeTuple(members=[]), - TypeTuple(members=[t])], - )) + simplify_type_system_test( + "(fp, (), ([cast(fp, T*)],))", + "(fp, (), ([fp],))", + TypeTuple( + members=[TypeFelt(), TypeTuple(members=[]), TypeTuple(members=[t])], + ), + ) def test_type_tuples_failures(): identifier_dict = { - scope('T'): StructDefinition( - full_name=scope('T'), + scope("T"): StructDefinition( + full_name=scope("T"), members={ - 'x': MemberDefinition(offset=0, cairo_type=TypeFelt()), - 'y': MemberDefinition(offset=1, cairo_type=TypeFelt()), + "x": MemberDefinition(offset=0, cairo_type=TypeFelt()), + "y": MemberDefinition(offset=1, cairo_type=TypeFelt()), }, size=2, ), } identifiers = IdentifierManager.from_dict(identifier_dict) - verify_exception('1 + cast((1, 2), T).x', """ + verify_exception( + "1 + cast((1, 2), T).x", + """ file:?:?: Accessing struct members for r-value structs is not supported yet. 1 + cast((1, 2), T).x ^***************^ -""", identifiers=identifiers) +""", + identifiers=identifiers, + ) def test_type_subscript_op(): felt_star_star = TypePointer(pointee=TypePointer(pointee=TypeFelt())) - t = TypeStruct(scope=scope('T'), is_fully_resolved=True) + t = TypeStruct(scope=scope("T"), is_fully_resolved=True) t_star = TypePointer(pointee=t) - identifier_dict = {scope('T'): StructDefinition(full_name=scope('T'), members={}, size=7)} + identifier_dict = {scope("T"): StructDefinition(full_name=scope("T"), members={}, size=7)} identifiers = IdentifierManager.from_dict(identifier_dict) - simplify_type_system_test('cast(fp, felt*)[3]', '[fp + 3 * 1]', TypeFelt()) - simplify_type_system_test('cast(fp, felt***)[0]', '[fp + 0 * 1]', felt_star_star) - simplify_type_system_test('[cast(fp, T****)][ap][ap]', '[[[fp] + ap * 1] + ap * 1]', t_star) + simplify_type_system_test("cast(fp, felt*)[3]", "[fp + 3 * 1]", TypeFelt()) + simplify_type_system_test("cast(fp, felt***)[0]", "[fp + 0 * 1]", felt_star_star) + simplify_type_system_test("[cast(fp, T****)][ap][ap]", "[[[fp] + ap * 1] + ap * 1]", t_star) simplify_type_system_test( - 'cast(fp, T**)[1][2]', '[[fp + 1 * 1] + 2 * 7]', t, identifiers=identifiers) + "cast(fp, T**)[1][2]", "[[fp + 1 * 1] + 2 * 7]", t, identifiers=identifiers + ) # Test that 'cast(fp, T*)[2 * ap + 3]' simplifies into '[fp + (2 * ap + 3) * 7]', but without # the parentheses. assert simplify_type_system( - mark_types_in_expr_resolved(parse_expr('cast(fp, T*)[2 * ap + 3]')), - identifiers=identifiers) == ( - remove_parentheses(parse_expr('[fp + (2 * ap + 3) * 7]')), t) + mark_types_in_expr_resolved(parse_expr("cast(fp, T*)[2 * ap + 3]")), identifiers=identifiers + ) == (remove_parentheses(parse_expr("[fp + (2 * ap + 3) * 7]")), t) # Test subscript operator for tuples. - simplify_type_system_test('(cast(fp, felt**), fp, cast(fp, T*))[2]', 'fp', t_star) - simplify_type_system_test('(cast(fp, felt**), fp, cast(fp, T*))[0]', 'fp', felt_star_star) - simplify_type_system_test('(cast(fp, felt**), ap, cast(fp, T*))[3*4 - 11]', 'ap', TypeFelt()) - simplify_type_system_test('[cast(ap, (felt, felt)*)][0]', '[ap + 0]', TypeFelt()) + simplify_type_system_test("(cast(fp, felt**), fp, cast(fp, T*))[2]", "fp", t_star) + simplify_type_system_test("(cast(fp, felt**), fp, cast(fp, T*))[0]", "fp", felt_star_star) + simplify_type_system_test("(cast(fp, felt**), ap, cast(fp, T*))[3*4 - 11]", "ap", TypeFelt()) + simplify_type_system_test("[cast(ap, (felt, felt)*)][0]", "[ap + 0]", TypeFelt()) simplify_type_system_test( - '[cast(ap, (T*, T, felt, T*, felt*)*)][3]', '[ap + 9]', t_star, identifiers=identifiers) + "[cast(ap, (T*, T, felt, T*, felt*)*)][3]", "[ap + 9]", t_star, identifiers=identifiers + ) # Test failures. - verify_exception('(fp, fp, fp)[cast(ap, felt*)]', """ + verify_exception( + "(fp, fp, fp)[cast(ap, felt*)]", + """ file:?:?: Cannot apply subscript-operator with offset of non-felt type 'felt*'. (fp, fp, fp)[cast(ap, felt*)] ^*************^ -""") +""", + ) - verify_exception('(fp, fp, fp)[[ap]]', """ + verify_exception( + "(fp, fp, fp)[[ap]]", + """ file:?:?: Subscript-operator for tuples supports only constant offsets, found 'ExprDeref'. (fp, fp, fp)[[ap]] ^**^ -""") +""", + ) # The simplifier in TypeSystemVisitor cannot access PRIME, so PyConsts are unsimplified. - verify_exception('(fp, fp, fp)[%[1%]]', """ + verify_exception( + "(fp, fp, fp)[%[1%]]", + """ file:?:?: Subscript-operator for tuples supports only constant offsets, found 'ExprPyConst'. (fp, fp, fp)[%[1%]] ^***^ -""") +""", + ) - verify_exception('(fp, fp, fp)[3]', """ + verify_exception( + "(fp, fp, fp)[3]", + """ file:?:?: Tuple index 3 is out of range [0, 3). (fp, fp, fp)[3] ^*************^ -""") +""", + ) - verify_exception('[cast(fp, (T*, T, felt)*)][-1]', """ + verify_exception( + "[cast(fp, (T*, T, felt)*)][-1]", + """ file:?:?: Tuple index -1 is out of range [0, 3). [cast(fp, (T*, T, felt)*)][-1] ^****************************^ -""") +""", + ) - verify_exception('cast(fp, felt)[0]', """ + verify_exception( + "cast(fp, felt)[0]", + """ file:?:?: Cannot apply subscript-operator to non-pointer, non-tuple type 'felt'. cast(fp, felt)[0] ^***************^ -""") +""", + ) - verify_exception('[cast(fp, T*)][0]', """ + verify_exception( + "[cast(fp, T*)][0]", + """ file:?:?: Cannot apply subscript-operator to non-pointer, non-tuple type 'T'. [cast(fp, T*)][0] ^***************^ -""") +""", + ) - verify_exception('cast(fp, felt*)[[cast(ap, T*)]]', """ + verify_exception( + "cast(fp, felt*)[[cast(ap, T*)]]", + """ file:?:?: Cannot apply subscript-operator with offset of non-felt type 'T'. cast(fp, felt*)[[cast(ap, T*)]] ^************^ -""") +""", + ) - verify_exception('cast(fp, Z*)[0]', """ + verify_exception( + "cast(fp, Z*)[0]", + """ file:?:?: Unknown identifier 'Z'. cast(fp, Z*)[0] ^*************^ -""", identifiers=identifiers) +""", + identifiers=identifiers, + ) - verify_exception('cast(fp, T*)[0]', """ + verify_exception( + "cast(fp, T*)[0]", + """ file:?:?: Unknown identifier 'T'. cast(fp, T*)[0] ^*************^ -""", identifiers=None) +""", + identifiers=None, + ) def test_type_dot_op(): @@ -194,34 +246,34 @@ def test_type_dot_op(): member r : R* end """ - t = TypeStruct(scope=scope('T'), is_fully_resolved=True) - s = TypeStruct(scope=scope('S'), is_fully_resolved=True) + t = TypeStruct(scope=scope("T"), is_fully_resolved=True) + s = TypeStruct(scope=scope("S"), is_fully_resolved=True) s_star = TypePointer(pointee=s) - r = TypeStruct(scope=scope('R'), is_fully_resolved=True) + r = TypeStruct(scope=scope("R"), is_fully_resolved=True) r_star = TypePointer(pointee=r) identifier_dict = { - scope('T'): StructDefinition( - full_name=scope('T'), + scope("T"): StructDefinition( + full_name=scope("T"), members={ - 't': MemberDefinition(offset=0, cairo_type=TypeFelt()), - 's': MemberDefinition(offset=1, cairo_type=s), - 'sp': MemberDefinition(offset=3, cairo_type=s_star), + "t": MemberDefinition(offset=0, cairo_type=TypeFelt()), + "s": MemberDefinition(offset=1, cairo_type=s), + "sp": MemberDefinition(offset=3, cairo_type=s_star), }, size=4, ), - scope('S'): StructDefinition( - full_name=scope('S'), + scope("S"): StructDefinition( + full_name=scope("S"), members={ - 'x': MemberDefinition(offset=0, cairo_type=TypeFelt()), - 'y': MemberDefinition(offset=1, cairo_type=TypeFelt()), + "x": MemberDefinition(offset=0, cairo_type=TypeFelt()), + "y": MemberDefinition(offset=1, cairo_type=TypeFelt()), }, size=2, ), - scope('R'): StructDefinition( - full_name=scope('R'), + scope("R"): StructDefinition( + full_name=scope("R"), members={ - 'r': MemberDefinition(offset=0, cairo_type=r_star), + "r": MemberDefinition(offset=0, cairo_type=r_star), }, size=1, ), @@ -230,121 +282,166 @@ def test_type_dot_op(): identifiers = IdentifierManager.from_dict(identifier_dict) for (orig_expr, simplified_expr, simplified_type) in [ - ('[cast(fp, T*)].t', '[fp]', TypeFelt()), - ('[cast(fp, T*)].s', '[fp + 1]', s), - ('[cast(fp, T*)].sp', '[fp + 3]', s_star), - ('[cast(fp, T*)].s.x', '[fp + 1]', TypeFelt()), - ('[cast(fp, T*)].s.y', '[fp + 1 + 1]', TypeFelt()), - ('[[cast(fp, T*)].sp].x', '[[fp + 3]]', TypeFelt()), - ('[cast(fp, R*)]', '[fp]', r), - ('[cast(fp, R*)].r', '[fp]', r_star), - ('[[[cast(fp, R*)].r].r].r', '[[[fp]]]', r_star), + ("[cast(fp, T*)].t", "[fp]", TypeFelt()), + ("[cast(fp, T*)].s", "[fp + 1]", s), + ("[cast(fp, T*)].sp", "[fp + 3]", s_star), + ("[cast(fp, T*)].s.x", "[fp + 1]", TypeFelt()), + ("[cast(fp, T*)].s.y", "[fp + 1 + 1]", TypeFelt()), + ("[[cast(fp, T*)].sp].x", "[[fp + 3]]", TypeFelt()), + ("[cast(fp, R*)]", "[fp]", r), + ("[cast(fp, R*)].r", "[fp]", r_star), + ("[[[cast(fp, R*)].r].r].r", "[[[fp]]]", r_star), # Test . as -> - ('cast(fp, T*).t', '[fp]', TypeFelt()), - ('cast(fp, T*).sp.y', '[[fp + 3] + 1]', TypeFelt()), - ('cast(fp, R*).r.r.r', '[[[fp]]]', r_star), + ("cast(fp, T*).t", "[fp]", TypeFelt()), + ("cast(fp, T*).sp.y", "[[fp + 3] + 1]", TypeFelt()), + ("cast(fp, R*).r.r.r", "[[[fp]]]", r_star), # More tests. - ('(cast(fp, T*).s)', '[fp + 1]', s), - ('(cast(fp, T*).s).x', '[fp + 1]', TypeFelt()), - ('(&(cast(fp, T*).s)).x', '[fp + 1]', TypeFelt()) + ("(cast(fp, T*).s)", "[fp + 1]", s), + ("(cast(fp, T*).s).x", "[fp + 1]", TypeFelt()), + ("(&(cast(fp, T*).s)).x", "[fp + 1]", TypeFelt()), ]: simplify_type_system_test( - orig_expr, simplified_expr, simplified_type, identifiers=identifiers) + orig_expr, simplified_expr, simplified_type, identifiers=identifiers + ) # Test failures. - verify_exception('cast(fp, felt).x', """ + verify_exception( + "cast(fp, felt).x", + """ file:?:?: Cannot apply dot-operator to non-struct type 'felt'. cast(fp, felt).x ^**************^ -""", identifiers=identifiers) +""", + identifiers=identifiers, + ) - verify_exception('cast(fp, felt*).x', """ + verify_exception( + "cast(fp, felt*).x", + """ file:?:?: Cannot apply dot-operator to pointer-to-non-struct type 'felt*'. cast(fp, felt*).x ^***************^ -""", identifiers=identifiers) +""", + identifiers=identifiers, + ) - verify_exception('cast(fp, T*).x', """ + verify_exception( + "cast(fp, T*).x", + """ file:?:?: Member 'x' does not appear in definition of struct 'T'. cast(fp, T*).x ^************^ -""", identifiers=identifiers) +""", + identifiers=identifiers, + ) - verify_exception('cast(fp, Z*).x', """ + verify_exception( + "cast(fp, Z*).x", + """ file:?:?: Unknown identifier 'Z'. cast(fp, Z*).x ^************^ -""", identifiers=identifiers) +""", + identifiers=identifiers, + ) - verify_exception('cast(fp, T*).x', """ + verify_exception( + "cast(fp, T*).x", + """ file:?:?: Identifiers must be initialized for type-simplification of dot-operator expressions. cast(fp, T*).x ^************^ -""", identifiers=None) +""", + identifiers=None, + ) - verify_exception('cast(fp, Z*).x', """ + verify_exception( + "cast(fp, Z*).x", + """ file:?:?: Type is expected to be fully resolved at this point. cast(fp, Z*).x ^************^ -""", identifiers=identifiers, resolve_types=False) +""", + identifiers=identifiers, + resolve_types=False, + ) def test_type_visitor_failures(): - verify_exception('[cast(fp, T*)] + 3', """ + verify_exception( + "[cast(fp, T*)] + 3", + """ file:?:?: Operator '+' is not implemented for types 'T' and 'felt'. [cast(fp, T*)] + 3 ^****************^ -""") - verify_exception('[[cast(fp, T*)]]', """ +""", + ) + verify_exception( + "[[cast(fp, T*)]]", + """ file:?:?: Cannot dereference type 'T'. [[cast(fp, T*)]] ^**************^ -""") - verify_exception('[cast(fp, T)]', """ +""", + ) + verify_exception( + "[cast(fp, T)]", + """ file:?:?: Cannot cast 'felt' to 'T'. [cast(fp, T)] ^*********^ -""") - verify_exception('&(cast(fp, T*) + 3)', """ +""", + ) + verify_exception( + "&(cast(fp, T*) + 3)", + """ file:?:?: Expression has no address. &(cast(fp, T*) + 3) ^**************^ -""") +""", + ) def test_type_visitor_pointer_arithmetic(): - t = TypeStruct(scope=scope('T'), is_fully_resolved=True) + t = TypeStruct(scope=scope("T"), is_fully_resolved=True) t_star = TypePointer(pointee=t) - simplify_type_system_test('cast(fp, T*) + 3', 'fp + 3', t_star) - simplify_type_system_test('cast(fp, T*) - 3', 'fp - 3', t_star) - simplify_type_system_test('cast(fp, T*) - cast(3, T*)', 'fp - 3', TypeFelt()) + simplify_type_system_test("cast(fp, T*) + 3", "fp + 3", t_star) + simplify_type_system_test("cast(fp, T*) - 3", "fp - 3", t_star) + simplify_type_system_test("cast(fp, T*) - cast(3, T*)", "fp - 3", TypeFelt()) def test_type_visitor_pointer_arithmetic_failures(): - verify_exception('cast(fp, T*) + cast(fp, T*)', """ + verify_exception( + "cast(fp, T*) + cast(fp, T*)", + """ file:?:?: Operator '+' is not implemented for types 'T*' and 'T*'. cast(fp, T*) + cast(fp, T*) ^*************************^ -""") - verify_exception('cast(fp, T*) - cast(fp, S*)', """ +""", + ) + verify_exception( + "cast(fp, T*) - cast(fp, S*)", + """ file:?:?: Operator '-' is not implemented for types 'T*' and 'S*'. cast(fp, T*) - cast(fp, S*) ^*************************^ -""") - verify_exception('fp - cast(fp, T*)', """ +""", + ) + verify_exception( + "fp - cast(fp, T*)", + """ file:?:?: Operator '-' is not implemented for types 'felt' and 'T*'. fp - cast(fp, T*) ^***************^ -""") +""", + ) def verify_exception( - expr_str: str, - error: str, - identifiers: Optional[IdentifierManager] = None, - resolve_types=True): + expr_str: str, error: str, identifiers: Optional[IdentifierManager] = None, resolve_types=True +): """ Verifies that calling simplify_type_system() on the code results in the given error. """ @@ -354,4 +451,4 @@ def verify_exception( parsed_expr = mark_types_in_expr_resolved(parsed_expr) simplify_type_system(parsed_expr, identifiers) # Remove line and column information from the error using a regular expression. - assert re.sub(':[0-9]+:[0-9]+: ', 'file:?:?: ', str(e.value)) == error.strip() + assert re.sub(":[0-9]+:[0-9]+: ", "file:?:?: ", str(e.value)) == error.strip() diff --git a/src/starkware/cairo/lang/compiler/type_utils.py b/src/starkware/cairo/lang/compiler/type_utils.py new file mode 100644 index 00000000..33710a2e --- /dev/null +++ b/src/starkware/cairo/lang/compiler/type_utils.py @@ -0,0 +1,41 @@ +from typing import Optional + +from starkware.cairo.lang.compiler.ast.cairo_types import CairoType, TypeFelt, TypeStruct, TypeTuple +from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager +from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition + + +def check_felts_only_type( + cairo_type: CairoType, identifier_manager: IdentifierManager +) -> Optional[int]: + """ + A felts-only type defined to be either felt or a struct whose members are all felts-only types. + Returns the size of the given type if it is felts-only and None otherwise. + """ + + if isinstance(cairo_type, TypeFelt): + return 1 + elif isinstance(cairo_type, TypeStruct): + struct_definition = get_struct_definition( + cairo_type.resolved_scope, identifier_manager=identifier_manager + ) + + size = 0 + for member_def in struct_definition.members.values(): + res = check_felts_only_type( + member_def.cairo_type, identifier_manager=identifier_manager + ) + if res is None: + return None + size += res + return size + elif isinstance(cairo_type, TypeTuple): + size = 0 + for item_type in cairo_type.members: + res = check_felts_only_type(item_type, identifier_manager=identifier_manager) + if res is None: + return None + size += res + return size + else: + return None diff --git a/src/starkware/cairo/lang/compiler/type_utils_test.py b/src/starkware/cairo/lang/compiler/type_utils_test.py new file mode 100644 index 00000000..0415072a --- /dev/null +++ b/src/starkware/cairo/lang/compiler/type_utils_test.py @@ -0,0 +1,53 @@ +from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME +from starkware.cairo.lang.compiler.parser import parse_type +from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import preprocess_str +from starkware.cairo.lang.compiler.type_system import mark_type_resolved +from starkware.cairo.lang.compiler.type_utils import check_felts_only_type + + +def test_check_felts_only_type(): + program = preprocess_str( + """ +struct A: + member x : felt +end + +struct B: +end + +struct C: + member x : felt + member y : (felt, A, B) + member z : A +end + +struct D: + member x : felt* +end + +struct E: + member x : D +end + """, + prime=DEFAULT_PRIME, + ) + + for (typ, expected_res) in [ + # Positive cases. + ("test_scope.A", 1), + ("test_scope.B", 0), + ("test_scope.C", 4), + ("(felt, felt)", 2), + ("(felt, (felt, test_scope.C))", 6), + # Negative cases. + ("test_scope.D", None), + ("test_scope.E", None), + ("(felt, test_scope.D)", None), + ]: + assert ( + check_felts_only_type( + cairo_type=mark_type_resolved(parse_type(typ)), + identifier_manager=program.identifiers, + ) + == expected_res + ) diff --git a/src/starkware/cairo/lang/ide/vscode-cairo/package.json b/src/starkware/cairo/lang/ide/vscode-cairo/package.json index 41333202..c4768c26 100644 --- a/src/starkware/cairo/lang/ide/vscode-cairo/package.json +++ b/src/starkware/cairo/lang/ide/vscode-cairo/package.json @@ -2,7 +2,7 @@ "name": "cairo", "displayName": "Cairo", "description": "Support Cairo syntax", - "version": "0.4.0", + "version": "0.4.1", "engines": { "vscode": "^1.30.0" }, diff --git a/src/starkware/cairo/lang/instances.py b/src/starkware/cairo/lang/instances.py index 8be94e17..ffeb59fb 100644 --- a/src/starkware/cairo/lang/instances.py +++ b/src/starkware/cairo/lang/instances.py @@ -29,7 +29,7 @@ class DilutedPoolInstanceDef: @dataclasses.dataclass class CairoLayout: - layout_name: str = '' + layout_name: str = "" cpu_component_step: int = 1 # Range check units. rc_units: int = 16 @@ -43,12 +43,12 @@ class CairoLayout: plain_instance = CairoLayout( - layout_name='plain', + layout_name="plain", n_trace_columns=8, ) small_instance = CairoLayout( - layout_name='small', + layout_name="small", rc_units=16, builtins=dict( output=True, @@ -58,7 +58,7 @@ class CairoLayout: element_height=256, element_bits=252, n_inputs=2, - hash_limit=2**251 + 17 * 2**192 + 1, + hash_limit=2 ** 251 + 17 * 2 ** 192 + 1, ), range_check=RangeCheckInstanceDef( ratio=8, @@ -75,7 +75,7 @@ class CairoLayout: ) dex_instance = CairoLayout( - layout_name='dex', + layout_name="dex", rc_units=4, builtins=dict( output=True, @@ -85,7 +85,7 @@ class CairoLayout: element_height=256, element_bits=252, n_inputs=2, - hash_limit=2**251 + 17 * 2**192 + 1, + hash_limit=2 ** 251 + 17 * 2 ** 192 + 1, ), range_check=RangeCheckInstanceDef( ratio=8, @@ -102,7 +102,7 @@ class CairoLayout: ) all_instance = CairoLayout( - layout_name='all', + layout_name="all", rc_units=8, public_memory_fraction=8, diluted_pool_instance_def=DilutedPoolInstanceDef( @@ -118,7 +118,7 @@ class CairoLayout: element_height=256, element_bits=252, n_inputs=2, - hash_limit=2**251 + 17 * 2**192 + 1, + hash_limit=2 ** 251 + 17 * 2 ** 192 + 1, ), range_check=RangeCheckInstanceDef( ratio=8, @@ -139,8 +139,8 @@ class CairoLayout: ) LAYOUTS: Dict[str, CairoLayout] = { - 'plain': plain_instance, - 'small': small_instance, - 'dex': dex_instance, - 'all': all_instance, + "plain": plain_instance, + "small": small_instance, + "dex": dex_instance, + "all": all_instance, } diff --git a/src/starkware/cairo/lang/setup.py b/src/starkware/cairo/lang/setup.py index 5516f845..365eda4e 100644 --- a/src/starkware/cairo/lang/setup.py +++ b/src/starkware/cairo/lang/setup.py @@ -3,42 +3,42 @@ import setuptools DIR = os.path.abspath(os.path.dirname(__file__)) -requirements = open(os.path.join(DIR, 'requirements.txt')).read().splitlines() -version = open(os.path.join(DIR, 'starkware/cairo/lang/VERSION')).read().strip() -long_description = open('README.md', 'r', encoding='utf-8').read() +requirements = open(os.path.join(DIR, "requirements.txt")).read().splitlines() +version = open(os.path.join(DIR, "starkware/cairo/lang/VERSION")).read().strip() +long_description = open("README.md", "r", encoding="utf-8").read() setuptools.setup( - name='cairo-lang', + name="cairo-lang", version=version, - author='Starkware', - author_email='info@starkware.co', - description='Compiler and runner for the Cairo language', + author="Starkware", + author_email="info@starkware.co", + description="Compiler and runner for the Cairo language", install_requires=requirements, long_description=long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", packages=setuptools.find_packages(), - python_requires='>=3.6', - setup_requires=['wheel'], - url='https://cairo-lang.org/', + python_requires=">=3.6", + setup_requires=["wheel"], + url="https://cairo-lang.org/", package_data={ - 'starkware.cairo.common': ['*.cairo'], - 'starkware.cairo.lang.compiler': ['cairo.ebnf'], - 'starkware.cairo.lang.tracer': ['*.html', '*.css', '*.js', '*.png'], - 'starkware.cairo.lang': ['VERSION'], - 'starkware.cairo.sharp': ['config.json'], - 'starkware.crypto.signature': ['pedersen_params.json'], - 'starkware.starknet': ['common/*.cairo'], - 'starkware.starknet.core.os': ['*.cairo', '*.json'], - 'starkware.starknet.security': ['whitelists/*.json'], + "starkware.cairo.common": ["*.cairo"], + "starkware.cairo.lang.compiler": ["cairo.ebnf"], + "starkware.cairo.lang.tracer": ["*.html", "*.css", "*.js", "*.png"], + "starkware.cairo.lang": ["VERSION"], + "starkware.cairo.sharp": ["config.json"], + "starkware.crypto.signature": ["pedersen_params.json"], + "starkware.starknet": ["common/*.cairo"], + "starkware.starknet.core.os": ["*.cairo", "*.json"], + "starkware.starknet.security": ["whitelists/*.json"], }, scripts=[ - 'starkware/cairo/lang/scripts/cairo-compile', - 'starkware/cairo/lang/scripts/cairo-format', - 'starkware/cairo/lang/scripts/cairo-hash-program', - 'starkware/cairo/lang/scripts/cairo-reconstruct-traceback', - 'starkware/cairo/lang/scripts/cairo-run', - 'starkware/cairo/lang/scripts/cairo-sharp', - 'starkware/starknet/scripts/starknet-compile', - 'starkware/starknet/scripts/starknet', - ] + "starkware/cairo/lang/scripts/cairo-compile", + "starkware/cairo/lang/scripts/cairo-format", + "starkware/cairo/lang/scripts/cairo-hash-program", + "starkware/cairo/lang/scripts/cairo-reconstruct-traceback", + "starkware/cairo/lang/scripts/cairo-run", + "starkware/cairo/lang/scripts/cairo-sharp", + "starkware/starknet/scripts/starknet-compile", + "starkware/starknet/scripts/starknet", + ], ) diff --git a/src/starkware/cairo/lang/tracer/profile.py b/src/starkware/cairo/lang/tracer/profile.py index 8b068640..a6faac0a 100644 --- a/src/starkware/cairo/lang/tracer/profile.py +++ b/src/starkware/cairo/lang/tracer/profile.py @@ -32,7 +32,7 @@ def __init__(self, initial_fp, memory): # A map from a string to its id in the string table. self._string_to_id = {} # First string in the table must be ''. - self.string_id('') + self.string_id("") # A map from a filename to its id in the mapping table. self._filename_to_mapping_id = {} # A map from a function name to its id in the function table. @@ -46,16 +46,16 @@ def __init__(self, initial_fp, memory): # Global fields. sample_type = self._profile.sample_type.add() - sample_type.type = self.string_id('running time') - sample_type.unit = self.string_id('steps') - self._profile.time_nanos = int(time.time() * 10**9) + sample_type.type = self.string_id("running time") + sample_type.unit = self.string_id("steps") + self._profile.time_nanos = int(time.time() * 10 ** 9) # Main function. - self._func_name_to_id['__main__'] = self.string_id('__main__') + self._func_name_to_id["__main__"] = self.string_id("__main__") main_func = self._profile.function.add() - main_func.id = self.string_id('__main__') - main_func.system_name = main_func.name = self.string_id('') - main_func.filename = self.string_id('') + main_func.id = self.string_id("__main__") + main_func.system_name = main_func.name = self.string_id("") + main_func.filename = self.string_id("") main_func.start_line = 0 def string_id(self, s: str) -> int: @@ -76,8 +76,9 @@ def update_mapping_pc_range(self, filename: str, accessed_pc: int) -> int: """ if filename not in self._filename_to_mapping_id: # Id 0 is reserved. Shift ids. - mapping_id = self._filename_to_mapping_id[filename] = len( - self._filename_to_mapping_id) + 1 + mapping_id = self._filename_to_mapping_id[filename] = ( + len(self._filename_to_mapping_id) + 1 + ) mapping = self._profile.mapping.add() mapping.id = mapping_id mapping.memory_start = accessed_pc @@ -174,12 +175,14 @@ def profile_from_tracer_data(tracer_data: TracerData): continue builder.function_id( name=str(name), - inst_location=tracer_data.program.debug_info.instruction_locations[ident.pc]) + inst_location=tracer_data.program.debug_info.instruction_locations[ident.pc], + ) # Locations. for pc_offset, inst_location in tracer_data.program.debug_info.instruction_locations.items(): builder.location_id( - pc=tracer_data.get_pc_from_offset(pc_offset), inst_location=inst_location) + pc=tracer_data.get_pc_from_offset(pc_offset), inst_location=inst_location + ) # Samples. for trace_entry in tracer_data.trace: diff --git a/src/starkware/cairo/lang/tracer/profiler.py b/src/starkware/cairo/lang/tracer/profiler.py index 1e5c31fd..807077ef 100644 --- a/src/starkware/cairo/lang/tracer/profiler.py +++ b/src/starkware/cairo/lang/tracer/profiler.py @@ -9,18 +9,22 @@ def main(): parser = argparse.ArgumentParser( - description='A tool to generate pprof-style profiling data from a Cairo trace.') - parser.add_argument( - '--program', type=str, required=True, help='A path to the program json file.') - parser.add_argument( - '--memory', type=str, required=True, help='A path to the memory file.') + description="A tool to generate pprof-style profiling data from a Cairo trace." + ) parser.add_argument( - '--trace', type=str, required=True, help='A path to the trace file.') + "--program", type=str, required=True, help="A path to the program json file." + ) + parser.add_argument("--memory", type=str, required=True, help="A path to the memory file.") + parser.add_argument("--trace", type=str, required=True, help="A path to the trace file.") parser.add_argument( - '--debug_info', type=str, required=True, help='A path to the run time debug info file.') + "--debug_info", type=str, required=True, help="A path to the run time debug info file." + ) parser.add_argument( - '--profile_output', type=str, default='profile.pb.gz', - help='A path to an output file to write profile data to. Can be opened in pprof.') + "--profile_output", + type=str, + default="profile.pb.gz", + help="A path to an output file to write profile data to. Can be opened in pprof.", + ) args = parser.parse_args() @@ -33,10 +37,10 @@ def main(): ) data = profile_from_tracer_data(tracer_data) - with open(args.profile_output, 'wb') as fp: + with open(args.profile_output, "wb") as fp: fp.write(data) return 0 -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/src/starkware/cairo/lang/tracer/tracer.py b/src/starkware/cairo/lang/tracer/tracer.py index 59e91796..a1a812e0 100755 --- a/src/starkware/cairo/lang/tracer/tracer.py +++ b/src/starkware/cairo/lang/tracer/tracer.py @@ -16,7 +16,7 @@ def trace_runner(runner): runner.vm_memory.relocate_memory() runner.vm_memory.freeze() runner.segments.compute_effective_sizes(include_tmp_segments=True) - if not hasattr(runner, 'relocated_trace'): + if not hasattr(runner, "relocated_trace"): runner.relocate() # Print the non-relocated registers, the relocated values are available in the tracer. @@ -27,8 +27,12 @@ def trace_runner(runner): run_tracer( TracerData( - program=runner.program, memory=memory, trace=trace, - program_base=runner.relocate_value(runner.program_base))) + program=runner.program, + memory=memory, + trace=trace, + program_base=runner.relocate_value(runner.program_base), + ) + ) class SimpleTCPServer(socketserver.TCPServer): @@ -47,25 +51,31 @@ class Handler(http.server.SimpleHTTPRequestHandler): def do_GET(self): parsed_path = urllib.parse.urlparse(self.path) query = urllib.parse.parse_qs(parsed_path.query) - if parsed_path.path == '/data.json': + if parsed_path.path == "/data.json": # Create the returned json file. - self.write_json({ - 'code': { - filename: input_file.to_html() - for filename, input_file in tracer_data.input_files.items()}, - 'trace': [ - {'pc': entry.pc, 'ap': entry.ap, 'fp': entry.fp} - for entry in tracer_data.trace], - 'memory': { - addr: field_element_repr(val, tracer_data.program.prime) - for addr, val in tracer_data.memory.items()}, - 'public_memory': tracer_data.public_memory, - 'memory_accesses': tracer_data.memory_accesses, - }) - elif parsed_path.path == '/eval.json': + self.write_json( + { + "code": { + filename: input_file.to_html() + for filename, input_file in tracer_data.input_files.items() + }, + "trace": [ + {"pc": entry.pc, "ap": entry.ap, "fp": entry.fp} + for entry in tracer_data.trace + ], + "memory": { + addr: field_element_repr(val, tracer_data.program.prime) + for addr, val in tracer_data.memory.items() + }, + "public_memory": tracer_data.public_memory, + "memory_accesses": tracer_data.memory_accesses, + } + ) + elif parsed_path.path == "/eval.json": evaluator = WatchEvaluator( - tracer_data, entry=tracer_data.trace[int(query['step'][0])]) - self.write_json([evaluator.eval_suppress_errors(expr) for expr in query['expr']]) + tracer_data, entry=tracer_data.trace[int(query["step"][0])] + ) + self.write_json([evaluator.eval_suppress_errors(expr) for expr in query["expr"]]) else: super().do_GET() @@ -74,10 +84,10 @@ def write_json(self, json_obj): try: self.send_response(200) - self.send_header('Content-type', 'text/json') - self.send_header('Content-Length', str(len(json_str))) + self.send_header("Content-type", "text/json") + self.send_header("Content-Length", str(len(json_str))) self.end_headers() - self.wfile.write(json_str.encode('utf8')) + self.wfile.write(json_str.encode("utf8")) except BrokenPipeError: # Request was canceled. pass @@ -86,31 +96,29 @@ def start_server(): port = 8100 while True: try: - return SimpleTCPServer(('localhost', port), Handler) + return SimpleTCPServer(("localhost", port), Handler) except OSError: pass # port was not available. Try the next one. port += 1 httpd = start_server() - print('Running tracer on http://localhost:%d/' % httpd.server_address[1]) + print("Running tracer on http://localhost:%d/" % httpd.server_address[1]) print() httpd.serve_forever() def main(): parser = argparse.ArgumentParser( - description='A tool to view the trace of a Cairo program execution.') - parser.add_argument( - '--program', type=str, required=True, help='A path to the program json file.') - parser.add_argument( - '--memory', type=str, required=True, help='A path to the memory file.') - parser.add_argument( - '--trace', type=str, required=True, help='A path to the trace file.') - parser.add_argument( - '--air_public_input', type=str, help='A path to the AIR public input file.') + description="A tool to view the trace of a Cairo program execution." + ) parser.add_argument( - '--debug_info', type=str, help='A path to the run time debug info file.') + "--program", type=str, required=True, help="A path to the program json file." + ) + parser.add_argument("--memory", type=str, required=True, help="A path to the memory file.") + parser.add_argument("--trace", type=str, required=True, help="A path to the trace file.") + parser.add_argument("--air_public_input", type=str, help="A path to the AIR public input file.") + parser.add_argument("--debug_info", type=str, help="A path to the run time debug info file.") args = parser.parse_args() @@ -126,5 +134,5 @@ def main(): return 0 -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/src/starkware/cairo/lang/tracer/tracer_data.py b/src/starkware/cairo/lang/tracer/tracer_data.py index 68aff2e6..4eaabb90 100644 --- a/src/starkware/cairo/lang/tracer/tracer_data.py +++ b/src/starkware/cairo/lang/tracer/tracer_data.py @@ -13,7 +13,9 @@ from starkware.cairo.lang.compiler.parser import parse_expr from starkware.cairo.lang.compiler.program import Program from starkware.cairo.lang.compiler.references import ( - FlowTrackingError, SubstituteRegisterTransformer) + FlowTrackingError, + SubstituteRegisterTransformer, +) from starkware.cairo.lang.compiler.resolve_search_result import resolve_search_result from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.compiler.substitute_identifiers import substitute_identifiers @@ -58,40 +60,52 @@ def __init__(self, content): self.tags = [] def mark_text( - self, line_start: int, col_start: int, line_end: int, col_end: int, classes: List[str]): + self, line_start: int, col_start: int, line_end: int, col_end: int, classes: List[str] + ): """ Surrounds the given part of the input text with a tag with the given classes. """ - self.marks.append(TextMark( - line_start=line_start, col_start=col_start, line_end=line_end, col_end=col_end, - classes=classes)) + self.marks.append( + TextMark( + line_start=line_start, + col_start=col_start, + line_end=line_end, + col_end=col_end, + classes=classes, + ) + ) # Find the offset of (line_start, col_start) inside the file, by computing the sum of the # lengths of the previous lines and adding col_start. Note that '\n's are not counted in # the sum, so we add line_start to the result. We have to subtract 2 since both # line_start and col_start are 1-based rather than 0-based. - offset_start = sum(map(len, self.lines[:line_start - 1])) + line_start + col_start - 2 + offset_start = sum(map(len, self.lines[: line_start - 1])) + line_start + col_start - 2 # Do the same for (line_end, col_end). - offset_end = sum(map(len, self.lines[:line_end - 1])) + line_end + col_end - 2 + offset_end = sum(map(len, self.lines[: line_end - 1])) + line_end + col_end - 2 self.tags.append((offset_start, -offset_end, f'')) - self.tags.append((offset_end, -float('inf'), '')) + self.tags.append((offset_end, -float("inf"), "")) def to_html(self): """ Returns the content of the file with the added HTML tags. Replaces spaces with ' ' and '\n' with '
'. """ - res = self.content.replace(' ', '\0') + res = self.content.replace(" ", "\0") for pos, size, tag_content in sorted(self.tags, reverse=True): res = res[:pos] + tag_content + res[pos:] - return res.replace('\0', ' ').replace('\n', '
\n') + return res.replace("\0", " ").replace("\n", "
\n") class TracerData: def __init__( - self, program: Program, memory: MemoryDict, trace: List[TraceEntry], program_base: int, - air_public_input: Optional[PublicInput] = None, - debug_info: Optional[DebugInfo] = None): + self, + program: Program, + memory: MemoryDict, + trace: List[TraceEntry], + program_base: int, + air_public_input: Optional[PublicInput] = None, + debug_info: Optional[DebugInfo] = None, + ): """ Constructs a TracerData object. program_base is the memory address where the program is loaded. @@ -123,22 +137,30 @@ def __init__( # Surround the instruction code with a tag. input_file.mark_text( - loc.start_line, loc.start_col, loc.end_line, loc.end_col, - [f'inst{pc_offset}', 'instruction']) + loc.start_line, + loc.start_col, + loc.end_line, + loc.end_col, + [f"inst{pc_offset}", "instruction"], + ) # Find memory accesses for each step. self.memory_accesses = [] for trace_entry in self.trace: run_context = RunContext( - pc=trace_entry.pc, ap=trace_entry.ap, fp=trace_entry.fp, memory=self.memory, - prime=self.program.prime) + pc=trace_entry.pc, + ap=trace_entry.ap, + fp=trace_entry.fp, + memory=self.memory, + prime=self.program.prime, + ) instruction_encoding, imm = run_context.get_instruction_encoding() instruction = decode_instruction(instruction_encoding, imm) dst_addr = run_context.compute_dst_addr(instruction) op0_addr = run_context.compute_op0_addr(instruction) op1_addr = run_context.compute_op1_addr(instruction, self.memory.get(op0_addr)) - self.memory_accesses.append({'dst': dst_addr, 'op0': op0_addr, 'op1': op1_addr}) + self.memory_accesses.append({"dst": dst_addr, "op0": op0_addr, "op1": op1_addr}) def get_pc_offset(self, pc: int) -> int: """ @@ -170,8 +192,13 @@ def get_current_identifier_values(self, entry: TraceEntry) -> Dict[str, str]: @classmethod def from_files( - cls, program_path: str, memory_path: str, trace_path: str, - air_public_input: Optional[str], debug_info_path: Optional[str] = None): + cls, + program_path: str, + memory_path: str, + trace_path: str, + air_public_input: Optional[str], + debug_info_path: Optional[str] = None, + ): """ Factory method constructing TracerData from files. """ @@ -187,13 +214,21 @@ def from_files( else: public_input = None - debug_info = DebugInfo.Schema().load(json.load(open(debug_info_path))) \ - if debug_info_path is not None else None + debug_info = ( + DebugInfo.Schema().load(json.load(open(debug_info_path))) + if debug_info_path is not None + else None + ) # Construct the instance. return cls( - program=program, memory=memory, trace=trace, program_base=program_base, - air_public_input=public_input, debug_info=debug_info) + program=program, + memory=memory, + trace=trace, + program_base=program_base, + air_public_input=public_input, + debug_info=debug_info, + ) def read_memory(memory_path: str, field_bytes: int) -> MemoryDict: @@ -201,7 +236,7 @@ def read_memory(memory_path: str, field_bytes: int) -> MemoryDict: Returns the memory (as a MemoryDict). """ # Use MemoryDict to verify that memory cells are consistent. - with open(memory_path, 'rb') as memory_file: + with open(memory_path, "rb") as memory_file: return MemoryDict.deserialize(memory_file.read(), field_bytes) @@ -211,12 +246,12 @@ def read_trace(trace_path: str) -> List[TraceEntry]: """ entries = [] serialization_size = TraceEntry.serialization_size() - with open(trace_path, 'rb') as trace_file: + with open(trace_path, "rb") as trace_file: while True: entry_serialized = trace_file.read(serialization_size) if not entry_serialized: break - assert len(entry_serialized) == serialization_size, 'Size of trace file is invalid.' + assert len(entry_serialized) == serialization_size, "Size of trace file is invalid." entry = TraceEntry.deserialize(entry_serialized) entries.append(entry) @@ -230,10 +265,10 @@ def field_element_repr(val: int, prime: int) -> str: # Shift val to the range (-prime // 2, prime // 2). shifted_val = (val + prime // 2) % prime - (prime // 2) # If shifted_val is small, use decimal representation. - if abs(shifted_val) < 2**40: + if abs(shifted_val) < 2 ** 40: return str(shifted_val) # Otherwise, use hex representation (allowing a sign if the number is close to prime). - if abs(shifted_val) < 2**100: + if abs(shifted_val) < 2 ** 100: return hex(shifted_val) return hex(val) @@ -241,8 +276,8 @@ def field_element_repr(val: int, prime: int) -> str: class WatchEvaluator(ExpressionEvaluator): def __init__(self, tracer_data: TracerData, entry: TraceEntry[int]): super().__init__( - prime=tracer_data.program.prime, ap=entry.ap, - fp=entry.fp, memory=tracer_data.memory) + prime=tracer_data.program.prime, ap=entry.ap, fp=entry.fp, memory=tracer_data.memory + ) self.tracer_data = tracer_data self.pc = entry.pc self.ap = entry.ap @@ -256,15 +291,16 @@ def __init__(self, tracer_data: TracerData, entry: TraceEntry[int]): self.accessible_scopes = info.accessible_scopes def eval(self, expr): - if expr == 'null': - return '' + if expr == "null": + return "" expr, expr_type = simplify_type_system( substitute_identifiers( - expr=parse_expr(expr), - get_identifier_callback=self.get_variable), - identifiers=self.tracer_data.program.identifiers) + expr=parse_expr(expr), get_identifier_callback=self.get_variable + ), + identifiers=self.tracer_data.program.identifiers, + ) if isinstance(expr_type, TypeStruct): - raise NotImplementedError('Structs are not supported.') + raise NotImplementedError("Structs are not supported.") res = self.visit(expr) if isinstance(res, ExprConst): return field_element_repr(res.val, self.tracer_data.program.prime) @@ -274,7 +310,7 @@ def eval_suppress_errors(self, expr): try: return self.eval(expr) except Exception as exc: - return f'{type(exc).__name__}: {exc}' + return f"{type(exc).__name__}: {exc}" def get_variable(self, var: ExprIdentifier): identifiers = self.tracer_data.program.identifiers @@ -283,28 +319,34 @@ def get_variable(self, var: ExprIdentifier): accessible_scopes=self.accessible_scopes, name=ScopedName.from_string(var.name), ), - identifiers=identifiers) + identifiers=identifiers, + ) if isinstance(identifier_definition, ConstDefinition): return identifier_definition.value if isinstance(identifier_definition, (ReferenceDefinition, OffsetReferenceDefinition)): return self.visit(self.eval_reference(identifier_definition, var.name)) - raise Exception( - f'Unexpected identifier {var.name} of type {identifier_definition.TYPE}.') + raise Exception(f"Unexpected identifier {var.name} of type {identifier_definition.TYPE}.") def eval_reference(self, identifier_definition, var_name: str): pc_offset = self.tracer_data.get_pc_offset(self.pc) assert self.tracer_data.program.debug_info is not None - current_flow_tracking_data = \ - self.tracer_data.program.debug_info.instruction_locations[pc_offset].flow_tracking_data + current_flow_tracking_data = self.tracer_data.program.debug_info.instruction_locations[ + pc_offset + ].flow_tracking_data try: substitute_transformer = SubstituteRegisterTransformer( ap=lambda location: ExprConst(val=self.ap, location=location), - fp=lambda location: ExprConst(val=self.fp, location=location)) - return self.visit(substitute_transformer.visit( - identifier_definition.eval( - reference_manager=self.tracer_data.program.reference_manager, - flow_tracking_data=current_flow_tracking_data))) + fp=lambda location: ExprConst(val=self.fp, location=location), + ) + return self.visit( + substitute_transformer.visit( + identifier_definition.eval( + reference_manager=self.tracer_data.program.reference_manager, + flow_tracking_data=current_flow_tracking_data, + ) + ) + ) except FlowTrackingError: raise FlowTrackingError(f"Invalid reference '{var_name}'.") diff --git a/src/starkware/cairo/lang/tracer/tracer_data_test.py b/src/starkware/cairo/lang/tracer/tracer_data_test.py index 5b655247..36cbb257 100644 --- a/src/starkware/cairo/lang/tracer/tracer_data_test.py +++ b/src/starkware/cairo/lang/tracer/tracer_data_test.py @@ -4,18 +4,20 @@ from starkware.cairo.lang.tracer.tracer_data import InputCodeFile, TracerData, WatchEvaluator from starkware.cairo.lang.vm.cairo_runner import CairoRunner -PRIME = 2**251 + 17 * 2**192 + 1 +PRIME = 2 ** 251 + 17 * 2 ** 192 + 1 def test_input_code_file(): - input_file = InputCodeFile('aTestLine') - input_file.mark_text(1, 2, 1, 6, classes=['test']) # Mark "test" - input_file.mark_text(1, 2, 1, 10, classes=['test_line']) # Mark "test line" - input_file.mark_text(1, 6, 1, 10, classes=['line']) # Mark "line" - input_file.mark_text(1, 1, 1, 2, classes=['a']) # Mark "a" - assert input_file.to_html() == \ - 'aTest' \ + input_file = InputCodeFile("aTestLine") + input_file.mark_text(1, 2, 1, 6, classes=["test"]) # Mark "test" + input_file.mark_text(1, 2, 1, 10, classes=["test_line"]) # Mark "test line" + input_file.mark_text(1, 6, 1, 10, classes=["line"]) # Mark "line" + input_file.mark_text(1, 1, 1, 2, classes=["a"]) # Mark "a" + assert ( + input_file.to_html() + == 'aTest' 'Line' + ) def test_tracer_data(): @@ -34,11 +36,8 @@ def test_tracer_data(): ret end """ - program: Program = compile_cairo( - code=code, - prime=PRIME, - debug_info=True) - runner = CairoRunner(program, layout='small') + program: Program = compile_cairo(code=code, prime=PRIME, debug_info=True) + runner = CairoRunner(program, layout="small") runner.initialize_segments() runner.initialize_main_entrypoint() runner.initialize_vm(hint_locals={}) @@ -49,38 +48,55 @@ def test_tracer_data(): trace = runner.relocated_trace tracer_data = TracerData( - program=program, memory=memory, trace=trace, - program_base=runner.relocate_value(runner.program_base)) + program=program, + memory=memory, + trace=trace, + program_base=runner.relocate_value(runner.program_base), + ) # Test watch evaluator. watch_evaluator = WatchEvaluator(tracer_data=tracer_data, entry=tracer_data.trace[0]) - with pytest.raises(TypeError, match='NoneType'): + with pytest.raises(TypeError, match="NoneType"): watch_evaluator.eval(None) - assert watch_evaluator.eval_suppress_errors('x') == "FlowTrackingError: Invalid reference 'x'." + assert watch_evaluator.eval_suppress_errors("x") == "FlowTrackingError: Invalid reference 'x'." watch_evaluator = WatchEvaluator(tracer_data=tracer_data, entry=tracer_data.trace[1]) - assert watch_evaluator.eval('x') == '2000' + assert watch_evaluator.eval("x") == "2000" watch_evaluator = WatchEvaluator(tracer_data=tracer_data, entry=tracer_data.trace[2]) - assert watch_evaluator.eval('[ap]') == '3000' - assert watch_evaluator.eval('[ap-1]') == '2000' - assert watch_evaluator.eval('[ap-2]') == '1000' - assert watch_evaluator.eval('[fp]') == '1000' - assert watch_evaluator.eval('x') == '5000' + assert watch_evaluator.eval("[ap]") == "3000" + assert watch_evaluator.eval("[ap-1]") == "2000" + assert watch_evaluator.eval("[ap-2]") == "1000" + assert watch_evaluator.eval("[fp]") == "1000" + assert watch_evaluator.eval("x") == "5000" # Test memory_accesses. - assert memory[tracer_data.memory_accesses[0]['op1']] == 1000 - assert memory[tracer_data.memory_accesses[1]['op1']] == 2000 - assert tracer_data.memory_accesses[2]['dst'] == trace[2].ap - assert tracer_data.memory_accesses[2]['op0'] == trace[2].ap - 2 - assert tracer_data.memory_accesses[2]['op1'] == trace[2].ap - 1 + assert memory[tracer_data.memory_accesses[0]["op1"]] == 1000 + assert memory[tracer_data.memory_accesses[1]["op1"]] == 2000 + assert tracer_data.memory_accesses[2]["dst"] == trace[2].ap + assert tracer_data.memory_accesses[2]["op0"] == trace[2].ap - 2 + assert tracer_data.memory_accesses[2]["op1"] == trace[2].ap - 1 # Test current identifier values. - assert tracer_data.get_current_identifier_values(trace[0]) == {'output_ptr': '21'} - assert tracer_data.get_current_identifier_values(trace[1]) == {'output_ptr': '21', 'x': '2000'} + assert tracer_data.get_current_identifier_values(trace[0]) == {"output_ptr": "21"} + assert tracer_data.get_current_identifier_values(trace[1]) == {"output_ptr": "21", "x": "2000"} assert tracer_data.get_current_identifier_values(trace[2]) == { - 'output_ptr': '21', 'x': '5000', 'y': '3000'} + "output_ptr": "21", + "x": "5000", + "y": "3000", + } assert tracer_data.get_current_identifier_values(trace[3]) == { - 'output_ptr': '21', 'x': '5000', 'y': '3000'} + "output_ptr": "21", + "x": "5000", + "y": "3000", + } assert tracer_data.get_current_identifier_values(trace[4]) == { - 'output_ptr': '21', 'x': '5000', 'y': '3000', '__temp0': '1234'} + "output_ptr": "21", + "x": "5000", + "y": "3000", + "__temp0": "1234", + } assert tracer_data.get_current_identifier_values(trace[5]) == { - 'output_ptr': '21', 'x': '5000', 'y': '3000', '__temp0': '1234'} + "output_ptr": "21", + "x": "5000", + "y": "3000", + "__temp0": "1234", + } diff --git a/src/starkware/cairo/lang/version.py b/src/starkware/cairo/lang/version.py index d639a71d..94a0d990 100644 --- a/src/starkware/cairo/lang/version.py +++ b/src/starkware/cairo/lang/version.py @@ -1,3 +1,3 @@ import os -__version__ = open(os.path.join(os.path.dirname(__file__), 'VERSION')).read().strip() +__version__ = open(os.path.join(os.path.dirname(__file__), "VERSION")).read().strip() diff --git a/src/starkware/cairo/lang/vm/air_public_input.py b/src/starkware/cairo/lang/vm/air_public_input.py index 5f199a6d..e5355b89 100644 --- a/src/starkware/cairo/lang/vm/air_public_input.py +++ b/src/starkware/cairo/lang/vm/air_public_input.py @@ -37,8 +37,9 @@ def extract_public_memory(public_input: PublicInput) -> Dict[int, int]: for entry in public_input.public_memory: addr = entry.address value = entry.value - assert addr not in memory, \ - f'Duplicate public memory entries found with the same address: {addr}' + assert ( + addr not in memory + ), f"Duplicate public memory entries found with the same address: {addr}" memory[addr] = value return memory @@ -48,9 +49,7 @@ def extract_program_output(public_input: PublicInput, memory: Dict[int, int]) -> Returns a list of field elements represeting the program output. This function fails if the program doesn't have an output segment. """ - assert 'output' in public_input.memory_segments, 'Missing output segment.' - output_addresses = public_input.memory_segments['output'] - assert output_addresses.stop_ptr is not None, 'Missing stop_ptr for the output segment.' - return [ - memory[addr] - for addr in range(output_addresses.begin_addr, output_addresses.stop_ptr)] + assert "output" in public_input.memory_segments, "Missing output segment." + output_addresses = public_input.memory_segments["output"] + assert output_addresses.stop_ptr is not None, "Missing stop_ptr for the output segment." + return [memory[addr] for addr in range(output_addresses.begin_addr, output_addresses.stop_ptr)] diff --git a/src/starkware/cairo/lang/vm/builtin_runner.py b/src/starkware/cairo/lang/vm/builtin_runner.py index 0d7addf2..9b26407c 100644 --- a/src/starkware/cairo/lang/vm/builtin_runner.py +++ b/src/starkware/cairo/lang/vm/builtin_runner.py @@ -131,8 +131,11 @@ def get_additional_data(self) -> Any: return def extend_additional_data( - self, data: Any, relocate_callback: Callable[[MaybeRelocatable], MaybeRelocatable], - data_is_trusted: bool = True): + self, + data: Any, + relocate_callback: Callable[[MaybeRelocatable], MaybeRelocatable], + data_is_trusted: bool = True, + ): """ Adds the additional data created by another instance of the builtin runner. relocate_callback is a callback function used to relocate the addresses. @@ -165,8 +168,8 @@ class SimpleBuiltinRunner(BuiltinRunner): """ def __init__( - self, name: str, included: bool, ratio: int, cells_per_instance: int, - n_input_cells: int): + self, name: str, included: bool, ratio: int, cells_per_instance: int, n_input_cells: int + ): """ Constructs a SimpleBuiltinRunner. cells_per_instance is the number of memory cells per invocation. @@ -185,16 +188,17 @@ def initialize_segments(self, runner): self.base = runner.segments.add() def initial_stack(self) -> List[MaybeRelocatable]: - assert self.base is not None, 'Uninitialized self.base.' + assert self.base is not None, "Uninitialized self.base." return [self.base] if self.included else [] def final_stack(self, runner, pointer): if self.included: self.stop_ptr = runner.vm_memory[pointer - 1] used = self.get_used_instances(runner=runner) * self.cells_per_instance - assert self.stop_ptr == self.base + used, \ - f'Invalid stop pointer for {self.name}. ' + \ - f'Expected: {self.base + used}, found: {self.stop_ptr}' + assert self.stop_ptr == self.base + used, ( + f"Invalid stop pointer for {self.name}. " + + f"Expected: {self.base + used}, found: {self.stop_ptr}" + ) return pointer - 1 else: self.stop_ptr = self.base @@ -213,12 +217,14 @@ def get_allocated_memory_units(self, runner): def get_used_cells_and_allocated_size(self, runner): if runner.vm.current_step < self.ratio: raise InsufficientAllocatedCells( - f'Number of steps must be at least {self.ratio} for the {self.name} builtin.') + f"Number of steps must be at least {self.ratio} for the {self.name} builtin." + ) used = self.get_used_cells(runner) size = self.cells_per_instance * safe_div(runner.vm.current_step, self.ratio) if used > size: raise InsufficientAllocatedCells( - f'The {self.name} builtin used {used} cells but the capacity is {size}.') + f"The {self.name} builtin used {used} cells but the capacity is {size}." + ) return used, size def finalize_segments(self, runner): @@ -227,10 +233,12 @@ def finalize_segments(self, runner): runner.segments.finalize(segment_index=self.base.segment_index, size=size) def get_memory_segment_addresses(self, runner): - return {self.name: MemorySegmentAddresses( - begin_addr=self.base, - stop_ptr=self.stop_ptr, - )} + return { + self.name: MemorySegmentAddresses( + begin_addr=self.base, + stop_ptr=self.stop_ptr, + ) + } def run_security_checks(self, runner): offsets = { @@ -242,19 +250,17 @@ def run_security_checks(self, runner): # Verify that n is not too large to make sure the expected_offsets set that is constructed # below is not too large. - assert n <= len(offsets) // self.n_input_cells, f'Missing memory cells for {self.name}.' + assert n <= len(offsets) // self.n_input_cells, f"Missing memory cells for {self.name}." # Check that the two inputs (x and y) of each instance are set. expected_offsets = { - self.cells_per_instance * i + j - for i in range(n) - for j in range(self.n_input_cells)} + self.cells_per_instance * i + j for i in range(n) for j in range(self.n_input_cells) + } if not expected_offsets <= offsets: missing_offsets = list(expected_offsets - offsets) - dots = '...' if len(missing_offsets) > 20 else '.' - missing_offsets_str = ', '.join(map(str, missing_offsets[:20])) + dots - raise AssertionError( - f'Missing memory cells for {self.name}: {missing_offsets_str}') + dots = "..." if len(missing_offsets) > 20 else "." + missing_offsets_str = ", ".join(map(str, missing_offsets[:20])) + dots + raise AssertionError(f"Missing memory cells for {self.name}: {missing_offsets_str}") def get_memory_accesses(self, runner): segment_size = runner.segments.get_segment_size(self.base.segment_index) diff --git a/src/starkware/cairo/lang/vm/cairo_pie.py b/src/starkware/cairo/lang/vm/cairo_pie.py index 0957cb02..0e858b23 100644 --- a/src/starkware/cairo/lang/vm/cairo_pie.py +++ b/src/starkware/cairo/lang/vm/cairo_pie.py @@ -31,8 +31,8 @@ class SegmentInfo: size: int def run_validity_checks(self): - assert isinstance(self.index, int) and 0 <= self.index < 2 ** 30, 'Invalid segment index.' - assert isinstance(self.size, int) and 0 <= self.size < 2 ** 30, 'Invalid segment size.' + assert isinstance(self.index, int) and 0 <= self.index < 2 ** 30, "Invalid segment index." + assert isinstance(self.size, int) and 0 <= self.size < 2 ** 30, "Invalid segment size." @marshmallow_dataclass.dataclass @@ -55,19 +55,21 @@ def field_bytes(self) -> int: return math.ceil(self.program.prime.bit_length() / 8) def validate_segment_order(self): - assert self.program_segment.index == 0, 'Invalid segment index for program_segment.' - assert self.execution_segment.index == 1, 'Invalid segment index for execution_segment.' + assert self.program_segment.index == 0, "Invalid segment index for program_segment." + assert self.execution_segment.index == 1, "Invalid segment index for execution_segment." for expected_segment, (name, builtin_segment) in enumerate( - self.builtin_segments.items(), 2): - assert builtin_segment.index == expected_segment, f'Invalid segment index for {name}.' + self.builtin_segments.items(), 2 + ): + assert builtin_segment.index == expected_segment, f"Invalid segment index for {name}." n_builtins = len(self.builtin_segments) - assert self.ret_fp_segment.index == n_builtins + 2, \ - f'Invalid segment index for ret_fp_segment. {self.ret_fp_segment.index}' - assert self.ret_pc_segment.index == n_builtins + 3, \ - 'Invalid segment index for ret_pc_segment.' - for expected_segment, segment in enumerate( - self.extra_segments, n_builtins + 4): - assert segment.index == expected_segment, 'Invalid segment indices for extra_segments.' + assert ( + self.ret_fp_segment.index == n_builtins + 2 + ), f"Invalid segment index for ret_fp_segment. {self.ret_fp_segment.index}" + assert ( + self.ret_pc_segment.index == n_builtins + 3 + ), "Invalid segment index for ret_pc_segment." + for expected_segment, segment in enumerate(self.extra_segments, n_builtins + 4): + assert segment.index == expected_segment, "Invalid segment indices for extra_segments." def all_segments(self) -> List[SegmentInfo]: """ @@ -91,21 +93,23 @@ def segment_sizes(self) -> Dict[int, int]: def run_validity_checks(self): self.program.run_validity_checks() assert isinstance(self.builtin_segments, dict) and all( - is_valid_builtin_name(name) for name in self.builtin_segments.keys()), \ - 'Invalid builtin_segments.' - assert isinstance(self.extra_segments, list), 'Invalid type for extra_segments.' + is_valid_builtin_name(name) for name in self.builtin_segments.keys() + ), "Invalid builtin_segments." + assert isinstance(self.extra_segments, list), "Invalid type for extra_segments." for segment_info in self.all_segments(): - assert isinstance(segment_info, SegmentInfo), 'Invalid type for segment_info.' + assert isinstance(segment_info, SegmentInfo), "Invalid type for segment_info." segment_info.run_validity_checks() - assert self.program_segment.size == len(self.program.data), \ - 'Program length does not match the program segment size.' + assert self.program_segment.size == len( + self.program.data + ), "Program length does not match the program segment size." assert self.program.builtins == list(self.builtin_segments.keys()), ( - f'Builtin list mismatch in builtin_segments. Builtins: {self.program.builtins}, ' - f'segment keys: {list(self.builtin_segments.keys())}.') - assert self.ret_fp_segment.size == 0, 'Invalid segment size for ret_fp. Must be 0.' - assert self.ret_pc_segment.size == 0, 'Invalid segment size for ret_pc. Must be 0.' + f"Builtin list mismatch in builtin_segments. Builtins: {self.program.builtins}, " + f"segment keys: {list(self.builtin_segments.keys())}." + ) + assert self.ret_fp_segment.size == 0, "Invalid segment size for ret_fp. Must be 0." + assert self.ret_pc_segment.size == 0, "Invalid segment size for ret_pc. Must be 0." self.validate_segment_order() @@ -116,38 +120,44 @@ class ExecutionResources: Indicates how many steps the program should run, how many memory cells are used from each builtin, and how many holes there are in the memory address space. """ + n_steps: int builtin_instance_counter: Dict[str, int] - n_memory_holes: int = field( - metadata=dict(marshmallow_field=mfields.Integer(missing=0))) + n_memory_holes: int = field(metadata=dict(marshmallow_field=mfields.Integer(missing=0))) Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema def run_validity_checks(self): - assert isinstance(self.n_steps, int) and 1 <= self.n_steps < 2 ** 30, ( - f'Invalid n_steps: {self.n_steps}.') - assert isinstance(self.n_memory_holes, int) and 0 <= self.n_memory_holes < 2 ** 30, ( - f'Invalid n_memory_holes: {self.n_memory_holes}.') + assert ( + isinstance(self.n_steps, int) and 1 <= self.n_steps < 2 ** 30 + ), f"Invalid n_steps: {self.n_steps}." + assert ( + isinstance(self.n_memory_holes, int) and 0 <= self.n_memory_holes < 2 ** 30 + ), f"Invalid n_memory_holes: {self.n_memory_holes}." assert isinstance(self.builtin_instance_counter, dict) and all( is_valid_builtin_name(name) and isinstance(size, int) and 0 <= size < 2 ** 30 - for name, size in self.builtin_instance_counter.items()), ( - 'Invalid builtin_instance_counter.') + for name, size in self.builtin_instance_counter.items() + ), "Invalid builtin_instance_counter." - def __add__(self, other: 'ExecutionResources') -> 'ExecutionResources': + def __add__(self, other: "ExecutionResources") -> "ExecutionResources": total_builtin_instance_counter = add_counters( - self.builtin_instance_counter, other.builtin_instance_counter) + self.builtin_instance_counter, other.builtin_instance_counter + ) return ExecutionResources( n_steps=self.n_steps + other.n_steps, builtin_instance_counter=total_builtin_instance_counter, - n_memory_holes=self.n_memory_holes + other.n_memory_holes) + n_memory_holes=self.n_memory_holes + other.n_memory_holes, + ) - def __sub__(self, other: 'ExecutionResources') -> 'ExecutionResources': + def __sub__(self, other: "ExecutionResources") -> "ExecutionResources": diff_builtin_instance_counter = sub_counters( - self.builtin_instance_counter, other.builtin_instance_counter) + self.builtin_instance_counter, other.builtin_instance_counter + ) diff_execution_resources = ExecutionResources( n_steps=self.n_steps - other.n_steps, builtin_instance_counter=diff_builtin_instance_counter, - n_memory_holes=self.n_memory_holes - other.n_memory_holes) + n_memory_holes=self.n_memory_holes - other.n_memory_holes, + ) diff_execution_resources.run_validity_checks() return diff_execution_resources @@ -156,7 +166,7 @@ def __sub__(self, other: 'ExecutionResources') -> 'ExecutionResources': def empty(cls): return cls(n_steps=0, builtin_instance_counter={}, n_memory_holes=0) - def copy(self) -> 'ExecutionResources': + def copy(self) -> "ExecutionResources": return copy.deepcopy(self) @@ -168,15 +178,16 @@ class CairoPie: For example, this may be used to join a few cairo runs into one, by concatenating respective segments. """ + metadata: CairoPieMetadata memory: MemoryDict additional_data: Dict[str, Any] execution_resources: ExecutionResources - METADATA_FILENAME = 'metadata.json' - MEMORY_FILENAME = 'memory.bin' - ADDITIONAL_DATA_FILENAME = 'additional_data.json' - EXECUTION_RESOURCES_FILENAME = 'execution_resources.json' + METADATA_FILENAME = "metadata.json" + MEMORY_FILENAME = "memory.bin" + ADDITIONAL_DATA_FILENAME = "additional_data.json" + EXECUTION_RESOURCES_FILENAME = "execution_resources.json" ALL_FILES = [ METADATA_FILENAME, MEMORY_FILENAME, @@ -186,50 +197,54 @@ class CairoPie: MAX_SIZE = 1024 ** 3 @classmethod - def from_file(cls, fileobj) -> 'CairoPie': + def from_file(cls, fileobj) -> "CairoPie": """ Loads an instance of CairoPie from a file. `fileobj` can be a path or a file object. """ if isinstance(fileobj, str): - fileobj = open(fileobj, 'rb') + fileobj = open(fileobj, "rb") verify_zip_file_prefix(fileobj=fileobj) with zipfile.ZipFile(fileobj) as zf: cls.verify_zip_format(zf) - with zf.open(cls.METADATA_FILENAME, 'r') as fp: + with zf.open(cls.METADATA_FILENAME, "r") as fp: metadata = CairoPieMetadata.Schema().load( - json.loads(fp.read(cls.MAX_SIZE).decode('ascii'))) - with zf.open(cls.MEMORY_FILENAME, 'r') as fp: + json.loads(fp.read(cls.MAX_SIZE).decode("ascii")) + ) + with zf.open(cls.MEMORY_FILENAME, "r") as fp: memory = MemoryDict.deserialize( data=fp.read(cls.MAX_SIZE), field_bytes=metadata.field_bytes, ) - with zf.open(cls.ADDITIONAL_DATA_FILENAME, 'r') as fp: - additional_data = json.loads(fp.read(cls.MAX_SIZE).decode('ascii')) - with zf.open(cls.EXECUTION_RESOURCES_FILENAME, 'r') as fp: + with zf.open(cls.ADDITIONAL_DATA_FILENAME, "r") as fp: + additional_data = json.loads(fp.read(cls.MAX_SIZE).decode("ascii")) + with zf.open(cls.EXECUTION_RESOURCES_FILENAME, "r") as fp: execution_resources = ExecutionResources.Schema().load( - json.loads(fp.read(cls.MAX_SIZE).decode('ascii'))) + json.loads(fp.read(cls.MAX_SIZE).decode("ascii")) + ) return cls(metadata, memory, additional_data, execution_resources) def to_file(self, file): - with zipfile.ZipFile(file, mode='w', compression=zipfile.ZIP_DEFLATED) as zf: - with zf.open(self.METADATA_FILENAME, 'w') as fp: - fp.write(json.dumps( - CairoPieMetadata.Schema().dump(self.metadata)).encode('ascii')) - with zf.open(self.MEMORY_FILENAME, 'w') as fp: + with zipfile.ZipFile(file, mode="w", compression=zipfile.ZIP_DEFLATED) as zf: + with zf.open(self.METADATA_FILENAME, "w") as fp: + fp.write(json.dumps(CairoPieMetadata.Schema().dump(self.metadata)).encode("ascii")) + with zf.open(self.MEMORY_FILENAME, "w") as fp: fp.write(self.memory.serialize(self.metadata.field_bytes)) - with zf.open(self.ADDITIONAL_DATA_FILENAME, 'w') as fp: - fp.write(json.dumps(self.additional_data).encode('ascii')) - with zf.open(self.EXECUTION_RESOURCES_FILENAME, 'w') as fp: - fp.write(json.dumps( - ExecutionResources.Schema().dump(self.execution_resources)).encode('ascii')) + with zf.open(self.ADDITIONAL_DATA_FILENAME, "w") as fp: + fp.write(json.dumps(self.additional_data).encode("ascii")) + with zf.open(self.EXECUTION_RESOURCES_FILENAME, "w") as fp: + fp.write( + json.dumps(ExecutionResources.Schema().dump(self.execution_resources)).encode( + "ascii" + ) + ) @classmethod - def deserialize(cls, cairo_pie_bytes: bytes) -> 'CairoPie': + def deserialize(cls, cairo_pie_bytes: bytes) -> "CairoPie": cairo_pie_file = io.BytesIO() cairo_pie_file.write(cairo_pie_bytes) return CairoPie.from_file(fileobj=cairo_pie_file) @@ -247,16 +262,16 @@ def run_validity_checks(self): self.metadata.run_validity_checks() self.execution_resources.run_validity_checks() - assert isinstance(self.memory, MemoryDict), 'Invalid type for memory.' + assert isinstance(self.memory, MemoryDict), "Invalid type for memory." self.run_memory_validity_checks() - assert sorted(f'{name}_builtin' for name in self.metadata.program.builtins) == sorted( - self.execution_resources.builtin_instance_counter.keys()), ( - 'Builtin list mismatch in execution_resources.') + assert sorted(f"{name}_builtin" for name in self.metadata.program.builtins) == sorted( + self.execution_resources.builtin_instance_counter.keys() + ), "Builtin list mismatch in execution_resources." assert isinstance(self.additional_data, dict) and all( - isinstance(name, str) and len(name) < 1000 for name in self.additional_data), ( - 'Invalid additional_data.') + isinstance(name, str) and len(name) < 1000 for name in self.additional_data + ), "Invalid additional_data." def run_memory_validity_checks(self): segment_sizes = self.metadata.segment_sizes() @@ -267,19 +282,22 @@ def is_valid_memory_addr(addr, allow_end_of_segment: bool = False): segment_sizes and its offset is in the valid range (if allow_end_of_segment=True, offset may refer to the next cell *after* the segment). """ - return isinstance(addr, RelocatableValue) and \ - isinstance(addr.segment_index, int) and \ - isinstance(addr.offset, int) and \ - addr.segment_index in segment_sizes and \ - 0 <= addr.offset < segment_sizes[addr.segment_index] + ( - 1 if allow_end_of_segment else 0) + return ( + isinstance(addr, RelocatableValue) + and isinstance(addr.segment_index, int) + and isinstance(addr.offset, int) + and addr.segment_index in segment_sizes + and 0 + <= addr.offset + < segment_sizes[addr.segment_index] + (1 if allow_end_of_segment else 0) + ) def is_valid_memory_value(value): return isinstance(value, int) or is_valid_memory_addr(value, allow_end_of_segment=True) for addr, value in self.memory.items(): - assert is_valid_memory_addr(addr), 'Invalid memory cell address.' - assert is_valid_memory_value(value), f'Invalid memory cell value.' + assert is_valid_memory_addr(addr), "Invalid memory cell address." + assert is_valid_memory_value(value), f"Invalid memory cell value." @classmethod def verify_zip_format(cls, zf: zipfile.ZipFile): @@ -288,31 +306,38 @@ def verify_zip_format(cls, zf: zipfile.ZipFile): type is ZIP_DEFLATED and that their size is not too big. """ # Check the compression algorithm. - assert all(zip_info.compress_type == zipfile.ZIP_DEFLATED for zip_info in zf.filelist), \ - 'Invalid compress type.' + assert all( + zip_info.compress_type == zipfile.ZIP_DEFLATED for zip_info in zf.filelist + ), "Invalid compress type." # Check that orig_filename == filename. # Use "type: ignore" since mypy doesn't recognize ZipInfo.orig_filename. assert all( - zip_info.orig_filename == zip_info.filename # type: ignore - for zip_info in zf.filelist), 'File name mismatch.' + zip_info.orig_filename == zip_info.filename for zip_info in zf.filelist # type: ignore + ), "File name mismatch." # Make sure we have exactly the files we expect, and that their size is reasonable. inner_files = {zip_info.filename: zip_info for zip_info in zf.filelist} - assert sorted(inner_files.keys()) == sorted(cls.ALL_FILES), \ - 'Invalid list of inner files in the CairoPIE zip.' - assert inner_files[cls.METADATA_FILENAME].file_size < cls.MAX_SIZE, \ - f'Invalid file size for {cls.METADATA_FILENAME}.' - assert inner_files[cls.MEMORY_FILENAME].file_size < cls.MAX_SIZE, \ - f'Invalid file size for {cls.MEMORY_FILENAME}.' - assert inner_files[cls.ADDITIONAL_DATA_FILENAME].file_size < cls.MAX_SIZE, \ - f'Invalid file size for {cls.ADDITIONAL_DATA_FILENAME}.' - assert inner_files[cls.EXECUTION_RESOURCES_FILENAME].file_size < 10000, \ - f'Invalid file size for {cls.EXECUTION_RESOURCES_FILENAME}.' + assert sorted(inner_files.keys()) == sorted( + cls.ALL_FILES + ), "Invalid list of inner files in the CairoPIE zip." + assert ( + inner_files[cls.METADATA_FILENAME].file_size < cls.MAX_SIZE + ), f"Invalid file size for {cls.METADATA_FILENAME}." + assert ( + inner_files[cls.MEMORY_FILENAME].file_size < cls.MAX_SIZE + ), f"Invalid file size for {cls.MEMORY_FILENAME}." + assert ( + inner_files[cls.ADDITIONAL_DATA_FILENAME].file_size < cls.MAX_SIZE + ), f"Invalid file size for {cls.ADDITIONAL_DATA_FILENAME}." + assert ( + inner_files[cls.EXECUTION_RESOURCES_FILENAME].file_size < 10000 + ), f"Invalid file size for {cls.EXECUTION_RESOURCES_FILENAME}." def get_segment(self, segment_info: SegmentInfo): return self.memory.get_range( - RelocatableValue(segment_index=segment_info.index, offset=0), size=segment_info.size) + RelocatableValue(segment_index=segment_info.index, offset=0), size=segment_info.size + ) def verify_zip_file_prefix(fileobj): @@ -321,4 +346,4 @@ def verify_zip_file_prefix(fileobj): """ fileobj.seek(0) # Make sure this is a zip file. - assert fileobj.read(2) in ['PK', b'PK'], 'Invalid prefix for zip file.' + assert fileobj.read(2) in ["PK", b"PK"], "Invalid prefix for zip file." diff --git a/src/starkware/cairo/lang/vm/cairo_pie_test.py b/src/starkware/cairo/lang/vm/cairo_pie_test.py index a1fb75d2..f416a39f 100644 --- a/src/starkware/cairo/lang/vm/cairo_pie_test.py +++ b/src/starkware/cairo/lang/vm/cairo_pie_test.py @@ -6,7 +6,11 @@ from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME from starkware.cairo.lang.compiler.cairo_compile import compile_cairo from starkware.cairo.lang.vm.cairo_pie import ( - CairoPie, CairoPieMetadata, ExecutionResources, SegmentInfo) + CairoPie, + CairoPieMetadata, + ExecutionResources, + SegmentInfo, +) from starkware.cairo.lang.vm.cairo_runner import get_runner_from_code from starkware.cairo.lang.vm.memory_dict import MemoryDict from starkware.cairo.lang.vm.relocatable import RelocatableValue @@ -15,8 +19,9 @@ def test_cairo_pie_serialize_deserialize(): program = compile_cairo( - code=[('%builtins output pedersen range_check ecdsa\nmain:\n[ap] = [ap]\n', '')], - prime=DEFAULT_PRIME) + code=[("%builtins output pedersen range_check ecdsa\nmain:\n[ap] = [ap]\n", "")], + prime=DEFAULT_PRIME, + ) metadata = CairoPieMetadata( program=program.stripped(), program_segment=SegmentInfo(0, 10), @@ -24,28 +29,30 @@ def test_cairo_pie_serialize_deserialize(): ret_fp_segment=SegmentInfo(6, 12), ret_pc_segment=SegmentInfo(7, 21), builtin_segments={ - 'a': SegmentInfo(4, 15), + "a": SegmentInfo(4, 15), }, extra_segments=[], ) - memory = MemoryDict({ - 1: 2, - RelocatableValue(3, 4): RelocatableValue(6, 7), - }) - additional_data = {'c': ['d', 3]} + memory = MemoryDict( + { + 1: 2, + RelocatableValue(3, 4): RelocatableValue(6, 7), + } + ) + additional_data = {"c": ["d", 3]} execution_resources = ExecutionResources( n_steps=10, n_memory_holes=7, builtin_instance_counter={ - 'output': 6, - 'pedersen': 3, - } + "output": 6, + "pedersen": 3, + }, ) cairo_pie = CairoPie( metadata=metadata, memory=memory, additional_data=additional_data, - execution_resources=execution_resources + execution_resources=execution_resources, ) fileobj = io.BytesIO() @@ -64,7 +71,7 @@ def cairo_pie(): return (output_ptr=output_ptr, pedersen_ptr=pedersen_ptr) end """ - runner = get_runner_from_code(code=[(code, '')], layout='small', prime=DEFAULT_PRIME) + runner = get_runner_from_code(code=[(code, "")], layout="small", prime=DEFAULT_PRIME) return runner.get_cairo_pie() @@ -75,48 +82,47 @@ def test_cairo_pie_validity(cairo_pie): def test_cairo_pie_validity_invalid_program_size(cairo_pie: CairoPie): cairo_pie.metadata.program_segment.size += 1 with pytest.raises( - AssertionError, match='Program length does not match the program segment size.'): + AssertionError, match="Program length does not match the program segment size." + ): cairo_pie.run_validity_checks() def test_cairo_pie_validity_invalid_builtin_list(cairo_pie: CairoPie): - cairo_pie.program.builtins.append('output') - with pytest.raises( - AssertionError, match='Invalid builtin list.'): + cairo_pie.program.builtins.append("output") + with pytest.raises(AssertionError, match="Invalid builtin list."): cairo_pie.run_validity_checks() def test_cairo_pie_validity_invalid_builtin_segments(cairo_pie: CairoPie): - cairo_pie.metadata.builtin_segments['tmp'] = cairo_pie.metadata.builtin_segments['output'] - with pytest.raises( - AssertionError, match='Builtin list mismatch in builtin_segments.'): + cairo_pie.metadata.builtin_segments["tmp"] = cairo_pie.metadata.builtin_segments["output"] + with pytest.raises(AssertionError, match="Builtin list mismatch in builtin_segments."): cairo_pie.run_validity_checks() def test_cairo_pie_validity_invalid_builtin_list_execution_resources(cairo_pie: CairoPie): - cairo_pie.execution_resources.builtin_instance_counter['tmp_builtin'] = \ - cairo_pie.execution_resources.builtin_instance_counter['output_builtin'] - with pytest.raises( - AssertionError, match='Builtin list mismatch in execution_resources.'): + cairo_pie.execution_resources.builtin_instance_counter[ + "tmp_builtin" + ] = cairo_pie.execution_resources.builtin_instance_counter["output_builtin"] + with pytest.raises(AssertionError, match="Builtin list mismatch in execution_resources."): cairo_pie.run_validity_checks() def test_cairo_pie_memory_negative_address(cairo_pie: CairoPie): # Write to a negative address. - cairo_pie.memory.set_without_checks(RelocatableValue( - segment_index=cairo_pie.metadata.program_segment.index, offset=-5), 0) - with pytest.raises( - AssertionError, match='Invalid memory cell address.'): + cairo_pie.memory.set_without_checks( + RelocatableValue(segment_index=cairo_pie.metadata.program_segment.index, offset=-5), 0 + ) + with pytest.raises(AssertionError, match="Invalid memory cell address."): cairo_pie.run_validity_checks() def test_cairo_pie_memory_invalid_address(cairo_pie: CairoPie): # Write to an invalid address. cairo_pie.memory.unfreeze_for_testing() - cairo_pie.memory[RelocatableValue( - segment_index=cairo_pie.metadata.ret_pc_segment.index, offset=0)] = 0 - with pytest.raises( - AssertionError, match='Invalid memory cell address.'): + cairo_pie.memory[ + RelocatableValue(segment_index=cairo_pie.metadata.ret_pc_segment.index, offset=0) + ] = 0 + with pytest.raises(AssertionError, match="Invalid memory cell address."): cairo_pie.run_validity_checks() @@ -124,17 +130,17 @@ def test_cairo_pie_memory_invalid_value(cairo_pie: CairoPie): # Write a value after the execution segment. output_end = RelocatableValue( segment_index=cairo_pie.metadata.execution_segment.index, - offset=cairo_pie.metadata.execution_segment.size) + offset=cairo_pie.metadata.execution_segment.size, + ) cairo_pie.memory.unfreeze_for_testing() cairo_pie.memory[output_end] = output_end + 2 # It should fail because the address is outside the segment expected size. - with pytest.raises( - AssertionError, match='Invalid memory cell address.'): + with pytest.raises(AssertionError, match="Invalid memory cell address."): cairo_pie.run_validity_checks() # Increase the size. cairo_pie.metadata.execution_segment.size += 1 # Now it should fail because of the value. - with pytest.raises(AssertionError, match='Invalid memory cell value.'): + with pytest.raises(AssertionError, match="Invalid memory cell value."): cairo_pie.run_validity_checks() @@ -142,7 +148,7 @@ def test_add_execution_resources(): """ Tests ExecutionResources __add__ calculation. """ - dummy_builtins = ['builtin1', 'builtin2', 'builtin3', 'builtin4'] + dummy_builtins = ["builtin1", "builtin2", "builtin3", "builtin4"] total_execution_resources = ExecutionResources.empty() total_builtin_instance_counter = {} @@ -162,13 +168,16 @@ def test_add_execution_resources(): random_builtin_instance_counter[random_builtin_type] = random_builtin_counter random_steps = random.randint(0, 1000) execution_resources = ExecutionResources( - n_steps=random_steps, builtin_instance_counter=random_builtin_instance_counter, - n_memory_holes=0) + n_steps=random_steps, + builtin_instance_counter=random_builtin_instance_counter, + n_memory_holes=0, + ) # Update totals. total_steps += random_steps total_builtin_instance_counter = add_counters( - total_builtin_instance_counter, random_builtin_instance_counter) + total_builtin_instance_counter, random_builtin_instance_counter + ) # Calculate total_execution_resources using __add__() function. total_execution_resources += execution_resources diff --git a/src/starkware/cairo/lang/vm/cairo_run.py b/src/starkware/cairo/lang/vm/cairo_run.py index 5756613c..90f92c8b 100644 --- a/src/starkware/cairo/lang/vm/cairo_run.py +++ b/src/starkware/cairo/lang/vm/cairo_run.py @@ -31,103 +31,144 @@ class CairoRunError(Exception): def main(): start_time = time.time() - parser = argparse.ArgumentParser( - description='A tool to run Cairo programs.') - parser.add_argument('-v', '--version', action='version', version=f'%(prog)s {__version__}') + parser = argparse.ArgumentParser(description="A tool to run Cairo programs.") + parser.add_argument("-v", "--version", action="version", version=f"%(prog)s {__version__}") parser.add_argument( - '--program', type=argparse.FileType('r'), help='The name of the program json file.') - parser.add_argument( - '--program_input', type=argparse.FileType('r'), - help='Path to a json file representing the (private) input of the program.') + "--program", type=argparse.FileType("r"), help="The name of the program json file." + ) parser.add_argument( - '--steps', type=int, - help='The number of instructions to perform. If steps is not given, runs the program until ' - 'the __end__ instruction, and then continues until the number of steps is a power of 2.') + "--program_input", + type=argparse.FileType("r"), + help="Path to a json file representing the (private) input of the program.", + ) parser.add_argument( - '--min_steps', type=int, - help='The minimal number of instructions to perform. This can be used to guarantee that ' - 'there will be enough builtin instances for the program.') + "--steps", + type=int, + help="The number of instructions to perform. If steps is not given, runs the program until " + "the __end__ instruction, and then continues until the number of steps is a power of 2.", + ) parser.add_argument( - '--debug_error', action='store_true', - help='If there is an error during the execution, stop the execution, but produce the ' - 'partial outputs.') + "--min_steps", + type=int, + help="The minimal number of instructions to perform. This can be used to guarantee that " + "there will be enough builtin instances for the program.", + ) parser.add_argument( - '--no_end', action='store_true', - help="Don't check that the program ended successfully.") + "--debug_error", + action="store_true", + help="If there is an error during the execution, stop the execution, but produce the " + "partial outputs.", + ) parser.add_argument( - '--print_memory', action='store_true', - help='Show the values on the memory after the execution.') + "--no_end", action="store_true", help="Don't check that the program ended successfully." + ) parser.add_argument( - '--relocate_prints', action='store_true', - help='Print memory and info after memory relocation.') + "--print_memory", + action="store_true", + help="Show the values on the memory after the execution.", + ) parser.add_argument( - '--secure_run', action='store_true', default=None, - help='Verify that the run is secure and can be run safely using the bootloader ' - '(the default).') + "--relocate_prints", + action="store_true", + help="Print memory and info after memory relocation.", + ) parser.add_argument( - '--no_secure_run', dest='secure_run', action='store_false', - help="Don't verify that the run is secure (see --secure_run).") + "--secure_run", + action="store_true", + default=None, + help="Verify that the run is secure and can be run safely using the bootloader " + "(the default).", + ) parser.add_argument( - '--print_info', action='store_true', - help='Print information on the execution of the program.') + "--no_secure_run", + dest="secure_run", + action="store_false", + help="Don't verify that the run is secure (see --secure_run).", + ) parser.add_argument( - '--print_segments', action='store_true', - help='Print the segment relocation table.') + "--print_info", + action="store_true", + help="Print information on the execution of the program.", + ) parser.add_argument( - '--print_output', action='store_true', - help='Prints the program output (if the output builtin is used).') + "--print_segments", action="store_true", help="Print the segment relocation table." + ) parser.add_argument( - '--memory_file', type=argparse.FileType('wb'), - help='Output file name for the memory.') + "--print_output", + action="store_true", + help="Prints the program output (if the output builtin is used).", + ) parser.add_argument( - '--trace_file', type=argparse.FileType('wb'), - help='Output file name for the execution trace.') + "--memory_file", type=argparse.FileType("wb"), help="Output file name for the memory." + ) parser.add_argument( - '--run_from_cairo_pie', type=argparse.FileType('rb'), - help='Runs a Cairo PIE file, instead of a program. ' - 'This flag can be used with --secure_run to verify the correctness of a Cairo PIE file.') + "--trace_file", + type=argparse.FileType("wb"), + help="Output file name for the execution trace.", + ) parser.add_argument( - '--cairo_pie_output', type=argparse.FileType('wb'), - help='Output file name for the CairoPIE object.') + "--run_from_cairo_pie", + type=argparse.FileType("rb"), + help="Runs a Cairo PIE file, instead of a program. " + "This flag can be used with --secure_run to verify the correctness of a Cairo PIE file.", + ) parser.add_argument( - '--debug_info_file', type=argparse.FileType('w'), - help='Output file name for debug information created at run time.') + "--cairo_pie_output", + type=argparse.FileType("wb"), + help="Output file name for the CairoPIE object.", + ) parser.add_argument( - '--air_public_input', type=argparse.FileType('w'), - help='Output file name for the public input json file of the Cairo AIR.') + "--debug_info_file", + type=argparse.FileType("w"), + help="Output file name for debug information created at run time.", + ) parser.add_argument( - '--air_private_input', type=argparse.FileType('w'), - help='Output file name for the private input json file of the Cairo AIR.') + "--air_public_input", + type=argparse.FileType("w"), + help="Output file name for the public input json file of the Cairo AIR.", + ) parser.add_argument( - '--layout', choices=LAYOUTS.keys(), default='plain', - help='The layout of the Cairo AIR.') + "--air_private_input", + type=argparse.FileType("w"), + help="Output file name for the private input json file of the Cairo AIR.", + ) parser.add_argument( - '--tracer', action='store_true', help='Run the tracer.') + "--layout", choices=LAYOUTS.keys(), default="plain", help="The layout of the Cairo AIR." + ) + parser.add_argument("--tracer", action="store_true", help="Run the tracer.") parser.add_argument( - '--profile_output', type=str, - help='A path to an output file to write profile data to. Can be opened in pprof. ' - 'Usually "profile.pb.gz".') + "--profile_output", + type=str, + help="A path to an output file to write profile data to. Can be opened in pprof. " + 'Usually "profile.pb.gz".', + ) parser.add_argument( - '--proof_mode', action='store_true', help='Prepare a provable execution trace.') + "--proof_mode", action="store_true", help="Prepare a provable execution trace." + ) parser.add_argument( - '--flavor', type=str, choices=['Debug', 'Release', 'RelWithDebInfo'], help='Build flavor.') + "--flavor", type=str, choices=["Debug", "Release", "RelWithDebInfo"], help="Build flavor." + ) python_dependencies.add_argparse_argument(parser) args = parser.parse_args() - assert int(args.program is not None) + int(args.run_from_cairo_pie is not None) == 1, \ - 'Exactly one of --program, --run_from_cairo_pie must be specified.' - assert not (args.proof_mode and args.run_from_cairo_pie), \ - '--proof_mode cannot be used with --run_from_cairo_pie.' - assert not (args.steps and args.min_steps), '--steps and --min_steps cannot be both specified.' - assert not (args.cairo_pie_output and args.no_end), \ - '--no_end and --cairo_pie_output cannot be both specified.' - assert not (args.cairo_pie_output and args.proof_mode), \ - '--proof_mode and --cairo_pie_output cannot be both specified.' + assert ( + int(args.program is not None) + int(args.run_from_cairo_pie is not None) == 1 + ), "Exactly one of --program, --run_from_cairo_pie must be specified." + assert not ( + args.proof_mode and args.run_from_cairo_pie + ), "--proof_mode cannot be used with --run_from_cairo_pie." + assert not (args.steps and args.min_steps), "--steps and --min_steps cannot be both specified." + assert not ( + args.cairo_pie_output and args.no_end + ), "--no_end and --cairo_pie_output cannot be both specified." + assert not ( + args.cairo_pie_output and args.proof_mode + ), "--proof_mode and --cairo_pie_output cannot be both specified." if args.air_public_input: - assert args.proof_mode, '--air_public_input can only be used in proof_mode.' + assert args.proof_mode, "--air_public_input can only be used in proof_mode." if args.air_private_input: - assert args.proof_mode, '--air_private_input can only be used in proof_mode.' + assert args.proof_mode, "--air_private_input can only be used in proof_mode." # If secure_run is not specified, the default is False in proof mode and True otherwise. if args.secure_run is None: @@ -140,7 +181,7 @@ def main(): print(err, file=sys.stderr) res = 1 except AssertionError as err: - print(f'Error: {err}', file=sys.stderr) + print(f"Error: {err}", file=sys.stderr) res = 1 # Generate python dependencies. @@ -154,9 +195,10 @@ def load_program(program) -> ProgramBase: program_json = json.load(program) except json.JSONDecodeError as err: raise CairoRunError( - 'Failed to load compiled program (not a valid JSON file). ' - 'Did you compile the code before running it? ' - f"Error: '{err}'") + "Failed to load compiled program (not a valid JSON file). " + "Did you compile the code before running it? " + f"Error: '{err}'" + ) return Program.Schema().load(program_json) @@ -165,17 +207,17 @@ def cairo_run(args): trace_file = args.trace_file if trace_file is None and trace_needed: # If --tracer or --profile_output is used, use a temporary file as trace_file. - trace_file = tempfile.NamedTemporaryFile(mode='wb') + trace_file = tempfile.NamedTemporaryFile(mode="wb") memory_file = args.memory_file if memory_file is None and trace_needed: # If --tracer or --profile_output is used, use a temporary file as memory_file. - memory_file = tempfile.NamedTemporaryFile(mode='wb') + memory_file = tempfile.NamedTemporaryFile(mode="wb") debug_info_file = args.debug_info_file if debug_info_file is None and trace_needed: # If --tracer or --profile_output is used, use a temporary file as debug_info_file. - debug_info_file = tempfile.NamedTemporaryFile(mode='w') + debug_info_file = tempfile.NamedTemporaryFile(mode="w") ret_code = 0 cairo_pie_input = None @@ -185,22 +227,27 @@ def cairo_run(args): steps_input = args.steps else: assert args.run_from_cairo_pie is not None - assert args.steps is None and args.min_steps is None, \ - '--steps and --min_steps cannot be specified in --run_from_cairo_pie mode.' + assert ( + args.steps is None and args.min_steps is None + ), "--steps and --min_steps cannot be specified in --run_from_cairo_pie mode." cairo_pie_input = CairoPie.from_file(args.run_from_cairo_pie) try: cairo_pie_input.run_validity_checks() except Exception as exc: # Trim error message in case it's too long. msg = str(exc)[:10000] - raise CairoRunError(f'Security check for the CairoPIE input failed: {msg}') + raise CairoRunError(f"Security check for the CairoPIE input failed: {msg}") program = cairo_pie_input.program initial_memory = cairo_pie_input.memory steps_input = cairo_pie_input.execution_resources.n_steps runner = CairoRunner( - program=program, layout=args.layout, memory=initial_memory, proof_mode=args.proof_mode, - allow_missing_builtins=args.proof_mode) + program=program, + layout=args.layout, + memory=initial_memory, + proof_mode=args.proof_mode, + allow_missing_builtins=args.proof_mode, + ) runner.initialize_segments() end = runner.initialize_main_entrypoint() @@ -215,16 +262,17 @@ def cairo_run(args): builtin_runner.extend_additional_data( data=cairo_pie_input.additional_data[name], relocate_callback=lambda x: x, - data_is_trusted=not args.secure_run) + data_is_trusted=not args.secure_run, + ) # Force segments sizes to match the CairoPie. runner.finalize_segments_by_cairo_pie(cairo_pie=cairo_pie_input) program_input = json.load(args.program_input) if args.program_input else {} - runner.initialize_vm(hint_locals={'program_input': program_input}) + runner.initialize_vm(hint_locals={"program_input": program_input}) try: if args.no_end: - assert args.steps is not None, '--steps must be specified when running with --no-end.' + assert args.steps is not None, "--steps must be specified when running with --no-end." else: additional_steps = 1 if args.proof_mode else 0 max_steps = steps_input - additional_steps if steps_input is not None else None @@ -246,7 +294,7 @@ def cairo_run(args): runner.end_run(disable_trace_padding=disable_trace_padding) except (VmException, AssertionError) as exc: if args.debug_error: - print(f'Got an error:\n{exc}') + print(f"Got an error:\n{exc}") ret_code = 1 else: raise exc @@ -262,9 +310,10 @@ def cairo_run(args): verify_secure_runner(runner) if args.run_from_cairo_pie is not None: assert cairo_pie_input is not None - assert cairo_pie_input == runner.get_cairo_pie(), \ - 'The Cairo PIE input is not identical to the resulting Cairo PIE. ' \ - 'This may indicate that the Cairo PIE was not generated by cairo_run.' + assert cairo_pie_input == runner.get_cairo_pie(), ( + "The Cairo PIE input is not identical to the resulting Cairo PIE. " + "This may indicate that the Cairo PIE was not generated by cairo_run." + ) if args.cairo_pie_output: runner.get_cairo_pie().to_file(args.cairo_pie_output) @@ -303,57 +352,81 @@ def cairo_run(args): public_input_file=args.air_public_input, memory=runner.relocated_memory, public_memory_addresses=runner.segments.get_public_memory_addresses( - runner.segment_offsets), + runner.segment_offsets + ), memory_segment_addresses=runner.get_memory_segment_addresses(), trace=runner.relocated_trace, rc_min=rc_min, - rc_max=rc_max) + rc_max=rc_max, + ) if args.air_private_input is not None: - assert args.trace_file is not None, \ - '--trace_file must be set when --air_private_input is set.' - assert args.memory_file is not None, \ - '--memory_file must be set when --air_private_input is set.' - json.dump({ - 'trace_path': f'{os.path.abspath(trace_file.name)}', - 'memory_path': f'{os.path.abspath(memory_file.name)}', - **runner.get_air_private_input(), - }, args.air_private_input, indent=4) + assert ( + args.trace_file is not None + ), "--trace_file must be set when --air_private_input is set." + assert ( + args.memory_file is not None + ), "--memory_file must be set when --air_private_input is set." + json.dump( + { + "trace_path": f"{os.path.abspath(trace_file.name)}", + "memory_path": f"{os.path.abspath(memory_file.name)}", + **runner.get_air_private_input(), + }, + args.air_private_input, + indent=4, + ) print(file=args.air_private_input) args.air_private_input.flush() if debug_info_file is not None: - json.dump( - DebugInfo.Schema().dump(runner.get_relocated_debug_info()), - debug_info_file) + json.dump(DebugInfo.Schema().dump(runner.get_relocated_debug_info()), debug_info_file) debug_info_file.flush() if args.tracer: - CAIRO_TRACER = 'starkware.cairo.lang.tracer.tracer' - subprocess.call(list(filter(None, [ - sys.executable, - '-m', - CAIRO_TRACER, - f'--program={args.program.name}', - f'--trace={trace_file.name}', - f'--memory={memory_file.name}', - f'--air_public_input={args.air_public_input.name}' if args.air_public_input else None, - f'--debug_info={debug_info_file.name}', - ]))) + CAIRO_TRACER = "starkware.cairo.lang.tracer.tracer" + subprocess.call( + list( + filter( + None, + [ + sys.executable, + "-m", + CAIRO_TRACER, + f"--program={args.program.name}", + f"--trace={trace_file.name}", + f"--memory={memory_file.name}", + f"--air_public_input={args.air_public_input.name}" + if args.air_public_input + else None, + f"--debug_info={debug_info_file.name}", + ], + ) + ) + ) if args.profile_output is not None: - CAIRO_PROFILER = 'starkware.cairo.lang.tracer.profiler' - subprocess.call(list(filter(None, [ - sys.executable, - '-m', - CAIRO_PROFILER, - f'--program={args.program.name}', - f'--trace={trace_file.name}', - f'--memory={memory_file.name}', - f'--air_public_input={args.air_public_input.name}' if args.air_public_input else None, - f'--debug_info={debug_info_file.name}', - f'--profile_output={args.profile_output}', - ]))) + CAIRO_PROFILER = "starkware.cairo.lang.tracer.profiler" + subprocess.call( + list( + filter( + None, + [ + sys.executable, + "-m", + CAIRO_PROFILER, + f"--program={args.program.name}", + f"--trace={trace_file.name}", + f"--memory={memory_file.name}", + f"--air_public_input={args.air_public_input.name}" + if args.air_public_input + else None, + f"--debug_info={debug_info_file.name}", + f"--profile_output={args.profile_output}", + ], + ) + ) + ) return ret_code @@ -373,15 +446,19 @@ def write_binary_memory(memory_file: BinaryIO, memory: MemoryDict, field_bytes: def write_air_public_input( - public_input_file, memory: MemoryDict, layout: str, - public_memory_addresses: List[Tuple[int, int]], - memory_segment_addresses: Dict[str, MemorySegmentAddresses], - trace: List[TraceEntry[int]], - rc_min: int, - rc_max: int): + public_input_file, + memory: MemoryDict, + layout: str, + public_memory_addresses: List[Tuple[int, int]], + memory_segment_addresses: Dict[str, MemorySegmentAddresses], + trace: List[TraceEntry[int]], + rc_min: int, + rc_max: int, +): public_memory = [ PublicMemoryEntry(address=addr, value=memory[addr], page=page) # type: ignore - for addr, page in public_memory_addresses] + for addr, page in public_memory_addresses + ] initial_pc = trace[0].pc assert isinstance(initial_pc, int) public_input = PublicInput( # type: ignore @@ -390,22 +467,20 @@ def write_air_public_input( rc_max=rc_max, n_steps=len(trace), memory_segments={ - 'program': MemorySegmentAddresses( # type: ignore - begin_addr=trace[0].pc, - stop_ptr=trace[-1].pc + "program": MemorySegmentAddresses( # type: ignore + begin_addr=trace[0].pc, stop_ptr=trace[-1].pc ), - 'execution': MemorySegmentAddresses( # type: ignore - begin_addr=trace[0].ap, - stop_ptr=trace[-1].ap + "execution": MemorySegmentAddresses( # type: ignore + begin_addr=trace[0].ap, stop_ptr=trace[-1].ap ), **memory_segment_addresses, }, public_memory=public_memory, ) public_input_file.write(PublicInput.Schema().dumps(public_input, indent=4)) - public_input_file.write('\n') + public_input_file.write("\n") public_input_file.flush() -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/src/starkware/cairo/lang/vm/cairo_runner.py b/src/starkware/cairo/lang/vm/cairo_runner.py index 068ce187..8e0a740d 100644 --- a/src/starkware/cairo/lang/vm/cairo_runner.py +++ b/src/starkware/cairo/lang/vm/cairo_runner.py @@ -3,10 +3,14 @@ from starkware.cairo.lang.builtins.bitwise.bitwise_builtin_runner import BitwiseBuiltinRunner from starkware.cairo.lang.builtins.hash.hash_builtin_runner import HashBuiltinRunner from starkware.cairo.lang.builtins.range_check.range_check_builtin_runner import ( - RangeCheckBuiltinRunner) + RangeCheckBuiltinRunner, +) from starkware.cairo.lang.builtins.signature.signature_builtin_runner import SignatureBuiltinRunner from starkware.cairo.lang.compiler.cairo_compile import ( - compile_cairo, compile_cairo_files, get_module_reader) + compile_cairo, + compile_cairo_files, + get_module_reader, +) from starkware.cairo.lang.compiler.debug_info import DebugInfo from starkware.cairo.lang.compiler.expression_simplifier import to_field_element from starkware.cairo.lang.compiler.preprocessor.default_pass_manager import default_pass_manager @@ -15,7 +19,11 @@ from starkware.cairo.lang.instances import LAYOUTS from starkware.cairo.lang.vm.builtin_runner import BuiltinRunner, InsufficientAllocatedCells from starkware.cairo.lang.vm.cairo_pie import ( - CairoPie, CairoPieMetadata, ExecutionResources, SegmentInfo) + CairoPie, + CairoPieMetadata, + ExecutionResources, + SegmentInfo, +) from starkware.cairo.lang.vm.crypto import pedersen_hash, verify_ecdsa from starkware.cairo.lang.vm.memory_dict import MemoryDict from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager @@ -45,57 +53,75 @@ def process_ecdsa(public_key, msg, signature): the ECDSA component. """ r, s = signature - return {'r': hex(r), 'w': hex(inv_mod_curve_size(s))} + return {"r": hex(r), "w": hex(inv_mod_curve_size(s))} class CairoRunner: def __init__( - self, program: ProgramBase, layout: str = 'plain', memory: MemoryDict = None, - proof_mode: Optional[bool] = None, allow_missing_builtins: Optional[bool] = None): + self, + program: ProgramBase, + layout: str = "plain", + memory: MemoryDict = None, + proof_mode: Optional[bool] = None, + allow_missing_builtins: Optional[bool] = None, + ): self.program = program self.layout = layout self.builtin_runners: Dict[str, BuiltinRunner] = {} self.original_steps = None self.proof_mode = False if proof_mode is None else proof_mode self.allow_missing_builtins = ( - False if allow_missing_builtins is None else allow_missing_builtins) + False if allow_missing_builtins is None else allow_missing_builtins + ) instance = LAYOUTS[self.layout] if not allow_missing_builtins: non_existing_builtins = set(self.program.builtins) - set(instance.builtins.keys()) - assert len(non_existing_builtins) == 0, \ - f'Builtins {non_existing_builtins} are not present in layout "{self.layout}"' + assert ( + len(non_existing_builtins) == 0 + ), f'Builtins {non_existing_builtins} are not present in layout "{self.layout}"' builtin_factories = dict( - output=lambda name, included: OutputBuiltinRunner( - included=included), + output=lambda name, included: OutputBuiltinRunner(included=included), pedersen=lambda name, included: HashBuiltinRunner( - name=name, included=included, - ratio=instance.builtins['pedersen'].ratio, hash_func=pedersen_hash), + name=name, + included=included, + ratio=instance.builtins["pedersen"].ratio, + hash_func=pedersen_hash, + ), range_check=lambda name, included: RangeCheckBuiltinRunner( - included=included, ratio=instance.builtins['range_check'].ratio, - inner_rc_bound=2 ** 16, n_parts=instance.builtins['range_check'].n_parts), + included=included, + ratio=instance.builtins["range_check"].ratio, + inner_rc_bound=2 ** 16, + n_parts=instance.builtins["range_check"].n_parts, + ), ecdsa=lambda name, included: SignatureBuiltinRunner( - name=name, included=included, - ratio=instance.builtins['ecdsa'].ratio, - process_signature=process_ecdsa, verify_signature=verify_ecdsa_sig), + name=name, + included=included, + ratio=instance.builtins["ecdsa"].ratio, + process_signature=process_ecdsa, + verify_signature=verify_ecdsa_sig, + ), bitwise=lambda name, included: BitwiseBuiltinRunner( - included=included, bitwise_builtin=instance.builtins['bitwise']), + included=included, bitwise_builtin=instance.builtins["bitwise"] + ), ) for name in instance.builtins: factory = builtin_factories.get(name) - assert factory is not None, f'The {name} builtin is not supported.' + assert factory is not None, f"The {name} builtin is not supported." included = name in self.program.builtins # In proof mode all the builtin_runners are required. if included or self.proof_mode: - self.builtin_runners[f'{name}_builtin'] = factory( # type: ignore - name=name, included=included) + self.builtin_runners[f"{name}_builtin"] = factory( # type: ignore + name=name, included=included + ) supported_builtin_list = list(builtin_factories.keys()) - assert is_subsequence(self.program.builtins, supported_builtin_list), \ - f'{self.program.builtins} is not a subsequence of {supported_builtin_list}.' + assert is_subsequence( + self.program.builtins, supported_builtin_list + ), f"{self.program.builtins} is not a subsequence of {supported_builtin_list}." self.memory = memory if memory is not None else MemoryDict() self.segments = MemorySegmentManager(memory=self.memory, prime=self.program.prime) @@ -111,18 +137,24 @@ def __init__( @classmethod def from_file( - cls, filename: str, prime: int, layout: str = 'plain', - remove_hints: bool = False, remove_builtins: bool = False, memory: MemoryDict = None, - preprocessor_cls: Type[Preprocessor] = Preprocessor, - proof_mode: Optional[bool] = None) -> 'CairoRunner': + cls, + filename: str, + prime: int, + layout: str = "plain", + remove_hints: bool = False, + remove_builtins: bool = False, + memory: MemoryDict = None, + preprocessor_cls: Type[Preprocessor] = Preprocessor, + proof_mode: Optional[bool] = None, + ) -> "CairoRunner": module_reader = get_module_reader(cairo_path=[]) program = compile_cairo_files( files=[filename], debug_info=True, pass_manager=default_pass_manager( - prime=prime, - read_module=module_reader.read, - preprocessor_cls=preprocessor_cls)) + prime=prime, read_module=module_reader.read, preprocessor_cls=preprocessor_cls + ), + ) if remove_hints: program.hints = {} if remove_builtins: @@ -154,9 +186,9 @@ def initialize_main_entrypoint(self): stack: List[MaybeRelocatable] = [] for builtin_name in self.program.builtins: - builtin_runner = self.builtin_runners.get(f'{builtin_name}_builtin') + builtin_runner = self.builtin_runners.get(f"{builtin_name}_builtin") if builtin_runner is None: - assert self.allow_missing_builtins, 'Missing builtin.' + assert self.allow_missing_builtins, "Missing builtin." stack += [0] else: stack += builtin_runner.initial_stack() @@ -167,22 +199,25 @@ def initialize_main_entrypoint(self): stack = [self.execution_base + 2, 0] + stack self.execution_public_memory = list(range(len(stack))) - assert isinstance(self.program, Program), \ - '--proof_mode cannot be used with a StrippedProgram.' + assert isinstance( + self.program, Program + ), "--proof_mode cannot be used with a StrippedProgram." self.initialize_state(self.program.start, stack) self.initial_fp = self.initial_ap = self.execution_base + 2 - return self.program_base + self.program.get_label('__end__') + return self.program_base + self.program.get_label("__end__") else: return_fp = self.segments.add() main = self.program.main - assert main is not None, 'Missing main().' - return self.initialize_function_entrypoint( - main, stack, return_fp=return_fp) + assert main is not None, "Missing main()." + return self.initialize_function_entrypoint(main, stack, return_fp=return_fp) def initialize_function_entrypoint( - self, entrypoint: Union[str, int], args: Sequence[MaybeRelocatable], - return_fp: MaybeRelocatable = 0): + self, + entrypoint: Union[str, int], + args: Sequence[MaybeRelocatable], + return_fp: MaybeRelocatable = 0, + ): end = self.segments.add() stack = list(args) + [return_fp, end] self.initialize_state(entrypoint, stack) @@ -198,8 +233,8 @@ def initialize_state(self, entrypoint: Union[str, int], stack: Sequence[MaybeRel self.load_data(self.execution_base, stack) def initialize_vm( - self, hint_locals, static_locals: Optional[Dict[str, Any]] = None, - vm_class=VirtualMachine): + self, hint_locals, static_locals: Optional[Dict[str, Any]] = None, vm_class=VirtualMachine + ): context = RunContext( pc=self.initial_pc, ap=self.initial_ap, @@ -212,7 +247,9 @@ def initialize_vm( static_locals = {} self.vm = vm_class( - self.program, context, hint_locals=hint_locals, + self.program, + context, + hint_locals=hint_locals, static_locals=dict(segments=self.segments, **static_locals), builtin_runners=self.builtin_runners, program_base=self.program_base, @@ -225,8 +262,8 @@ def initialize_vm( self.vm.validate_existing_memory() def run_until_label( - self, label_or_pc: Union[str, int], - run_resources: Optional[RunResources] = None): + self, label_or_pc: Union[str, int], run_resources: Optional[RunResources] = None + ): """ Runs the VM until label is reached, and stops right before that instruction is executed. 'label_or_pc' should be either a label string or an integer offset from the program_base. @@ -234,8 +271,7 @@ def run_until_label( label = self._to_pc(label_or_pc) self.run_until_pc(self.program_base + label, run_resources=run_resources) - def run_until_pc( - self, addr: MaybeRelocatable, run_resources: Optional[RunResources] = None): + def run_until_pc(self, addr: MaybeRelocatable, run_resources: Optional[RunResources] = None): """ Runs the VM until pc reaches 'addr', and stop right before that instruction is executed. """ @@ -248,14 +284,15 @@ def run_until_pc( if self.vm.run_context.pc != addr: raise self.vm.as_vm_exception( - ResourcesError('Error: End of program was not reached'), - self.vm.run_context.pc) + ResourcesError("Error: End of program was not reached"), self.vm.run_context.pc + ) def vm_step(self): if self.vm.run_context.pc == self.final_pc: raise self.vm.as_vm_exception( - Exception('Error: Execution reached the end of the program.'), - self.vm.run_context.pc) + Exception("Error: Execution reached the end of the program."), + self.vm.run_context.pc, + ) self.vm.step() def run_for_steps(self, steps: int): @@ -279,10 +316,11 @@ def run_until_next_power_of_2(self): self.run_until_steps(next_power_of_2(self.vm.current_step)) def end_run(self, disable_trace_padding: bool = True, disable_finalize_all: bool = False): - assert not self._run_ended, 'end_run called twice.' + assert not self._run_ended, "end_run called twice." self.accessed_addresses = { - self.vm_memory.relocate_value(addr) for addr in self.vm.accessed_addresses} + self.vm_memory.relocate_value(addr) for addr in self.vm.accessed_addresses + } self.vm_memory.relocate_memory() self.vm.end_run() @@ -308,24 +346,27 @@ def read_return_values(self): Reads builtin return values (end pointers) and adds them to the public memory. Note: end_run() must precede a call to this method. """ - assert self._run_ended, 'Run must be ended before calling read_return_values.' + assert self._run_ended, "Run must be ended before calling read_return_values." pointer = self.vm.run_context.ap for builtin_name in self.program.builtins[::-1]: - builtin_runner = self.builtin_runners.get(f'{builtin_name}_builtin') + builtin_runner = self.builtin_runners.get(f"{builtin_name}_builtin") if builtin_runner is None: - assert self.allow_missing_builtins, 'Missing builtin.' + assert self.allow_missing_builtins, "Missing builtin." pointer -= 1 - assert self.vm_memory[pointer] == 0, \ - f'The stop pointer of the missing builtin "{builtin_name}" must be 0.' + assert ( + self.vm_memory[pointer] == 0 + ), f'The stop pointer of the missing builtin "{builtin_name}" must be 0.' else: pointer = builtin_runner.final_stack(self, pointer) - assert not self._segments_finalized, \ - 'Cannot add the return values to the public memory after segment finalization.' + assert ( + not self._segments_finalized + ), "Cannot add the return values to the public memory after segment finalization." # Add return values to public memory. - self.execution_public_memory += list(range( - pointer - self.execution_base, self.vm.run_context.ap - self.execution_base)) + self.execution_public_memory += list( + range(pointer - self.execution_base, self.vm.run_context.ap - self.execution_base) + ) def check_used_cells(self): """ @@ -339,7 +380,7 @@ def check_used_cells(self): self.check_memory_usage() self.check_diluted_check_usage() except InsufficientAllocatedCells as e: - print(f'Warning: {e} Increasing number of steps.') + print(f"Warning: {e} Increasing number of steps.") return False return True @@ -354,14 +395,18 @@ def finalize_segments(self): if self._segments_finalized: return - assert self._run_ended, 'Run must be ended before calling finalize_segments.' + assert self._run_ended, "Run must be ended before calling finalize_segments." self.segments.finalize( - self.program_base.segment_index, size=len(self.program.data), - public_memory=[(i, 0) for i in range(len(self.program.data))]) + self.program_base.segment_index, + size=len(self.program.data), + public_memory=[(i, 0) for i in range(len(self.program.data))], + ) self.segments.finalize( self.execution_base.segment_index, public_memory=[ - (x + self.execution_base.offset, 0) for x in self.execution_public_memory]) + (x + self.execution_base.offset, 0) for x in self.execution_public_memory + ], + ) for builtin_runner in self.builtin_runners.values(): builtin_runner.finalize_segments(self) @@ -397,14 +442,16 @@ def check_range_check_usage(self): instance = LAYOUTS[self.layout] rc_units_used_by_builtins = sum( builtin_runner.get_used_perm_range_check_units(self) - for builtin_runner in self.builtin_runners.values()) + for builtin_runner in self.builtin_runners.values() + ) # Out of the range check units allowed per step three are used for the instruction. unused_rc_units = (instance.rc_units - 3) * self.vm.current_step - rc_units_used_by_builtins - rc_usage_upper_bound = (rc_max - rc_min) + rc_usage_upper_bound = rc_max - rc_min if unused_rc_units < rc_usage_upper_bound: raise InsufficientAllocatedCells( - f'There are only {unused_rc_units} cells to fill the range checks holes, but ' - f'potentially {rc_usage_upper_bound} are required.') + f"There are only {unused_rc_units} cells to fill the range checks holes, but " + f"potentially {rc_usage_upper_bound} are required." + ) def get_memory_holes(self): assert self.accessed_addresses is not None @@ -416,7 +463,8 @@ def get_memory_holes(self): for addr in builtin_runner.get_memory_accesses(self) } return self.segments.get_memory_holes( - accessed_addresses=self.accessed_addresses | builtin_accessed_addresses) + accessed_addresses=self.accessed_addresses | builtin_accessed_addresses + ) def check_memory_usage(self): """ @@ -425,19 +473,22 @@ def check_memory_usage(self): instance = LAYOUTS[self.layout] builtins_memory_units = sum( builtin_runner.get_allocated_memory_units(self) - for builtin_runner in self.builtin_runners.values()) + for builtin_runner in self.builtin_runners.values() + ) # Out of the memory units available per step, a fraction is used for public memory, and # four are used for the instruction. total_memory_units = instance.memory_units_per_step * self.vm.current_step public_memory_units = safe_div(total_memory_units, instance.public_memory_fraction) instruction_memory_units = 4 * self.vm.current_step - unused_memory_units = total_memory_units - \ - (public_memory_units + instruction_memory_units + builtins_memory_units) + unused_memory_units = total_memory_units - ( + public_memory_units + instruction_memory_units + builtins_memory_units + ) memory_address_holes = self.get_memory_holes() if unused_memory_units < memory_address_holes: raise InsufficientAllocatedCells( - f'There are only {unused_memory_units} cells to fill the memory address holes, but ' - f'{memory_address_holes} are required.') + f"There are only {unused_memory_units} cells to fill the memory address holes, but " + f"{memory_address_holes} are required." + ) def check_diluted_check_usage(self): """ @@ -451,19 +502,22 @@ def check_diluted_check_usage(self): builtin_runner.get_used_diluted_check_units( diluted_spacing=instance.diluted_pool_instance_def.spacing, diluted_n_bits=instance.diluted_pool_instance_def.n_bits, - ) * safe_div( + ) + * safe_div( self.vm.current_step, - builtin_runner.ratio if hasattr(builtin_runner, 'ratio') else 1, + builtin_runner.ratio if hasattr(builtin_runner, "ratio") else 1, ) - for builtin_runner in self.builtin_runners.values()) + for builtin_runner in self.builtin_runners.values() + ) diluted_units = instance.diluted_pool_instance_def.units_per_step * self.vm.current_step unused_diluted_units = diluted_units - diluted_units_used_by_builtins diluted_usage_upper_bound = 2 ** instance.diluted_pool_instance_def.n_bits if unused_diluted_units < diluted_usage_upper_bound: raise InsufficientAllocatedCells( - f'There are only {unused_diluted_units} cells to fill the diluted check holes, but ' - f'potentially {diluted_usage_upper_bound} are required.') + f"There are only {unused_diluted_units} cells to fill the diluted check holes, but " + f"potentially {diluted_usage_upper_bound} are required." + ) # Helper functions. @@ -477,13 +531,15 @@ def _to_pc(self, label_or_pc: Union[str, int]) -> int: Otherwise, return it unchanged. """ if isinstance(label_or_pc, str): - assert isinstance(self.program, Program), \ - 'Label name cannot be used with a StrippedProgram.' + assert isinstance( + self.program, Program + ), "Label name cannot be used with a StrippedProgram." return self.program.get_label(label_or_pc) return label_or_pc - def load_data(self, ptr: MaybeRelocatable, data: Sequence[MaybeRelocatable]) -> \ - MaybeRelocatable: + def load_data( + self, ptr: MaybeRelocatable, data: Sequence[MaybeRelocatable] + ) -> MaybeRelocatable: """ Writes data into the memory at address ptr and returns the first address after the data. """ @@ -504,11 +560,15 @@ def relocate_value(self, value): def relocate(self): self.segment_offsets = self.segments.relocate_segments() - self.relocated_memory = MemoryDict({ - self.relocate_value(addr): self.relocate_value(value) - for addr, value in self.vm_memory.items()}) + self.relocated_memory = MemoryDict( + { + self.relocate_value(addr): self.relocate_value(value) + for addr, value in self.vm_memory.items() + } + ) self.relocated_trace = relocate_trace( - self.vm.trace, self.segment_offsets, self.program.prime) + self.vm.trace, self.segment_offsets, self.program.prime + ) for builtin_runner in self.builtin_runners.values(): builtin_runner.relocate(self.relocate_value) @@ -523,14 +583,19 @@ def get_relocated_debug_info(self): def get_memory_segment_addresses(self) -> Dict[str, MemorySegmentAddresses]: def get_segment_addresses( - name: str, segment_addresses: MemorySegmentAddresses) -> MemorySegmentAddresses: - stop_ptr = segment_addresses.stop_ptr if name in self.program.builtins else \ - segment_addresses.begin_addr + name: str, segment_addresses: MemorySegmentAddresses + ) -> MemorySegmentAddresses: + stop_ptr = ( + segment_addresses.stop_ptr + if name in self.program.builtins + else segment_addresses.begin_addr + ) - assert stop_ptr is not None, f'The {name} builtin stop pointer was not set.' + assert stop_ptr is not None, f"The {name} builtin stop pointer was not set." return MemorySegmentAddresses( begin_addr=self.relocate_value(segment_addresses.begin_addr), - stop_ptr=self.relocate_value(stop_ptr)) + stop_ptr=self.relocate_value(stop_ptr), + ) return { name: get_segment_addresses(name, segment_addresses) @@ -539,31 +604,31 @@ def get_segment_addresses( } def print_memory(self, relocated: bool): - print('Addr Value') - print('-----------') + print("Addr Value") + print("-----------") old_addr = -1 memory = self.relocated_memory if relocated else self.vm_memory for addr in sorted(memory.keys()): val = memory[addr] if addr != old_addr + 1: - print('\u22ee') - print(f'{addr:<5} {to_field_element(val=val, prime=self.program.prime)}') + print("\u22ee") + print(f"{addr:<5} {to_field_element(val=val, prime=self.program.prime)}") old_addr = addr print() def print_output(self, output_callback=to_field_element): - if 'output_builtin' not in self.builtin_runners: + if "output_builtin" not in self.builtin_runners: return - output_runner = self.builtin_runners['output_builtin'] - print('Program output:') + output_runner = self.builtin_runners["output_builtin"] + print("Program output:") _, size = output_runner.get_used_cells_and_allocated_size(self) for i in range(size): val = self.vm_memory.get(output_runner.base + i) if val is not None: - print(f' {output_callback(val=val, prime=self.program.prime)}') + print(f" {output_callback(val=val, prime=self.program.prime)}") else: - print(' ') + print(" ") print() @@ -590,19 +655,19 @@ def get_info(self, relocated: bool) -> str: def print_segment_relocation_table(self): if self.segment_offsets is not None: - print('Segment relocation table:') + print("Segment relocation table:") for segment_index in range(self.segments.n_segments): - print(f'{segment_index:<5} {self.segment_offsets[segment_index]}') + print(f"{segment_index:<5} {self.segment_offsets[segment_index]}") def get_builtin_usage(self) -> str: if len(self.builtin_runners) == 0: - return '' + return "" - builtin_usage_str = '\nBuiltin usage:\n' + builtin_usage_str = "\nBuiltin usage:\n" for name, builtin_runner in self.builtin_runners.items(): used, size = builtin_runner.get_used_cells_and_allocated_size(self) - percentage = f'{used / size * 100:.2f}%' if size > 0 else '100%' - builtin_usage_str += f'{name:<30s} {percentage:>7s} (used {used} cells)\n' + percentage = f"{used / size * 100:.2f}%" if size > 0 else "100%" + builtin_usage_str += f"{name:<30s} {percentage:>7s} (used {used} cells)\n" return builtin_usage_str @@ -614,14 +679,16 @@ def get_builtin_segments_info(self): for builtin in self.builtin_runners.values(): for name, segment_addresses in builtin.get_memory_segment_addresses(self).items(): begin_addr = segment_addresses.begin_addr - assert isinstance(begin_addr, RelocatableValue), \ - f'{name} segment begin_addr is not a RelocatableValue {begin_addr}.' - assert begin_addr.offset == 0, \ - f'Unexpected {name} segment begin_addr {begin_addr.offset}.' - assert segment_addresses.stop_ptr is not None, f'{name} segment stop ptr is None.' + assert isinstance( + begin_addr, RelocatableValue + ), f"{name} segment begin_addr is not a RelocatableValue {begin_addr}." + assert ( + begin_addr.offset == 0 + ), f"Unexpected {name} segment begin_addr {begin_addr.offset}." + assert segment_addresses.stop_ptr is not None, f"{name} segment stop ptr is None." segment_index = begin_addr.segment_index segment_size = segment_addresses.stop_ptr - begin_addr - assert name not in builtin_segments, f'Builtin segment name collision: {name}.' + assert name not in builtin_segments, f"Builtin segment name collision: {name}." builtin_segments[name] = SegmentInfo(index=segment_index, size=segment_size) return builtin_segments @@ -635,7 +702,8 @@ def get_execution_resources(self) -> ExecutionResources: return ExecutionResources( n_steps=n_steps, n_memory_holes=n_memory_holes, - builtin_instance_counter=builtin_instance_counter) + builtin_instance_counter=builtin_instance_counter, + ) def get_cairo_pie(self) -> CairoPie: """ @@ -649,20 +717,23 @@ def get_cairo_pie(self) -> CairoPie: # Note that n_used_builtins might be smaller then len(builtin_segments). n_used_builtins = len(self.program.builtins) ret_fp, ret_pc = ( - self.vm_memory[self.execution_base + n_used_builtins + i] for i in range(2)) + self.vm_memory[self.execution_base + n_used_builtins + i] for i in range(2) + ) - assert isinstance(ret_fp, RelocatableValue), f'Expecting a relocatable value got {ret_fp}.' - assert isinstance(ret_pc, RelocatableValue), f'Expecting a relocatable value got {ret_pc}.' + assert isinstance(ret_fp, RelocatableValue), f"Expecting a relocatable value got {ret_fp}." + assert isinstance(ret_pc, RelocatableValue), f"Expecting a relocatable value got {ret_pc}." assert self.segments.get_segment_size(ret_fp.segment_index) == 0, ( - 'Unexpected ret_fp_segment size ' - f'{self.segments.get_segment_size(ret_fp.segment_index)}') + "Unexpected ret_fp_segment size " + f"{self.segments.get_segment_size(ret_fp.segment_index)}" + ) assert self.segments.get_segment_size(ret_pc.segment_index) == 0, ( - 'Unexpected ret_pc_segment size ' - f'{self.segments.get_segment_size(ret_pc.segment_index)}') + "Unexpected ret_pc_segment size " + f"{self.segments.get_segment_size(ret_pc.segment_index)}" + ) for addr in self.program_base, self.execution_base, ret_fp, ret_pc: - assert addr.offset == 0, 'Expecting a 0 offset.' + assert addr.offset == 0, "Expecting a 0 offset." known_segment_indices[addr.segment_index] = None # Put all the remaining segments in extra_segments. @@ -676,9 +747,11 @@ def get_cairo_pie(self) -> CairoPie: cairo_pie_metadata = CairoPieMetadata( program=self.program.stripped(), program_segment=SegmentInfo( - index=self.program_base.segment_index, size=len(self.program.data)), + index=self.program_base.segment_index, size=len(self.program.data) + ), execution_segment=SegmentInfo( - index=self.execution_base.segment_index, size=execution_size), + index=self.execution_base.segment_index, size=execution_size + ), ret_fp_segment=SegmentInfo(ret_fp.segment_index, size=0), ret_pc_segment=SegmentInfo(ret_pc.segment_index, size=0), builtin_segments=builtin_segments, @@ -699,7 +772,8 @@ def get_cairo_pie(self) -> CairoPie: def get_runner_from_code( - code: Union[str, Sequence[Tuple[str, str]]], layout: str, prime: int) -> CairoRunner: + code: Union[str, Sequence[Tuple[str, str]]], layout: str, prime: int +) -> CairoRunner: """ Given a code with some compile and run parameters (prime, layout, etc.), runs the code using Cairo runner and returns the runner. diff --git a/src/starkware/cairo/lang/vm/cairo_runner_test.py b/src/starkware/cairo/lang/vm/cairo_runner_test.py index a465a843..263a7bbc 100644 --- a/src/starkware/cairo/lang/vm/cairo_runner_test.py +++ b/src/starkware/cairo/lang/vm/cairo_runner_test.py @@ -9,7 +9,7 @@ from starkware.cairo.lang.vm.utils import RunResources from starkware.cairo.lang.vm.vm import VmException, VmExceptionBase -CAIRO_FILE = os.path.join(os.path.dirname(__file__), 'test.cairo') +CAIRO_FILE = os.path.join(os.path.dirname(__file__), "test.cairo") PRIME = 2 ** 251 + 17 * 2 ** 192 + 1 @@ -24,11 +24,11 @@ def test_run_until_label(): runner.run_until_label(3) assert runner.vm.run_context.pc - runner.program_base == 3 assert runner.vm.current_step == 3 - runner.run_until_label('label1') + runner.run_until_label("label1") assert runner.vm.run_context.pc - runner.program_base == 6 assert runner.vm.current_step == 6 - with pytest.raises(VmException, match='End of program was not reached'): - runner.run_until_label('label0', run_resources=RunResources(steps=100)) + with pytest.raises(VmException, match="End of program was not reached"): + runner.run_until_label("label0", run_resources=RunResources(steps=100)) assert runner.vm.run_context.pc - runner.program_base == 8 assert runner.vm.current_step == 106 runner.run_until_next_power_of_2() @@ -42,13 +42,13 @@ def test_run_past_end(): end """ program = compile_cairo(code, PRIME) - runner = CairoRunner(program, layout='plain') + runner = CairoRunner(program, layout="plain") runner.initialize_segments() runner.initialize_main_entrypoint() runner.initialize_vm({}) runner.run_for_steps(1) - with pytest.raises(VmException, match='Error: Execution reached the end of the program.'): + with pytest.raises(VmException, match="Error: Execution reached the end of the program."): runner.run_for_steps(1) @@ -64,30 +64,32 @@ def test_bad_stop_ptr(): end """ with pytest.raises( - AssertionError, - match='Invalid stop pointer for output. Expected: 2:1, found: 2:3'): - get_runner_from_code(code, layout='small', prime=PRIME) + AssertionError, match="Invalid stop pointer for output. Expected: 2:1, found: 2:3" + ): + get_runner_from_code(code, layout="small", prime=PRIME) def test_builtin_list(): # This should work. program = compile_cairo( - code=[('%builtins output pedersen range_check ecdsa\n', '')], prime=PRIME) - CairoRunner(program, layout='small') + code=[("%builtins output pedersen range_check ecdsa\n", "")], prime=PRIME + ) + CairoRunner(program, layout="small") # These should fail. - program = compile_cairo(code=[('%builtins pedersen output\n', '')], prime=PRIME) + program = compile_cairo(code=[("%builtins pedersen output\n", "")], prime=PRIME) with pytest.raises( - AssertionError, - match=r"\['pedersen', 'output'\] is not a subsequence of " - r"\['output', 'pedersen', 'range_check', 'ecdsa', 'bitwise']."): - CairoRunner(program, layout='small') + AssertionError, + match=r"\['pedersen', 'output'\] is not a subsequence of " + r"\['output', 'pedersen', 'range_check', 'ecdsa', 'bitwise'].", + ): + CairoRunner(program, layout="small") - program = compile_cairo(code=[('%builtins pedersen foo\n', '')], prime=PRIME) + program = compile_cairo(code=[("%builtins pedersen foo\n", "")], prime=PRIME) with pytest.raises( - AssertionError, - match=r'Builtins {\'foo\'} are not present in layout "small"'): - CairoRunner(program, layout='small') + AssertionError, match=r'Builtins {\'foo\'} are not present in layout "small"' + ): + CairoRunner(program, layout="small") def test_missing_exit_scope(): @@ -98,9 +100,10 @@ def test_missing_exit_scope(): end """ with pytest.raises( - VmExceptionBase, - match=re.escape('Every enter_scope() requires a corresponding exit_scope().')): - runner = get_runner_from_code(code, layout='small', prime=PRIME) + VmExceptionBase, + match=re.escape("Every enter_scope() requires a corresponding exit_scope()."), + ): + runner = get_runner_from_code(code, layout="small", prime=PRIME) def test_load_data_after_init(): @@ -109,7 +112,7 @@ def test_load_data_after_init(): ret end """ - runner = get_runner_from_code(code, layout='plain', prime=PRIME) + runner = get_runner_from_code(code, layout="plain", prime=PRIME) addr = runner.segments.add() runner.vm_memory.unfreeze_for_testing() runner.load_data(addr, [42]) @@ -125,7 +128,7 @@ def test_small_memory_hole(): ret end """ - runner = get_runner_from_code(code, layout='plain', prime=PRIME) + runner = get_runner_from_code(code, layout="plain", prime=PRIME) runner.check_memory_usage() @@ -138,12 +141,14 @@ def test_memory_hole_insufficient(): ret end """ - runner = get_runner_from_code(code, layout='plain', prime=PRIME) + runner = get_runner_from_code(code, layout="plain", prime=PRIME) with pytest.raises( - InsufficientAllocatedCells, - match=re.escape( - 'There are only 8 cells to fill the memory address holes, but 999 are required.')): + InsufficientAllocatedCells, + match=re.escape( + "There are only 8 cells to fill the memory address holes, but 999 are required." + ), + ): runner.check_memory_usage() @@ -167,11 +172,13 @@ def test_hint_memory_holes(): """ code_no_hint, code_untouched_hint, code_touched_hint = [ code_base_format.format(extra_code) - for extra_code in ['', '%{ memory[ap] = 7 %}', '%{ memory[ap] = 7 %}\n [ap]=[ap]']] + for extra_code in ["", "%{ memory[ap] = 7 %}", "%{ memory[ap] = 7 %}\n [ap]=[ap]"] + ] runner_no_hint, runner_untouched_hint, runner_touched_hint = [ - get_runner_from_code(code, layout='plain', prime=PRIME) - for code in (code_no_hint, code_untouched_hint, code_touched_hint)] + get_runner_from_code(code, layout="plain", prime=PRIME) + for code in (code_no_hint, code_untouched_hint, code_touched_hint) + ] def filter_program_segment(addr_lst): return {addr for addr in addr_lst if addr.segment_index != 0} @@ -189,14 +196,20 @@ def filter_program_segment(addr_lst): } assert filter_program_segment(runner_no_hint.vm_memory.keys()) == accessed_addresses assert filter_program_segment(runner_no_hint.accessed_addresses) == accessed_addresses - assert filter_program_segment(runner_untouched_hint.vm_memory.keys()) == \ - accessed_addresses | {initial_ap + 7} + assert filter_program_segment(runner_untouched_hint.vm_memory.keys()) == accessed_addresses | { + initial_ap + 7 + } assert filter_program_segment(runner_untouched_hint.accessed_addresses) == accessed_addresses - assert filter_program_segment(runner_touched_hint.vm_memory.keys()) == \ - accessed_addresses | {initial_ap + 7} - assert filter_program_segment(runner_touched_hint.accessed_addresses) == \ - accessed_addresses | {initial_ap + 7} - - assert runner_no_hint.get_memory_holes() == \ - runner_untouched_hint.get_memory_holes() == \ - runner_touched_hint.get_memory_holes() + 1 == 5 + assert filter_program_segment(runner_touched_hint.vm_memory.keys()) == accessed_addresses | { + initial_ap + 7 + } + assert filter_program_segment(runner_touched_hint.accessed_addresses) == accessed_addresses | { + initial_ap + 7 + } + + assert ( + runner_no_hint.get_memory_holes() + == runner_untouched_hint.get_memory_holes() + == runner_touched_hint.get_memory_holes() + 1 + == 5 + ) diff --git a/src/starkware/cairo/lang/vm/memory_dict.py b/src/starkware/cairo/lang/vm/memory_dict.py index 613745be..bbdd9025 100644 --- a/src/starkware/cairo/lang/vm/memory_dict.py +++ b/src/starkware/cairo/lang/vm/memory_dict.py @@ -9,7 +9,7 @@ class UnknownMemoryError(KeyError): def __init__(self, addr): self.addr = addr - super().__init__(f'Unknown value for memory cell at address {addr}.') + super().__init__(f"Unknown value for memory cell at address {addr}.") __str__: Callable[[BaseException], str] = Exception.__str__ @@ -20,7 +20,8 @@ def __init__(self, addr, old_value, new_value): self.old_value = old_value self.new_value = new_value super().__init__( - f'Inconsistent memory assignment at address {addr}. {old_value} != {new_value}.') + f"Inconsistent memory assignment at address {addr}. {old_value} != {new_value}." + ) class MemoryDict(UserDict): @@ -48,12 +49,11 @@ def _check_element(self, num: MaybeRelocatable, name: str, exc_type: Type[Except if isinstance(num, RelocatableValue): return if not isinstance(num, int): - raise exc_type(f'{name} must be an int, not {type(num).__name__}.') + raise exc_type(f"{name} must be an int, not {type(num).__name__}.") if num < 0: - raise exc_type(f'{name} must be nonnegative. Got {num}.') + raise exc_type(f"{name} must be nonnegative. Got {num}.") - def add_relocation_rule( - self, src_ptr: RelocatableValue, dest_ptr: RelocatableValue): + def add_relocation_rule(self, src_ptr: RelocatableValue, dest_ptr: RelocatableValue): """ Adds a relocation rule that moves values from the 'src_ptr' segment to 'dest_ptr'. @@ -65,11 +65,12 @@ def add_relocation_rule( and consequently adding a relocation rule does not allow the VM to read the value at memory['dest_ptr']. """ - assert src_ptr.segment_index < 0, f'src_ptr.segment_index must be < 0, src_ptr={src_ptr}.' - assert src_ptr.offset == 0, f'src_ptr.offset must be 0, src_ptr={src_ptr}.' + assert src_ptr.segment_index < 0, f"src_ptr.segment_index must be < 0, src_ptr={src_ptr}." + assert src_ptr.offset == 0, f"src_ptr.offset must be 0, src_ptr={src_ptr}." segment_index = src_ptr.segment_index - assert segment_index not in self.relocation_rules, \ - f'The segment with index {segment_index} already has a relocation rule.' + assert ( + segment_index not in self.relocation_rules + ), f"The segment with index {segment_index} already has a relocation rule." self.relocation_rules[segment_index] = dest_ptr @@ -111,19 +112,18 @@ def relocate_memory(self): """ Relocates the memory according to the relocation rules and clears self.relocation_rules. """ - assert not self._frozen, 'Memory is frozen and cannot be changed.' + assert not self._frozen, "Memory is frozen and cannot be changed." if len(self.relocation_rules) == 0: return self.data = { - self.relocate_value(addr): self.relocate_value(value) - for addr, value in self.items() + self.relocate_value(addr): self.relocate_value(value) for addr, value in self.items() } self.relocation_rules = {} def __getitem__(self, addr: MaybeRelocatable) -> MaybeRelocatable: - self._check_element(addr, 'Memory address', KeyError) + self._check_element(addr, "Memory address", KeyError) try: value = super().__getitem__(addr) except KeyError: @@ -132,14 +132,15 @@ def __getitem__(self, addr: MaybeRelocatable) -> MaybeRelocatable: return self.relocate_value(value) def __setitem__(self, addr: MaybeRelocatable, value: MaybeRelocatable): - assert not self._frozen, 'Memory is frozen and cannot be changed.' - self._check_element(addr, 'Memory address', KeyError) - self._check_element(value, 'Memory value', ValueError) + assert not self._frozen, "Memory is frozen and cannot be changed." + self._check_element(addr, "Memory address", KeyError) + self._check_element(value, "Memory value", ValueError) # Additionally, check that address doesn't have a negative offset. if isinstance(addr, RelocatableValue) and addr.offset < 0: raise ValueError( - f'The offset of a relocatable value must be nonnegative. Found: {addr}.') + f"The offset of a relocatable value must be nonnegative. Found: {addr}." + ) current = self.data.setdefault(addr, value) self.verify_same_value(addr, current, value) @@ -160,13 +161,15 @@ def set_without_checks(self, addr: MaybeRelocatable, value: MaybeRelocatable): self.data[addr] = value def serialize(self, field_bytes): - assert len(self.relocation_rules) == 0, \ - 'Cannot serialize a MemoryDict with active segment relocation rules.' + assert ( + len(self.relocation_rules) == 0 + ), "Cannot serialize a MemoryDict with active segment relocation rules." - return b''.join( - RelocatableValue.to_bytes(addr, ADDR_SIZE_IN_BYTES, 'little') + - RelocatableValue.to_bytes(value, field_bytes, 'little') - for addr, value in self.items()) + return b"".join( + RelocatableValue.to_bytes(addr, ADDR_SIZE_IN_BYTES, "little") + + RelocatableValue.to_bytes(value, field_bytes, "little") + for addr, value in self.items() + ) def get_range(self, addr, size) -> List[MaybeRelocatable]: return [self[addr + i] for i in range(size)] @@ -183,14 +186,16 @@ def get_range_as_ints(self, addr, size) -> List[int]: @classmethod def deserialize(cls, data, field_bytes): pair_size = ADDR_SIZE_IN_BYTES + field_bytes - assert len(data) % (pair_size) == 0,\ - f'Data must consist of pairs of address (8 bytes) and value ({field_bytes} bytes).' + assert ( + len(data) % (pair_size) == 0 + ), f"Data must consist of pairs of address (8 bytes) and value ({field_bytes} bytes)." pair_stream = ( - data[pair_size * i: pair_size * (i + 1)] - for i in range(len(data) // pair_size)) + data[pair_size * i : pair_size * (i + 1)] for i in range(len(data) // pair_size) + ) return cls( ( - RelocatableValue.from_bytes(pair[:ADDR_SIZE_IN_BYTES], 'little'), - RelocatableValue.from_bytes(pair[ADDR_SIZE_IN_BYTES:], 'little') + RelocatableValue.from_bytes(pair[:ADDR_SIZE_IN_BYTES], "little"), + RelocatableValue.from_bytes(pair[ADDR_SIZE_IN_BYTES:], "little"), ) - for pair in pair_stream) + for pair in pair_stream + ) diff --git a/src/starkware/cairo/lang/vm/memory_dict_test.py b/src/starkware/cairo/lang/vm/memory_dict_test.py index 8d534e04..a320da24 100644 --- a/src/starkware/cairo/lang/vm/memory_dict_test.py +++ b/src/starkware/cairo/lang/vm/memory_dict_test.py @@ -1,15 +1,52 @@ import pytest from starkware.cairo.lang.vm.memory_dict import ( - InconsistentMemoryError, MemoryDict, UnknownMemoryError) + InconsistentMemoryError, + MemoryDict, + UnknownMemoryError, +) from starkware.cairo.lang.vm.relocatable import RelocatableValue def test_memory_dict_serialize(): memory = MemoryDict({1: 2, 3: 4, 5: 6}) - expected_serialized = bytes([ - 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, - 6, 0, 0]) + expected_serialized = bytes( + [ + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 2, + 0, + 0, + 3, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 4, + 0, + 0, + 5, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 6, + 0, + 0, + ] + ) serialized = memory.serialize(3) assert expected_serialized == serialized assert MemoryDict.deserialize(serialized, 3) == memory @@ -23,11 +60,11 @@ def test_memory_dict_getitem(): def test_memory_dict_check_element(): memory = MemoryDict() - with pytest.raises(KeyError, match='must be an int'): - memory['not a number'] = 12 - with pytest.raises(KeyError, match='must be nonnegative'): + with pytest.raises(KeyError, match="must be an int"): + memory["not a number"] = 12 + with pytest.raises(KeyError, match="must be nonnegative"): memory[-12] = 13 - with pytest.raises(ValueError, match='The offset of a relocatable value must be nonnegative'): + with pytest.raises(ValueError, match="The offset of a relocatable value must be nonnegative"): memory[RelocatableValue(segment_index=10, offset=-2)] = 13 # A value may have a negative offset. memory[13] = RelocatableValue(segment_index=10, offset=-2) @@ -35,9 +72,9 @@ def test_memory_dict_check_element(): def test_memory_dict_get(): memory = MemoryDict({14: 15}) - assert memory.get(14, 'default') == 15 - assert memory.get(1234, 'default') == 'default' - assert memory.get(-10, 'default') == 'default' + assert memory.get(14, "default") == 15 + assert memory.get(1234, "default") == "default" + assert memory.get(-10, "default") == "default" # Attempting to read address with a negative offset is ok, it simply returns None. assert memory.get(RelocatableValue(segment_index=10, offset=-2)) is None @@ -48,11 +85,11 @@ def test_memory_dict_setdefault(): assert memory[14] == 15 memory.setdefault(123, 456) assert memory[123] == 456 - with pytest.raises(ValueError, match='must be an int'): - memory.setdefault(10, 'default') - with pytest.raises(KeyError, match='must be nonnegative'): + with pytest.raises(ValueError, match="must be an int"): + memory.setdefault(10, "default") + with pytest.raises(KeyError, match="must be nonnegative"): memory.setdefault(-10, 123) - with pytest.raises(ValueError, match='The offset of a relocatable value must be nonnegative'): + with pytest.raises(ValueError, match="The offset of a relocatable value must be nonnegative"): memory[RelocatableValue(segment_index=10, offset=-2)] = 13 @@ -94,18 +131,23 @@ def test_segment_relocation_failures(): memory = MemoryDict() relocation_target = RelocatableValue(segment_index=4, offset=25) - with pytest.raises(AssertionError, match='src_ptr.segment_index must be < 0, src_ptr=1:2.'): - memory.add_relocation_rule(src_ptr=RelocatableValue( - segment_index=1, offset=2), dest_ptr=relocation_target) + with pytest.raises(AssertionError, match="src_ptr.segment_index must be < 0, src_ptr=1:2."): + memory.add_relocation_rule( + src_ptr=RelocatableValue(segment_index=1, offset=2), dest_ptr=relocation_target + ) - with pytest.raises(AssertionError, match='src_ptr.offset must be 0, src_ptr=-3:2.'): - memory.add_relocation_rule(src_ptr=RelocatableValue( - segment_index=-3, offset=2), dest_ptr=relocation_target) + with pytest.raises(AssertionError, match="src_ptr.offset must be 0, src_ptr=-3:2."): + memory.add_relocation_rule( + src_ptr=RelocatableValue(segment_index=-3, offset=2), dest_ptr=relocation_target + ) - memory.add_relocation_rule(src_ptr=RelocatableValue( - segment_index=-3, offset=0), dest_ptr=relocation_target) + memory.add_relocation_rule( + src_ptr=RelocatableValue(segment_index=-3, offset=0), dest_ptr=relocation_target + ) with pytest.raises( - AssertionError, match='The segment with index -3 already has a relocation rule.'): - memory.add_relocation_rule(src_ptr=RelocatableValue( - segment_index=-3, offset=0), dest_ptr=relocation_target) + AssertionError, match="The segment with index -3 already has a relocation rule." + ): + memory.add_relocation_rule( + src_ptr=RelocatableValue(segment_index=-3, offset=0), dest_ptr=relocation_target + ) diff --git a/src/starkware/cairo/lang/vm/memory_segments.py b/src/starkware/cairo/lang/vm/memory_segments.py index 3c573818..e689e28c 100644 --- a/src/starkware/cairo/lang/vm/memory_segments.py +++ b/src/starkware/cairo/lang/vm/memory_segments.py @@ -53,8 +53,11 @@ def add_temp_segment(self) -> RelocatableValue: return RelocatableValue(segment_index=segment_index, offset=0) def finalize( - self, segment_index: int, size: Optional[int] = None, - public_memory: Sequence[Tuple[int, int]] = []): + self, + segment_index: int, + size: Optional[int] = None, + public_memory: Sequence[Tuple[int, int]] = [], + ): """ Writes the following information for the given segment: * size - The size of the segment (to be used in relocate_segments). @@ -75,14 +78,16 @@ def compute_effective_sizes(self, include_tmp_segments: bool = False): # segment_sizes is already cached. return - assert self.memory.is_frozen(), 'Memory has to be frozen before calculating effective size.' + assert self.memory.is_frozen(), "Memory has to be frozen before calculating effective size." first_segment_index = -self.n_temp_segments if include_tmp_segments else 0 self._segment_used_sizes = { - index: 0 for index in range(first_segment_index, self.n_segments)} + index: 0 for index in range(first_segment_index, self.n_segments) + } for addr in self.memory: - assert isinstance(addr, RelocatableValue), \ - f'Expected memory address to be relocatable value. Found: {addr}.' + assert isinstance( + addr, RelocatableValue + ), f"Expected memory address to be relocatable value. Found: {addr}." previous_max_size = self._segment_used_sizes[addr.segment_index] self._segment_used_sizes[addr.segment_index] = max(previous_max_size, addr.offset + 1) @@ -90,13 +95,14 @@ def relocate_segments(self) -> Dict[int, int]: current_addr = FIRST_MEMORY_ADDR res = {} - assert self._segment_used_sizes is not None, \ - 'compute_effective_sizes must be called before relocate_segments.' + assert ( + self._segment_used_sizes is not None + ), "compute_effective_sizes must be called before relocate_segments." for segment_index, used_size in self._segment_used_sizes.items(): res[segment_index] = current_addr size = self.get_segment_size(segment_index=segment_index) - assert size >= used_size, f'Segment {segment_index} exceeded its allocated size.' + assert size >= used_size, f"Segment {segment_index} exceeded its allocated size." current_addr += size return res @@ -113,18 +119,20 @@ def get_public_memory_addresses(self, segment_offsets: Dict[int, int]) -> List[T res.append((segment_start + offset, page_id)) return res - def initialize_segments_from(self, other: 'MemorySegmentManager'): + def initialize_segments_from(self, other: "MemorySegmentManager"): """ Adds the segments used by the given MemorySegmentManager. Note that this function must be called before any segments are added, to make the segment indices identical. """ - assert self.n_segments == 0, \ - 'initialize_segments_from() must be called before segments are added.' + assert ( + self.n_segments == 0 + ), "initialize_segments_from() must be called before segments are added." self.n_segments = other.n_segments - def load_data(self, ptr: MaybeRelocatable, data: Sequence[MaybeRelocatable]) -> \ - MaybeRelocatable: + def load_data( + self, ptr: MaybeRelocatable, data: Sequence[MaybeRelocatable] + ) -> MaybeRelocatable: """ Writes data into the memory at address ptr and returns the first address after the data. """ @@ -159,24 +167,29 @@ def get_memory_holes(self, accessed_addresses: Set[MaybeRelocatable]) -> int: # A map from segment index to the set of accessed offsets. accessed_offsets_sets: Dict[int, Set] = defaultdict(set) for addr in accessed_addresses: - assert isinstance(addr, RelocatableValue), \ - f'Expected memory address to be relocatable value. Found: {addr}.' + assert isinstance( + addr, RelocatableValue + ), f"Expected memory address to be relocatable value. Found: {addr}." index, offset = addr.segment_index, addr.offset - assert offset >= 0, f'Address offsets must be non-negative. Found: {offset}.' - assert offset <= self.get_segment_size(segment_index=index), \ - f'Accessed address {addr} has higher offset than the maximal offset ' \ - f'{self.get_segment_size(segment_index=index)} encountered in the memory segment.' + assert offset >= 0, f"Address offsets must be non-negative. Found: {offset}." + assert offset <= self.get_segment_size(segment_index=index), ( + f"Accessed address {addr} has higher offset than the maximal offset " + f"{self.get_segment_size(segment_index=index)} encountered in the memory segment." + ) accessed_offsets_sets[index].add(offset) - assert self._segment_used_sizes is not None, \ - 'compute_effective_sizes must be called before get_memory_holes.' + assert ( + self._segment_used_sizes is not None + ), "compute_effective_sizes must be called before get_memory_holes." return sum( self.get_segment_size(segment_index=index) - len(accessed_offsets_sets[index]) - for index in self._segment_sizes.keys() | self._segment_used_sizes.keys()) + for index in self._segment_sizes.keys() | self._segment_used_sizes.keys() + ) def get_segment_used_size(self, segment_index: int) -> int: - assert self._segment_used_sizes is not None, \ - 'compute_effective_sizes must be called before get_segment_used_size.' + assert ( + self._segment_used_sizes is not None + ), "compute_effective_sizes must be called before get_segment_used_size." return self._segment_used_sizes[segment_index] @@ -185,5 +198,8 @@ def get_segment_size(self, segment_index: int) -> int: Returns the finalized size of the given segment. If the segment has not been finalized, returns its used size. """ - return self._segment_sizes[segment_index] if segment_index in self._segment_sizes \ + return ( + self._segment_sizes[segment_index] + if segment_index in self._segment_sizes else self.get_segment_used_size(segment_index=segment_index) + ) diff --git a/src/starkware/cairo/lang/vm/memory_segments_test.py b/src/starkware/cairo/lang/vm/memory_segments_test.py index b15ad3d4..a020d9c1 100644 --- a/src/starkware/cairo/lang/vm/memory_segments_test.py +++ b/src/starkware/cairo/lang/vm/memory_segments_test.py @@ -4,7 +4,7 @@ from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager from starkware.cairo.lang.vm.relocatable import RelocatableValue -PRIME = 2**251 + 17 * 2**192 + 1 +PRIME = 2 ** 251 + 17 * 2 ** 192 + 1 def test_relocate_segments(): @@ -35,30 +35,43 @@ def test_relocate_segments(): segment_offsets = segments.relocate_segments() assert segment_offsets == {0: 1, 1: 4, 2: 12, 3: 12, 4: 13, 5: 15, 6: 20} assert segments.get_public_memory_addresses(segment_offsets) == [ - (1, 0), (2, 1), (4, 0), (5, 0), (6, 0), (7, 0), (8, 0), (9, 0), (10, 0), (11, 0), (14, 2)] + (1, 0), + (2, 1), + (4, 0), + (5, 0), + (6, 0), + (7, 0), + (8, 0), + (9, 0), + (10, 0), + (11, 0), + (14, 2), + ] # Negative flows. segments = MemorySegmentManager(memory=MemoryDict({}), prime=PRIME) segments.add(size=1) - with pytest.raises(AssertionError, match='compute_effective_sizes must be called before'): + with pytest.raises(AssertionError, match="compute_effective_sizes must be called before"): segments.relocate_segments() segments.memory[RelocatableValue(0, 2)] = 0 segments.memory.freeze() segments.compute_effective_sizes() - with pytest.raises(AssertionError, match='Segment 0 exceeded its allocated size'): + with pytest.raises(AssertionError, match="Segment 0 exceeded its allocated size"): segments.relocate_segments() def test_get_segment_used_size(): - memory = MemoryDict({ - RelocatableValue(0, 0): 0, - RelocatableValue(0, 2): 0, - RelocatableValue(1, 5): 0, - RelocatableValue(1, 7): 0, - RelocatableValue(3, 0): 0, - RelocatableValue(4, 1): 0, - }) + memory = MemoryDict( + { + RelocatableValue(0, 0): 0, + RelocatableValue(0, 2): 0, + RelocatableValue(1, 5): 0, + RelocatableValue(1, 7): 0, + RelocatableValue(3, 0): 0, + RelocatableValue(4, 1): 0, + } + ) segments = MemorySegmentManager(memory=memory, prime=PRIME) segments.n_segments = 5 memory.freeze() diff --git a/src/starkware/cairo/lang/vm/output_builtin_runner.py b/src/starkware/cairo/lang/vm/output_builtin_runner.py index 852da2ac..5d90b34d 100644 --- a/src/starkware/cairo/lang/vm/output_builtin_runner.py +++ b/src/starkware/cairo/lang/vm/output_builtin_runner.py @@ -27,16 +27,17 @@ def initialize_segments(self, runner): self.stop_ptr: Optional[RelocatableValue] = None def initial_stack(self) -> List[MaybeRelocatable]: - assert self.base is not None, 'Uninitialized self.base.' + assert self.base is not None, "Uninitialized self.base." return [self.base] if self.included else [] def final_stack(self, runner, pointer): if self.included: self.stop_ptr = runner.vm_memory[pointer - 1] used = self.get_used_cells(runner=runner) - assert self.stop_ptr == self.base + used, \ - 'Invalid stop pointer for output. ' + \ - f'Expected: {self.base + used}, found: {self.stop_ptr}' + assert self.stop_ptr == self.base + used, ( + "Invalid stop pointer for output. " + + f"Expected: {self.base + used}, found: {self.stop_ptr}" + ) return pointer - 1 else: self.stop_ptr = self.base @@ -64,23 +65,25 @@ def finalize_segments(self, runner): # A map from an offset to its page id. offset_to_page = {} for page_id, page in self.pages.items(): - assert page.start + page.size <= size, f'Page {page_id} is out of bounds.' + assert page.start + page.size <= size, f"Page {page_id} is out of bounds." for i in range(page.start, page.start + page.size): - assert offset_to_page.setdefault(i, page_id) == page_id, \ - f'Offset {i} was already assigned a page.' + assert ( + offset_to_page.setdefault(i, page_id) == page_id + ), f"Offset {i} was already assigned a page." public_memory: List[Tuple[int, int]] = [] for i in range(size): public_memory.append((i, offset_to_page.get(i, 0))) - runner.segments.finalize( - self.base.segment_index, size=size, public_memory=public_memory) + runner.segments.finalize(self.base.segment_index, size=size, public_memory=public_memory) def get_memory_segment_addresses(self, runner): - return {'output': MemorySegmentAddresses( - begin_addr=self.base, - stop_ptr=self.stop_ptr, - )} + return { + "output": MemorySegmentAddresses( + begin_addr=self.base, + stop_ptr=self.stop_ptr, + ) + } def add_page(self, page_id: int, page_start: MaybeRelocatable, page_size: int): """ @@ -89,10 +92,11 @@ def add_page(self, page_id: int, page_start: MaybeRelocatable, page_size: int): All public memory cells which were not assigned a page, will be in page 0. This function should be used in Cairo hints. """ - assert page_id not in self.pages, f'Page {page_id} was already assigned.' - assert isinstance(page_start, RelocatableValue) and \ - page_start.segment_index == self.base.segment_index, \ - 'page_start must be in the output segment.' + assert page_id not in self.pages, f"Page {page_id} was already assigned." + assert ( + isinstance(page_start, RelocatableValue) + and page_start.segment_index == self.base.segment_index + ), "page_start must be in the output segment." start = page_start - self.base self.pages[page_id] = PublicMemoryPage(start=start, size=page_size) @@ -106,29 +110,33 @@ def add_attribute(self, attribute_name: str, attribute_value: dict): def get_additional_data(self): return { - 'pages': { + "pages": { str(page_id): [page_info.start, page_info.size] - for page_id, page_info in sorted(self.pages.items())}, - 'attributes': self.attributes, + for page_id, page_info in sorted(self.pages.items()) + }, + "attributes": self.attributes, } def extend_additional_data(self, data, relocate_callback, data_is_trusted=True): - assert isinstance(data, dict) and sorted(data.keys()) == ['attributes', 'pages'], \ - 'Invalid output builtin data.' + assert isinstance(data, dict) and sorted(data.keys()) == [ + "attributes", + "pages", + ], "Invalid output builtin data." # Process the 'pages' field. - assert isinstance(data['pages'], dict), 'Invalid output builtin pages field.' - for page_id_str, values in data['pages'].items(): - assert isinstance(page_id_str, str) and \ - isinstance(values, list) and \ - len(values) == 2 and \ - all(isinstance(x, int) and 0 < x < 2**30 for x in values), \ - 'Invalid output builtin pages field.' + assert isinstance(data["pages"], dict), "Invalid output builtin pages field." + for page_id_str, values in data["pages"].items(): + assert ( + isinstance(page_id_str, str) + and isinstance(values, list) + and len(values) == 2 + and all(isinstance(x, int) and 0 < x < 2 ** 30 for x in values) + ), "Invalid output builtin pages field." self.pages[int(page_id_str)] = PublicMemoryPage(start=values[0], size=values[1]) # Process the 'attributes' field. - assert isinstance(data['attributes'], dict), 'Invalid output builtin attributes field.' - self.attributes.update(data['attributes']) + assert isinstance(data["attributes"], dict), "Invalid output builtin attributes field." + self.attributes.update(data["attributes"]) def run_security_checks(self, runner): return @@ -174,6 +182,6 @@ def expected_stack(self, public_input): if not self.included: return [], [] - addresses = public_input.memory_segments['output'] - assert 0 <= addresses.begin_addr <= addresses.stop_ptr < 2**64 + addresses = public_input.memory_segments["output"] + assert 0 <= addresses.begin_addr <= addresses.stop_ptr < 2 ** 64 return [addresses.begin_addr], [addresses.stop_ptr] diff --git a/src/starkware/cairo/lang/vm/output_builtin_runner_test.py b/src/starkware/cairo/lang/vm/output_builtin_runner_test.py index e4182a6e..088dc668 100644 --- a/src/starkware/cairo/lang/vm/output_builtin_runner_test.py +++ b/src/starkware/cairo/lang/vm/output_builtin_runner_test.py @@ -16,11 +16,12 @@ def runner_and_output_runner(): ret end """ - program = compile_cairo(code=[(code, '')], prime=PRIME, add_start=True) + program = compile_cairo(code=[(code, "")], prime=PRIME, add_start=True) runner = CairoRunner( - program=program, layout='plain', proof_mode=True, allow_missing_builtins=True) + program=program, layout="plain", proof_mode=True, allow_missing_builtins=True + ) runner.initialize_segments() - output_builtin_runner = runner.builtin_runners['output'] = OutputBuiltinRunner(included=True) + output_builtin_runner = runner.builtin_runners["output"] = OutputBuiltinRunner(included=True) output_builtin_runner.initialize_segments(runner=runner) runner.initialize_main_entrypoint() runner.initialize_vm(hint_locals={}) @@ -39,32 +40,43 @@ def test_pages(runner_and_output_runner): output_builtin_runner.add_page(page_id=3, page_start=base + 9, page_size=3) # page_start must be in the output segment (base). - with pytest.raises(AssertionError, match='page_start must be in the output segment'): + with pytest.raises(AssertionError, match="page_start must be in the output segment"): output_builtin_runner.add_page( - page_id=4, page_start=RelocatableValue(999, 999), page_size=3) + page_id=4, page_start=RelocatableValue(999, 999), page_size=3 + ) runner.end_run() runner.finalize_segments() # A list of output cells and their page id. offset_page_pairs = [ - (0, 0), (1, 0), (2, 0), - (3, 1), (4, 1), (5, 1), (6, 1), - (7, 0), (8, 0), - (9, 3), (10, 3), (11, 3), - (12, 0), (13, 0), (14, 0), + (0, 0), + (1, 0), + (2, 0), + (3, 1), + (4, 1), + (5, 1), + (6, 1), + (7, 0), + (8, 0), + (9, 3), + (10, 3), + (11, 3), + (12, 0), + (13, 0), + (14, 0), ] - assert runner.segments.public_memory_offsets[base.segment_index] == \ - offset_page_pairs + assert runner.segments.public_memory_offsets[base.segment_index] == offset_page_pairs # Check that get_public_memory_addresses() returns the correct page_id for each value. # The program and execution segments are always in page 0. segment_offsets = {0: 0, 1: 10, 2: 100} assert runner.segments.get_public_memory_addresses(segment_offsets=segment_offsets) == ( - [(i, 0) for i in range(len(runner.program.data))] + # Program segment. - [(10, 0), (11, 0), (12, 0)] + # Execution segment. - [(100 + offset, page_id) for offset, page_id in offset_page_pairs]) # Output segment. + [(i, 0) for i in range(len(runner.program.data))] # Program segment. + + [(10, 0), (11, 0), (12, 0)] # Execution segment. + + [(100 + offset, page_id) for offset, page_id in offset_page_pairs] # Output segment. + ) def test_pages_collision(runner_and_output_runner): @@ -75,7 +87,7 @@ def test_pages_collision(runner_and_output_runner): output_builtin_runner.add_page(page_id=1, page_start=base + 10, page_size=4) output_builtin_runner.add_page(page_id=2, page_start=base + 12, page_size=4) runner.end_run() - with pytest.raises(AssertionError, match='Offset 12 was already assigned a page.'): + with pytest.raises(AssertionError, match="Offset 12 was already assigned a page."): output_builtin_runner.finalize_segments(runner=runner) @@ -88,5 +100,5 @@ def test_pages_out_of_bounds(runner_and_output_runner): output_builtin_runner.add_page(page_id=2, page_start=base + 7, page_size=4) output_builtin_runner.add_page(page_id=3, page_start=base + 11, page_size=2) runner.end_run() - with pytest.raises(AssertionError, match='Page 2 is out of bounds.'): + with pytest.raises(AssertionError, match="Page 2 is out of bounds."): output_builtin_runner.finalize_segments(runner=runner) diff --git a/src/starkware/cairo/lang/vm/reconstruct_traceback.py b/src/starkware/cairo/lang/vm/reconstruct_traceback.py index cc5f7505..613b8fb2 100755 --- a/src/starkware/cairo/lang/vm/reconstruct_traceback.py +++ b/src/starkware/cairo/lang/vm/reconstruct_traceback.py @@ -11,54 +11,62 @@ def reconstruct_traceback(program: Program, traceback_txt: str): def location_replacer(match: re.Match, keep_original_line: bool) -> str: - assert program.debug_info is not None, 'Missing debug information in the compiled program.' - pc = int(match.group('pc')) + assert program.debug_info is not None, "Missing debug information in the compiled program." + pc = int(match.group("pc")) instruction_location = program.debug_info.instruction_locations.get(pc) if instruction_location is None: # Return the text unchanged. return match.group(0) res = instruction_location.inst.to_string_with_content( - match.group(0) if keep_original_line else '') + match.group(0) if keep_original_line else "" + ) return res traceback_txt = re.sub( - r'Unknown location \(pc=0:(?P\d+)\)', + r"Unknown location \(pc=0:(?P\d+)\)", lambda match: location_replacer(match=match, keep_original_line=False), - traceback_txt) + traceback_txt, + ) traceback_txt = re.sub( - r'Error at pc=0:(?P\d+):', + r"Error at pc=0:(?P\d+):", lambda match: location_replacer(match=match, keep_original_line=True), - traceback_txt) + traceback_txt, + ) return traceback_txt def main(): parser = argparse.ArgumentParser( - description='A tool to reconstruct Cairo traceback given a compiled program with debug ' - 'information.') - parser.add_argument('-v', '--version', action='version', version=f'%(prog)s {__version__}') - parser.add_argument('--program', type=str, help='A path to the Cairo program.') - parser.add_argument('--contract', type=str, help='A path to the StarkNet contract.') + description="A tool to reconstruct Cairo traceback given a compiled program with debug " + "information." + ) + parser.add_argument("-v", "--version", action="version", version=f"%(prog)s {__version__}") + parser.add_argument("--program", type=str, help="A path to the Cairo program.") + parser.add_argument("--contract", type=str, help="A path to the StarkNet contract.") parser.add_argument( - '--traceback', type=str, required=True, - help='A path to the traceback file with the missing location information. ' - 'Use "-" to read the traceback from stdin.') + "--traceback", + type=str, + required=True, + help="A path to the traceback file with the missing location information. " + 'Use "-" to read the traceback from stdin.', + ) args = parser.parse_args() - assert (0 if args.program is None else 1) + (0 if args.contract is None else 1) == 1, \ - 'Exactly one of --program, --contract must be specified.' + assert (0 if args.program is None else 1) + ( + 0 if args.contract is None else 1 + ) == 1, "Exactly one of --program, --contract must be specified." if args.program is not None: program_json = json.load(open(args.program)) else: assert args.contract is not None - program_json = json.load(open(args.contract))['program'] + program_json = json.load(open(args.contract))["program"] program = Program.load(program_json) - traceback = (open(args.traceback) if args.traceback != '-' else sys.stdin).read() + traceback = (open(args.traceback) if args.traceback != "-" else sys.stdin).read() print(reconstruct_traceback(program, traceback)) return 0 -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/src/starkware/cairo/lang/vm/reconstruct_traceback_test.py b/src/starkware/cairo/lang/vm/reconstruct_traceback_test.py index c32665cd..e4125274 100644 --- a/src/starkware/cairo/lang/vm/reconstruct_traceback_test.py +++ b/src/starkware/cairo/lang/vm/reconstruct_traceback_test.py @@ -24,27 +24,32 @@ def test_reconstruct_traceback(): return () end """ - codes = [(code, 'filename')] + codes = [(code, "filename")] program_with_debug_info = compile_cairo(code=codes, prime=DEFAULT_PRIME, debug_info=True) program_without_debug_info = compile_cairo(code=codes, prime=DEFAULT_PRIME, debug_info=False) with pytest.raises(VmException) as exc: - get_main_runner(program=program_without_debug_info, hint_locals={}, layout='plain') + get_main_runner(program=program_without_debug_info, hint_locals={}, layout="plain") exception_str = str(exc.value) # The exception before calling reconstruct_traceback(). - assert exception_str == """\ + assert ( + exception_str + == """\ Error at pc=0:2: An ASSERT_EQ instruction failed: 1 != 0 Cairo traceback (most recent call last): Unknown location (pc=0:8) Unknown location (pc=0:5)\ """ + ) res = reconstruct_traceback(program=program_with_debug_info, traceback_txt=exception_str) # The exception after calling reconstruct_traceback(). - assert res == """\ + assert ( + res + == """\ filename:3:9: Error at pc=0:2: assert 0 = 1 ^**********^ @@ -57,3 +62,4 @@ def test_reconstruct_traceback(): bar() ^***^\ """ + ) diff --git a/src/starkware/cairo/lang/vm/relocatable.py b/src/starkware/cairo/lang/vm/relocatable.py index 28fb0ecb..075db20c 100644 --- a/src/starkware/cairo/lang/vm/relocatable.py +++ b/src/starkware/cairo/lang/vm/relocatable.py @@ -1,8 +1,8 @@ import dataclasses from typing import Dict, Tuple, TypeVar, Union -MaybeRelocatable = Union[int, 'RelocatableValue'] -T = TypeVar('T', int, MaybeRelocatable) +MaybeRelocatable = Union[int, "RelocatableValue"] +T = TypeVar("T", int, MaybeRelocatable) @dataclasses.dataclass(frozen=True) @@ -11,28 +11,31 @@ class RelocatableValue: A value in the cairo vm representing an address in some memory segment. This is meant to be replaced by a real memory address (field element) after the VM finished. """ + segment_index: int offset: int SEGMENT_BITS = 16 OFFSET_BITS = 47 - def __add__(self, other: MaybeRelocatable) -> 'RelocatableValue': + def __add__(self, other: MaybeRelocatable) -> "RelocatableValue": if isinstance(other, int): return RelocatableValue(self.segment_index, self.offset + other) - assert not isinstance(other, RelocatableValue), \ - f'Cannot add two relocatable values: {self} + {other}.' + assert not isinstance( + other, RelocatableValue + ), f"Cannot add two relocatable values: {self} + {other}." return NotImplemented - def __radd__(self, other: MaybeRelocatable) -> 'RelocatableValue': + def __radd__(self, other: MaybeRelocatable) -> "RelocatableValue": return self + other def __sub__(self, other: MaybeRelocatable) -> MaybeRelocatable: if isinstance(other, int): return RelocatableValue(self.segment_index, self.offset - other) - assert self.segment_index == other.segment_index, \ - 'Can only subtract two relocatable values of the same segment ' \ - f'({self.segment_index} != {other.segment_index}).' + assert self.segment_index == other.segment_index, ( + "Can only subtract two relocatable values of the same segment " + f"({self.segment_index} != {other.segment_index})." + ) return self.offset - other.offset def __mod__(self, other: int): @@ -47,7 +50,7 @@ def __lt__(self, other: MaybeRelocatable): return (self.segment_index, self.offset) < (other.segment_index, other.offset) def __le__(self, other: MaybeRelocatable): - return (self < other or self == other) + return self < other or self == other def __ge__(self, other: MaybeRelocatable): return not (self < other) @@ -59,10 +62,10 @@ def __hash__(self): return hash((self.segment_index, self.offset)) def __format__(self, format_spec): - return f'{self.segment_index}:{self.offset}'.__format__(format_spec) + return f"{self.segment_index}:{self.offset}".__format__(format_spec) def __str__(self): - return f'{self.segment_index}:{self.offset}' + return f"{self.segment_index}:{self.offset}" def to_bytes(self, n_bytes: int, byte_order: str) -> bytes: """ @@ -100,7 +103,7 @@ def to_tuple(value: MaybeRelocatable) -> Tuple[int, ...]: elif isinstance(value, int): return (value,) else: - raise NotImplementedError(f'Expected MaybeRelocatable, got: {type(value).__name__}.') + raise NotImplementedError(f"Expected MaybeRelocatable, got: {type(value).__name__}.") @classmethod def from_tuple(cls, value: Tuple[int, ...]) -> MaybeRelocatable: @@ -112,12 +115,15 @@ def from_tuple(cls, value: Tuple[int, ...]) -> MaybeRelocatable: elif len(value) == 1: return value[0] else: - raise NotImplementedError(f'Expected a tuple of size 1 or 2, got: {value}.') + raise NotImplementedError(f"Expected a tuple of size 1 or 2, got: {value}.") def relocate_value( - value: MaybeRelocatable, segment_offsets: Dict[int, T], prime: int, - allow_missing_segments: bool = False) -> T: + value: MaybeRelocatable, + segment_offsets: Dict[int, T], + prime: int, + allow_missing_segments: bool = False, +) -> T: if isinstance(value, int): return value elif isinstance(value, RelocatableValue): @@ -134,4 +140,4 @@ def relocate_value( assert value < prime return value else: - raise NotImplementedError('Not relocatable') + raise NotImplementedError("Not relocatable") diff --git a/src/starkware/cairo/lang/vm/relocatable_fields.py b/src/starkware/cairo/lang/vm/relocatable_fields.py index 57b9d38b..e2581262 100644 --- a/src/starkware/cairo/lang/vm/relocatable_fields.py +++ b/src/starkware/cairo/lang/vm/relocatable_fields.py @@ -28,10 +28,8 @@ def _serialize(self, value, attr, obj, **kwargs): if value is None: return None return [ - (RelocatableValue.to_tuple(x), RelocatableValue.to_tuple(y)) - for x, y in value.items()] + (RelocatableValue.to_tuple(x), RelocatableValue.to_tuple(y)) for x, y in value.items() + ] def _deserialize(self, value, attr, data, **kwargs): - return { - RelocatableValue.from_tuple(x): RelocatableValue.from_tuple(y) - for x, y in value} + return {RelocatableValue.from_tuple(x): RelocatableValue.from_tuple(y) for x, y in value} diff --git a/src/starkware/cairo/lang/vm/relocatable_fields_test.py b/src/starkware/cairo/lang/vm/relocatable_fields_test.py index 5040a16e..d0c33171 100644 --- a/src/starkware/cairo/lang/vm/relocatable_fields_test.py +++ b/src/starkware/cairo/lang/vm/relocatable_fields_test.py @@ -6,14 +6,17 @@ from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue from starkware.cairo.lang.vm.relocatable_fields import ( - MaybeRelocatableDictField, MaybeRelocatableField) + MaybeRelocatableDictField, + MaybeRelocatableField, +) @marshmallow_dataclass.dataclass class DummyStruct: val: MaybeRelocatable = field(metadata=dict(marshmallow_field=MaybeRelocatableField())) dct: Dict[MaybeRelocatable, MaybeRelocatable] = field( - metadata=dict(marshmallow_field=MaybeRelocatableDictField())) + metadata=dict(marshmallow_field=MaybeRelocatableDictField()) + ) Schema: ClassVar[Type[marshmallow.Schema]] = marshmallow.Schema diff --git a/src/starkware/cairo/lang/vm/relocatable_test.py b/src/starkware/cairo/lang/vm/relocatable_test.py index d6380ddd..32324f30 100644 --- a/src/starkware/cairo/lang/vm/relocatable_test.py +++ b/src/starkware/cairo/lang/vm/relocatable_test.py @@ -38,13 +38,16 @@ def test_relocatable_inequalities(): assert not (y <= x) -@pytest.mark.parametrize('byte_order', ['little', 'big']) -@pytest.mark.parametrize('n_bytes', [16, 32]) +@pytest.mark.parametrize("byte_order", ["little", "big"]) +@pytest.mark.parametrize("n_bytes", [16, 32]) def test_relocatable_value_serialization(byte_order, n_bytes): for num in [19, RelocatableValue(2, 5)]: - assert RelocatableValue.from_bytes( - RelocatableValue.to_bytes(num, n_bytes, byte_order), - byte_order) == num + assert ( + RelocatableValue.from_bytes( + RelocatableValue.to_bytes(num, n_bytes, byte_order), byte_order + ) + == num + ) def test_to_tuple_from_tuple(): @@ -59,5 +62,6 @@ def test_to_tuple_from_tuple(): def test_relocatable_value_frozen(): x = RelocatableValue(1, 2) with pytest.raises( - dataclasses.FrozenInstanceError, match="cannot assign to field 'no_such_field'"): + dataclasses.FrozenInstanceError, match="cannot assign to field 'no_such_field'" + ): x.no_such_field = 5 diff --git a/src/starkware/cairo/lang/vm/security.py b/src/starkware/cairo/lang/vm/security.py index 3b644e63..14c97c1f 100644 --- a/src/starkware/cairo/lang/vm/security.py +++ b/src/starkware/cairo/lang/vm/security.py @@ -24,21 +24,22 @@ def verify_secure_runner(runner: CairoRunner, verify_builtins=True): for addr in runner.vm_memory: # Check pure addresses. if not isinstance(addr, RelocatableValue): - raise SecurityError(f'Accessed address {addr} is not relocatable.') + raise SecurityError(f"Accessed address {addr} is not relocatable.") # Check non negative offset. if addr.offset < 0: - raise SecurityError(f'Accessed address {addr} has negative offset.') + raise SecurityError(f"Accessed address {addr} has negative offset.") # Check builtin segment out of bounds. if addr.segment_index in builtin_segment_sizes: if not addr.offset < builtin_segment_sizes[addr.segment_index]: raise SecurityError( - 'Out of bounds access to builtin segment ' - f'{builtin_segment_names[addr.segment_index]} at {addr}.') + "Out of bounds access to builtin segment " + f"{builtin_segment_names[addr.segment_index]} at {addr}." + ) # Check out of bounds for program segment. if addr.segment_index == runner.program_base.segment_index: if not addr.offset < len(runner.program.data): - raise SecurityError(f'Out of bounds access to program segment at {addr}.') + raise SecurityError(f"Out of bounds access to program segment at {addr}.") # Builtin specific checks. try: diff --git a/src/starkware/cairo/lang/vm/security_test.py b/src/starkware/cairo/lang/vm/security_test.py index 11e71bdc..a279789e 100644 --- a/src/starkware/cairo/lang/vm/security_test.py +++ b/src/starkware/cairo/lang/vm/security_test.py @@ -5,36 +5,44 @@ from starkware.cairo.lang.vm.relocatable import RelocatableValue from starkware.cairo.lang.vm.security import SecurityError, verify_secure_runner -PRIME = 2**251 + 17 * 2**192 + 1 +PRIME = 2 ** 251 + 17 * 2 ** 192 + 1 -def run_code_in_runner(code, layout='plain'): +def run_code_in_runner(code, layout="plain"): return get_runner_from_code(code=code, layout=layout, prime=PRIME) def test_completeness(): - verify_secure_runner(run_code_in_runner(""" + verify_secure_runner( + run_code_in_runner( + """ main: [ap] = 1 ret -""")) +""" + ) + ) def test_negative_address(): - runner = run_code_in_runner(""" + runner = run_code_in_runner( + """ main: [ap] = 0; ap++ ret -""") +""" + ) # Access negative offset manually, so it is not taken modulo prime. runner.vm_memory.set_without_checks(RelocatableValue(segment_index=0, offset=-17), 0) - with pytest.raises(SecurityError, match='Accessed address 0:-17 has negative offset.'): + with pytest.raises(SecurityError, match="Accessed address 0:-17 has negative offset."): verify_secure_runner(runner) def test_out_of_program_bounds(): - with pytest.raises(SecurityError, match='Out of bounds access to program segment'): - verify_secure_runner(run_code_in_runner(""" + with pytest.raises(SecurityError, match="Out of bounds access to program segment"): + verify_secure_runner( + run_code_in_runner( + """ main: call test ret @@ -42,26 +50,32 @@ def test_out_of_program_bounds(): [ap] = [fp - 1] # pc. [ap] = [[ap] + 4] # Write right after end of program. ret -""")) +""" + ) + ) def test_pure_address_access(): - runner = run_code_in_runner(""" + runner = run_code_in_runner( + """ main: [fp - 1] = [fp - 1] # nop. ret -""") +""" + ) # Access a pure address manually, because runner disallows it as well. runner.vm_memory.unfreeze_for_testing() runner.vm_memory[1234] = 1 - with pytest.raises(SecurityError, match='Accessed address 1234 is not relocatable.'): + with pytest.raises(SecurityError, match="Accessed address 1234 is not relocatable."): verify_secure_runner(runner) def test_builtin_segment_access(): with get_crypto_lib_context_manager(flavor=None): - verify_secure_runner(run_code_in_runner(""" + verify_secure_runner( + run_code_in_runner( + """ %builtins pedersen main: [ap] = 1; ap++ @@ -70,26 +84,34 @@ def test_builtin_segment_access(): [ap] = [[fp - 3] + 2]; ap++ # Read hash result. [ap] = [fp - 3] + 3; ap++ # Return pedersen_ptr. ret -""", layout='small')) +""", + layout="small", + ) + ) # Out of bound is not ok. - runner = run_code_in_runner(""" + runner = run_code_in_runner( + """ %builtins pedersen main: [fp - 1] = [[fp - 3] + 2] # Access only the result portion of the builtin. [ap] = [fp - 3] + 3; ap++ # Return pedersen_ptr. ret -""", layout='small') +""", + layout="small", + ) # Access out of bounds manually, because runner disallows it as well. - pedersen_base = runner.builtin_runners['pedersen_builtin'].base + pedersen_base = runner.builtin_runners["pedersen_builtin"].base runner.vm_memory.unfreeze_for_testing() runner.vm_memory[pedersen_base + 7] = 1 - with pytest.raises(SecurityError, match='Out of bounds access to builtin segment pedersen'): + with pytest.raises(SecurityError, match="Out of bounds access to builtin segment pedersen"): verify_secure_runner(runner) # Invalid segment size (only first input is written). - with pytest.raises(SecurityError, match=r'Missing memory cells for pedersen: 1, 4\.'): - verify_secure_runner(run_code_in_runner(""" + with pytest.raises(SecurityError, match=r"Missing memory cells for pedersen: 1, 4\."): + verify_secure_runner( + run_code_in_runner( + """ %builtins pedersen func main{pedersen_ptr}(): assert [pedersen_ptr] = 0 @@ -99,4 +121,7 @@ def test_builtin_segment_access(): let pedersen_ptr = pedersen_ptr + 6 return () end -""", layout='small')) +""", + layout="small", + ) + ) diff --git a/src/starkware/cairo/lang/vm/trace_entry.py b/src/starkware/cairo/lang/vm/trace_entry.py index 81db3feb..dfac0149 100644 --- a/src/starkware/cairo/lang/vm/trace_entry.py +++ b/src/starkware/cairo/lang/vm/trace_entry.py @@ -4,7 +4,7 @@ from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, relocate_value -T = TypeVar('T', int, MaybeRelocatable) +T = TypeVar("T", int, MaybeRelocatable) @dataclasses.dataclass @@ -13,6 +13,7 @@ class TraceEntry(Generic[T]): A trace entry for every instruction that was executed. Holds the register values before the instruction was executed. """ + pc: T ap: T fp: T @@ -27,13 +28,13 @@ def serialize(self) -> bytes: for x in values: assert isinstance(x, int) assert 0 <= x < 2 ** 64 - return struct.pack('<3Q', *values) + return struct.pack("<3Q", *values) @classmethod - def deserialize(cls, serialized: bytes) -> 'TraceEntry': - assert len(serialized) == cls.serialization_size(), 'Unexpected input length.' + def deserialize(cls, serialized: bytes) -> "TraceEntry": + assert len(serialized) == cls.serialization_size(), "Unexpected input length." - ap, fp, pc = struct.unpack('<3Q', serialized) + ap, fp, pc = struct.unpack("<3Q", serialized) return cls( pc=pc, @@ -47,17 +48,22 @@ def serialization_size(): def relocate_trace( - trace: List[TraceEntry[MaybeRelocatable]], segment_offsets: Dict[int, T], prime: int, - allow_missing_segments: bool = False) -> List[TraceEntry[T]]: + trace: List[TraceEntry[MaybeRelocatable]], + segment_offsets: Dict[int, T], + prime: int, + allow_missing_segments: bool = False, +) -> List[TraceEntry[T]]: new_trace: List[TraceEntry[T]] = [] def relocate_val(x): return relocate_value(x, segment_offsets, prime, allow_missing_segments) for entry in trace: - new_trace.append(TraceEntry( - pc=relocate_val(entry.pc), - ap=relocate_val(entry.ap), - fp=relocate_val(entry.fp), - )) + new_trace.append( + TraceEntry( + pc=relocate_val(entry.pc), + ap=relocate_val(entry.ap), + fp=relocate_val(entry.fp), + ) + ) return new_trace diff --git a/src/starkware/cairo/lang/vm/trace_entry_test.py b/src/starkware/cairo/lang/vm/trace_entry_test.py index f0dc3321..37955330 100644 --- a/src/starkware/cairo/lang/vm/trace_entry_test.py +++ b/src/starkware/cairo/lang/vm/trace_entry_test.py @@ -4,14 +4,21 @@ def test_trace_entry_serialization(): # Test serialization of a TraceEntry (values taken from the instruction # "[ap] = [ap - 1] + 2; ap++"). - entry = TraceEntry(ap=0x66, fp=0x64, pc=0xa) + entry = TraceEntry(ap=0x66, fp=0x64, pc=0xA) serialized = entry.serialize() assert len(serialized) == TraceEntry.serialization_size() - assert serialized.hex() == """ + assert ( + serialized.hex() + == """ 66 00 00 00 00 00 00 00 64 00 00 00 00 00 00 00 0a 00 00 00 00 00 00 00 -""".replace(' ', '').replace('\n', '') +""".replace( + " ", "" + ).replace( + "\n", "" + ) + ) # Test deserialization. assert TraceEntry.deserialize(serialized).serialize() == serialized diff --git a/src/starkware/cairo/lang/vm/utils.py b/src/starkware/cairo/lang/vm/utils.py index fd78b1be..9dc5301b 100644 --- a/src/starkware/cairo/lang/vm/utils.py +++ b/src/starkware/cairo/lang/vm/utils.py @@ -11,7 +11,7 @@ class IntAsHex(mfields.Field): field elements. """ - default_error_messages = {'invalid': 'Expected hex string, got: "{input}".'} + default_error_messages = {"invalid": 'Expected hex string, got: "{input}".'} def _serialize(self, value, attr, obj, **kwargs): if value is None: @@ -20,8 +20,8 @@ def _serialize(self, value, attr, obj, **kwargs): return hex(value) def _deserialize(self, value, attr, data, **kwargs): - if re.match('^0x[0-9a-f]+$', value) is None: - self.fail('invalid', input=value) + if re.match("^0x[0-9a-f]+$", value) is None: + self.fail("invalid", input=value) return int(value, 16) @@ -48,6 +48,7 @@ class RunResources: """ Maintains the resources of a Cairo run. Can be used across multiple runners. """ + steps: Optional[int] @property diff --git a/src/starkware/cairo/lang/vm/validated_memory_dict.py b/src/starkware/cairo/lang/vm/validated_memory_dict.py index a4b7084a..718550e2 100644 --- a/src/starkware/cairo/lang/vm/validated_memory_dict.py +++ b/src/starkware/cairo/lang/vm/validated_memory_dict.py @@ -3,7 +3,7 @@ from starkware.cairo.lang.vm.memory_dict import MemoryDict from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue -ValidationRule = Callable[['MemoryDict', RelocatableValue], Set[RelocatableValue]] +ValidationRule = Callable[["MemoryDict", RelocatableValue], Set[RelocatableValue]] class ValidatedMemoryDict: @@ -29,8 +29,8 @@ def __setitem__(self, addr: MaybeRelocatable, value: MaybeRelocatable): self._validate_memory_cell(addr, value) def __getattr__(self, name: str): - if name in ['__deepcopy__', '__getstate__', '__setstate__']: - raise AttributeError(f'ValidatedMemoryDict has no attribute named {name}.') + if name in ["__deepcopy__", "__getstate__", "__setstate__"]: + raise AttributeError(f"ValidatedMemoryDict has no attribute named {name}.") return getattr(self.__memory, name) def __iter__(self): diff --git a/src/starkware/cairo/lang/vm/validated_memory_dict_test.py b/src/starkware/cairo/lang/vm/validated_memory_dict_test.py index 170f399c..d36cdf77 100644 --- a/src/starkware/cairo/lang/vm/validated_memory_dict_test.py +++ b/src/starkware/cairo/lang/vm/validated_memory_dict_test.py @@ -21,8 +21,9 @@ def rule_identical_pairs(mem, addr): return set() def rule_constant_value(mem, addr, constant): - assert mem[addr] == constant, \ - f'Expected value in address {addr} to be {constant}, got {mem[addr]}.' + assert ( + mem[addr] == constant + ), f"Expected value in address {addr} to be {constant}, got {mem[addr]}." return {addr} memory_validator.add_validation_rule(1, lambda memory, addr: set()) @@ -52,11 +53,9 @@ def rule_constant_value(mem, addr, constant): memory_validator.validate_existing_memory() # Invalidate existing memory and test negative case. - with pytest.raises( - AssertionError, match='Expected value in address 4:0 to be 0, got 1.'): + with pytest.raises(AssertionError, match="Expected value in address 4:0 to be 0, got 1."): memory_validator[addr4] = 1 # Test validation of existing invalid memory. - with pytest.raises( - AssertionError, match='Expected value in address 4:0 to be 0, got 1.'): + with pytest.raises(AssertionError, match="Expected value in address 4:0 to be 0, got 1."): memory_validator.validate_existing_memory() diff --git a/src/starkware/cairo/lang/vm/vm.py b/src/starkware/cairo/lang/vm/vm.py index 790bacc3..0688f98d 100644 --- a/src/starkware/cairo/lang/vm/vm.py +++ b/src/starkware/cairo/lang/vm/vm.py @@ -11,7 +11,10 @@ from starkware.cairo.lang.compiler.error_handling import LocationError from starkware.cairo.lang.compiler.expression_evaluator import ExpressionEvaluator from starkware.cairo.lang.compiler.instruction import ( - Instruction, Register, decode_instruction_values) + Instruction, + Register, + decode_instruction_values, +) from starkware.cairo.lang.compiler.program import Program, ProgramBase from starkware.cairo.lang.vm.builtin_runner import BuiltinRunner from starkware.cairo.lang.vm.memory_dict import MemoryDict @@ -21,7 +24,7 @@ from starkware.cairo.lang.vm.vm_consts import VmConsts, VmConstsContext from starkware.python.math_utils import div_mod -Rule = Callable[['VirtualMachine', RelocatableValue], Optional[int]] +Rule = Callable[["VirtualMachine", RelocatableValue], Optional[int]] MAX_TRACEBACK_ENTRIES = 20 @@ -46,9 +49,14 @@ class VmExceptionBase(Exception): class VmException(LocationError, VmExceptionBase): def __init__( - self, pc, inst_location: Optional[InstructionLocation], inner_exc, - traceback: Optional[str] = None, notes: Optional[List[str]] = None, - hint_index: Optional[int] = None): + self, + pc, + inst_location: Optional[InstructionLocation], + inner_exc, + traceback: Optional[str] = None, + notes: Optional[List[str]] = None, + hint_index: Optional[int] = None, + ): self.pc = pc self.inner_exc = inner_exc location = None @@ -60,7 +68,8 @@ def __init__( if hint_location is not None: location = hint_location.location super().__init__( - f'Error at pc={self.pc}:\n{inner_exc}', location=location, traceback=traceback) + f"Error at pc={self.pc}:\n{inner_exc}", location=location, traceback=traceback + ) if notes is not None: self.notes += notes @@ -71,16 +80,16 @@ def __init__(self, addr, current_value, new_value): self.current_value = current_value self.new_value = new_value super().__init__( - f'Inconsistent auto deduction rule at address {addr}. {current_value} != {new_value}.') + f"Inconsistent auto deduction rule at address {addr}. {current_value} != {new_value}." + ) class PureValueError(VmExceptionBase): def __init__(self, oper, *values): self.oper = oper self.values = values - values_str = f'values {values}' if len(values) > 1 else f'value {values[0]}' - super().__init__( - f'Could not complete computation {oper} of non pure {values_str}.') + values_str = f"values {values}" if len(values) > 1 else f"value {values[0]}" + super().__init__(f"Could not complete computation {oper} of non pure {values_str}.") class HintException(VmExceptionBase): @@ -89,12 +98,13 @@ def __init__(self, vm, exc_type, exc_value, exc_tb): fix = self.fix_name_and_line(vm, exc_value) if fix is not None: filename, line_num = fix - exc_value = IndentationError(exc_value.msg, ( - filename, line_num, exc_value.offset, exc_value.text)) + exc_value = IndentationError( + exc_value.msg, (filename, line_num, exc_value.offset, exc_value.text) + ) tb_exception = traceback.TracebackException(exc_type, exc_value, exc_tb) # First item in the traceback is the call to exec, remove it. - assert tb_exception.stack[0].filename.endswith('vm.py') + assert tb_exception.stack[0].filename.endswith("vm.py") del tb_exception.stack[0] # If we have location information, replace '' entries with the correct filename @@ -104,31 +114,29 @@ def replace_stack_item(item: traceback.FrameSummary) -> traceback.FrameSummary: if fix is None: return item filename, line_num = fix - return traceback.FrameSummary( - filename=filename, - lineno=line_num, - name=item.name) + return traceback.FrameSummary(filename=filename, lineno=line_num, name=item.name) + tb_exception.stack = traceback.StackSummary.from_list( - map(replace_stack_item, tb_exception.stack)) - super().__init__(f'Got an exception while executing a hint.') - self.exception_str = ''.join(tb_exception.format()) + map(replace_stack_item, tb_exception.stack) + ) + super().__init__(f"Got an exception while executing a hint.") + self.exception_str = "".join(tb_exception.format()) self.inner_exc = exc_value ExcType = Union[IndentationError, SyntaxError, traceback.FrameSummary] - def fix_name_and_line( - self, vm, exc_value: ExcType) -> Optional[Tuple[str, int]]: - m = re.match('^[0-9]+)>$', str(exc_value.filename)) + def fix_name_and_line(self, vm, exc_value: ExcType) -> Optional[Tuple[str, int]]: + m = re.match("^[0-9]+)>$", str(exc_value.filename)) if m is None: return None - pc, index = vm.hint_pc_and_index[int(m.group('index'))] + pc, index = vm.hint_pc_and_index[int(m.group("index"))] location = vm.get_location(pc) if (location is None) or (location.hints[index] is None): return None hint_location = location.hints[index] start_line = hint_location.location.start_line prefix_lines = hint_location.n_prefix_newlines - line_num = (exc_value.lineno + start_line + prefix_lines - 1) + line_num = exc_value.lineno + start_line + prefix_lines - 1 filename = hint_location.location.input_file.filename return filename, line_num @@ -152,8 +160,9 @@ def get_instruction_encoding(self) -> Tuple[int, Optional[int]]: """ instruction_encoding = self.memory[self.pc] - assert isinstance(instruction_encoding, int), \ - f'Instruction should be an int. Found: {instruction_encoding}' + assert isinstance( + instruction_encoding, int + ), f"Instruction should be an int. Found: {instruction_encoding}" imm_addr = (self.pc + 1) % self.prime return instruction_encoding, self.memory.get(imm_addr) @@ -165,7 +174,7 @@ def compute_dst_addr(self, instruction: Instruction): elif instruction.dst_register is Register.FP: base_addr = self.fp else: - raise NotImplementedError('Invalid dst_register value') + raise NotImplementedError("Invalid dst_register value") return (base_addr + instruction.off0) % self.prime def compute_op0_addr(self, instruction: Instruction): @@ -175,7 +184,7 @@ def compute_op0_addr(self, instruction: Instruction): elif instruction.op0_register is Register.FP: base_addr = self.fp else: - raise NotImplementedError('Invalid op0_register value') + raise NotImplementedError("Invalid op0_register value") return (base_addr + instruction.off1) % self.prime def compute_op1_addr(self, instruction: Instruction, op0: Optional[MaybeRelocatable]): @@ -185,13 +194,13 @@ def compute_op1_addr(self, instruction: Instruction, op0: Optional[MaybeRelocata elif instruction.op1_addr is Instruction.Op1Addr.AP: base_addr = self.ap elif instruction.op1_addr is Instruction.Op1Addr.IMM: - assert instruction.off2 == 1, 'In immediate mode, off2 should be 1.' + assert instruction.off2 == 1, "In immediate mode, off2 should be 1." base_addr = self.pc elif instruction.op1_addr is Instruction.Op1Addr.OP0: - assert op0 is not None, 'op0 must be known in double dereference.' + assert op0 is not None, "op0 must be known in double dereference." base_addr = op0 else: - raise NotImplementedError('Invalid op1_register value') + raise NotImplementedError("Invalid op1_register value") return (base_addr + instruction.off2) % self.prime def get_traceback_entries(self): @@ -219,10 +228,14 @@ def get_traceback_entries(self): # instruction1 (with no immediate). # In rare cases this may be ambiguous. if instruction1 is not None and is_call_instruction( - encoded_instruction=instruction1, imm=None): + encoded_instruction=instruction1, imm=None + ): call_pc = ret_pc - 1 - elif instruction0 is not None and instruction1 is not None and is_call_instruction( - encoded_instruction=instruction0, imm=instruction1): + elif ( + instruction0 is not None + and instruction1 is not None + and is_call_instruction(encoded_instruction=instruction0, imm=instruction1) + ): call_pc = ret_pc - 2 else: # If none of them seems like the calling instruction, abort. @@ -241,9 +254,14 @@ class CompiledHint: class VirtualMachine: def __init__( - self, program: ProgramBase, run_context: RunContext, - hint_locals: Dict[str, Any], static_locals: Optional[Dict[str, Any]] = None, - builtin_runners: Dict[str, BuiltinRunner] = {}, program_base: Optional[int] = None): + self, + program: ProgramBase, + run_context: RunContext, + hint_locals: Dict[str, Any], + static_locals: Optional[Dict[str, Any]] = None, + builtin_runners: Dict[str, BuiltinRunner] = {}, + program_base: Optional[int] = None, + ): """ hints - a dictionary from memory addresses to an executable object. When the pc points to the memory address, before the execution of the instruction, @@ -273,7 +291,8 @@ def __init__( # A set to track the memory addresses accessed by actual Cairo instructions (as opposed to # hints), necessary for accurate counting of memory holes. self.accessed_addresses: Set[MaybeRelocatable] = { - self.program_base + i for i in range(len(self.program.data))} + self.program_base + i for i in range(len(self.program.data)) + } # If program is a StrippedProgram, there are no hints or debug information to load. if isinstance(program, Program): @@ -296,18 +315,21 @@ def __init__( self.skip_instruction_execution = False from starkware.python import math_utils + self.static_locals = static_locals.copy() if static_locals is not None else {} - self.static_locals.update({ - 'PRIME': self.prime, - 'fadd': lambda a, b, p=self.prime: (a + b) % p, - 'fsub': lambda a, b, p=self.prime: (a - b) % p, - 'fmul': lambda a, b, p=self.prime: (a * b) % p, - 'fdiv': lambda a, b, p=self.prime: math_utils.div_mod(a, b, p), - 'fpow': lambda a, b, p=self.prime: pow(a, b, p), - 'fis_quad_residue': lambda a, p=self.prime: math_utils.is_quad_residue(a, p), - 'fsqrt': lambda a, p=self.prime: math_utils.sqrt(a, p), - 'safe_div': math_utils.safe_div, - }) + self.static_locals.update( + { + "PRIME": self.prime, + "fadd": lambda a, b, p=self.prime: (a + b) % p, + "fsub": lambda a, b, p=self.prime: (a - b) % p, + "fmul": lambda a, b, p=self.prime: (a * b) % p, + "fdiv": lambda a, b, p=self.prime: math_utils.div_mod(a, b, p), + "fpow": lambda a, b, p=self.prime: pow(a, b, p), + "fis_quad_residue": lambda a, p=self.prime: math_utils.is_quad_residue(a, p), + "fsqrt": lambda a, p=self.prime: math_utils.sqrt(a, p), + "safe_div": math_utils.safe_div, + } + ) def validate_existing_memory(self): """ @@ -322,21 +344,28 @@ def load_hints(self, program: Program, program_base: MaybeRelocatable): for hint_index, hint in enumerate(hints): hint_id = len(self.hint_pc_and_index) self.hint_pc_and_index[hint_id] = (pc + program_base, hint_index) - compiled_hints.append(CompiledHint( - compiled=self.compile_hint( - hint.code, f'', hint_index=hint_index), - # Use hint=hint in the lambda's arguments to capture this value (otherwise, it - # will use the same hint object for all iterations). - consts=lambda pc, ap, fp, memory, hint=hint: VmConsts( - context=VmConstsContext( - identifiers=program.identifiers, - evaluator=ExpressionEvaluator( - self.prime, ap, fp, memory, program.identifiers).eval, - reference_manager=program.reference_manager, - flow_tracking_data=hint.flow_tracking_data, - memory=memory, - pc=pc), - accessible_scopes=hint.accessible_scopes))) + compiled_hints.append( + CompiledHint( + compiled=self.compile_hint( + hint.code, f"", hint_index=hint_index + ), + # Use hint=hint in the lambda's arguments to capture this value (otherwise, it + # will use the same hint object for all iterations). + consts=lambda pc, ap, fp, memory, hint=hint: VmConsts( + context=VmConstsContext( + identifiers=program.identifiers, + evaluator=ExpressionEvaluator( + self.prime, ap, fp, memory, program.identifiers + ).eval, + reference_manager=program.reference_manager, + flow_tracking_data=hint.flow_tracking_data, + memory=memory, + pc=pc, + ), + accessible_scopes=hint.accessible_scopes, + ), + ) + ) self.hints[pc + program_base] = compiled_hints def load_debug_info(self, debug_info: Optional[DebugInfo], program_base: MaybeRelocatable): @@ -349,8 +378,9 @@ def load_debug_info(self, debug_info: Optional[DebugInfo], program_base: MaybeRe self.instruction_debug_info[program_base + offset] = location_info def load_program(self, program: Program, program_base: MaybeRelocatable): - assert self.prime == program.prime, \ - f'Unexpected prime for loaded program: {program.prime} != {self.prime}.' + assert ( + self.prime == program.prime + ), f"Unexpected prime for loaded program: {program.prime} != {self.prime}." self.load_debug_info(program.debug_info, program_base) self.load_hints(program, program_base) @@ -370,7 +400,7 @@ def enter_scope(self, new_scope_locals: Optional[dict] = None): self.exec_scopes.append({**new_scope_locals, **self.builtin_runners}) def exit_scope(self): - assert len(self.exec_scopes) > 1, 'Cannot exit main scope.' + assert len(self.exec_scopes) > 1, "Cannot exit main scope." self.exec_scopes.pop() def update_registers(self, instruction: Instruction, operands: Operands): @@ -380,19 +410,19 @@ def update_registers(self, instruction: Instruction, operands: Operands): elif instruction.fp_update is Instruction.FpUpdate.DST: self.run_context.fp = operands.dst elif instruction.fp_update is not Instruction.FpUpdate.REGULAR: - raise NotImplementedError('Invalid fp_update value') + raise NotImplementedError("Invalid fp_update value") # Update ap. if instruction.ap_update is Instruction.ApUpdate.ADD: if operands.res is None: - raise NotImplementedError('Res.UNCONSTRAINED cannot be used with ApUpdate.ADD') + raise NotImplementedError("Res.UNCONSTRAINED cannot be used with ApUpdate.ADD") self.run_context.ap += operands.res % self.prime elif instruction.ap_update is Instruction.ApUpdate.ADD1: self.run_context.ap += 1 elif instruction.ap_update is Instruction.ApUpdate.ADD2: self.run_context.ap += 2 elif instruction.ap_update is not Instruction.ApUpdate.REGULAR: - raise NotImplementedError('Invalid ap_update value') + raise NotImplementedError("Invalid ap_update value") self.run_context.ap = self.run_context.ap % self.prime # Update pc. @@ -402,13 +432,13 @@ def update_registers(self, instruction: Instruction, operands: Operands): self.run_context.pc += instruction.size elif instruction.pc_update is Instruction.PcUpdate.JUMP: if operands.res is None: - raise NotImplementedError('Res.UNCONSTRAINED cannot be used with PcUpdate.JUMP') + raise NotImplementedError("Res.UNCONSTRAINED cannot be used with PcUpdate.JUMP") self.run_context.pc = operands.res elif instruction.pc_update is Instruction.PcUpdate.JUMP_REL: if operands.res is None: - raise NotImplementedError('Res.UNCONSTRAINED cannot be used with PcUpdate.JUMP_REL') + raise NotImplementedError("Res.UNCONSTRAINED cannot be used with PcUpdate.JUMP_REL") if not isinstance(operands.res, int): - raise PureValueError('jmp rel', operands.res) + raise PureValueError("jmp rel", operands.res) self.run_context.pc += operands.res elif instruction.pc_update is Instruction.PcUpdate.JNZ: if self.is_zero(operands.dst): @@ -416,59 +446,76 @@ def update_registers(self, instruction: Instruction, operands: Operands): else: self.run_context.pc += operands.op1 else: - raise NotImplementedError('Invalid pc_update value') + raise NotImplementedError("Invalid pc_update value") self.run_context.pc = self.run_context.pc % self.prime def deduce_op0( - self, instruction: Instruction, dst: Optional[MaybeRelocatable], - op1: Optional[MaybeRelocatable]) -> \ - Tuple[Optional[MaybeRelocatable], Optional[MaybeRelocatable]]: + self, + instruction: Instruction, + dst: Optional[MaybeRelocatable], + op1: Optional[MaybeRelocatable], + ) -> Tuple[Optional[MaybeRelocatable], Optional[MaybeRelocatable]]: if instruction.opcode is Instruction.Opcode.CALL: return self.run_context.pc + instruction.size, None elif instruction.opcode is Instruction.Opcode.ASSERT_EQ: - if (instruction.res is Instruction.Res.ADD) and (dst is not None) and \ - (op1 is not None): + if (instruction.res is Instruction.Res.ADD) and (dst is not None) and (op1 is not None): return (dst - op1) % self.prime, dst # type: ignore - elif (instruction.res is Instruction.Res.MUL) and isinstance(dst, int) and \ - isinstance(op1, int) and op1 != 0: + elif ( + (instruction.res is Instruction.Res.MUL) + and isinstance(dst, int) + and isinstance(op1, int) + and op1 != 0 + ): return div_mod(dst, op1, self.prime), dst return None, None def deduce_op1( - self, instruction: Instruction, dst: Optional[MaybeRelocatable], - op0: Optional[MaybeRelocatable]) -> \ - Tuple[Optional[MaybeRelocatable], Optional[MaybeRelocatable]]: + self, + instruction: Instruction, + dst: Optional[MaybeRelocatable], + op0: Optional[MaybeRelocatable], + ) -> Tuple[Optional[MaybeRelocatable], Optional[MaybeRelocatable]]: if instruction.opcode is Instruction.Opcode.ASSERT_EQ: if (instruction.res is Instruction.Res.OP1) and (dst is not None): return dst, dst - elif (instruction.res is Instruction.Res.ADD) and (dst is not None) and \ - (op0 is not None): + elif ( + (instruction.res is Instruction.Res.ADD) and (dst is not None) and (op0 is not None) + ): return (dst - op0) % self.prime, dst # type: ignore - elif (instruction.res is Instruction.Res.MUL) and isinstance(dst, int) and \ - isinstance(op0, int) and op0 != 0: + elif ( + (instruction.res is Instruction.Res.MUL) + and isinstance(dst, int) + and isinstance(op0, int) + and op0 != 0 + ): return div_mod(dst, op0, self.prime), dst return None, None def compute_res( - self, instruction: Instruction, op0: MaybeRelocatable, op1: MaybeRelocatable, - op0_addr: MaybeRelocatable) -> Optional[MaybeRelocatable]: + self, + instruction: Instruction, + op0: MaybeRelocatable, + op1: MaybeRelocatable, + op0_addr: MaybeRelocatable, + ) -> Optional[MaybeRelocatable]: if instruction.res is Instruction.Res.OP1: return op1 elif instruction.res is Instruction.Res.ADD: return (op0 + op1) % self.prime elif instruction.res is Instruction.Res.MUL: if isinstance(op0, RelocatableValue) or isinstance(op1, RelocatableValue): - raise PureValueError('*', op0, op1) + raise PureValueError("*", op0, op1) return (op0 * op1) % self.prime elif instruction.res is Instruction.Res.UNCONSTRAINED: # In this case res should be the inverse of dst. # For efficiency, we do not compute it here. return None else: - raise NotImplementedError('Invalid res value') + raise NotImplementedError("Invalid res value") - def compute_operands(self, instruction: Instruction) -> \ - Tuple[Operands, List[int], List[MaybeRelocatable]]: + def compute_operands( + self, instruction: Instruction + ) -> Tuple[Operands, List[int], List[MaybeRelocatable]]: """ Computes the values of the operands. Deduces dst if needed. Returns: @@ -549,11 +596,11 @@ def compute_operands(self, instruction: Instruction) -> \ if should_update_op1: self.validated_memory[op1_addr] = op1 - return Operands( - dst=dst, - op0=op0, - op1=op1, - res=res), [dst_addr, op0_addr, op1_addr], [dst, op0, op1] + return ( + Operands(dst=dst, op0=op0, op1=op1, res=res), + [dst_addr, op0_addr, op1_addr], + [dst, op0, op1], + ) def is_zero(self, value): """ @@ -563,7 +610,7 @@ def is_zero(self, value): if not isinstance(value, int): if isinstance(value, RelocatableValue) and value.offset >= 0: return False - raise PureValueError('jmp != 0', value) + raise PureValueError("jmp != 0", value) return value == 0 def is_integer_value(self, value): @@ -591,51 +638,53 @@ def decode_current_instruction(self): def opcode_assertions(self, instruction: Instruction, operands: Operands): if instruction.opcode is Instruction.Opcode.ASSERT_EQ: if operands.res is None: - raise NotImplementedError( - 'Res.UNCONSTRAINED cannot be used with Opcode.ASSERT_EQ') + raise NotImplementedError("Res.UNCONSTRAINED cannot be used with Opcode.ASSERT_EQ") if operands.dst != operands.res and not self.check_eq(operands.dst, operands.res): raise Exception( - f'An ASSERT_EQ instruction failed: {operands.dst} != {operands.res}') + f"An ASSERT_EQ instruction failed: {operands.dst} != {operands.res}" + ) elif instruction.opcode is Instruction.Opcode.CALL: next_pc = self.run_context.pc + instruction.size if operands.op0 != next_pc and not self.check_eq(operands.op0, next_pc): raise Exception( - 'Call failed to write return-pc (inconsistent op0): ' + - f'{operands.op0} != {next_pc}. Did you forget to increment ap?') + "Call failed to write return-pc (inconsistent op0): " + + f"{operands.op0} != {next_pc}. Did you forget to increment ap?" + ) fp = self.run_context.fp if operands.dst != fp and not self.check_eq(operands.dst, fp): raise Exception( - 'Call failed to write return-fp (inconsistent dst): ' + - f'{operands.dst} != {fp}. Did you forget to increment ap?') + "Call failed to write return-fp (inconsistent dst): " + + f"{operands.dst} != {fp}. Did you forget to increment ap?" + ) elif instruction.opcode in [Instruction.Opcode.RET, Instruction.Opcode.NOP]: # Nothing to check. pass else: - raise NotImplementedError(f'Unsupported opcode {instruction.opcode}') + raise NotImplementedError(f"Unsupported opcode {instruction.opcode}") def step(self): self.skip_instruction_execution = False # Execute hints. for hint_index, hint in enumerate(self.hints.get(self.run_context.pc, [])): exec_locals = self.exec_scopes[-1] - exec_locals['memory'] = memory = self.validated_memory - exec_locals['ap'] = ap = self.run_context.ap - exec_locals['fp'] = fp = self.run_context.fp - exec_locals['pc'] = pc = self.run_context.pc - exec_locals['current_step'] = self.current_step - exec_locals['ids'] = hint.consts(pc, ap, fp, memory) - - exec_locals['vm_load_program'] = self.load_program - exec_locals['vm_enter_scope'] = self.enter_scope - exec_locals['vm_exit_scope'] = self.exit_scope + exec_locals["memory"] = memory = self.validated_memory + exec_locals["ap"] = ap = self.run_context.ap + exec_locals["fp"] = fp = self.run_context.fp + exec_locals["pc"] = pc = self.run_context.pc + exec_locals["current_step"] = self.current_step + exec_locals["ids"] = hint.consts(pc, ap, fp, memory) + + exec_locals["vm_load_program"] = self.load_program + exec_locals["vm_enter_scope"] = self.enter_scope + exec_locals["vm_exit_scope"] = self.exit_scope exec_locals.update(self.static_locals) self.exec_hint(hint.compiled, exec_locals, hint_index=hint_index) # Clear ids (which will be rewritten by the next hint anyway) to make the VM instance # smaller and faster to copy. - del exec_locals['ids'] - del exec_locals['memory'] + del exec_locals["ids"] + del exec_locals["memory"] if self.skip_instruction_execution: return @@ -651,12 +700,12 @@ def compile_hint(self, source, filename, hint_index: int): This function can be overridden by subclasses. """ try: - return compile(source, filename, mode='exec') + return compile(source, filename, mode="exec") except (IndentationError, SyntaxError): hint_exception = HintException(self, *sys.exc_info()) raise self.as_vm_exception( - hint_exception, notes=[hint_exception.exception_str], - hint_index=hint_index) from None + hint_exception, notes=[hint_exception.exception_str], hint_index=hint_index + ) from None def exec_hint(self, code, globals_, hint_index): """ @@ -668,14 +717,15 @@ def exec_hint(self, code, globals_, hint_index): except Exception: hint_exception = HintException(self, *sys.exc_info()) raise self.as_vm_exception( - hint_exception, notes=[hint_exception.exception_str], - hint_index=hint_index) from None + hint_exception, notes=[hint_exception.exception_str], hint_index=hint_index + ) from None def run_instruction(self, instruction, instruction_encoding): try: # Compute operands. operands, operands_mem_addresses, operands_mem_values = self.compute_operands( - instruction) + instruction + ) except Exception as exc: raise self.as_vm_exception(exc) from None @@ -686,11 +736,13 @@ def run_instruction(self, instruction, instruction_encoding): raise self.as_vm_exception(exc) from None # Write to trace. - self.trace.append(TraceEntry( - pc=self.run_context.pc, - ap=self.run_context.ap, - fp=self.run_context.fp, - )) + self.trace.append( + TraceEntry( + pc=self.run_context.pc, + ap=self.run_context.ap, + fp=self.run_context.fp, + ) + ) self.accessed_addresses.update(operands_mem_addresses) self.accessed_addresses.add(self.run_context.pc) @@ -720,8 +772,8 @@ def last_pc(self): return self.trace[-1].pc def as_vm_exception( - self, exc, pc=None, notes: Optional[List[str]] = None, - hint_index: Optional[int] = None): + self, exc, pc=None, notes: Optional[List[str]] = None, hint_index: Optional[int] = None + ): """ Wraps the exception with a VmException, adding to it location information. If pc is not given the current pc is used. @@ -747,16 +799,16 @@ def get_traceback(self) -> Optional[str]: """ Returns the traceback at the current pc. """ - traceback = '' + traceback = "" for traceback_pc in self.run_context.get_traceback_entries(): location = self.get_location(pc=traceback_pc) if location is None: - traceback += f'Unknown location (pc={traceback_pc})\n' + traceback += f"Unknown location (pc={traceback_pc})\n" continue - traceback += location.inst.to_string_with_content(message=f'(pc={traceback_pc})') + '\n' + traceback += location.inst.to_string_with_content(message=f"(pc={traceback_pc})") + "\n" if len(traceback) == 0: return None - return 'Cairo traceback (most recent call last):\n' + traceback + return "Cairo traceback (most recent call last):\n" + traceback def add_validation_rule(self, segment_index, rule: ValidationRule, *args): self.validated_memory.add_validation_rule(segment_index, rule, *args) @@ -808,12 +860,12 @@ def verify_auto_deductions(self): def end_run(self): self.verify_auto_deductions() if len(self.exec_scopes) != 1: - raise VmExceptionBase('Every enter_scope() requires a corresponding exit_scope().') + raise VmExceptionBase("Every enter_scope() requires a corresponding exit_scope().") def get_perm_range_check_limits( - trace: List[TraceEntry[int]], - memory: MemoryDict) -> Tuple[int, int]: + trace: List[TraceEntry[int]], memory: MemoryDict +) -> Tuple[int, int]: """ Returns the minimum value and maximum value in the perm_range_check component. """ diff --git a/src/starkware/cairo/lang/vm/vm_consts.py b/src/starkware/cairo/lang/vm/vm_consts.py index 3d5af65a..e4138ba8 100644 --- a/src/starkware/cairo/lang/vm/vm_consts.py +++ b/src/starkware/cairo/lang/vm/vm_consts.py @@ -3,14 +3,27 @@ from typing import Any, Callable, List, Optional, Union from starkware.cairo.lang.compiler.ast.cairo_types import ( - CairoType, TypeFelt, TypePointer, TypeStruct) + CairoType, + TypeFelt, + TypePointer, + TypeStruct, +) from starkware.cairo.lang.compiler.ast.expr import ExprCast, ExprDeref, Expression from starkware.cairo.lang.compiler.constants import SIZE_CONSTANT from starkware.cairo.lang.compiler.identifier_definition import ( - ConstDefinition, IdentifierDefinition, LabelDefinition, ReferenceDefinition, StructDefinition) + ConstDefinition, + IdentifierDefinition, + LabelDefinition, + ReferenceDefinition, + StructDefinition, +) from starkware.cairo.lang.compiler.identifier_manager import ( - IdentifierError, IdentifierManager, IdentifierScope, IdentifierSearchResult, - MissingIdentifierError) + IdentifierError, + IdentifierManager, + IdentifierScope, + IdentifierSearchResult, + MissingIdentifierError, +) from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition from starkware.cairo.lang.compiler.preprocessor.flow import FlowTrackingData, ReferenceManager from starkware.cairo.lang.compiler.references import FlowTrackingError, Reference @@ -39,19 +52,18 @@ class VmConstsBase(ABC): """ def __init__(self, context: VmConstsContext): - object.__setattr__(self, '_context', context) + object.__setattr__(self, "_context", context) def __getattr__(self, name: str): - if name.startswith('__'): - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{name}'") + if name.startswith("__"): + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") try: return self.get_or_set_value(name, None) except FlowTrackingError: raise FlowTrackingError(f"Reference '{name}' is revoked.") from None def __setattr__(self, name: str, value): - assert value is not None, 'Setting a value to None is not allowed.' + assert value is not None, "Setting a value to None is not allowed." try: self.get_or_set_value(name, value) except FlowTrackingError: @@ -66,8 +78,8 @@ def get_or_set_value(self, name: str, set_value: Optional[MaybeRelocatable]): def search_identifier_or_scope( - identifiers: IdentifierManager, accessible_scopes: List[ScopedName], - name: ScopedName) -> Union[IdentifierSearchResult, 'IdentifierScope']: + identifiers: IdentifierManager, accessible_scopes: List[ScopedName], name: ScopedName +) -> Union[IdentifierSearchResult, "IdentifierScope"]: """ If there is an identifier with the given name, returns an IdentifierSearchResult. Otherwise, if there is a scope with that name, returns the IdentifierScope instance. @@ -86,8 +98,13 @@ def search_identifier_or_scope( class VmConsts(VmConstsBase): def __init__( - self, *, accessible_scopes: List[ScopedName], path: ScopedName = ScopedName(), - instruction_offset: Optional[int] = None, **kw): + self, + *, + accessible_scopes: List[ScopedName], + path: ScopedName = ScopedName(), + instruction_offset: Optional[int] = None, + **kw, + ): """ Constructs a VmConsts which is used to dynamically resolve constant values. The 'path' parameter is the scoped name used to get from the global consts variable @@ -98,10 +115,10 @@ def __init__( is a label with the name 'path' then it holds the offset of said label. """ super().__init__(**kw) - object.__setattr__(self, '_accessible_scopes', accessible_scopes) - object.__setattr__(self, '_path', path) + object.__setattr__(self, "_accessible_scopes", accessible_scopes) + object.__setattr__(self, "_path", path) if instruction_offset is not None: - object.__setattr__(self, 'instruction_offset_', instruction_offset) + object.__setattr__(self, "instruction_offset_", instruction_offset) def get_or_set_value(self, name: str, set_value: Optional[MaybeRelocatable]): """ @@ -113,23 +130,24 @@ def get_or_set_value(self, name: str, set_value: Optional[MaybeRelocatable]): result = search_identifier_or_scope( identifiers=self._context.identifiers, accessible_scopes=self._accessible_scopes, - name=ScopedName.from_string(name)) + name=ScopedName.from_string(name), + ) except MissingIdentifierError as exc: raise MissingIdentifierError(self._path + exc.fullname) from None value: Optional[IdentifierDefinition] if isinstance(result, IdentifierSearchResult): value = result.identifier_definition - handler_name = f'handle_{type(value).__name__}' + handler_name = f"handle_{type(value).__name__}" scope = result.get_canonical_name() identifier_type = value.TYPE elif isinstance(result, IdentifierScope): value = None - handler_name = 'handle_scope' + handler_name = "handle_scope" scope = result.fullname - identifier_type = 'scope' + identifier_type = "scope" else: - raise NotImplementedError(f'Unexpected type {type(result).__name__}.') + raise NotImplementedError(f"Unexpected type {type(result).__name__}.") if handler_name not in dir(self): self.raise_unsupported_error(name=self._path + name, identifier_type=identifier_type) @@ -137,30 +155,43 @@ def get_or_set_value(self, name: str, set_value: Optional[MaybeRelocatable]): return getattr(self, handler_name)(name, value, scope, set_value) def handle_ConstDefinition( - self, name: str, identifier: ConstDefinition, scope: ScopedName, - set_value: Optional[MaybeRelocatable]): - assert set_value is None, 'Cannot change the value of a constant.' + self, + name: str, + identifier: ConstDefinition, + scope: ScopedName, + set_value: Optional[MaybeRelocatable], + ): + assert set_value is None, "Cannot change the value of a constant." # The current attribute is a const, return its value. return identifier.value def handle_scope( - self, name: str, identifier: Union[IdentifierScope, LabelDefinition], - scope: ScopedName, set_value: Optional[MaybeRelocatable]): - assert set_value is None, 'Cannot change the value of a scope definition.' + self, + name: str, + identifier: Union[IdentifierScope, LabelDefinition], + scope: ScopedName, + set_value: Optional[MaybeRelocatable], + ): + assert set_value is None, "Cannot change the value of a scope definition." # The current attribute is a namespace or a label. return VmConsts( context=self._context, accessible_scopes=[scope], path=self._path + name, - instruction_offset=identifier.pc if isinstance(identifier, LabelDefinition) else None) + instruction_offset=identifier.pc if isinstance(identifier, LabelDefinition) else None, + ) handle_LabelDefinition = handle_scope handle_FunctionDefinition = handle_scope def handle_StructDefinition( - self, name: str, identifier: StructDefinition, scope: ScopedName, - set_value: Optional[MaybeRelocatable]): - assert set_value is None, 'Cannot change the value of a struct definition.' + self, + name: str, + identifier: StructDefinition, + scope: ScopedName, + set_value: Optional[MaybeRelocatable], + ): + assert set_value is None, "Cannot change the value of a struct definition." return VmConstsStruct( context=self._context, @@ -168,22 +199,25 @@ def handle_StructDefinition( ) def handle_ReferenceDefinition( - self, name: str, identifier: ReferenceDefinition, scope: ScopedName, - set_value: Optional[MaybeRelocatable]): + self, + name: str, + identifier: ReferenceDefinition, + scope: ScopedName, + set_value: Optional[MaybeRelocatable], + ): # In set mode, take the address of the given reference instead. reference = self._context.flow_tracking_data.resolve_reference( - reference_manager=self._context.reference_manager, name=identifier.full_name) + reference_manager=self._context.reference_manager, name=identifier.full_name + ) if set_value is None: - expr = reference.eval( - self._context.flow_tracking_data.ap_tracking) - expr, expr_type = simplify_type_system( - expr, - identifiers=self._context.identifiers) + expr = reference.eval(self._context.flow_tracking_data.ap_tracking) + expr, expr_type = simplify_type_system(expr, identifiers=self._context.identifiers) if isinstance(expr_type, TypeStruct): # If the reference is of type T, take its address and treat it as T*. - assert isinstance(expr, ExprDeref), \ - f"Expected expression of type '{expr_type.format()}' to have an address." + assert isinstance( + expr, ExprDeref + ), f"Expected expression of type '{expr_type.format()}' to have an address." expr = expr.addr expr_type = TypePointer(pointee=expr_type) val = self._context.evaluator(expr) @@ -193,19 +227,20 @@ def handle_ReferenceDefinition( return val else: # Typed reference, return VmConstsReference which allows accessing members. - assert isinstance(expr_type, TypePointer) and \ - isinstance(expr_type.pointee, TypeStruct), \ - 'Type must be of the form T*.' + assert isinstance(expr_type, TypePointer) and isinstance( + expr_type.pointee, TypeStruct + ), "Type must be of the form T*." return VmConstsReference( - context=self._context, - struct_name=expr_type.pointee.scope, - reference_value=val) + context=self._context, struct_name=expr_type.pointee.scope, reference_value=val + ) else: - assert str(scope[-1:]) == name, 'Expecting scope to end with name.' + assert str(scope[-1:]) == name, "Expecting scope to end with name." value, value_type = simplify_type_system( - reference.value, - identifiers=self._context.identifiers) - assert isinstance(value, ExprDeref), f"""\ + reference.value, identifiers=self._context.identifiers + ) + assert isinstance( + value, ExprDeref + ), f"""\ {scope} (= {value.format()}) does not reference memory and cannot be assigned.""" value_ref = Reference( @@ -214,8 +249,9 @@ def handle_ReferenceDefinition( ap_tracking_data=reference.ap_tracking_data, ) - addr = self._context.evaluator(value_ref.eval( - self._context.flow_tracking_data.ap_tracking)) + addr = self._context.evaluator( + value_ref.eval(self._context.flow_tracking_data.ap_tracking) + ) self._context.memory[addr] = set_value def raise_unsupported_error(self, name: ScopedName, identifier_type: str): @@ -223,7 +259,8 @@ def raise_unsupported_error(self, name: ScopedName, identifier_type: str): Raises an exception which says that the identifier type is not supported. """ raise NotImplementedError( - f"Unsupported identifier type '{identifier_type}' of identifier '{name}'.") + f"Unsupported identifier type '{identifier_type}' of identifier '{name}'." + ) class VmConstsReference(VmConstsBase): @@ -233,11 +270,16 @@ def __init__(self, *, reference_value, struct_name: ScopedName, **kw): """ super().__init__(**kw) - object.__setattr__(self, '_struct_definition', get_struct_definition( - struct_name=struct_name, identifier_manager=self._context.identifiers)) + object.__setattr__( + self, + "_struct_definition", + get_struct_definition( + struct_name=struct_name, identifier_manager=self._context.identifiers + ), + ) - object.__setattr__(self, '_reference_value', reference_value) - object.__setattr__(self, 'address_', reference_value) + object.__setattr__(self, "_reference_value", reference_value) + object.__setattr__(self, "address_", reference_value) @property def type_(self): @@ -255,7 +297,8 @@ def get_or_set_value(self, name: str, set_value: Optional[MaybeRelocatable]): member_def = self._struct_definition.members.get(name) if member_def is None: raise IdentifierError( - f"'{name}' is not a member of '{self._struct_definition.full_name}'.") from None + f"'{name}' is not a member of '{self._struct_definition.full_name}'." + ) from None addr = self._reference_value + member_def.offset @@ -267,18 +310,18 @@ def get_or_set_value(self, name: str, set_value: Optional[MaybeRelocatable]): return self._context.memory[addr] elif isinstance(expr_type, TypeStruct): return VmConstsReference( - context=self._context, - struct_name=expr_type.scope, - reference_value=addr) + context=self._context, struct_name=expr_type.scope, reference_value=addr + ) else: # Typed reference, return VmConstsReference which allows accessing members. - assert isinstance(expr_type, TypePointer) and \ - isinstance(expr_type.pointee, TypeStruct), \ - 'Type must be of the form T*.' + assert isinstance(expr_type, TypePointer) and isinstance( + expr_type.pointee, TypeStruct + ), "Type must be of the form T*." return VmConstsReference( context=self._context, struct_name=expr_type.pointee.scope, - reference_value=self._context.memory[addr]) + reference_value=self._context.memory[addr], + ) def is_simple_type(expr_type: CairoType) -> bool: @@ -288,9 +331,9 @@ def is_simple_type(expr_type: CairoType) -> bool: the returned value will be a int/relocatable value (unlike T and T* where the returned value is VmConstsReference to allow accessing submembers). """ - is_pointer_to_felt_or_pointer = ( - isinstance(expr_type, TypePointer) and - isinstance(expr_type.pointee, (TypePointer, TypeFelt))) + is_pointer_to_felt_or_pointer = isinstance(expr_type, TypePointer) and isinstance( + expr_type.pointee, (TypePointer, TypeFelt) + ) return isinstance(expr_type, TypeFelt) or is_pointer_to_felt_or_pointer @@ -300,16 +343,17 @@ def __init__(self, *, struct_definition: StructDefinition, **kw): Constructs a VmConstsStruct which allows accessing structs. """ super().__init__(**kw) - object.__setattr__(self, '_struct_definition', struct_definition) + object.__setattr__(self, "_struct_definition", struct_definition) def __eq__(self, other): if not isinstance(other, self.__class__): return False - return self._struct_definition == other._struct_definition and \ - self._context is other._context + return ( + self._struct_definition == other._struct_definition and self._context is other._context + ) def get_or_set_value(self, name: str, set_value: Optional[MaybeRelocatable]): - assert set_value is None, 'Cannot change the value of a constant.' + assert set_value is None, "Cannot change the value of a constant." if name == str(SIZE_CONSTANT): return self._struct_definition.size @@ -317,6 +361,7 @@ def get_or_set_value(self, name: str, set_value: Optional[MaybeRelocatable]): member_def = self._struct_definition.members.get(name) if member_def is None: raise IdentifierError( - f"'{name}' is not a member of '{self._struct_definition.full_name}'.") from None + f"'{name}' is not a member of '{self._struct_definition.full_name}'." + ) from None return member_def.offset diff --git a/src/starkware/cairo/lang/vm/vm_consts_test.py b/src/starkware/cairo/lang/vm/vm_consts_test.py index fc907769..6b78f560 100644 --- a/src/starkware/cairo/lang/vm/vm_consts_test.py +++ b/src/starkware/cairo/lang/vm/vm_consts_test.py @@ -6,10 +6,19 @@ from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt, TypePointer, TypeStruct from starkware.cairo.lang.compiler.expression_evaluator import ExpressionEvaluator from starkware.cairo.lang.compiler.identifier_definition import ( - AliasDefinition, ConstDefinition, IdentifierDefinition, LabelDefinition, MemberDefinition, - ReferenceDefinition, StructDefinition) + AliasDefinition, + ConstDefinition, + IdentifierDefinition, + LabelDefinition, + MemberDefinition, + ReferenceDefinition, + StructDefinition, +) from starkware.cairo.lang.compiler.identifier_manager import ( - IdentifierError, IdentifierManager, MissingIdentifierError) + IdentifierError, + IdentifierManager, + MissingIdentifierError, +) from starkware.cairo.lang.compiler.parser import parse_expr from starkware.cairo.lang.compiler.preprocessor.flow import FlowTrackingDataActual, ReferenceManager from starkware.cairo.lang.compiler.preprocessor.reg_tracking import RegTrackingData @@ -27,14 +36,18 @@ def dummy_evaluator(expr): def test_vmconsts_simple(): identifier_values = { - scope('x.y.z'): ConstDefinition(1), - scope('x.z'): ConstDefinition(2), - scope('y'): ConstDefinition(3), + scope("x.y.z"): ConstDefinition(1), + scope("x.z"): ConstDefinition(2), + scope("y"): ConstDefinition(3), } context = VmConstsContext( - identifiers=IdentifierManager.from_dict(identifier_values), evaluator=dummy_evaluator, + identifiers=IdentifierManager.from_dict(identifier_values), + evaluator=dummy_evaluator, reference_manager=ReferenceManager(), - flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()), memory={}, pc=0) + flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()), + memory={}, + pc=0, + ) consts = VmConsts(context=context, accessible_scopes=[ScopedName()]) assert consts.x.y.z == 1 assert consts.x.z == 2 @@ -44,14 +57,18 @@ def test_vmconsts_simple(): def test_label(): identifier_values = { - scope('x'): LabelDefinition(10), - scope('x.y'): ConstDefinition(1), - scope('y'): ConstDefinition(2), + scope("x"): LabelDefinition(10), + scope("x.y"): ConstDefinition(1), + scope("y"): ConstDefinition(2), } context = VmConstsContext( - identifiers=IdentifierManager.from_dict(identifier_values), evaluator=dummy_evaluator, + identifiers=IdentifierManager.from_dict(identifier_values), + evaluator=dummy_evaluator, reference_manager=ReferenceManager(), - flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()), memory={}, pc=0) + flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()), + memory={}, + pc=0, + ) consts = VmConsts(context=context, accessible_scopes=[ScopedName()]) assert consts.x.instruction_offset_ == 10 assert consts.y == 2 @@ -60,14 +77,18 @@ def test_label(): def test_alias(): identifier_values = { - scope('w'): AliasDefinition(scope('z.y')), - scope('x'): AliasDefinition(scope('z')), - scope('z.y'): ConstDefinition(1), + scope("w"): AliasDefinition(scope("z.y")), + scope("x"): AliasDefinition(scope("z")), + scope("z.y"): ConstDefinition(1), } context = VmConstsContext( - identifiers=IdentifierManager.from_dict(identifier_values), evaluator=dummy_evaluator, + identifiers=IdentifierManager.from_dict(identifier_values), + evaluator=dummy_evaluator, reference_manager=ReferenceManager(), - flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()), memory={}, pc=0) + flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()), + memory={}, + pc=0, + ) consts = VmConsts(context=context, accessible_scopes=[ScopedName()]) assert consts.x.y == 1 assert consts.w == 1 @@ -75,14 +96,18 @@ def test_alias(): def test_scope_order(): identifier_values = { - scope('x.y'): ConstDefinition(1), - scope('y'): ConstDefinition(2), + scope("x.y"): ConstDefinition(1), + scope("y"): ConstDefinition(2), } context = VmConstsContext( - identifiers=IdentifierManager.from_dict(identifier_values), evaluator=dummy_evaluator, + identifiers=IdentifierManager.from_dict(identifier_values), + evaluator=dummy_evaluator, reference_manager=ReferenceManager(), - flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()), memory={}, pc=0) - consts = VmConsts(context=context, accessible_scopes=[ScopedName(), scope('x')]) + flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()), + memory={}, + pc=0, + ) + consts = VmConsts(context=context, accessible_scopes=[ScopedName(), scope("x")]) assert consts.y == 1 assert consts.x.y == 1 @@ -90,63 +115,72 @@ def test_scope_order(): def test_references(): reference_manager = ReferenceManager() references = { - scope('x.ref'): reference_manager.alloc_id(Reference( - pc=0, - value=parse_expr('[ap + 1]'), - ap_tracking_data=RegTrackingData(group=0, offset=2), - )), - scope('x.ref2'): reference_manager.alloc_id(Reference( - pc=0, - value=parse_expr('[ap + 1] + 0'), - ap_tracking_data=RegTrackingData(group=0, offset=2), - )), - scope('x.ref3'): reference_manager.alloc_id(Reference( - pc=0, - value=parse_expr('ap + 1'), - ap_tracking_data=RegTrackingData(group=0, offset=2), - )), - scope('x.typeref'): reference_manager.alloc_id(Reference( - pc=0, - value=mark_types_in_expr_resolved(parse_expr('cast(ap + 1, MyStruct*)')), - ap_tracking_data=RegTrackingData(group=0, offset=3), - )), - scope('x.typeref2'): reference_manager.alloc_id(Reference( - pc=0, - value=mark_types_in_expr_resolved(parse_expr('cast([ap + 1], MyStruct*)')), - ap_tracking_data=RegTrackingData(group=0, offset=3), - )), + scope("x.ref"): reference_manager.alloc_id( + Reference( + pc=0, + value=parse_expr("[ap + 1]"), + ap_tracking_data=RegTrackingData(group=0, offset=2), + ) + ), + scope("x.ref2"): reference_manager.alloc_id( + Reference( + pc=0, + value=parse_expr("[ap + 1] + 0"), + ap_tracking_data=RegTrackingData(group=0, offset=2), + ) + ), + scope("x.ref3"): reference_manager.alloc_id( + Reference( + pc=0, + value=parse_expr("ap + 1"), + ap_tracking_data=RegTrackingData(group=0, offset=2), + ) + ), + scope("x.typeref"): reference_manager.alloc_id( + Reference( + pc=0, + value=mark_types_in_expr_resolved(parse_expr("cast(ap + 1, MyStruct*)")), + ap_tracking_data=RegTrackingData(group=0, offset=3), + ) + ), + scope("x.typeref2"): reference_manager.alloc_id( + Reference( + pc=0, + value=mark_types_in_expr_resolved(parse_expr("cast([ap + 1], MyStruct*)")), + ap_tracking_data=RegTrackingData(group=0, offset=3), + ) + ), } - my_struct = TypeStruct( - scope=scope('MyStruct'), is_fully_resolved=True) + my_struct = TypeStruct(scope=scope("MyStruct"), is_fully_resolved=True) my_struct_star = TypePointer(pointee=my_struct) identifier_values = { - scope('x.ref'): ReferenceDefinition( - full_name=scope('x.ref'), cairo_type=TypeFelt(), references=[] + scope("x.ref"): ReferenceDefinition( + full_name=scope("x.ref"), cairo_type=TypeFelt(), references=[] ), - scope('x.ref2'): ReferenceDefinition( - full_name=scope('x.ref2'), cairo_type=TypeFelt(), references=[] + scope("x.ref2"): ReferenceDefinition( + full_name=scope("x.ref2"), cairo_type=TypeFelt(), references=[] ), - scope('x.ref3'): ReferenceDefinition( - full_name=scope('x.ref3'), cairo_type=TypeFelt(), references=[] + scope("x.ref3"): ReferenceDefinition( + full_name=scope("x.ref3"), cairo_type=TypeFelt(), references=[] ), - scope('x.typeref'): ReferenceDefinition( - full_name=scope('x.typeref'), cairo_type=my_struct_star, references=[] + scope("x.typeref"): ReferenceDefinition( + full_name=scope("x.typeref"), cairo_type=my_struct_star, references=[] ), - scope('x.typeref2'): ReferenceDefinition( - full_name=scope('x.typeref2'), cairo_type=my_struct_star, references=[] + scope("x.typeref2"): ReferenceDefinition( + full_name=scope("x.typeref2"), cairo_type=my_struct_star, references=[] ), - scope('MyStruct'): StructDefinition( - full_name=scope('MyStruct'), + scope("MyStruct"): StructDefinition( + full_name=scope("MyStruct"), members={ - 'member': MemberDefinition(offset=10, cairo_type=TypeFelt()), - 'struct': MemberDefinition(offset=11, cairo_type=my_struct), + "member": MemberDefinition(offset=10, cairo_type=TypeFelt()), + "struct": MemberDefinition(offset=11, cairo_type=my_struct), }, size=11, ), } identifiers = IdentifierManager.from_dict(identifier_values) - prime = 2**64 + 13 + prime = 2 ** 64 + 13 ap = 100 fp = 200 memory = { @@ -166,42 +200,31 @@ def test_references(): reference_manager=reference_manager, flow_tracking_data=flow_tracking_data, memory=memory, - pc=0) + pc=0, + ) consts = VmConsts(context=context, accessible_scopes=[ScopedName()]) assert consts.x.ref == memory[(ap - 2) + 1] assert consts.x.typeref.address_ == (ap - 1) + 1 assert consts.x.typeref.member == memory[(ap - 1) + 1 + 10] - with pytest.raises( - IdentifierError, - match="'abc' is not a member of 'MyStruct'."): + with pytest.raises(IdentifierError, match="'abc' is not a member of 'MyStruct'."): consts.x.typeref.abc - with pytest.raises( - IdentifierError, - match="'SIZE' is not a member of 'MyStruct'."): + with pytest.raises(IdentifierError, match="'SIZE' is not a member of 'MyStruct'."): consts.x.typeref.SIZE - with pytest.raises( - AssertionError, - match='Cannot change the value of a struct definition.'): + with pytest.raises(AssertionError, match="Cannot change the value of a struct definition."): consts.MyStruct = 13 assert consts.MyStruct.member == 10 - with pytest.raises( - AssertionError, - match='Cannot change the value of a constant.'): + with pytest.raises(AssertionError, match="Cannot change the value of a constant."): consts.MyStruct.member = 13 assert consts.MyStruct.SIZE == 11 - with pytest.raises( - AssertionError, - match='Cannot change the value of a constant.'): + with pytest.raises(AssertionError, match="Cannot change the value of a constant."): consts.MyStruct.SIZE = 13 - with pytest.raises( - IdentifierError, - match="'abc' is not a member of 'MyStruct'."): + with pytest.raises(IdentifierError, match="'abc' is not a member of 'MyStruct'."): consts.MyStruct.abc # Test that VmConsts can be used to assign values to references of the form '[...]'. @@ -231,16 +254,16 @@ def test_references(): 4321 + 11 + 10: 2, } - with pytest.raises(AssertionError, match='Cannot change the value of a scope definition'): + with pytest.raises(AssertionError, match="Cannot change the value of a scope definition"): consts.x = 1000 with pytest.raises( AssertionError, - match=r'x.ref2 \(= \[ap \+ 1\] \+ 0\) does not reference memory and cannot be assigned.', + match=r"x.ref2 \(= \[ap \+ 1\] \+ 0\) does not reference memory and cannot be assigned.", ): consts.x.ref2 = 1000 with pytest.raises( AssertionError, - match=r'x.typeref \(= ap \+ 1\) does not reference memory and cannot be assigned.', + match=r"x.typeref \(= ap \+ 1\) does not reference memory and cannot be assigned.", ): consts.x.typeref = 1000 @@ -252,16 +275,21 @@ def get_vm_consts(identifier_values, reference_manager, flow_tracking_data, memo identifiers = IdentifierManager.from_dict(identifier_values) context = VmConstsContext( identifiers=identifiers, - evaluator=ExpressionEvaluator(2**64 + 13, 0, 0, memory, identifiers).eval, + evaluator=ExpressionEvaluator(2 ** 64 + 13, 0, 0, memory, identifiers).eval, reference_manager=reference_manager, - flow_tracking_data=flow_tracking_data, memory=memory, pc=9) + flow_tracking_data=flow_tracking_data, + memory=memory, + pc=9, + ) return VmConsts(context=context, accessible_scopes=[ScopedName()]) def test_reference_rebinding(): identifier_values = { - scope('ref'): ReferenceDefinition( - full_name=scope('ref'), cairo_type=TypeFelt(), references=[], + scope("ref"): ReferenceDefinition( + full_name=scope("ref"), + cairo_type=TypeFelt(), + references=[], ) } @@ -273,10 +301,10 @@ def test_reference_rebinding(): flow_tracking_data = flow_tracking_data.add_reference( reference_manager=reference_manager, - name=scope('ref'), + name=scope("ref"), ref=Reference( pc=10, - value=parse_expr('10'), + value=parse_expr("10"), ap_tracking_data=RegTrackingData(group=0, offset=2), ), ) @@ -285,16 +313,14 @@ def test_reference_rebinding(): def test_reference_to_structs(): - t = TypeStruct(scope=scope('T'), is_fully_resolved=True) + t = TypeStruct(scope=scope("T"), is_fully_resolved=True) t_star = TypePointer(pointee=t) identifier_values = { - scope('ref'): ReferenceDefinition( - full_name=scope('ref'), cairo_type=t, references=[] - ), - scope('T'): StructDefinition( - full_name=scope('T'), + scope("ref"): ReferenceDefinition(full_name=scope("ref"), cairo_type=t, references=[]), + scope("T"): StructDefinition( + full_name=scope("T"), members={ - 'x': MemberDefinition(offset=3, cairo_type=t_star), + "x": MemberDefinition(offset=3, cairo_type=t_star), }, size=4, ), @@ -303,16 +329,15 @@ def test_reference_to_structs(): flow_tracking_data = FlowTrackingDataActual(ap_tracking=RegTrackingData()) flow_tracking_data = flow_tracking_data.add_reference( reference_manager=reference_manager, - name=scope('ref'), + name=scope("ref"), ref=Reference( pc=0, - value=mark_types_in_expr_resolved(parse_expr('[cast(100, T*)]')), + value=mark_types_in_expr_resolved(parse_expr("[cast(100, T*)]")), ap_tracking_data=RegTrackingData(group=0, offset=2), ), ) memory = {103: 200} - consts = get_vm_consts( - identifier_values, reference_manager, flow_tracking_data, memory=memory) + consts = get_vm_consts(identifier_values, reference_manager, flow_tracking_data, memory=memory) assert consts.ref.address_ == 100 assert consts.ref.x.address_ == 200 @@ -326,15 +351,19 @@ def test_reference_to_structs(): def test_missing_attributes(): identifier_values = { - scope('x.y'): ConstDefinition(1), - scope('z'): AliasDefinition(scope('x')), - scope('x.missing'): AliasDefinition(scope('nothing')), + scope("x.y"): ConstDefinition(1), + scope("z"): AliasDefinition(scope("x")), + scope("x.missing"): AliasDefinition(scope("nothing")), } context = VmConstsContext( - identifiers=IdentifierManager.from_dict(identifier_values), evaluator=dummy_evaluator, + identifiers=IdentifierManager.from_dict(identifier_values), + evaluator=dummy_evaluator, reference_manager=ReferenceManager(), - flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()), memory={}, pc=0) + flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()), + memory={}, + pc=0, + ) consts = VmConsts(context=context, accessible_scopes=[ScopedName()]) # Identifier not exists anywhere. @@ -355,60 +384,72 @@ def test_missing_attributes(): # Pass through bad alias. with pytest.raises( - IdentifierError, - match="Alias resolution failed: x.missing -> nothing. Unknown identifier 'nothing'."): + IdentifierError, + match="Alias resolution failed: x.missing -> nothing. Unknown identifier 'nothing'.", + ): consts.x.missing.y def test_unsupported_attribute(): class UnsupportedIdentifier(IdentifierDefinition): - TYPE: ClassVar[str] = 'tested_t' + TYPE: ClassVar[str] = "tested_t" identifier_values = { - scope('x'): UnsupportedIdentifier(), - scope('y.z'): UnsupportedIdentifier(), + scope("x"): UnsupportedIdentifier(), + scope("y.z"): UnsupportedIdentifier(), } context = VmConstsContext( - identifiers=IdentifierManager.from_dict(identifier_values), evaluator=dummy_evaluator, + identifiers=IdentifierManager.from_dict(identifier_values), + evaluator=dummy_evaluator, reference_manager=ReferenceManager(), - flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()), memory={}, pc=0) - consts = VmConsts(context=context, accessible_scopes=[scope('')]) + flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()), + memory={}, + pc=0, + ) + consts = VmConsts(context=context, accessible_scopes=[scope("")]) # Identifier in root namespace. with pytest.raises( - NotImplementedError, - match="Unsupported identifier type 'tested_t' of identifier 'x'."): + NotImplementedError, match="Unsupported identifier type 'tested_t' of identifier 'x'." + ): consts.x # Identifier in sub namespace. with pytest.raises( - NotImplementedError, - match="Unsupported identifier type 'tested_t' of identifier 'y.z'."): + NotImplementedError, match="Unsupported identifier type 'tested_t' of identifier 'y.z'." + ): consts.y.z def test_get_dunder_something(): context = VmConstsContext( - identifiers=IdentifierManager(), evaluator=dummy_evaluator, + identifiers=IdentifierManager(), + evaluator=dummy_evaluator, reference_manager=ReferenceManager(), - flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()), memory={}, pc=0) - consts = VmConsts(context=context, accessible_scopes=[scope('')]) + flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()), + memory={}, + pc=0, + ) + consts = VmConsts(context=context, accessible_scopes=[scope("")]) with pytest.raises( - AttributeError, - match=re.escape("'VmConsts' object has no attribute '__something'")): + AttributeError, match=re.escape("'VmConsts' object has no attribute '__something'") + ): consts.__something def test_unparsed(): identifier_values = { - scope('x'): LabelDefinition(10), + scope("x"): LabelDefinition(10), } context = VmConstsContext( identifiers=IdentifierManager.from_dict(identifier_values), evaluator=dummy_evaluator, reference_manager=ReferenceManager(), - flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()), memory={}, pc=0) - consts = VmConsts(context=context, accessible_scopes=[scope('')]) + flow_tracking_data=FlowTrackingDataActual(ap_tracking=RegTrackingData()), + memory={}, + pc=0, + ) + consts = VmConsts(context=context, accessible_scopes=[scope("")]) with pytest.raises(IdentifierError, match="Unexpected '.' after 'x' which is label."): consts.x.z @@ -416,26 +457,26 @@ def test_unparsed(): def test_revoked_reference(): reference_manager = ReferenceManager() - ref_id = reference_manager.alloc_id(reference=Reference( - pc=0, - value=parse_expr('[ap + 1]'), - ap_tracking_data=RegTrackingData(group=0, offset=2), - )) + ref_id = reference_manager.alloc_id( + reference=Reference( + pc=0, + value=parse_expr("[ap + 1]"), + ap_tracking_data=RegTrackingData(group=0, offset=2), + ) + ) identifier_values = { - scope('x'): ReferenceDefinition( - full_name=scope('x'), cairo_type=TypeFelt(), references=[] - ), + scope("x"): ReferenceDefinition(full_name=scope("x"), cairo_type=TypeFelt(), references=[]), } identifiers = IdentifierManager.from_dict(identifier_values) - prime = 2**64 + 13 + prime = 2 ** 64 + 13 ap = 100 fp = 200 memory = {} flow_tracking_data = FlowTrackingDataActual( ap_tracking=RegTrackingData(group=1, offset=4), - reference_ids={scope('x'): ref_id}, + reference_ids={scope("x"): ref_id}, ) context = VmConstsContext( identifiers=identifiers, @@ -443,7 +484,8 @@ def test_revoked_reference(): reference_manager=reference_manager, flow_tracking_data=flow_tracking_data, memory=memory, - pc=0) + pc=0, + ) consts = VmConsts(context=context, accessible_scopes=[ScopedName()]) with pytest.raises(FlowTrackingError, match="Reference 'x' is revoked."): diff --git a/src/starkware/cairo/lang/vm/vm_test.py b/src/starkware/cairo/lang/vm/vm_test.py index 06fad7c9..621d5c4d 100644 --- a/src/starkware/cairo/lang/vm/vm_test.py +++ b/src/starkware/cairo/lang/vm/vm_test.py @@ -5,13 +5,20 @@ from starkware.cairo.lang.compiler.cairo_compile import compile_cairo from starkware.cairo.lang.vm.memory_dict import ( - InconsistentMemoryError, MemoryDict, UnknownMemoryError) + InconsistentMemoryError, + MemoryDict, + UnknownMemoryError, +) from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue from starkware.cairo.lang.vm.vm import ( - InconsistentAutoDeductionError, RunContext, VirtualMachine, VmException) + InconsistentAutoDeductionError, + RunContext, + VirtualMachine, + VmException, +) from starkware.python.test_utils import maybe_raises -PRIME = 2**64 + 13 +PRIME = 2 ** 64 + 13 def run_single(code: str, steps: int, *, pc=RelocatableValue(0, 10), ap=100, fp=100, extra_mem={}): @@ -21,7 +28,7 @@ def run_single(code: str, steps: int, *, pc=RelocatableValue(0, 10), ap=100, fp= memory: Dict[MaybeRelocatable, MaybeRelocatable] = { **{pc + i: v for i, v in enumerate(program.data)}, fp - 1: 1234, - **extra_mem + **extra_mem, } context = RunContext( pc=pc, @@ -68,8 +75,11 @@ def test_simple(): vm = run_single(code, 9, pc=10, ap=102, extra_mem={101: 1}) assert [vm.run_context.memory[101 + i] for i in range(7)] == [1, 3, 9, 10, 16, 48, 10] - assert vm.accessed_addresses == \ - set(vm.run_context.memory.keys()) == {*range(10, 28), 99, *range(101, 108)} + assert ( + vm.accessed_addresses + == set(vm.run_context.memory.keys()) + == {*range(10, 28), 99, *range(101, 108)} + ) def test_jnz(): @@ -87,11 +97,16 @@ def test_jnz(): vm = run_single(code, 100, ap=101) - assert [vm.run_context.memory[101 + i] for i in range(8 + 25)] == \ - [7, 6, 5, 4, 3, 2, 1, 0] + [4, 3, 2, 1, 0] * 5 + assert [vm.run_context.memory[101 + i] for i in range(8 + 25)] == [7, 6, 5, 4, 3, 2, 1, 0] + [ + 4, + 3, + 2, + 1, + 0, + ] * 5 -@pytest.mark.parametrize('offset', [0, -1]) +@pytest.mark.parametrize("offset", [0, -1]) def test_jnz_relocatables(offset: int): code = """ jmp body if [ap - 1] != 0 @@ -100,9 +115,11 @@ def test_jnz_relocatables(offset: int): [ap] = 1; ap++ """ relocatable_value = RelocatableValue(segment_index=5, offset=offset) - error_message = \ - None if relocatable_value.offset >= 0 else \ - f'Could not complete computation jmp != 0 of non pure value {relocatable_value}' + error_message = ( + None + if relocatable_value.offset >= 0 + else f"Could not complete computation jmp != 0 of non pure value {relocatable_value}" + ) with maybe_raises(expected_exception=VmException, error_message=error_message): vm = run_single(code, 2, ap=102, extra_mem={101: relocatable_value}) assert vm.run_context.memory[102] == 1 @@ -136,8 +153,21 @@ def test_call_ret(): # Consider the memory cells which are at least 1000 to filter out pc and fp addresses. mem = [vm.run_context.memory[100 + i] for i in range(25)] - assert [x for x in mem if isinstance(x, int) and x >= 1000] == \ - [1000, 2000, 3000, 2001, 3000, 2002, 1001, 2000, 3000, 2001, 3000, 2002, 1002] + assert [x for x in mem if isinstance(x, int) and x >= 1000] == [ + 1000, + 2000, + 3000, + 2001, + 3000, + 2002, + 1001, + 2000, + 3000, + 2001, + 3000, + 2002, + 1002, + ] def test_addap(): @@ -259,11 +289,10 @@ def f(): # In this test we actually do write the code to a file, to allow the linecache module to fetch # the line raising the exception. - cairo_file = tempfile.NamedTemporaryFile('w') + cairo_file = tempfile.NamedTemporaryFile("w") print(code, file=cairo_file) cairo_file.flush() - program = compile_cairo( - code=[(code, cairo_file.name)], prime=PRIME, debug_info=True) + program = compile_cairo(code=[(code, cairo_file.name)], prime=PRIME, debug_info=True) program_base = 10 memory = {program_base + i: v for i, v in enumerate(program.data)} @@ -283,7 +312,9 @@ def f(): vm.step() with pytest.raises(VmException) as excinfo: vm.step() - assert str(excinfo.value) == f"""\ + assert ( + str(excinfo.value) + == f"""\ {cairo_file.name}:13:1: Error at pc=12: Got an exception while executing a hint. %{{ @@ -295,6 +326,7 @@ def f(): 0 / 0 # Raises exception. ZeroDivisionError: division by zero\ """ + ) def test_hint_indentation_error(): @@ -311,11 +343,10 @@ def f(): # In this test we actually do write the code to a file, to allow the linecache module to fetch # the line raising the exception. - cairo_file = tempfile.NamedTemporaryFile('w') + cairo_file = tempfile.NamedTemporaryFile("w") print(code, file=cairo_file) cairo_file.flush() - program = compile_cairo( - code=[(code, cairo_file.name)], prime=PRIME, debug_info=True) + program = compile_cairo(code=[(code, cairo_file.name)], prime=PRIME, debug_info=True) program_base = 10 memory = {program_base + i: v for i, v in enumerate(program.data)} @@ -360,11 +391,10 @@ def f(): # In this test we actually do write the code to a file, to allow the linecache module to fetch # the line raising the exception. - cairo_file = tempfile.NamedTemporaryFile('w') + cairo_file = tempfile.NamedTemporaryFile("w") print(code, file=cairo_file) cairo_file.flush() - program = compile_cairo( - code=[(code, cairo_file.name)], prime=PRIME, debug_info=True) + program = compile_cairo(code=[(code, cairo_file.name)], prime=PRIME, debug_info=True) program_base = 10 memory = {program_base + i: v for i, v in enumerate(program.data)} @@ -466,16 +496,16 @@ def test_skip_instruction_execution(): ) vm = VirtualMachine(program, context, {}) - vm.enter_scope({'vm': vm}) + vm.enter_scope({"vm": vm}) exec_locals = vm.exec_scopes[-1] - assert 'x' not in exec_locals + assert "x" not in exec_locals assert vm.run_context.pc == 0 vm.step() - assert exec_locals['x'] == 0 + assert exec_locals["x"] == 0 assert vm.run_context.pc == 2 vm.step() - assert exec_locals['x'] == 1 + assert exec_locals["x"] == 1 assert vm.run_context.pc == 4 assert vm.run_context.ap == initial_ap + 1 assert vm.run_context.memory[vm.run_context.ap - 1] == 10 @@ -516,7 +546,7 @@ def rule_ap_segment(vm, addr, val): assert vm.run_context.memory[initial_fp] == 200 assert vm.run_context.memory[initial_fp + 1] == 300 - with pytest.raises(InconsistentAutoDeductionError, match='at address 2:100. 200 != 456'): + with pytest.raises(InconsistentAutoDeductionError, match="at address 2:100. 200 != 456"): vm.verify_auto_deductions() @@ -550,10 +580,10 @@ def test_memory_validation_in_hints(): assert vm.validated_memory._ValidatedMemoryDict__validated_addresses == {initial_ap_and_fp} def fail_validation(memory, addr): - raise Exception('Validation failed.') + raise Exception("Validation failed.") vm.add_validation_rule(1, fail_validation) - with pytest.raises(VmException, match='Exception: Validation failed.'): + with pytest.raises(VmException, match="Exception: Validation failed."): vm.step() @@ -562,7 +592,7 @@ def test_nonpure_mul(): [ap] = [ap - 1] * 2; ap++ """ - with pytest.raises(VmException, match='Could not complete computation *'): + with pytest.raises(VmException, match="Could not complete computation *"): run_single(code, 1, ap=102, extra_mem={101: RelocatableValue(1, 0)}) @@ -571,7 +601,7 @@ def test_nonpure_jmp_rel(): jmp rel [ap - 1] """ - with pytest.raises(VmException, match='Could not complete computation jmp rel'): + with pytest.raises(VmException, match="Could not complete computation jmp rel"): run_single(code, 1, ap=102, extra_mem={101: RelocatableValue(1, 0)}) @@ -589,7 +619,7 @@ def test_jmp_segment(): **{program_base_b + i: v for i, v in enumerate(program.data)}, 99: 0, 100: program_base_b, - 101: program_base_a + 101: program_base_a, } context = RunContext( pc=program_base_a, @@ -627,9 +657,12 @@ def test_simple_deductions(): vm = run_single(code, 6, ap=101, extra_mem={99: 3, 100: 2}) assert [vm.run_context.memory[101 + i] for i in range(6)] == [ - (2 * PRIME + 2) // 3, (2 * PRIME + 2) // 3, - PRIME - 1, PRIME - 1, - 2, 2 + (2 * PRIME + 2) // 3, + (2 * PRIME + 2) // 3, + PRIME - 1, + PRIME - 1, + 2, + 2, ] @@ -638,7 +671,7 @@ def test_failing_assert_eq(): [ap] = [ap + 1] + [ap + 2] """ - with pytest.raises(VmException, match='An ASSERT_EQ instruction failed'): + with pytest.raises(VmException, match="An ASSERT_EQ instruction failed"): run_single(code, 1, extra_mem={100: 1, 101: 3, 102: 2}) @@ -646,7 +679,7 @@ def test_call_unknown(): code = """ call rel [ap] """ - with pytest.raises(VmException, match='Unknown value for memory cell at address 100'): + with pytest.raises(VmException, match="Unknown value for memory cell at address 100"): run_single(code, 1) @@ -655,12 +688,16 @@ def test_call_wrong_operands(): call rel 0 """ with pytest.raises( - VmException, match=r'Call failed to write return-pc \(inconsistent op0\): 0 != 0:12. ' + - 'Did you forget to increment ap?'): + VmException, + match=r"Call failed to write return-pc \(inconsistent op0\): 0 != 0:12. " + + "Did you forget to increment ap?", + ): run_single(code, 1, extra_mem={101: 0}) with pytest.raises( - VmException, match=r'Call failed to write return-fp \(inconsistent dst\): 0 != 100. ' + - 'Did you forget to increment ap?'): + VmException, + match=r"Call failed to write return-fp \(inconsistent dst\): 0 != 100. " + + "Did you forget to increment ap?", + ): run_single(code, 1, extra_mem={100: 0}) @@ -688,7 +725,9 @@ def test_traceback(): with pytest.raises(VmException) as exc_info: run_single(code, 100, ap=101, extra_mem={99: 3, 100: 2}) - assert str(exc_info.value) == """\ + assert ( + str(exc_info.value) + == """\ :5:9: Error at pc=0:12: Got an exception while executing a hint. %{ assert ids.x != 0 %} @@ -708,3 +747,4 @@ def test_traceback(): File "", line 5, in AssertionError\ """ + ) diff --git a/src/starkware/cairo/sharp/client_lib.py b/src/starkware/cairo/sharp/client_lib.py index a35e99ba..b4d92b8e 100644 --- a/src/starkware/cairo/sharp/client_lib.py +++ b/src/starkware/cairo/sharp/client_lib.py @@ -27,9 +27,10 @@ def add_job(self, cairo_pie: CairoPie) -> str: """ res = self._send( - 'add_job', {'cairo_pie': base64.b64encode(cairo_pie.serialize()).decode('ascii')}) - assert 'cairo_job_key' in res, f'Error when sending job to SHARP: {res}.' - return res['cairo_job_key'] + "add_job", {"cairo_pie": base64.b64encode(cairo_pie.serialize()).decode("ascii")} + ) + assert "cairo_job_key" in res, f"Error when sending job to SHARP: {res}." + return res["cairo_job_key"] def get_status(self, job_key: str) -> str: """ @@ -37,10 +38,9 @@ def get_status(self, job_key: str) -> str: job_key: used to query the state of the job in the system - returned by 'add_job'. """ - res = self._send('get_status', {'cairo_job_key': job_key}) - assert 'status' in res, \ - f"Error when checking status of job with key '{job_key}': {res}." - return res['status'] + res = self._send("get_status", {"cairo_job_key": job_key}) + assert "status" in res, f"Error when checking status of job with key '{job_key}': {res}." + return res["status"] def _send(self, action: str, payload: dict) -> dict: """ @@ -50,10 +50,11 @@ def _send(self, action: str, payload: dict) -> dict: """ data = { - 'action': action, - 'request': payload, + "action": action, + "request": payload, } http = urllib3.PoolManager() res = http.request( - method='POST', url=self.service_url, body=json.dumps(data).encode('utf-8')) - return json.loads(res.data.decode('utf-8')) + method="POST", url=self.service_url, body=json.dumps(data).encode("utf-8") + ) + return json.loads(res.data.decode("utf-8")) diff --git a/src/starkware/cairo/sharp/client_lib_test.py b/src/starkware/cairo/sharp/client_lib_test.py index e43e0897..ba140608 100644 --- a/src/starkware/cairo/sharp/client_lib_test.py +++ b/src/starkware/cairo/sharp/client_lib_test.py @@ -12,8 +12,9 @@ class MockCairoPie: """ Mock classes used in the test. """ + def serialize(self): - return b'' + return b"" @dataclasses.dataclass @@ -22,20 +23,21 @@ class Response: def test_add_job(monkeypatch): - expected_url = 'some url' + expected_url = "some url" expected_data = { - 'action': 'add_job', - 'request': {'cairo_pie': base64.b64encode(MockCairoPie().serialize()).decode('ascii')} + "action": "add_job", + "request": {"cairo_pie": base64.b64encode(MockCairoPie().serialize()).decode("ascii")}, } - expected_res = 'some id' + expected_res = "some id" # A mock function enforcing expected scenario. def check_expected(_, method: str, url: str, body: str): - assert method == 'POST' + assert method == "POST" assert url == expected_url assert json.loads(body) == expected_data - return Response(json.dumps({'cairo_job_key': expected_res}).encode('utf-8')) - monkeypatch.setattr(PoolManager, 'request', check_expected) + return Response(json.dumps({"cairo_job_key": expected_res}).encode("utf-8")) + + monkeypatch.setattr(PoolManager, "request", check_expected) # Test the scenario. client = ClientLib(expected_url) @@ -44,21 +46,19 @@ def check_expected(_, method: str, url: str, body: str): def test_get_status(monkeypatch): - expected_url = 'some url' - expected_id = 'some id' - expected_data = { - 'action': 'get_status', - 'request': {'cairo_job_key': expected_id} - } - expected_res = 'the status' + expected_url = "some url" + expected_id = "some id" + expected_data = {"action": "get_status", "request": {"cairo_job_key": expected_id}} + expected_res = "the status" # A mock function enforcing expected scenario. def check_expected(_, method: str, url: str, body: str): - assert method == 'POST' + assert method == "POST" assert url == expected_url assert json.loads(body) == expected_data - return Response(json.dumps({'status': expected_res}).encode('utf-8')) - monkeypatch.setattr(PoolManager, 'request', check_expected) + return Response(json.dumps({"status": expected_res}).encode("utf-8")) + + monkeypatch.setattr(PoolManager, "request", check_expected) # Test the scenario. client = ClientLib(expected_url) @@ -70,14 +70,15 @@ def test_error(monkeypatch): # A mock function enforcing expected scenario. def check_expected(_, method: str, url: str, body: str): # Return an empty response - this should be invalid. - return Response(b'{}') - monkeypatch.setattr(PoolManager, 'request', check_expected) + return Response(b"{}") + + monkeypatch.setattr(PoolManager, "request", check_expected) # Test the scenario. - client = ClientLib('') + client = ClientLib("") - with pytest.raises(AssertionError, match='Error when sending job to SHARP:'): + with pytest.raises(AssertionError, match="Error when sending job to SHARP:"): client.add_job(MockCairoPie()) - with pytest.raises(AssertionError, match='Error when checking status of job with key'): - client.get_status('') + with pytest.raises(AssertionError, match="Error when checking status of job with key"): + client.get_status("") diff --git a/src/starkware/cairo/sharp/fact_checker.py b/src/starkware/cairo/sharp/fact_checker.py index 7979de08..999c9cc8 100644 --- a/src/starkware/cairo/sharp/fact_checker.py +++ b/src/starkware/cairo/sharp/fact_checker.py @@ -2,25 +2,13 @@ FACT_REGISTRY_ABI = [ { - 'constant': True, - 'inputs': [ - { - 'internalType': 'bytes32', - 'name': 'fact', - 'type': 'bytes32' - } - ], - 'name': 'isValid', - 'outputs': [ - { - 'internalType': 'bool', - 'name': '', - 'type': 'bool' - } - ], - 'payable': False, - 'stateMutability': 'view', - 'type': 'function' + "constant": True, + "inputs": [{"internalType": "bytes32", "name": "fact", "type": "bytes32"}], + "name": "isValid", + "outputs": [{"internalType": "bool", "name": "", "type": "bool"}], + "payable": False, + "stateMutability": "view", + "type": "function", } ] @@ -39,7 +27,8 @@ def __init__(self, fact_registry_address: str, node_rpc_url: str): # Initialize a contract instance, used to query the fact-registry contract. w3 = Web3(HTTPProvider(node_rpc_url)) self.contract = w3.eth.contract( # type: ignore - address=fact_registry_address, abi=FACT_REGISTRY_ABI) + address=fact_registry_address, abi=FACT_REGISTRY_ABI + ) def is_valid(self, fact: str) -> bool: """ diff --git a/src/starkware/cairo/sharp/fact_checker_test.py b/src/starkware/cairo/sharp/fact_checker_test.py index 97dc3595..b8618f99 100644 --- a/src/starkware/cairo/sharp/fact_checker_test.py +++ b/src/starkware/cairo/sharp/fact_checker_test.py @@ -6,4 +6,4 @@ def test_init(): Initializes the FactChecker. This test is a basic sanity check. """ - FactChecker(fact_registry_address='', node_rpc_url='') + FactChecker(fact_registry_address="", node_rpc_url="") diff --git a/src/starkware/cairo/sharp/sharp_client.py b/src/starkware/cairo/sharp/sharp_client.py index e6c2a1a8..b758fae3 100755 --- a/src/starkware/cairo/sharp/sharp_client.py +++ b/src/starkware/cairo/sharp/sharp_client.py @@ -22,8 +22,13 @@ class SharpClient: """ def __init__( - self, service_client: ClientLib, contract_client: FactChecker, - steps_limit: int, cairo_compiler_path: str, cairo_run_path: str): + self, + service_client: ClientLib, + contract_client: FactChecker, + steps_limit: int, + cairo_compiler_path: str, + cairo_run_path: str, + ): """ service_client: a client to communicate with the proving service. contract_client: a client to inspect verified statements. @@ -43,14 +48,17 @@ def compile_cairo(self, source_code_path: str, flags: Optional[List[str]] = None and returns the compiled program. """ used_flags = [] if flags is None else flags - with tempfile.NamedTemporaryFile('w') as compiled_program_file: + with tempfile.NamedTemporaryFile("w") as compiled_program_file: # Compile the program. - subprocess.check_call([ - self.cairo_compiler_path, - source_code_path, - f'--output={compiled_program_file.name}', - ] + used_flags) - program = Program.Schema().load(json.load(open(compiled_program_file.name, 'r'))) + subprocess.check_call( + [ + self.cairo_compiler_path, + source_code_path, + f"--output={compiled_program_file.name}", + ] + + used_flags + ) + program = Program.Schema().load(json.load(open(compiled_program_file.name, "r"))) return program def run_program(self, program: Program, program_input_path: Optional[str]) -> CairoPie: @@ -58,18 +66,25 @@ def run_program(self, program: Program, program_input_path: Optional[str]) -> Ca Runs the program, with the provided input, and returns the Cairo PIE (Position Independent Execution). """ - with tempfile.NamedTemporaryFile('w') as cairo_pie_file, \ - tempfile.NamedTemporaryFile('w') as program_file: - json.dump( - Program.Schema().dump(program), program_file, indent=4, sort_keys=True) + with tempfile.NamedTemporaryFile("w") as cairo_pie_file, tempfile.NamedTemporaryFile( + "w" + ) as program_file: + json.dump(Program.Schema().dump(program), program_file, indent=4, sort_keys=True) program_file.flush() - cairo_run_cmd = list(filter(None, [ - self.cairo_run_path, - '--layout=all', - f'--program={program_file.name}', - f'--program_input={program_input_path}' if program_input_path is not None else None, - f'--cairo_pie_output={cairo_pie_file.name}', - ])) + cairo_run_cmd = list( + filter( + None, + [ + self.cairo_run_path, + "--layout=all", + f"--program={program_file.name}", + f"--program_input={program_input_path}" + if program_input_path is not None + else None, + f"--cairo_pie_output={cairo_pie_file.name}", + ], + ) + ) subprocess.check_call(cairo_run_cmd) cairo_pie = CairoPie.from_file(cairo_pie_file.name) return cairo_pie @@ -95,9 +110,10 @@ def submit_cairo_pie(self, cairo_pie: CairoPie) -> str: Asserts that the number of execution steps does not exceed the allowed limit. """ n_steps = cairo_pie.execution_resources.n_steps - assert n_steps < self.steps_limit, \ - f'Execution trace length exceeds limit. The execution length is {n_steps} ' \ - f'and the limit is {self.steps_limit}.' + assert n_steps < self.steps_limit, ( + f"Execution trace length exceeds limit. The execution length is {n_steps} " + f"and the limit is {self.steps_limit}." + ) return self.service_client.add_job(cairo_pie=cairo_pie) @@ -105,7 +121,7 @@ def job_failed(self, job_key: str) -> bool: """ Returns True if and only if the job has failed, thus is not expected to be proven. """ - return self.service_client.get_status(job_key) in ['INVALID', 'FAILED'] + return self.service_client.get_status(job_key) in ["INVALID", "FAILED"] def get_job_status(self, job_key: str) -> str: """ @@ -126,22 +142,22 @@ def init_client(bin_dir: str, node_rpc_url: Optional[str] = None) -> SharpClient Initialized a SharpClient instance, with or without node access. """ # Load configuration file. - CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json') - with open(CONFIG_PATH, 'r') as config_file: + CONFIG_PATH = os.path.join(os.path.dirname(__file__), "config.json") + with open(CONFIG_PATH, "r") as config_file: config = json.load(config_file) # Get Cairo toolchain executable paths. - CAIRO_COMPILE_EXE = os.path.join(os.path.join(bin_dir, 'cairo-compile')) - CAIRO_RUN_EXE = os.path.join(os.path.join(bin_dir, 'cairo-run')) + CAIRO_COMPILE_EXE = os.path.join(os.path.join(bin_dir, "cairo-compile")) + CAIRO_RUN_EXE = os.path.join(os.path.join(bin_dir, "cairo-run")) # Initialize the SharpClient. client = SharpClient( - service_client=ClientLib(config['prover_url']), + service_client=ClientLib(config["prover_url"]), contract_client=FactChecker( - fact_registry_address=config['verifier_address'], - node_rpc_url=node_rpc_url if node_rpc_url is not None else '' + fact_registry_address=config["verifier_address"], + node_rpc_url=node_rpc_url if node_rpc_url is not None else "", ), - steps_limit=config['steps_limit'], + steps_limit=config["steps_limit"], cairo_compiler_path=CAIRO_COMPILE_EXE, cairo_run_path=CAIRO_RUN_EXE, ) @@ -151,59 +167,62 @@ def init_client(bin_dir: str, node_rpc_url: Optional[str] = None) -> SharpClient def submit(args, command_args): parser = argparse.ArgumentParser( - description='Submits a Cairo job to SHARP. ' - 'You can provide (1) the source code and the program input OR (2) the compiled program and ' - 'the program input OR (3) the Cairo PIE.') + description="Submits a Cairo job to SHARP. " + "You can provide (1) the source code and the program input OR (2) the compiled program and " + "the program input OR (3) the Cairo PIE." + ) parser.add_argument( - '--source', type=str, required=False, help='A path to the Cairo source code.') - parser.add_argument( - '--program', type=str, required=False, help='A path to the compiled program.') + "--source", type=str, required=False, help="A path to the Cairo source code." + ) parser.add_argument( - '--program_input', type=str, required=False, help='A path to the program input.') + "--program", type=str, required=False, help="A path to the compiled program." + ) parser.add_argument( - '--cairo_pie', type=str, required=False, help='A path to the Cairo PIE.') + "--program_input", type=str, required=False, help="A path to the program input." + ) + parser.add_argument("--cairo_pie", type=str, required=False, help="A path to the Cairo PIE.") parser.parse_args(command_args, namespace=args) is_not_none = lambda x: 1 if x is not None else 0 assert ( - is_not_none(args.source) + is_not_none(args.program) + is_not_none(args.cairo_pie) == 1), \ - 'Exactly one of --source, --program, --cairo_pie must be specified.' + is_not_none(args.source) + is_not_none(args.program) + is_not_none(args.cairo_pie) == 1 + ), "Exactly one of --source, --program, --cairo_pie must be specified." client = init_client(bin_dir=args.bin_dir) if args.cairo_pie is not None: - assert args.program_input is None, \ - 'Error: --program_input cannot be specified with --cairo_pie.' + assert ( + args.program_input is None + ), "Error: --program_input cannot be specified with --cairo_pie." cairo_pie = CairoPie.from_file(args.cairo_pie) else: if args.program is not None: program = Program.Schema().load(json.load(open(args.program))) else: assert args.source is not None - print('Compiling...', file=sys.stderr) + print("Compiling...", file=sys.stderr) program = client.compile_cairo(source_code_path=args.source) - print('Running...', file=sys.stderr) + print("Running...", file=sys.stderr) cairo_pie = client.run_program(program=program, program_input_path=args.program_input) fact = client.get_fact(cairo_pie) - print('Submitting to SHARP...', file=sys.stderr) + print("Submitting to SHARP...", file=sys.stderr) job_key = client.submit_cairo_pie(cairo_pie=cairo_pie) - print('Job sent.', file=sys.stderr) + print("Job sent.", file=sys.stderr) - print(f'Job key: {job_key}') - print(f'Fact: {fact}') + print(f"Job key: {job_key}") + print(f"Fact: {fact}") return 0 def get_job_status(args, command_args): - parser = argparse.ArgumentParser( - description='Retreive the status of a SHARP Cairo job.') - parser.add_argument('job_key', type=str, help='The key identifying the job.') + parser = argparse.ArgumentParser(description="Retreive the status of a SHARP Cairo job.") + parser.add_argument("job_key", type=str, help="The key identifying the job.") parser.parse_args(command_args, namespace=args) @@ -219,10 +238,12 @@ def is_verified(args, command_args): The fact is provided in the command args. """ parser = argparse.ArgumentParser( - description='Verify a fact is registered on the SHARP fact-registry.') - parser.add_argument('fact', type=str, help='The fact to verify if registered.') + description="Verify a fact is registered on the SHARP fact-registry." + ) + parser.add_argument("fact", type=str, help="The fact to verify if registered.") parser.add_argument( - '--node_url', required=True, type=str, help='URL for a Goerli Ethereum node RPC API.') + "--node_url", required=True, type=str, help="URL for a Goerli Ethereum node RPC API." + ) parser.parse_args(command_args, namespace=args) @@ -234,20 +255,27 @@ def is_verified(args, command_args): def main(): subparsers = { - 'submit': submit, - 'status': get_job_status, - 'is_verified': is_verified, + "submit": submit, + "status": get_job_status, + "is_verified": is_verified, } - parser = argparse.ArgumentParser(description='A tool to communicate with SHARP.') - parser.add_argument('command', choices=subparsers.keys()) + parser = argparse.ArgumentParser(description="A tool to communicate with SHARP.") + parser.add_argument("command", choices=subparsers.keys()) parser.add_argument( - '--bin_dir', type=str, default='', - help='The path to a directory that contains the cairo-compile and cairo-run scripts. ' - "If not specified, files are assumed to be in the system's PATH.") + "--bin_dir", + type=str, + default="", + help="The path to a directory that contains the cairo-compile and cairo-run scripts. " + "If not specified, files are assumed to be in the system's PATH.", + ) parser.add_argument( - '--flavor', type=str, default='Release', choices=['Debug', 'Release', 'RelWithDebInfo'], - help='Build flavor') + "--flavor", + type=str, + default="Release", + choices=["Debug", "Release", "RelWithDebInfo"], + help="Build flavor", + ) args, unknown = parser.parse_known_args() @@ -256,8 +284,8 @@ def main(): # Invoke the requested command. return subparsers[args.command](args, unknown) except Exception as exc: - print(f'Error: {exc}', file=sys.stderr) + print(f"Error: {exc}", file=sys.stderr) -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/src/starkware/cairo/sharp/sharp_client_test.py b/src/starkware/cairo/sharp/sharp_client_test.py index e9ae51ed..f2f99997 100644 --- a/src/starkware/cairo/sharp/sharp_client_test.py +++ b/src/starkware/cairo/sharp/sharp_client_test.py @@ -8,9 +8,9 @@ from starkware.cairo.sharp.sharp_client import SharpClient DIR = os.path.dirname(__file__) -CAIRO_SCRIPTS_DIR = os.path.join(DIR, '../lang/scripts') -CAIRO_COMPILE_EXE = os.path.join(CAIRO_SCRIPTS_DIR, 'cairo-compile') -CAIRO_RUN_EXE = os.path.join(CAIRO_SCRIPTS_DIR, 'cairo-run') +CAIRO_SCRIPTS_DIR = os.path.join(DIR, "../lang/scripts") +CAIRO_COMPILE_EXE = os.path.join(CAIRO_SCRIPTS_DIR, "cairo-compile") +CAIRO_RUN_EXE = os.path.join(CAIRO_SCRIPTS_DIR, "cairo-run") def test_compile_and_run(): @@ -19,8 +19,12 @@ def test_compile_and_run(): Verifies the output of the execution is as expected. """ client = SharpClient( - service_client=None, contract_client=None, steps_limit=0, - cairo_compiler_path=CAIRO_COMPILE_EXE, cairo_run_path=CAIRO_RUN_EXE) + service_client=None, + contract_client=None, + steps_limit=0, + cairo_compiler_path=CAIRO_COMPILE_EXE, + cairo_run_path=CAIRO_RUN_EXE, + ) cairo_program = """ %builtins output @@ -31,18 +35,18 @@ def test_compile_and_run(): return (output_ptr=output_ptr + 1) end """ - program_input = {'x': 3} + program_input = {"x": 3} - with tempfile.NamedTemporaryFile('w') as cairo_prog_file: + with tempfile.NamedTemporaryFile("w") as cairo_prog_file: cairo_prog_file.write(cairo_program) cairo_prog_file.flush() compiled_program = client.compile_cairo(cairo_prog_file.name) - with tempfile.NamedTemporaryFile('w') as prog_input_file: + with tempfile.NamedTemporaryFile("w") as prog_input_file: prog_input_file.write(json.dumps(program_input)) prog_input_file.flush() cairo_pie = client.run_program(compiled_program, prog_input_file.name) - assert get_program_output(cairo_pie) == [3**2] + assert get_program_output(cairo_pie) == [3 ** 2] def test_get_fact(monkeypatch): @@ -56,60 +60,78 @@ def __init__(self, program: str, output: str): self.output = output client = SharpClient( - service_client=None, contract_client=None, steps_limit=0, - cairo_compiler_path='', cairo_run_path='') + service_client=None, + contract_client=None, + steps_limit=0, + cairo_compiler_path="", + cairo_run_path="", + ) monkeypatch.setattr( - sharp_client, 'compute_program_hash_chain', lambda program: f'hash({program})') + sharp_client, "compute_program_hash_chain", lambda program: f"hash({program})" + ) monkeypatch.setattr( - sharp_client, 'get_cairo_pie_fact_info', + sharp_client, + "get_cairo_pie_fact_info", lambda cairo_pie, program_hash: FactInfo( - fact=f'hash({program_hash}, hash({cairo_pie.output}))', + fact=f"hash({program_hash}, hash({cairo_pie.output}))", program_output=None, - fact_topology=None)) + fact_topology=None, + ), + ) - assert client.get_fact(CairoPieStub('program', 'output')) == 'hash(hash(program), hash(output))' + assert client.get_fact(CairoPieStub("program", "output")) == "hash(hash(program), hash(output))" def test_fact_registered(): """ Tests that fact_registered() checks facts as expected, using FactChecker mock. """ + class FactCheckerStub: def is_valid(self, fact: str) -> bool: - return fact == 'valid' + return fact == "valid" client = SharpClient( - service_client=None, contract_client=FactCheckerStub(), steps_limit=0, - cairo_compiler_path='', cairo_run_path='') + service_client=None, + contract_client=FactCheckerStub(), + steps_limit=0, + cairo_compiler_path="", + cairo_run_path="", + ) - assert client.fact_registered('valid') - assert not client.fact_registered('not valid') + assert client.fact_registered("valid") + assert not client.fact_registered("not valid") def test_job_failed(): """ Tests that job_failed() interacts with the SHARP service correctly, using ClientLib mock. """ + class ClientLibStub: def get_status(self, job_key): - if job_key == 'invalid_job': - return 'INVALID' - if job_key == 'failed_job': - return 'FAILED' - return 'Success' + if job_key == "invalid_job": + return "INVALID" + if job_key == "failed_job": + return "FAILED" + return "Success" client = SharpClient( - service_client=ClientLibStub(), contract_client=None, steps_limit=0, - cairo_compiler_path='', cairo_run_path='') + service_client=ClientLibStub(), + contract_client=None, + steps_limit=0, + cairo_compiler_path="", + cairo_run_path="", + ) # Test job_failed() - assert client.job_failed('invalid_job') - assert client.job_failed('failed_job') - assert not client.job_failed('valid_job') + assert client.job_failed("invalid_job") + assert client.job_failed("failed_job") + assert not client.job_failed("valid_job") # Test get_status() - assert client.get_job_status('valid_job') == 'Success' - assert client.get_job_status('invalid_job') == 'INVALID' - assert client.get_job_status('failed_job') == 'FAILED' + assert client.get_job_status("valid_job") == "Success" + assert client.get_job_status("invalid_job") == "INVALID" + assert client.get_job_status("failed_job") == "FAILED" diff --git a/src/starkware/crypto/starkware/crypto/signature/fast_pedersen_hash.py b/src/starkware/crypto/starkware/crypto/signature/fast_pedersen_hash.py index aae758d4..658d5fbe 100644 --- a/src/starkware/crypto/starkware/crypto/signature/fast_pedersen_hash.py +++ b/src/starkware/crypto/starkware/crypto/signature/fast_pedersen_hash.py @@ -2,18 +2,19 @@ from fastecdsa.point import Point from starkware.crypto.signature.signature import ( - ALPHA, BETA, CONSTANT_POINTS, EC_ORDER, FIELD_PRIME, N_ELEMENT_BITS_HASH, SHIFT_POINT) - -curve = Curve( - 'Curve0', - FIELD_PRIME, ALPHA, BETA, + CONSTANT_POINTS, EC_ORDER, - *SHIFT_POINT) + FIELD_PRIME, + N_ELEMENT_BITS_HASH, + SHIFT_POINT, +) + +curve = Curve("Curve0", FIELD_PRIME, ALPHA, BETA, EC_ORDER, *SHIFT_POINT) LOW_PART_BITS = 248 -LOW_PART_MASK = 2**248 - 1 +LOW_PART_MASK = 2 ** 248 - 1 HASH_SHIFT_POINT = Point(*SHIFT_POINT, curve=curve) P_0 = Point(*CONSTANT_POINTS[2], curve=curve) P_1 = Point(*CONSTANT_POINTS[2 + LOW_PART_BITS], curve=curve) @@ -22,7 +23,7 @@ def process_single_element(element: int, p1, p2) -> Point: - assert element < FIELD_PRIME, 'Element integer value >= FIELD_PRIME' + assert element < FIELD_PRIME, "Element integer value >= FIELD_PRIME" high_nibble = element >> LOW_PART_BITS low_part = element & LOW_PART_MASK @@ -37,17 +38,19 @@ def pedersen_hash(x: int, y: int) -> int: where x_low is the 248 low bits of x, x_high is the 4 high bits of x and similarly for y. shift_point, P_0, P_1, P_2, P_3 are constant points generated from the digits of pi. """ - return (HASH_SHIFT_POINT + process_single_element(x, P_0, P_1) + - process_single_element(y, P_2, P_3)).x + return ( + HASH_SHIFT_POINT + process_single_element(x, P_0, P_1) + process_single_element(y, P_2, P_3) + ).x def pedersen_hash_func(x: bytes, y: bytes) -> bytes: """ A variant of 'pedersen_hash', where the elements and their resulting hash are in bytes. """ - assert len(x) == len(y) == 32, 'Unexpected element length.' + assert len(x) == len(y) == 32, "Unexpected element length." return pedersen_hash( - *(int.from_bytes(element, 'big', signed=False) for element in (x, y))).to_bytes(32, 'big') + *(int.from_bytes(element, "big", signed=False) for element in (x, y)) + ).to_bytes(32, "big") async def async_pedersen_hash_func(x: bytes, y: bytes) -> bytes: diff --git a/src/starkware/crypto/starkware/crypto/signature/math_utils.py b/src/starkware/crypto/starkware/crypto/signature/math_utils.py index e2a901ea..2dc3a4d4 100644 --- a/src/starkware/crypto/starkware/crypto/signature/math_utils.py +++ b/src/starkware/crypto/starkware/crypto/signature/math_utils.py @@ -30,7 +30,7 @@ def pi_as_string(digits: int) -> str: Returns pi as a string of decimal digits without the decimal point ("314..."). """ mpmath.mp.dps = digits # Set number of digits. - return '3' + str(mpmath.mp.pi)[2:] + return "3" + str(mpmath.mp.pi)[2:] def is_quad_residue(n: int, p: int) -> bool: diff --git a/src/starkware/crypto/starkware/crypto/signature/nothing_up_my_sleeve_gen.py b/src/starkware/crypto/starkware/crypto/signature/nothing_up_my_sleeve_gen.py index 4efdb656..78d1632f 100644 --- a/src/starkware/crypto/starkware/crypto/signature/nothing_up_my_sleeve_gen.py +++ b/src/starkware/crypto/starkware/crypto/signature/nothing_up_my_sleeve_gen.py @@ -32,20 +32,21 @@ # (a) large, # (b) has a big multiplicative subgroup of size which is a power of two, # (c) sparse representation for efficient modular arithmetics. -FIELD_PRIME = 2**251 + 17 * 2**192 + 1 +FIELD_PRIME = 2 ** 251 + 17 * 2 ** 192 + 1 # Generator of the multiplicative group of the field. FIELD_GEN = 3 # Elliptic curve parameters. ALPHA = 1 -EC_ORDER = 0x800000000000010ffffffffffffffffb781126dcae7b2321e66a241adc64d2f +EC_ORDER = 0x800000000000010FFFFFFFFFFFFFFFFB781126DCAE7B2321E66A241ADC64D2F ############################ # Parameters and constants # ############################ + def generate_constant_points(n_points): """ Generates points from the curve y^2 = x^3 + x + beta over GF(FIELD_PRIME) where beta and the @@ -73,9 +74,9 @@ def generate_constant_points(n_points): i = 0 while len(constant_points) < n_points: i += 1 - x = int(pi_str[i * 76:(i + 1) * 76]) + x = int(pi_str[i * 76 : (i + 1) * 76]) while True: - y_squared = x**3 + ALPHA * x + beta + y_squared = x ** 3 + ALPHA * x + beta if is_quad_residue(y_squared, FIELD_PRIME): y = sqrt_mod(y_squared, FIELD_PRIME) break @@ -98,9 +99,9 @@ def generate_constant_points(n_points): N_ECDSA_POINTS = 1 N_HASH_POINTS = N_INPUTS * N_ELEMENT_BITS -print('Generating points, this may take a while...') +print("Generating points, this may take a while...") BETA, CONSTANT_POINTS = generate_constant_points(N_SHIFT_POINTS + N_ECDSA_POINTS + N_HASH_POINTS) -assert BETA == 0x6f21413efbe40de150e596d72f7a8c5609ad26c15c915c1f4cdfcb99cee9e89 +assert BETA == 0x6F21413EFBE40DE150E596D72F7A8C5609AD26C15C915C1F4CDFCB99CEE9E89 COPYRIGHT_STRING = """\ ############################################################################### @@ -120,19 +121,22 @@ def generate_constant_points(n_points): ############################################################################### """ -AUTO_GENERATED_STRING = \ - 'The following data was auto-generated. PLEASE DO NOT EDIT.' +AUTO_GENERATED_STRING = "The following data was auto-generated. PLEASE DO NOT EDIT." # Write generated parameters to file. -PEDERSEN_HASH_POINT_FILENAME = os.path.join( - os.path.dirname(__file__), 'pedersen_params.json') -open(PEDERSEN_HASH_POINT_FILENAME, 'w').write( - json.dumps({ - '_license': COPYRIGHT_STRING.splitlines(), - '_comment': AUTO_GENERATED_STRING, - 'FIELD_PRIME': FIELD_PRIME, - 'FIELD_GEN': FIELD_GEN, - 'EC_ORDER': EC_ORDER, - 'ALPHA': ALPHA, - 'BETA': BETA, - 'CONSTANT_POINTS': CONSTANT_POINTS}, indent=4)) +PEDERSEN_HASH_POINT_FILENAME = os.path.join(os.path.dirname(__file__), "pedersen_params.json") +open(PEDERSEN_HASH_POINT_FILENAME, "w").write( + json.dumps( + { + "_license": COPYRIGHT_STRING.splitlines(), + "_comment": AUTO_GENERATED_STRING, + "FIELD_PRIME": FIELD_PRIME, + "FIELD_GEN": FIELD_GEN, + "EC_ORDER": EC_ORDER, + "ALPHA": ALPHA, + "BETA": BETA, + "CONSTANT_POINTS": CONSTANT_POINTS, + }, + indent=4, + ) +) diff --git a/src/starkware/crypto/starkware/crypto/signature/signature.py b/src/starkware/crypto/starkware/crypto/signature/signature.py index 31a9dac9..d8988669 100644 --- a/src/starkware/crypto/starkware/crypto/signature/signature.py +++ b/src/starkware/crypto/starkware/crypto/signature/signature.py @@ -24,18 +24,24 @@ from ecdsa.rfc6979 import generate_k from starkware.crypto.signature.math_utils import ( - ECPoint, div_mod, ec_add, ec_double, ec_mult, is_quad_residue, sqrt_mod) - -PEDERSEN_HASH_POINT_FILENAME = os.path.join( - os.path.dirname(__file__), 'pedersen_params.json') + ECPoint, + div_mod, + ec_add, + ec_double, + ec_mult, + is_quad_residue, + sqrt_mod, +) + +PEDERSEN_HASH_POINT_FILENAME = os.path.join(os.path.dirname(__file__), "pedersen_params.json") PEDERSEN_PARAMS = json.load(open(PEDERSEN_HASH_POINT_FILENAME)) -FIELD_PRIME = PEDERSEN_PARAMS['FIELD_PRIME'] -FIELD_GEN = PEDERSEN_PARAMS['FIELD_GEN'] -ALPHA = PEDERSEN_PARAMS['ALPHA'] -BETA = PEDERSEN_PARAMS['BETA'] -EC_ORDER = PEDERSEN_PARAMS['EC_ORDER'] -CONSTANT_POINTS = PEDERSEN_PARAMS['CONSTANT_POINTS'] +FIELD_PRIME = PEDERSEN_PARAMS["FIELD_PRIME"] +FIELD_GEN = PEDERSEN_PARAMS["FIELD_GEN"] +ALPHA = PEDERSEN_PARAMS["ALPHA"] +BETA = PEDERSEN_PARAMS["BETA"] +EC_ORDER = PEDERSEN_PARAMS["EC_ORDER"] +CONSTANT_POINTS = PEDERSEN_PARAMS["CONSTANT_POINTS"] N_ELEMENT_BITS_ECDSA = math.floor(math.log(FIELD_PRIME, 2)) assert N_ELEMENT_BITS_ECDSA == 251 @@ -44,18 +50,20 @@ assert N_ELEMENT_BITS_HASH == 252 # Elliptic curve parameters. -assert 2**N_ELEMENT_BITS_ECDSA < EC_ORDER < FIELD_PRIME +assert 2 ** N_ELEMENT_BITS_ECDSA < EC_ORDER < FIELD_PRIME SHIFT_POINT = CONSTANT_POINTS[0] MINUS_SHIFT_POINT = (SHIFT_POINT[0], FIELD_PRIME - SHIFT_POINT[1]) EC_GEN = CONSTANT_POINTS[1] assert SHIFT_POINT == [ - 0x49ee3eba8c1600700ee1b87eb599f16716b0b1022947733551fde4050ca6804, - 0x3ca0cfe4b3bc6ddf346d49d06ea0ed34e621062c0e056c1d0405d266e10268a] + 0x49EE3EBA8C1600700EE1B87EB599F16716B0B1022947733551FDE4050CA6804, + 0x3CA0CFE4B3BC6DDF346D49D06EA0ED34E621062C0E056C1D0405D266E10268A, +] assert EC_GEN == [ - 0x1ef15c18599971b7beced415a40f0c7deacfd9b0d1819e03d723d8bc943cfca, - 0x5668060aa49730b7be4801df46ec62de53ecd11abe43a32873000c36e8dc1f] + 0x1EF15C18599971B7BECED415A40F0C7DEACFD9B0D1819E03D723D8BC943CFCA, + 0x5668060AA49730B7BE4801DF46EC62DE53ECD11ABE43A32873000C36E8DC1F, +] ######### @@ -68,7 +76,7 @@ class InvalidPublicKeyError(Exception): def __init__(self): - super().__init__('Given x coordinate does not represent any point on the elliptic curve.') + super().__init__("Given x coordinate does not represent any point on the elliptic curve.") def get_y_coordinate(stark_key_x_coordinate: int) -> int: @@ -111,21 +119,24 @@ def generate_k_rfc6979(msg_hash: int, priv_key: int, seed: Optional[int] = None) msg_hash *= 16 if seed is None: - extra_entropy = b'' + extra_entropy = b"" else: - extra_entropy = seed.to_bytes(math.ceil(seed.bit_length() / 8), 'big') + extra_entropy = seed.to_bytes(math.ceil(seed.bit_length() / 8), "big") return generate_k( - EC_ORDER, priv_key, hashlib.sha256, - msg_hash.to_bytes(math.ceil(msg_hash.bit_length() / 8), 'big'), - extra_entropy=extra_entropy) + EC_ORDER, + priv_key, + hashlib.sha256, + msg_hash.to_bytes(math.ceil(msg_hash.bit_length() / 8), "big"), + extra_entropy=extra_entropy, + ) def sign(msg_hash: int, priv_key: int, seed: Optional[int] = None) -> ECSignature: # Note: msg_hash must be smaller than 2**N_ELEMENT_BITS_ECDSA. # Message whose hash is >= 2**N_ELEMENT_BITS_ECDSA cannot be signed. # This happens with a very small probability. - assert 0 <= msg_hash < 2**N_ELEMENT_BITS_ECDSA, 'Message not signable.' + assert 0 <= msg_hash < 2 ** N_ELEMENT_BITS_ECDSA, "Message not signable." # Choose a valid k. In our version of ECDSA not every k value is valid, # and there is a negligible probability a drawn k cannot be used for signing. @@ -143,7 +154,7 @@ def sign(msg_hash: int, priv_key: int, seed: Optional[int] = None) -> ECSignatur # DIFF: in classic ECDSA, we take int(x) % n. r = int(x) - if not (1 <= r < 2**N_ELEMENT_BITS_ECDSA): + if not (1 <= r < 2 ** N_ELEMENT_BITS_ECDSA): # Bad value. This fails with negligible probability. continue @@ -152,7 +163,7 @@ def sign(msg_hash: int, priv_key: int, seed: Optional[int] = None) -> ECSignatur continue w = div_mod(k, msg_hash + r * priv_key, EC_ORDER) - if not (1 <= w < 2**N_ELEMENT_BITS_ECDSA): + if not (1 <= w < 2 ** N_ELEMENT_BITS_ECDSA): # Bad value. This fails with negligible probability. continue @@ -165,7 +176,7 @@ def mimic_ec_mult_air(m: int, point: ECPoint, shift_point: ECPoint) -> ECPoint: Computes m * point + shift_point using the same steps like the AIR and throws an exception if and only if the AIR errors. """ - assert 0 < m < 2**N_ELEMENT_BITS_ECDSA + assert 0 < m < 2 ** N_ELEMENT_BITS_ECDSA partial_sum = shift_point for _ in range(N_ELEMENT_BITS_ECDSA): assert partial_sum[0] != point[0] @@ -179,15 +190,15 @@ def mimic_ec_mult_air(m: int, point: ECPoint, shift_point: ECPoint) -> ECPoint: def verify(msg_hash: int, r: int, s: int, public_key: Union[int, ECPoint]) -> bool: # Compute w = s^-1 (mod EC_ORDER). - assert 1 <= s < EC_ORDER, 's = %s' % s + assert 1 <= s < EC_ORDER, "s = %s" % s w = inv_mod_curve_size(s) # Preassumptions: # DIFF: in classic ECDSA, we assert 1 <= r, w <= EC_ORDER-1. # Since r, w < 2**N_ELEMENT_BITS_ECDSA < EC_ORDER, we only need to verify r, w != 0. - assert 1 <= r < 2**N_ELEMENT_BITS_ECDSA, 'r = %s' % r - assert 1 <= w < 2**N_ELEMENT_BITS_ECDSA, 'w = %s' % w - assert 0 <= msg_hash < 2**N_ELEMENT_BITS_ECDSA, 'msg_hash = %s' % msg_hash + assert 1 <= r < 2 ** N_ELEMENT_BITS_ECDSA, "r = %s" % r + assert 1 <= w < 2 ** N_ELEMENT_BITS_ECDSA, "w = %s" % w + assert 0 <= msg_hash < 2 ** N_ELEMENT_BITS_ECDSA, "msg_hash = %s" % msg_hash if isinstance(public_key, int): # Only the x coordinate of the point is given, check the two possibilities for the y @@ -196,15 +207,19 @@ def verify(msg_hash: int, r: int, s: int, public_key: Union[int, ECPoint]) -> bo y = get_y_coordinate(public_key) except InvalidPublicKeyError: return False - assert pow(y, 2, FIELD_PRIME) == ( - pow(public_key, 3, FIELD_PRIME) + ALPHA * public_key + BETA) % FIELD_PRIME - return verify(msg_hash, r, s, (public_key, y)) or \ - verify(msg_hash, r, s, (public_key, (-y) % FIELD_PRIME)) + assert ( + pow(y, 2, FIELD_PRIME) + == (pow(public_key, 3, FIELD_PRIME) + ALPHA * public_key + BETA) % FIELD_PRIME + ) + return verify(msg_hash, r, s, (public_key, y)) or verify( + msg_hash, r, s, (public_key, (-y) % FIELD_PRIME) + ) else: # The public key is provided as a point. # Verify it is on the curve. - assert (public_key[1]**2 - ( - public_key[0]**3 + ALPHA * public_key[0] + BETA)) % FIELD_PRIME == 0 + assert ( + public_key[1] ** 2 - (public_key[0] ** 3 + ALPHA * public_key[0] + BETA) + ) % FIELD_PRIME == 0 # Signature validation. # DIFF: original formula is: @@ -230,6 +245,7 @@ def verify(msg_hash: int, r: int, s: int, public_key: Union[int, ECPoint]) -> bo # Pedersen hash # ################# + def pedersen_hash(*elements: int) -> int: return pedersen_hash_as_point(*elements)[0] @@ -242,10 +258,12 @@ def pedersen_hash_as_point(*elements: int) -> ECPoint: point = SHIFT_POINT for i, x in enumerate(elements): assert 0 <= x < FIELD_PRIME - point_list = CONSTANT_POINTS[2 + i * N_ELEMENT_BITS_HASH:2 + (i + 1) * N_ELEMENT_BITS_HASH] + point_list = CONSTANT_POINTS[ + 2 + i * N_ELEMENT_BITS_HASH : 2 + (i + 1) * N_ELEMENT_BITS_HASH + ] assert len(point_list) == N_ELEMENT_BITS_HASH for pt in point_list: - assert point[0] != pt[0], 'Unhashable input.' + assert point[0] != pt[0], "Unhashable input." if x & 1: point = ec_add(point, pt, FIELD_PRIME) x >>= 1 diff --git a/src/starkware/python/async_subprocess.py b/src/starkware/python/async_subprocess.py index d8e79389..60373ab9 100644 --- a/src/starkware/python/async_subprocess.py +++ b/src/starkware/python/async_subprocess.py @@ -7,11 +7,12 @@ async def async_check_output(args: Union[str, List[str]], shell: bool = False, c An async equivalent to subprocess.check_output(). """ if shell: - assert isinstance(args, str), 'args must be a string where shell=True.' + assert isinstance(args, str), "args must be a string where shell=True." # Pass '-e' to stop after failure if args consists of multiple commands. - args = ['bash', '-e', '-c', args] + args = ["bash", "-e", "-c", args] proc = await asyncio.create_subprocess_exec( - *args, cwd=cwd, env=env, stdout=asyncio.subprocess.PIPE) + *args, cwd=cwd, env=env, stdout=asyncio.subprocess.PIPE + ) return_code = await proc.wait() assert return_code == 0 assert proc.stdout is not None diff --git a/src/starkware/python/expression_string.py b/src/starkware/python/expression_string.py index 3e19f166..8e2192e1 100644 --- a/src/starkware/python/expression_string.py +++ b/src/starkware/python/expression_string.py @@ -19,11 +19,11 @@ class OperatorPrecedence(Enum): Represents the precedence of an operator. """ - LOWEST = 0 # Unary minus. - PLUS = auto() # Either + or -. - MUL = auto() # Either * or /. - POW = auto() # Power (**). - ADDROF = auto() # Address-of operator (&). + LOWEST = 0 # Unary minus. + PLUS = auto() # Either + or -. + MUL = auto() # Either * or /. + POW = auto() # Power (**). + ADDROF = auto() # Address-of operator (&). HIGHEST = auto() # Numeric values, variables, parenthesized expressions, ... def __lt__(self, other): @@ -75,28 +75,28 @@ def __str__(self): def __add__(self, other): other = to_expr_string(other) - return ExpressionString(f'{self:PLUS} + {other:PLUS}', OperatorPrecedence.PLUS) + return ExpressionString(f"{self:PLUS} + {other:PLUS}", OperatorPrecedence.PLUS) def __sub__(self, other): # Note that self and other are not symmetric. For example (a + b) - (c + d) should be: # a + b - (c + d). other = to_expr_string(other) - return ExpressionString(f'{self:PLUS} - {other:MUL}', OperatorPrecedence.PLUS) + return ExpressionString(f"{self:PLUS} - {other:MUL}", OperatorPrecedence.PLUS) def __mul__(self, other): other = to_expr_string(other) - return ExpressionString(f'{self:MUL} * {other:MUL}', OperatorPrecedence.MUL) + return ExpressionString(f"{self:MUL} * {other:MUL}", OperatorPrecedence.MUL) def __truediv__(self, other): # Note that self and other are not symmetric. For example (a * b) / (c * d) should be: # a * b / (c * d). other = to_expr_string(other) - return ExpressionString(f'{self:MUL} / {other:POW}', OperatorPrecedence.MUL) + return ExpressionString(f"{self:MUL} / {other:POW}", OperatorPrecedence.MUL) def __pow__(self, other): other = to_expr_string(other) # For the two expressions (a ** b) ** c and a ** (b ** c), parentheses will always be added. - return ExpressionString(f'{self:HIGHEST}^{other:HIGHEST}', OperatorPrecedence.POW) + return ExpressionString(f"{self:HIGHEST}^{other:HIGHEST}", OperatorPrecedence.POW) def double_star_pow(self, other): """ @@ -104,13 +104,13 @@ def double_star_pow(self, other): """ other = to_expr_string(other) # For the two expressions (a ** b) ** c and a ** (b ** c), parentheses will always be added. - return ExpressionString(f'{self:HIGHEST} ** {other:HIGHEST}', OperatorPrecedence.POW) + return ExpressionString(f"{self:HIGHEST} ** {other:HIGHEST}", OperatorPrecedence.POW) def __neg__(self): - return ExpressionString(f'-{self:ADDROF}', OperatorPrecedence.LOWEST) + return ExpressionString(f"-{self:ADDROF}", OperatorPrecedence.LOWEST) def address_of(self): - return ExpressionString(f'&{self:ADDROF}', OperatorPrecedence.ADDROF) + return ExpressionString(f"&{self:ADDROF}", OperatorPrecedence.ADDROF) def prepend(self, txt): """ @@ -124,7 +124,7 @@ def _maybe_add_parentheses(self, operator_precedence: OperatorPrecedence) -> str to operator_precedence and with parentheses otherwise. """ if self.outmost_operator_precedence < operator_precedence: - return '(%s)' % self.txt + return "(%s)" % self.txt else: return self.txt diff --git a/src/starkware/python/expression_string_test.py b/src/starkware/python/expression_string_test.py index 2eb01407..ca8545d3 100644 --- a/src/starkware/python/expression_string_test.py +++ b/src/starkware/python/expression_string_test.py @@ -3,38 +3,38 @@ def test_expression_string(): # Declare a few variables. - a = ExpressionString.highest('a') - b = ExpressionString.highest('b') - c = ExpressionString.highest('c') - d = ExpressionString.highest('d') - e = ExpressionString.highest('e') - f = ExpressionString.highest('f') - - assert str(a + b + c + d) == 'a + b + c + d' - assert str((a + b) + (c + (d + e) + f)) == 'a + b + c + d + e + f' - assert str((a + b) - (c - (d - e + f))) == 'a + b - (c - (d - e + f))' - assert str(-a + (-b)) == '(-a) + (-b)' - - assert str(a * b * c * d) == 'a * b * c * d' - assert str((a * b) * (c * (d * e) * f)) == 'a * b * c * d * e * f' - assert str((a * b) / (c / (d / e * f))) == 'a * b / (c / (d / e * f))' - assert str((-a) * b) == '(-a) * b' - assert str(-(a * b)) == '-(a * b)' - - assert str((a + b) * c + d + e * f) == '(a + b) * c + d + e * f' - assert str(a - (b - c) / (d - e) / f) == 'a - (b - c) / (d - e) / f' - - assert str((a ** b) ** c) == '(a^b)^c' - assert str(a ** b ** c) == 'a^(b^c)' - assert str(a ** ((b ** c) ** (d ** e)) ** f) == 'a^(((b^c)^(d^e))^f)' - assert str(a / b ** (c + d) * (e + f)) == 'a / b^(c + d) * (e + f)' - - assert str(-a) == '-a' - assert str(-(a + b) + (-(a - b)) - (-(a - b))) == '(-(a + b)) + (-(a - b)) - (-(a - b))' - assert str(-(a * b) * (-(a / b)) / (-(a / b))) == '(-(a * b)) * (-(a / b)) / (-(a / b))' - assert str((-((-a) ** (-b))) ** c) == '(-((-a)^(-b)))^c' - assert str(-(-a)) == '-(-a)' - - assert str((a ** b).address_of().address_of()) == '&&(a^b)' - assert str((-((-a).address_of())).address_of()) == '&(-&(-a))' - assert str(a.address_of() - b.address_of()) == '&a - &b' + a = ExpressionString.highest("a") + b = ExpressionString.highest("b") + c = ExpressionString.highest("c") + d = ExpressionString.highest("d") + e = ExpressionString.highest("e") + f = ExpressionString.highest("f") + + assert str(a + b + c + d) == "a + b + c + d" + assert str((a + b) + (c + (d + e) + f)) == "a + b + c + d + e + f" + assert str((a + b) - (c - (d - e + f))) == "a + b - (c - (d - e + f))" + assert str(-a + (-b)) == "(-a) + (-b)" + + assert str(a * b * c * d) == "a * b * c * d" + assert str((a * b) * (c * (d * e) * f)) == "a * b * c * d * e * f" + assert str((a * b) / (c / (d / e * f))) == "a * b / (c / (d / e * f))" + assert str((-a) * b) == "(-a) * b" + assert str(-(a * b)) == "-(a * b)" + + assert str((a + b) * c + d + e * f) == "(a + b) * c + d + e * f" + assert str(a - (b - c) / (d - e) / f) == "a - (b - c) / (d - e) / f" + + assert str((a ** b) ** c) == "(a^b)^c" + assert str(a ** b ** c) == "a^(b^c)" + assert str(a ** ((b ** c) ** (d ** e)) ** f) == "a^(((b^c)^(d^e))^f)" + assert str(a / b ** (c + d) * (e + f)) == "a / b^(c + d) * (e + f)" + + assert str(-a) == "-a" + assert str(-(a + b) + (-(a - b)) - (-(a - b))) == "(-(a + b)) + (-(a - b)) - (-(a - b))" + assert str(-(a * b) * (-(a / b)) / (-(a / b))) == "(-(a * b)) * (-(a / b)) / (-(a / b))" + assert str((-((-a) ** (-b))) ** c) == "(-((-a)^(-b)))^c" + assert str(-(-a)) == "-(-a)" + + assert str((a ** b).address_of().address_of()) == "&&(a^b)" + assert str((-((-a).address_of())).address_of()) == "&(-&(-a))" + assert str(a.address_of() - b.address_of()) == "&a - &b" diff --git a/src/starkware/python/json_rpc/client.py b/src/starkware/python/json_rpc/client.py index bc10fec4..af79e0f2 100644 --- a/src/starkware/python/json_rpc/client.py +++ b/src/starkware/python/json_rpc/client.py @@ -18,11 +18,11 @@ def call(self, *args, **kwargs) -> str: """ Returns a JSON-RPC call. """ - assert len(args) == 0, 'JSON-RPC call can only contain named arguments.' + assert len(args) == 0, "JSON-RPC call can only contain named arguments." - call_dict: Dict[str, Any] = {'jsonrpc': '2.0', 'method': self.name, 'id': None} + call_dict: Dict[str, Any] = {"jsonrpc": "2.0", "method": self.name, "id": None} if len(kwargs) != 0: - call_dict['params'] = kwargs + call_dict["params"] = kwargs return json.dumps(call_dict) diff --git a/src/starkware/python/json_rpc/client_test.py b/src/starkware/python/json_rpc/client_test.py index f92391e8..92a0989c 100644 --- a/src/starkware/python/json_rpc/client_test.py +++ b/src/starkware/python/json_rpc/client_test.py @@ -15,19 +15,19 @@ def test_json_rpc_encoder(): """ encoder = JsonRpcEncoder() - assert json.loads(encoder.bar.call(x=1, y='abc', z={'a': 3, 'b': 'c'})) == { - 'jsonrpc': '2.0', - 'method': 'bar', - 'params': { - 'x': 1, - 'y': 'abc', - 'z': { - 'a': 3, - 'b': 'c', + assert json.loads(encoder.bar.call(x=1, y="abc", z={"a": 3, "b": "c"})) == { + "jsonrpc": "2.0", + "method": "bar", + "params": { + "x": 1, + "y": "abc", + "z": { + "a": 3, + "b": "c", }, }, - 'id': None, + "id": None, } - with pytest.raises(AssertionError, match='JSON-RPC call can only contain named arguments.'): + with pytest.raises(AssertionError, match="JSON-RPC call can only contain named arguments."): encoder.foo.call(1, 2, x=3) diff --git a/src/starkware/python/math_utils.py b/src/starkware/python/math_utils.py index ef68d1c6..9c763cdd 100644 --- a/src/starkware/python/math_utils.py +++ b/src/starkware/python/math_utils.py @@ -11,7 +11,7 @@ def safe_div(x: int, y: int): """ assert isinstance(x, int) and isinstance(y, int) assert y != 0 - assert x % y == 0, f'{x} is not divisible by {y}.' + assert x % y == 0, f"{x} is not divisible by {y}." return x // y @@ -35,7 +35,7 @@ def next_power_of_2(x: int): """ assert isinstance(x, int) and x > 0 res = 2 ** (x - 1).bit_length() - assert x <= res < 2 * x, f'{x}, {res}' + assert x <= res < 2 * x, f"{x}, {res}" return res @@ -87,6 +87,7 @@ def ec_add(point1, point2, p): return x, y + def ec_double(point, alpha, p): """ Doubles a point on an elliptic curve with the equation y^2 = x^3 + alpha*x + beta mod p. diff --git a/src/starkware/python/math_utils_test.py b/src/starkware/python/math_utils_test.py index fcae26a4..a905a378 100644 --- a/src/starkware/python/math_utils_test.py +++ b/src/starkware/python/math_utils_test.py @@ -1,8 +1,18 @@ import pytest from starkware.python.math_utils import ( - div_ceil, div_mod, ec_add, ec_double, ec_mult, is_power_of_2, is_quad_residue, next_power_of_2, - safe_div, safe_log2, sqrt) + div_ceil, + div_mod, + ec_add, + ec_double, + ec_mult, + is_power_of_2, + is_quad_residue, + next_power_of_2, + safe_div, + safe_log2, + sqrt, +) def test_ec_add(): diff --git a/src/starkware/python/merkle_tree.py b/src/starkware/python/merkle_tree.py index 14b40032..8c77323d 100644 --- a/src/starkware/python/merkle_tree.py +++ b/src/starkware/python/merkle_tree.py @@ -35,10 +35,10 @@ def decode_node(node): """ left_child, right_child = node if left_child is None: - assert right_child is not None, 'No updates in tree' - case = 'right' + assert right_child is not None, "No updates in tree" + case = "right" elif right_child is None: - case = 'left' + case = "left" else: - case = 'both' + case = "both" return left_child, right_child, case diff --git a/src/starkware/python/object_utils.py b/src/starkware/python/object_utils.py index f02b64cc..cbf364d4 100644 --- a/src/starkware/python/object_utils.py +++ b/src/starkware/python/object_utils.py @@ -2,31 +2,31 @@ import marshmallow -T = TypeVar('T') +T = TypeVar("T") def show_attr_predicate(obj: Any, attr_name: str) -> bool: """ Attribute selector that can be used in 'generic_object_repr'. """ - return not callable(getattr(obj, attr_name)) and not attr_name.startswith('_') + return not callable(getattr(obj, attr_name)) and not attr_name.startswith("_") def skip_config_attr_predicate(obj: Any, attr_name: str) -> bool: """ Attribute selector that can be used in 'generic_object_repr'. """ - return not attr_name.endswith('_config') and show_attr_predicate(obj=obj, attr_name=attr_name) + return not attr_name.endswith("_config") and show_attr_predicate(obj=obj, attr_name=attr_name) def _object_dumps(obj: Any) -> str: """ A helper function for 'generic_object_repr'. """ - if hasattr(obj, 'dumps'): + if hasattr(obj, "dumps"): return obj.dumps() - if hasattr(obj, 'Schema'): + if hasattr(obj, "Schema"): schema: marshmallow.Schema = obj.Schema() # type: ignore[attr-defined] return schema.dumps(obj=obj) @@ -34,14 +34,16 @@ def _object_dumps(obj: Any) -> str: def generic_object_repr( - obj: T, show_attr_predicate: Callable[[T, str], bool] = show_attr_predicate, - exclude: Optional[Set[str]] = None): + obj: T, + show_attr_predicate: Callable[[T, str], bool] = show_attr_predicate, + exclude: Optional[Set[str]] = None, +): """ A generic repr function implementation. The given show_attr_predicate argument determines which attributes to show, while the given exclude argument determines which attributes not to show. """ - if hasattr(obj, 'dumps') or hasattr(obj, 'Schema'): + if hasattr(obj, "dumps") or hasattr(obj, "Schema"): return _object_dumps(obj=obj) if exclude is None: @@ -49,9 +51,12 @@ def generic_object_repr( attributes_to_show = { attr: getattr(obj, attr) - for attr in obj.__dict__ if show_attr_predicate(obj, attr) and attr not in exclude} - attributes_repr = ', '.join( - f'{attr_name}={_object_dumps(obj=attr_value)}' - for attr_name, attr_value in attributes_to_show.items()) - - return f'{type(obj).__name__}({attributes_repr})' + for attr in obj.__dict__ + if show_attr_predicate(obj, attr) and attr not in exclude + } + attributes_repr = ", ".join( + f"{attr_name}={_object_dumps(obj=attr_value)}" + for attr_name, attr_value in attributes_to_show.items() + ) + + return f"{type(obj).__name__}({attributes_repr})" diff --git a/src/starkware/python/python_dependencies.py b/src/starkware/python/python_dependencies.py index 69dd3256..029a477d 100644 --- a/src/starkware/python/python_dependencies.py +++ b/src/starkware/python/python_dependencies.py @@ -1,9 +1,10 @@ import os import sys -ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) -assert os.path.basename(ROOT_DIR) in ['src', 'site-packages', 'dist-packages'] or \ - os.path.basename(ROOT_DIR).endswith('-site') +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +assert os.path.basename(ROOT_DIR) in ["src", "site-packages", "dist-packages"] or os.path.basename( + ROOT_DIR +).endswith("-site") def generate_python_dependencies(dependencies_path, start_time): @@ -11,16 +12,18 @@ def generate_python_dependencies(dependencies_path, start_time): Creates a CMake file with the loaded python module. """ files = [ - x.__file__ for x in sys.modules.values() - if hasattr(x, '__file__') and x.__file__ is not None] + x.__file__ + for x in sys.modules.values() + if hasattr(x, "__file__") and x.__file__ is not None + ] - res = 'SET (DEPENDENCIES\n' + res = "SET (DEPENDENCIES\n" for filename in sorted(files): if filename.startswith(ROOT_DIR): - res += filename + '\n' - res += ')\n' + res += filename + "\n" + res += ")\n" - with open(dependencies_path, 'w') as dependencies_file: + with open(dependencies_path, "w") as dependencies_file: dependencies_file.write(res) # Change the modification time of the file to make sure it is older than the generated files. @@ -33,8 +36,10 @@ def add_argparse_argument(parser): Use process_args at the end of the program to generate the dependency file. """ parser.add_argument( - '--python_dependencies', type=str, - help='Output the starkware python modules this file depends on as a CMake file.') + "--python_dependencies", + type=str, + help="Output the starkware python modules this file depends on as a CMake file.", + ) def process_args(args, start_time): diff --git a/src/starkware/python/random_test.py b/src/starkware/python/random_test.py index f28ad2cb..31eff59c 100644 --- a/src/starkware/python/random_test.py +++ b/src/starkware/python/random_test.py @@ -14,14 +14,14 @@ def _get_seeds(n_nightly_runs: int, seed: Optional[int]) -> List[int]: Gets a list of seeds based on environment variables and the seed function argument. If RANDOM_TEST_N_RUNS is specified, returns a list of RANDOM_TEST_N_RUNS random seeds. """ - n_iters_env_var = os.environ.get('RANDOM_TEST_N_RUNS') + n_iters_env_var = os.environ.get("RANDOM_TEST_N_RUNS") if n_iters_env_var is None: - n_iters = n_nightly_runs if (os.environ.get('NIGHTLY_TEST') == '1') else 1 + n_iters = n_nightly_runs if (os.environ.get("NIGHTLY_TEST") == "1") else 1 else: return [random.randrange(sys.maxsize) for _ in range(int(n_iters_env_var))] - seed_env_var = os.environ.get('RANDOM_TEST_SEED') - if seed_env_var == 'random': + seed_env_var = os.environ.get("RANDOM_TEST_SEED") + if seed_env_var == "random": return [random.randrange(sys.maxsize) for _ in range(n_iters)] elif seed_env_var is not None: return [int(seed_env_var)] @@ -29,21 +29,23 @@ def _get_seeds(n_nightly_runs: int, seed: Optional[int]) -> List[int]: return [seed] # If we got here, then the seed wasn't set with an environment variable or a function argument. - if os.environ.get('NIGHTLY_TEST') == '1': + if os.environ.get("NIGHTLY_TEST") == "1": return [random.randrange(sys.maxsize) for _ in range(n_iters)] return [0] def _print_seed(seed: int, decorator_name: str): - print(f'The seed used in the test is {seed}.') - print(f'To reproduce the results set the environment variable RANDOM_TEST_SEED to {seed}.') + print(f"The seed used in the test is {seed}.") + print(f"To reproduce the results set the environment variable RANDOM_TEST_SEED to {seed}.") print( - f"(This can be done by adding 'RANDOM_TEST_SEED={seed}' at the beginning of the command).") + f"(This can be done by adding 'RANDOM_TEST_SEED={seed}' at the beginning of the command)." + ) print(f"Alternatively, you can add 'seed={seed}' to the '{decorator_name}' decorator") def _convert_function_to_function_or_coroutine( - caller_func: Callable, callee_func: Callable) -> Callable: + caller_func: Callable, callee_func: Callable +) -> Callable: """ Gets a function `caller_func` and a function or co-routine `callee_func`. `caller_func` is expected to yield values of the form `callee_func(...)` (which are either @@ -54,6 +56,7 @@ def _convert_function_to_function_or_coroutine( Exceptions that were thrown will be raised into the caller function. """ if inspect.iscoroutinefunction(callee_func): + @functools.wraps(callee_func) async def return_value(*args, **kwargs): gen = caller_func(*args, **kwargs) @@ -62,11 +65,14 @@ async def return_value(*args, **kwargs): await run except Exception as e: gen.throw(e) + else: + @functools.wraps(callee_func) def return_value(*args, **kwargs): for run in caller_func(*args, **kwargs): pass + return return_value @@ -91,6 +97,7 @@ def random_test(n_nightly_runs: int = 10, seed: Optional[int] = None): Setting the environment variable can be done by prefixing the command line with `RANDOM_TEST_SEED=10` for example. """ + def convert_test_func(test_func: Callable): seeds = _get_seeds(n_nightly_runs=n_nightly_runs, seed=seed) @@ -100,14 +107,18 @@ def fixate_seed_and_yield_test_run(*args, seed, **kwargs): try: yield test_func(*args, seed=seed, **kwargs) except Exception: - _print_seed(seed=seed, decorator_name='random_test') + _print_seed(seed=seed, decorator_name="random_test") raise finally: random.setstate(old_state) + # We need to use pytest.mark.parametrize rather than running the test in a for loop. If we # do the latter, pytest won't re-create the fixtures for each run. - return pytest.mark.parametrize('seed', seeds)(_convert_function_to_function_or_coroutine( - caller_func=fixate_seed_and_yield_test_run, callee_func=test_func)) + return pytest.mark.parametrize("seed", seeds)( + _convert_function_to_function_or_coroutine( + caller_func=fixate_seed_and_yield_test_run, callee_func=test_func + ) + ) return convert_test_func @@ -125,19 +136,23 @@ def parametrize_random_object(n_nightly_runs: int = 10, seed: Optional[int] = No For explanation on environment variables, read the doc of the `random_test` decorator. """ + def convert_test_func( - test_func: Callable[[NamedArg(type=random.Random, name='random_object')], None]): + test_func: Callable[[NamedArg(type=random.Random, name="random_object")], None] + ): seeds = _get_seeds(n_nightly_runs=n_nightly_runs, seed=seed) def fixate_seed_and_yield_test_run(*args, **kwargs): yield test_func(*args, **kwargs) return pytest.mark.parametrize( - 'random_object', + "random_object", [random.Random(seed) for seed in seeds], - ids=[f'Random({seed})' for seed in seeds], + ids=[f"Random({seed})" for seed in seeds], )( _convert_function_to_function_or_coroutine( - caller_func=fixate_seed_and_yield_test_run, callee_func=test_func)) + caller_func=fixate_seed_and_yield_test_run, callee_func=test_func + ) + ) return convert_test_func diff --git a/src/starkware/python/test_utils.py b/src/starkware/python/test_utils.py index 3e8618da..a59231ec 100644 --- a/src/starkware/python/test_utils.py +++ b/src/starkware/python/test_utils.py @@ -6,8 +6,8 @@ def maybe_raises( - expected_exception, error_message: Optional[str], - escape_error_message: bool = True) -> ContextManager: + expected_exception, error_message: Optional[str], escape_error_message: bool = True +) -> ContextManager: """ A utility function for parameterized tests with both positive and negative cases. If error_message is None, it expects no error, diff --git a/src/starkware/python/test_utils_test.py b/src/starkware/python/test_utils_test.py index 1bfbe3e8..d37f44c4 100644 --- a/src/starkware/python/test_utils_test.py +++ b/src/starkware/python/test_utils_test.py @@ -7,10 +7,13 @@ def maybe_trigger_exception(error_message): assert error_message is None, error_message -@pytest.mark.parametrize('error_message, res_type', [ - (None, None), - ('test', pytest.raises(AssertionError, maybe_trigger_exception, 'test')), -]) +@pytest.mark.parametrize( + "error_message, res_type", + [ + (None, None), + ("test", pytest.raises(AssertionError, maybe_trigger_exception, "test")), + ], +) def test_maybe_raises(error_message, res_type): with maybe_raises(AssertionError, error_message) as ex: maybe_trigger_exception(error_message) diff --git a/src/starkware/python/utils.py b/src/starkware/python/utils.py index 47d39fdc..58f2707b 100644 --- a/src/starkware/python/utils.py +++ b/src/starkware/python/utils.py @@ -15,25 +15,26 @@ def get_package_path(): Returns ROOT_PATH s.t. $ROOT_PATH/starkware is the package folder. """ import starkware.python - return os.path.abspath(os.path.join(os.path.dirname(starkware.python.__file__), '../../')) + return os.path.abspath(os.path.join(os.path.dirname(starkware.python.__file__), "../../")) -def get_build_dir_path(rel_path=''): + +def get_build_dir_path(rel_path=""): """ Returns a path to a file inside the build directory (or the docker). rel_path is the relative path of the file with respect to the build directory. """ - build_root = os.environ['BUILD_ROOT'] + build_root = os.environ["BUILD_ROOT"] return os.path.join(build_root, rel_path) -def get_source_dir_path(rel_path=''): +def get_source_dir_path(rel_path=""): """ Returns a path to a file inside the source directory. Does not work in docker. rel_path is the relative path of the file with respect to the source directory. """ - source_root = os.path.join(os.environ['BUILD_ROOT'], '../../') - assert os.path.exists(os.path.join(source_root, 'src')) + source_root = os.path.join(os.environ["BUILD_ROOT"], "../../") + assert os.path.exists(os.path.join(source_root, "src")) return os.path.join(source_root, rel_path) @@ -43,7 +44,7 @@ def assert_same_and_get(*args): For example, assert_same_and_get(5, 5, 5) will return 5, and assert_same_and_get(0, 1) will raise an AssertionError. """ - assert len(set(args)) == 1, 'Values are not the same (%s)' % (args,) + assert len(set(args)) == 1, "Values are not the same (%s)" % (args,) return args[0] @@ -85,15 +86,15 @@ def indent(code, indentation): if len(code) == 0: return code if isinstance(indentation, int): - indentation = ' ' * indentation + indentation = " " * indentation elif not isinstance(indentation, str): - raise TypeError(f'Supports only int or str, got {type(indentation).__name__}') + raise TypeError(f"Supports only int or str, got {type(indentation).__name__}") # Replace every occurrence of \n, with \n followed by indentation, # unless the \n is the last characther of the string or is followed by another \n. # We enforce the "not followed by ..." condition using negative lookahead (?!\n|$), # looking for end of string ($) or another \n. - return indentation + re.sub(r'\n(?!\n|$)', '\n' + indentation, code) + return indentation + re.sub(r"\n(?!\n|$)", "\n" + indentation, code) def get_random_instance() -> random.Random: @@ -104,7 +105,8 @@ def get_random_instance() -> random.Random: def initialize_random( - random_object: Optional[random.Random] = None, seed: Optional[int] = None) -> random.Random: + random_object: Optional[random.Random] = None, seed: Optional[int] = None +) -> random.Random: """ Returns a Random object initialized according to the given parameters. If both are None, the Random instance instantiated in the random module is returned. @@ -129,7 +131,7 @@ def compare_files(src, dst, fix): If 'fix' is False, checks that the files are the same. If 'fix' is True, overrides dst with src. """ - subprocess.check_call(['cp' if fix else 'diff', src, dst]) + subprocess.check_call(["cp" if fix else "diff", src, dst]) def remove_trailing_spaces(code): @@ -137,7 +139,7 @@ def remove_trailing_spaces(code): Removes spaces from end of lines. For example, remove_trailing_spaces('hello \nworld \n') -> 'hello\nworld\n'. """ - return re.sub(' +$', '', code, flags=re.MULTILINE) + return re.sub(" +$", "", code, flags=re.MULTILINE) def should_discard_key(key, exclude: List[str]) -> bool: @@ -159,8 +161,9 @@ class WriteOnceDict(UserDict): """ def __setitem__(self, key, value): - assert key not in self.data, \ - f"Trying to set key={key} to '{value}' but key={key} is already set to '{self[key]}'." + assert ( + key not in self.data + ), f"Trying to set key={key} to '{value}' but key={key} is already set to '{self[key]}'." self.data[key] = value @@ -169,7 +172,7 @@ def camel_to_snake_case(camel_case_name: str) -> str: Converts a name with Capital first letters to lower case with '_' as separators. For example, CamelToSnakeCase -> camel_to_snake_case. """ - return (camel_case_name[0] + re.sub(r'([A-Z])', r'_\1', camel_case_name[1:])).lower() + return (camel_case_name[0] + re.sub(r"([A-Z])", r"_\1", camel_case_name[1:])).lower() def snake_to_camel_case(snake_case_name: str) -> str: @@ -177,7 +180,7 @@ def snake_to_camel_case(snake_case_name: str) -> str: Converts the first letter to upper case (if possible) and all the '_l' to 'L'. For example snake_to_camel_case -> SnakeToCamelCase. """ - return re.subn(r'(^|_)([a-z])', lambda m: m.group(2).upper(), snake_case_name)[0] + return re.subn(r"(^|_)([a-z])", lambda m: m.group(2).upper(), snake_case_name)[0] async def cancel_futures(*futures: asyncio.Future): @@ -201,7 +204,7 @@ def safe_zip(*iterables: Iterable[Any]) -> Iterable: """ sentinel = object() for combo in itertools.zip_longest(*iterables, fillvalue=sentinel): - assert sentinel not in combo, 'Iterables to safe_zip are not equal in length.' + assert sentinel not in combo, "Iterables to safe_zip are not equal in length." yield combo @@ -225,6 +228,7 @@ def composition_function(*args, **kwargs): for func in reversed(funcs[:-1]): return_value = func(return_value) return return_value + return composition_function @@ -237,7 +241,7 @@ def to_bytes(value: int, length: Optional[int] = None, byte_order: Optional[str] length = HASH_BYTES if byte_order is None: - byte_order = 'big' + byte_order = "big" return int.to_bytes(value, length=length, byteorder=byte_order) @@ -248,7 +252,7 @@ def from_bytes(value: bytes, byte_order: Optional[str] = None) -> int: Default byte order is 'big'. """ if byte_order is None: - byte_order = 'big' + byte_order = "big" return int.from_bytes(value, byteorder=byte_order) @@ -257,8 +261,8 @@ def blockify(data, chunk_size: int) -> Iterable: """ Returns the given data partitioned to chunks of chunks_size (last chunk might be smaller). """ - assert chunk_size > 0, f'chunk_size must be greater than 0. Got: {chunk_size}.' - return (data[i:i + chunk_size] for i in range(0, len(data), chunk_size)) + assert chunk_size > 0, f"chunk_size must be greater than 0. Got: {chunk_size}." + return (data[i : i + chunk_size] for i in range(0, len(data), chunk_size)) def all_subclasses(cls: type) -> List[type]: @@ -270,4 +274,5 @@ def all_subclasses(cls: type) -> List[type]: def _all_subclasses(cls: type) -> List[type]: return [cls] + list( - itertools.chain(*[_all_subclasses(subclass) for subclass in cls.__subclasses__()])) + itertools.chain(*[_all_subclasses(subclass) for subclass in cls.__subclasses__()]) + ) diff --git a/src/starkware/python/utils_test.py b/src/starkware/python/utils_test.py index 4b6afbb5..e2b0bafb 100644 --- a/src/starkware/python/utils_test.py +++ b/src/starkware/python/utils_test.py @@ -3,13 +3,20 @@ import pytest from starkware.python.utils import ( - WriteOnceDict, all_subclasses, blockify, composite, indent, safe_zip, unique) + WriteOnceDict, + all_subclasses, + blockify, + composite, + indent, + safe_zip, + unique, +) def test_indent(): - assert indent('aa\n bb', 2) == ' aa\n bb' - assert indent('aa\n bb\n', 2) == ' aa\n bb\n' - assert indent(' aa\n bb\n\ncc\n', 2) == ' aa\n bb\n\n cc\n' + assert indent("aa\n bb", 2) == " aa\n bb" + assert indent("aa\n bb\n", 2) == " aa\n bb\n" + assert indent(" aa\n bb\n\ncc\n", 2) == " aa\n bb\n\n cc\n" def test_unique(): @@ -22,9 +29,11 @@ def test_write_once_dict(): key = 5 value = None d[key] = value - with pytest.raises(AssertionError, match=re.escape( - f"Trying to set key=5 to 'b' but key=5 is already set to 'None'.")): - d[key] = 'b' + with pytest.raises( + AssertionError, + match=re.escape(f"Trying to set key=5 to 'b' but key=5 is already set to 'None'."), + ): + d[key] = "b" def test_safe_zip(): @@ -32,14 +41,14 @@ def test_safe_zip(): assert list(safe_zip()) == list(zip()) # Test equal-length iterables (including a generator). - assert ( - list(safe_zip((i for i in range(3)), range(3, 6), [1, 2, 3])) == - list(zip((i for i in range(3)), range(3, 6), [1, 2, 3]))) + assert list(safe_zip((i for i in range(3)), range(3, 6), [1, 2, 3])) == list( + zip((i for i in range(3)), range(3, 6), [1, 2, 3]) + ) # Test unequal-length iterables. test_cases = [[range(4), range(3)], [[], range(3)]] for iterables in test_cases: - with pytest.raises(AssertionError, match='Iterables to safe_zip are not equal in length.'): + with pytest.raises(AssertionError, match="Iterables to safe_zip are not equal in length."): list(safe_zip(*iterables)) # Consume generator to get to the error. @@ -55,7 +64,7 @@ def test_blockify(): # Edge cases. assert list(blockify(data=[], chunk_size=2)) == [] assert list(blockify(data=data, chunk_size=len(data))) == [data] - with pytest.raises(expected_exception=AssertionError, match='chunk_size'): + with pytest.raises(expected_exception=AssertionError, match="chunk_size"): blockify(data=data, chunk_size=0) assert list(blockify(data=data, chunk_size=4)) == [[1, 2, 3, 4], [5, 6, 7]] diff --git a/src/starkware/starknet/business_logic/internal_transaction.py b/src/starkware/starknet/business_logic/internal_transaction.py index 469249c5..1db5e5b1 100644 --- a/src/starkware/starknet/business_logic/internal_transaction.py +++ b/src/starkware/starknet/business_logic/internal_transaction.py @@ -4,8 +4,6 @@ import logging from abc import abstractmethod from dataclasses import field -from starkware.cairo.lang.vm.relocatable import RelocatableValue -from starkware.starknet.public.abi import STORAGE_PTR_OFFSET, SYSCALL_PTR_OFFSET from typing import ClassVar, Dict, List, Optional, Tuple, Type, cast import marshmallow @@ -15,6 +13,7 @@ from services.everest.api.gateway.transaction import EverestTransaction from starkware.cairo.common.cairo_function_runner import CairoFunctionRunner from starkware.cairo.lang.vm.cairo_pie import ExecutionResources +from starkware.cairo.lang.vm.relocatable import RelocatableValue from starkware.cairo.lang.vm.security import SecurityError from starkware.cairo.lang.vm.utils import ResourcesError, RunResources from starkware.cairo.lang.vm.vm import HintException, VmException, VmExceptionBase @@ -34,6 +33,7 @@ from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.definitions.general_config import StarknetGeneralConfig from starkware.starknet.definitions.transaction_type import TransactionType +from starkware.starknet.public.abi import STORAGE_PTR_OFFSET, SYSCALL_PTR_OFFSET from starkware.starknet.services.api.contract_definition import ( ContractDefinition, ContractEntryPoint, @@ -477,9 +477,10 @@ def _run( ) # Complete handler validations. - storage_stop_ptr = segment_utils.get_os_segment_stop_ptr( - runner=runner, ptr_offset=STORAGE_PTR_OFFSET, os_context=os_context - ) + with wrap_with_stark_exception(code=StarknetErrorCode.SECURITY_ERROR): + storage_stop_ptr = segment_utils.get_os_segment_stop_ptr( + runner=runner, ptr_offset=STORAGE_PTR_OFFSET, os_context=os_context + ) syscall_handler.finalize_storage_validations( segments=runner.segments, storage_stop_ptr=storage_stop_ptr ) diff --git a/src/starkware/starknet/business_logic/internal_transaction_interface.py b/src/starkware/starknet/business_logic/internal_transaction_interface.py index e023b315..ce87c023 100644 --- a/src/starkware/starknet/business_logic/internal_transaction_interface.py +++ b/src/starkware/starknet/business_logic/internal_transaction_interface.py @@ -11,7 +11,6 @@ EverestInternalTransaction, EverestTransactionExecutionInfo, ) -from services.everest.business_logic.internal_transaction import EverestInternalTransaction from services.everest.business_logic.state import CarriedStateBase from services.everest.definitions import fields as everest_fields from starkware.cairo.lang.vm.utils import RunResources diff --git a/src/starkware/starknet/business_logic/state.py b/src/starkware/starknet/business_logic/state.py index 7c5ec20e..6b6b9c15 100644 --- a/src/starkware/starknet/business_logic/state.py +++ b/src/starkware/starknet/business_logic/state.py @@ -17,8 +17,8 @@ from starkware.starknet.definitions.general_config import StarknetGeneralConfig from starkware.starknet.services.api.contract_definition import ContractDefinition from starkware.starknet.storage.starknet_storage import StorageLeaf +from starkware.starkware_utils.commitment_tree.patricia_tree.patricia_tree import PatriciaTree from starkware.starkware_utils.config_base import Config -from starkware.starkware_utils.patricia_tree.patricia_tree import PatriciaTree from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass from starkware.storage.storage import FactFetchingContext diff --git a/src/starkware/starknet/business_logic/state_objects.py b/src/starkware/starknet/business_logic/state_objects.py index 04044572..7a5b2ab5 100644 --- a/src/starkware/starknet/business_logic/state_objects.py +++ b/src/starkware/starknet/business_logic/state_objects.py @@ -10,13 +10,13 @@ from starkware.starknet.definitions import fields from starkware.starknet.services.api.contract_definition import ContractDefinition from starkware.starknet.storage.starknet_storage import StorageLeaf -from starkware.storage.storage import HASH_BYTES, Fact, FactFetchingContext, HashFunctionType +from starkware.starkware_utils.commitment_tree.patricia_tree.nodes import EmptyNodeFact +from starkware.starkware_utils.commitment_tree.patricia_tree.patricia_tree import PatriciaTree from starkware.starkware_utils.validated_dataclass import ( ValidatedDataclass, ValidatedMarshmallowDataclass, ) -from starkware.starkware_utils.patricia_tree.patricia_tree import PatriciaTree -from starkware.starkware_utils.patricia_tree.nodes import EmptyNodeFact +from starkware.storage.storage import HASH_BYTES, Fact, FactFetchingContext, HashFunctionType @marshmallow_dataclass.dataclass(frozen=True) diff --git a/src/starkware/starknet/cli/CMakeLists.txt b/src/starkware/starknet/cli/CMakeLists.txt index 87045581..76c9d903 100644 --- a/src/starkware/starknet/cli/CMakeLists.txt +++ b/src/starkware/starknet/cli/CMakeLists.txt @@ -9,6 +9,7 @@ python_lib(starknet_cli_lib cairo_version_lib cairo_vm_utils_lib services_external_api_lib + starknet_abi_lib starknet_compile_lib starknet_contract_definition_lib starknet_definitions_lib diff --git a/src/starkware/starknet/cli/starknet_cli.py b/src/starkware/starknet/cli/starknet_cli.py index 097b9d02..ef9d563b 100755 --- a/src/starkware/starknet/cli/starknet_cli.py +++ b/src/starkware/starknet/cli/starknet_cli.py @@ -8,11 +8,17 @@ import sys from services.external_api.base_client import RetryConfig +from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt, TypePointer +from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager +from starkware.cairo.lang.compiler.parser import parse_type from starkware.cairo.lang.compiler.program import Program +from starkware.cairo.lang.compiler.type_system import mark_type_resolved +from starkware.cairo.lang.compiler.type_utils import check_felts_only_type from starkware.cairo.lang.version import __version__ from starkware.cairo.lang.vm.reconstruct_traceback import reconstruct_traceback from starkware.starknet.compiler.compile import get_selector_from_name from starkware.starknet.definitions import fields +from starkware.starknet.public.abi_structs import struct_definition_from_abi_entry from starkware.starknet.services.api.contract_definition import ContractDefinition from starkware.starknet.services.api.feeder_gateway.feeder_gateway_client import FeederGatewayClient from starkware.starknet.services.api.gateway.gateway_client import GatewayClient @@ -125,6 +131,16 @@ async def invoke_or_call(args, command_args, call: bool): ) abi = json.load(args.abi) + + # Load types. + identifiers = IdentifierManager() + for abi_entry in abi: + if abi_entry["type"] == "struct": + struct_definition = struct_definition_from_abi_entry(abi_entry=abi_entry) + identifiers.add_identifier( + name=struct_definition.full_name, definition=struct_definition + ) + try: address = int(args.address, 16) except ValueError: @@ -134,23 +150,25 @@ async def invoke_or_call(args, command_args, call: bool): previous_felt_input = None current_inputs_ptr = 0 for input_desc in abi_entry["inputs"]: - if input_desc["type"] == "felt": - assert current_inputs_ptr < len( - inputs - ), f"Expected at least {current_inputs_ptr + 1} inputs, got {len(inputs)}." - - previous_felt_input = inputs[current_inputs_ptr] - current_inputs_ptr += 1 - elif input_desc["type"] == "felt*": + typ = mark_type_resolved(parse_type(input_desc["type"])) + typ_size = check_felts_only_type(cairo_type=typ, identifier_manager=identifiers) + if typ_size is not None: + assert current_inputs_ptr + typ_size <= len(inputs), ( + f"Expected at least {current_inputs_ptr + typ_size} inputs, " + f"got {len(inputs)}." + ) + + current_inputs_ptr += typ_size + elif typ == TypePointer(pointee=TypeFelt()): assert previous_felt_input is not None, ( f'The array argument {input_desc["name"]} of type felt* must be preceded ' "by a length argument of type felt." ) current_inputs_ptr += previous_felt_input - previous_felt_input = None else: raise Exception(f'Type {input_desc["type"]} is not supported.') + previous_felt_input = inputs[current_inputs_ptr - 1] if typ == TypeFelt() else None break else: raise Exception(f"Function {args.function} not found.") diff --git a/src/starkware/starknet/compiler/compile.py b/src/starkware/starknet/compiler/compile.py index 3a8a3141..4a7fc88e 100644 --- a/src/starkware/starknet/compiler/compile.py +++ b/src/starkware/starknet/compiler/compile.py @@ -1,7 +1,7 @@ import argparse import json import sys -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME from starkware.cairo.lang.compiler.assembler import assemble @@ -86,9 +86,14 @@ def get_entry_points_by_decorators( def compile_starknet_files( - files, debug_info: bool = False, disable_hint_validation: bool = False + files, + debug_info: bool = False, + disable_hint_validation: bool = False, + cairo_path: Optional[List[str]] = None, ) -> ContractDefinition: - module_reader = get_module_reader(cairo_path=[]) + if cairo_path is None: + cairo_path = [] + module_reader = get_module_reader(cairo_path=cairo_path) pass_manager = starknet_pass_manager( prime=DEFAULT_PRIME, diff --git a/src/starkware/starknet/compiler/contract_interface.py b/src/starkware/starknet/compiler/contract_interface.py index 4e0628db..ae01adc4 100644 --- a/src/starkware/starknet/compiler/contract_interface.py +++ b/src/starkware/starknet/compiler/contract_interface.py @@ -8,6 +8,7 @@ CommentedCodeElement, ) from starkware.cairo.lang.compiler.error_handling import Location, ParentLocation +from starkware.cairo.lang.compiler.parser import ParserContext from starkware.cairo.lang.compiler.preprocessor.identifier_aware_visitor import ( IdentifierAwareVisitor, ) @@ -163,7 +164,9 @@ def process_contract_function( code_block = autogen_parse_code_block( path=function_info.autogen_code_name, code=code, - parent_location=function_info.parent_location, + parser_context=ParserContext( + parent_location=function_info.parent_location, + ), ) call_func = code_block.code_elements[1].code_elm @@ -196,7 +199,9 @@ def generate_contract_interface_namespace( code_block = autogen_parse_code_block( path=AUTOGEN_PREFIX + contract_name, code=code, - parent_location=contract_info.parent_location, + parser_context=ParserContext( + parent_location=contract_info.parent_location, + ), ) assert len(code_block.code_elements) == 1 res = code_block.code_elements[0].code_elm @@ -265,7 +270,9 @@ def generate_contract_function_body(self, function_info: ContractFunctionInfo): return autogen_parse_code_block( path=function_info.autogen_code_name, code=code, - parent_location=function_info.parent_location, + parser_context=ParserContext( + parent_location=function_info.parent_location, + ), ) @@ -301,7 +308,9 @@ def get_code_elements(code: str) -> List[CommentedCodeElement]: return autogen_parse_code_block( path=function_info.autogen_code_name, code=code, - parent_location=function_info.parent_location, + parser_context=ParserContext( + parent_location=function_info.parent_location, + ), ).code_elements code_elements: List[CommentedCodeElement] = [] @@ -317,7 +326,7 @@ def get_code_elements(code: str) -> List[CommentedCodeElement]: args = [ ArgumentInfo( name=typed_identifier.identifier.name, - cairo_type=typed_identifier.get_type(), + cairo_type=self.resolve_type(typed_identifier.get_type()), location=non_optional_location(typed_identifier.identifier.location), ) for typed_identifier in function_info.elm.arguments.identifiers @@ -325,9 +334,8 @@ def get_code_elements(code: str) -> List[CommentedCodeElement]: code_elements += encode_data( arguments=args, encoding_type=EncodingType.CALLDATA, - # Passing has_range_check_builtin=True will skip the check for the existence of the - # range check builtin. has_range_check_builtin=True, + identifiers=self.identifiers, ) code_elements += get_code_elements( @@ -357,10 +365,9 @@ def get_code_elements(code: str) -> List[CommentedCodeElement]: data_size="retdata_size", arguments=rets, encoding_type=EncodingType.RETURN, - # Passing has_range_check_builtin=True will skip the check for the existence of the - # range check builtin. has_range_check_builtin=True, location=function_info.parent_location[0], + identifiers=self.identifiers, ) # Update the return values. return_str = ret_arg_list.format() diff --git a/src/starkware/starknet/compiler/data_encoder.py b/src/starkware/starknet/compiler/data_encoder.py index f8d76217..dc98b331 100644 --- a/src/starkware/starknet/compiler/data_encoder.py +++ b/src/starkware/starknet/compiler/data_encoder.py @@ -2,14 +2,23 @@ from enum import Enum, auto from typing import List, Optional, Sequence, Tuple -from starkware.cairo.lang.compiler.ast.cairo_types import CairoType, TypeFelt, TypePointer +from starkware.cairo.lang.compiler.ast.cairo_types import ( + CairoType, + TypeFelt, + TypePointer, + TypeStruct, + TypeTuple, +) from starkware.cairo.lang.compiler.ast.code_elements import CommentedCodeElement from starkware.cairo.lang.compiler.ast.expr import ArgList, ExprAssignment, ExprIdentifier from starkware.cairo.lang.compiler.ast.notes import Notes from starkware.cairo.lang.compiler.error_handling import Location, ParentLocation from starkware.cairo.lang.compiler.identifier_definition import StructDefinition +from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager +from starkware.cairo.lang.compiler.parser import ParserContext from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError from starkware.cairo.lang.compiler.preprocessor.preprocessor_utils import autogen_parse_code_block +from starkware.cairo.lang.compiler.type_utils import check_felts_only_type class EncodingType(Enum): @@ -65,11 +74,17 @@ class DataEncodingProcessor: EncodingType.RETURN: "return value", } - def __init__(self, encoding_type: EncodingType, has_range_check_builtin: bool): + def __init__( + self, + encoding_type: EncodingType, + has_range_check_builtin: bool, + identifiers: IdentifierManager, + ): self.encoding_type = encoding_type self.has_range_check_builtin = has_range_check_builtin self.code_elements: List[CommentedCodeElement] = [] self.args: List[ExprAssignment] = [] + self.identifiers = identifiers @property def var_name(self): @@ -96,7 +111,10 @@ def add_code(self, code: str, parent_location: ParentLocation): code_block = autogen_parse_code_block( path="autogen/starknet/arg_processor", code=code, - parent_location=parent_location, + parser_context=ParserContext( + parent_location=parent_location, + resolved_types=True, + ), ) self.code_elements += code_block.code_elements @@ -129,6 +147,16 @@ def run(self, arguments: Sequence[ArgumentInfo]): ) code_block_str = self.process_felt_ptr(arg_info=arg_info) + elif isinstance(cairo_type, (TypeTuple, TypeStruct)): + size = check_felts_only_type( + cairo_type=cairo_type, identifier_manager=self.identifiers + ) + if size is None: + raise PreprocessorError( + f"{self.arg_text} must consist only of felts.", + location=arg_info.location, + ) + code_block_str = self.process_felts_object(arg_info=arg_info, size=size) elif isinstance(cairo_type, TypeFelt): code_block_str = self.process_felt(arg_info=arg_info) else: @@ -172,6 +200,14 @@ def process_felt_ptr(self, arg_info: ArgumentInfo): "Array arguments are not supported in this context", location=arg_info.location ) + def process_felts_object(self, arg_info: ArgumentInfo, size: int): + """ + Handles tuples or structs which recursively consist only of felts. + """ + raise PreprocessorError( + "Tuples/structs are not supported in this context", location=arg_info.location + ) + class DataDecoder(DataEncodingProcessor): def __init__( @@ -181,9 +217,12 @@ def __init__( has_range_check_builtin: bool, encoding_type: EncodingType, location: Location, + identifiers: IdentifierManager, ): super().__init__( - encoding_type=encoding_type, has_range_check_builtin=has_range_check_builtin + encoding_type=encoding_type, + has_range_check_builtin=has_range_check_builtin, + identifiers=identifiers, ) self.data_ptr = data_ptr self.data_size = data_size @@ -229,6 +268,13 @@ def process_felt_ptr(self, arg_info: ArgumentInfo): tempvar __{self.var_name}_ptr = __{self.var_name}_ptr + __{self.var_name}_arg_{arg_info.name}_len """ + def process_felts_object(self, arg_info: ArgumentInfo, size: int): + return f"""\ +let __{self.var_name}_arg_{arg_info.name} = [ + cast(__{self.var_name}_ptr, {TypePointer(pointee=arg_info.cairo_type).format()})] +let __{self.var_name}_ptr = __{self.var_name}_ptr + {size} +""" + def decode_data( data_ptr: str, @@ -237,6 +283,7 @@ def decode_data( encoding_type: EncodingType, has_range_check_builtin: bool, location: Location, + identifiers: IdentifierManager, ) -> Tuple[List[CommentedCodeElement], ArgList]: """ Processes the calldata of a function. @@ -256,6 +303,7 @@ def decode_data( encoding_type=encoding_type, has_range_check_builtin=has_range_check_builtin, location=location, + identifiers=identifiers, ) parser.run(arguments) args = parser.args @@ -286,16 +334,30 @@ def process_felt_ptr(self, arg_info: ArgumentInfo): memcpy(dst=__{self.var_name}_ptr_copy, src={arg_info.name}, len={arg_info.name}_len) """ + def process_felts_object(self, arg_info: ArgumentInfo, size: int): + body = "\n".join( + f"assert [__{self.var_name}_ptr + {i}] = [__{self.var_name}_tmp + {i}]" + for i in range(size) + ) + return f"""\ +# Create a reference to {arg_info.name} as felt*. +let __{self.var_name}_tmp : felt* = cast(&{arg_info.name}, felt*) +{body} +let __{self.var_name}_ptr = __{self.var_name}_ptr + {size} +""" + def encode_data( arguments: Sequence[ArgumentInfo], encoding_type: EncodingType, has_range_check_builtin: bool, + identifiers: IdentifierManager, ) -> List[CommentedCodeElement]: parser = DataEncoder( encoding_type=encoding_type, has_range_check_builtin=has_range_check_builtin, + identifiers=identifiers, ) parser.run(arguments) return parser.code_elements diff --git a/src/starkware/starknet/compiler/data_encoder_test.py b/src/starkware/starknet/compiler/data_encoder_test.py index 33f681b5..33ce8dd6 100644 --- a/src/starkware/starknet/compiler/data_encoder_test.py +++ b/src/starkware/starknet/compiler/data_encoder_test.py @@ -3,11 +3,16 @@ import pytest -from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt, TypePointer +from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME +from starkware.cairo.lang.compiler.ast.cairo_types import TypeFelt, TypePointer, TypeTuple from starkware.cairo.lang.compiler.error_handling import InputFile, Location +from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager +from starkware.cairo.lang.compiler.parser import parse_type from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError +from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import preprocess_str from starkware.cairo.lang.compiler.scoped_name import ScopedName from starkware.cairo.lang.compiler.type_casts import FELT_STAR +from starkware.cairo.lang.compiler.type_system import mark_type_resolved from starkware.starknet.compiler.data_encoder import ( ArgumentInfo, EncodingType, @@ -29,7 +34,7 @@ def dummy_location(): ) -def run_data_encoder( +def run_decode_data( arguments: List[ArgumentInfo], encoding_type: EncodingType = EncodingType.CALLDATA, has_range_check_builtin=True, @@ -41,6 +46,7 @@ def run_data_encoder( encoding_type=encoding_type, has_range_check_builtin=has_range_check_builtin, location=dummy_location(), + identifiers=IdentifierManager(), ) @@ -50,15 +56,22 @@ def test_decode_data_flow(): ArgumentInfo(name="a_len", cairo_type=TypeFelt(), location=location), ArgumentInfo(name="a", cairo_type=FELT_STAR, location=location), ArgumentInfo(name="b", cairo_type=TypeFelt(), location=location), + ArgumentInfo( + name="c", + cairo_type=TypeTuple(members=[TypeFelt(), TypeTuple(members=[TypeFelt(), TypeFelt()])]), + location=location, + ), ] - code_elements, expr = run_data_encoder(arguments) + code_elements, expr = run_decode_data(arguments) assert ( "".join(code_element.format(100) + "\n" for code_element in code_elements) == """\ let __calldata_ptr : felt* = cast(data_ptr, felt*) + let __calldata_arg_a_len = [__calldata_ptr] let __calldata_ptr = __calldata_ptr + 1 + # Check that the length is non-negative. assert [range_check_ptr] = __calldata_arg_a_len let range_check_ptr = range_check_ptr + 1 @@ -67,14 +80,25 @@ def test_decode_data_flow(): # Use 'tempvar' instead of 'let' to avoid repeating this computation for the # following arguments. tempvar __calldata_ptr = __calldata_ptr + __calldata_arg_a_len + let __calldata_arg_b = [__calldata_ptr] let __calldata_ptr = __calldata_ptr + 1 + +let __calldata_arg_c = [ + cast(__calldata_ptr, (felt, (felt, felt))*)] +let __calldata_ptr = __calldata_ptr + 3 + let __calldata_actual_size = __calldata_ptr - cast(data_ptr, felt*) assert data_size = __calldata_actual_size -""" +""".replace( + "\n\n", "\n" + ) ) - assert expr.format() == "a_len=__calldata_arg_a_len, a=__calldata_arg_a, b=__calldata_arg_b," + assert ( + expr.format() + == "a_len=__calldata_arg_a_len, a=__calldata_arg_a, b=__calldata_arg_b, c=__calldata_arg_c," + ) assert code_elements[0].code_elm.expr.location.parent_location == ( location, @@ -83,7 +107,7 @@ def test_decode_data_flow(): # Do the same, with encoding_type=EncodingType.RETURN. # Only validate the beginning of the generated code. - code_elements, expr = run_data_encoder(arguments, encoding_type=EncodingType.RETURN) + code_elements, expr = run_decode_data(arguments, encoding_type=EncodingType.RETURN) assert "".join(code_element.format(100) + "\n" for code_element in code_elements).startswith( """\ let __return_value_ptr : felt* = cast(data_ptr, felt*) @@ -99,7 +123,7 @@ def test_decode_data_flow(): def test_decode_data_failure(): location = dummy_location() with pytest.raises(PreprocessorError, match=re.escape("Unsupported argument type felt**.")): - run_data_encoder( + run_decode_data( [ ArgumentInfo(name="arg_a", cairo_type=FELT_STAR_STAR, location=location), ArgumentInfo(name="arg_b", cairo_type=TypeFelt(), location=location), @@ -110,7 +134,7 @@ def test_decode_data_failure(): match='Array argument "arg_a" must be preceded by a length ' 'argument named "arg_a_len" of type felt.', ): - run_data_encoder( + run_decode_data( [ ArgumentInfo(name="arg_a", cairo_type=FELT_STAR, location=location), ArgumentInfo(name="arg_b", cairo_type=TypeFelt(), location=location), @@ -123,7 +147,7 @@ def test_decode_data_failure(): "array arguments in external functions." ), ): - run_data_encoder( + run_decode_data( [ ArgumentInfo(name="arg_len", cairo_type=TypeFelt(), location=location), ArgumentInfo(name="arg", cairo_type=FELT_STAR, location=location), @@ -133,15 +157,31 @@ def test_decode_data_failure(): def test_encode_data_for_return(): + identifiers = preprocess_str( + """ +struct MyStruct: + member x : felt + member y : felt +end +""", + prime=DEFAULT_PRIME, + ).identifiers + location = dummy_location() code_elements = encode_data( [ ArgumentInfo(name="a", cairo_type=TypeFelt(), location=location), ArgumentInfo(name="b_len", cairo_type=TypeFelt(), location=location), ArgumentInfo(name="b", cairo_type=FELT_STAR, location=location), + ArgumentInfo( + name="c", + cairo_type=mark_type_resolved(parse_type("(test_scope.MyStruct, felt)")), + location=location, + ), ], encoding_type=EncodingType.RETURN, has_range_check_builtin=True, + identifiers=identifiers, ) assert ( @@ -149,8 +189,10 @@ def test_encode_data_for_return(): == """\ assert [__return_value_ptr] = a let __return_value_ptr = __return_value_ptr + 1 + assert [__return_value_ptr] = b_len let __return_value_ptr = __return_value_ptr + 1 + # Check that the length is non-negative. assert [range_check_ptr] = b_len # Store the updated range_check_ptr as a local variable to keep it available after @@ -162,7 +204,16 @@ def test_encode_data_for_return(): # the memcpy. local __return_value_ptr : felt* = __return_value_ptr + b_len memcpy(dst=__return_value_ptr_copy, src=b, len=b_len) -""" + +# Create a reference to c as felt*. +let __return_value_tmp : felt* = cast(&c, felt*) +assert [__return_value_ptr + 0] = [__return_value_tmp + 0] +assert [__return_value_ptr + 1] = [__return_value_tmp + 1] +assert [__return_value_ptr + 2] = [__return_value_tmp + 2] +let __return_value_ptr = __return_value_ptr + 3 +""".replace( + "\n\n", "\n" + ) ) assert code_elements[0].code_elm.a.location.parent_location == ( diff --git a/src/starkware/starknet/compiler/starknet_preprocessor.py b/src/starkware/starknet/compiler/starknet_preprocessor.py index 66d74f4c..301f16e5 100644 --- a/src/starkware/starknet/compiler/starknet_preprocessor.py +++ b/src/starkware/starknet/compiler/starknet_preprocessor.py @@ -42,6 +42,7 @@ FutureIdentifierDefinition, StructDefinition, ) +from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition from starkware.cairo.lang.compiler.instruction import Register from starkware.cairo.lang.compiler.preprocessor.preprocessor import ( PreprocessedProgram, @@ -58,6 +59,10 @@ struct_to_argument_info_list, ) from starkware.starknet.definitions.constants import STARKNET_LANG_DIRECTIVE +from starkware.starknet.public.abi_structs import ( + prepare_type_for_abi, + struct_definition_to_abi_entry, +) from starkware.starknet.security.secure_hints import HintsWhitelist, InsecureHintError from starkware.starknet.services.api.contract_definition import SUPPORTED_BUILTINS from starkware.starkware_utils.subsequence import is_subsequence @@ -92,6 +97,10 @@ def __init__(self, **kwargs): self.os_context: Optional[Dict[str, int]] = None # JSON dict for the ABI output. self.abi: List[dict] = [] + # A map from external struct (short) name to its ABI entry. + self.abi_structs: Dict[str, dict] = {} + # A map from external struct (short) name to the fully qualified name. + self.abi_structs_fullnames: Dict[str, ScopedName] = {} def get_external_decorator(self, elm: CodeElementFunction) -> Optional[ExprIdentifier]: """ @@ -278,6 +287,7 @@ def create_func_wrapper(self, elm: CodeElementFunction, func_alias_name: str): encoding_type=EncodingType.CALLDATA, has_range_check_builtin="range_check_ptr" in os_context, location=func_location, + identifiers=self.identifiers, ) for code_element in code_elements: @@ -373,12 +383,15 @@ def add_abi_entry( outputs = [] for m_name, member in arg_struct_def.members.items(): assert is_type_resolved(member.cairo_type) + abi_type_info = prepare_type_for_abi(member.cairo_type) inputs.append( { "name": m_name, - "type": member.cairo_type.format(), + "type": abi_type_info.modified_type.format(), } ) + for struct_name in abi_type_info.structs: + self.add_struct_to_abi(struct_name) for m_name, member in ret_struct_def.members.items(): assert isinstance(member.cairo_type, TypeFelt) outputs.append( @@ -397,11 +410,43 @@ def add_abi_entry( res["stateMutability"] = "view" self.abi.append(res) + def add_struct_to_abi(self, struct_name: ScopedName): + """ + Adds the given struct (add all the structs mentioned in its members) to self.abi_structs. + """ + + struct_definition = get_struct_definition( + struct_name=struct_name, identifier_manager=self.identifiers + ) + + short_name = struct_name.path[-1] + + if short_name in self.abi_structs: + existing_full_name = self.abi_structs_fullnames[short_name] + if existing_full_name != struct_name: + raise PreprocessorError( + f"Found two external structs named {short_name}: " + f"{existing_full_name}, {struct_name}.", + location=struct_definition.location, + ) + return + + abi_entry, inner_structs = struct_definition_to_abi_entry( + struct_definition=struct_definition + ) + + self.abi_structs_fullnames[short_name] = struct_name + self.abi_structs[short_name] = abi_entry + + # Visit the types of the inner structs recursively. + for name in inner_structs: + self.add_struct_to_abi(name) + def get_program(self) -> StarknetPreprocessedProgram: program = super().get_program() return StarknetPreprocessedProgram( # type: ignore **program.__dict__, - abi=self.abi, + abi=list(self.abi_structs.values()) + self.abi, ) def process_retdata( diff --git a/src/starkware/starknet/compiler/starknet_preprocessor_test.py b/src/starkware/starknet/compiler/starknet_preprocessor_test.py index 718463d0..fc8fb4cd 100644 --- a/src/starkware/starknet/compiler/starknet_preprocessor_test.py +++ b/src/starkware/starknet/compiler/starknet_preprocessor_test.py @@ -392,12 +392,26 @@ def test_invalid_hint(): ) -def test_abi(): +def test_abi_basic(): program = preprocess_str( """ %lang starknet %builtins range_check +namespace MyNamespace: + struct ExternalStruct: + member y: (felt, felt) + end +end + +struct ExternalStruct2: + member x: (felt, MyNamespace.ExternalStruct) +end + +struct NonExternalStruct: +end + + @external func f(a : felt, arr_len : felt, arr : felt*) -> (b : felt, c : felt): return (0, 1) @@ -409,13 +423,25 @@ def test_abi(): end @l1_handler -func handler(from_address): +func handler(from_address, a: ExternalStruct2): return () end """ ) assert program.abi == [ + { + "type": "struct", + "name": "ExternalStruct2", + "members": [{"name": "x", "offset": 0, "type": "(felt, ExternalStruct)"}], + "size": 3, + }, + { + "type": "struct", + "name": "ExternalStruct", + "members": [{"name": "y", "offset": 0, "type": "(felt, felt)"}], + "size": 2, + }, { "inputs": [ {"name": "a", "type": "felt"}, @@ -439,9 +465,49 @@ def test_abi(): "stateMutability": "view", }, { - "inputs": [{"name": "from_address", "type": "felt"}], + "inputs": [ + {"name": "from_address", "type": "felt"}, + {"name": "a", "type": "ExternalStruct2"}, + ], "name": "handler", "outputs": [], "type": "l1_handler", }, ] + + +def test_abi_failures(): + verify_exception( + """ +%lang starknet + +namespace a: + struct MyStruct: + end +end + +namespace b: + struct MyStruct: + end + + struct MyStruct2: + member x: ((MyStruct, MyStruct), felt) + end +end + +@external +func f(x : (felt, a.MyStruct)): + return() +end + +@view +func g(y : b.MyStruct2): + return() +end +""", + """ +file:?:?: Found two external structs named MyStruct: test_scope.a.MyStruct, test_scope.b.MyStruct. + struct MyStruct: + ^******^ +""", + ) diff --git a/src/starkware/starknet/compiler/storage_var.py b/src/starkware/starknet/compiler/storage_var.py index f1801223..0af90140 100644 --- a/src/starkware/starknet/compiler/storage_var.py +++ b/src/starkware/starknet/compiler/storage_var.py @@ -1,26 +1,17 @@ import dataclasses from typing import Optional, Tuple -from starkware.cairo.lang.compiler.ast.cairo_types import ( - CairoType, - TypeFelt, - TypePointer, - TypeStruct, - TypeTuple, -) -from starkware.cairo.lang.compiler.ast.code_elements import ( - CodeElementFunction, -) +from starkware.cairo.lang.compiler.ast.cairo_types import CairoType, TypeFelt, TypePointer +from starkware.cairo.lang.compiler.ast.code_elements import CodeElementFunction from starkware.cairo.lang.compiler.ast.formatting_utils import get_max_line_length from starkware.cairo.lang.compiler.error_handling import Location -from starkware.cairo.lang.compiler.identifier_manager import IdentifierManager -from starkware.cairo.lang.compiler.identifier_utils import get_struct_definition from starkware.cairo.lang.compiler.parser import parse from starkware.cairo.lang.compiler.preprocessor.identifier_aware_visitor import ( IdentifierAwareVisitor, ) from starkware.cairo.lang.compiler.preprocessor.preprocessor_error import PreprocessorError from starkware.cairo.lang.compiler.preprocessor.preprocessor_utils import verify_empty_code_block +from starkware.cairo.lang.compiler.type_utils import check_felts_only_type from starkware.starknet.definitions.constants import STARKNET_LANG_DIRECTIVE from starkware.starknet.public.abi import MAX_STORAGE_ITEM_SIZE, get_storage_var_address @@ -150,7 +141,10 @@ def process_storage_var(visitor: IdentifierAwareVisitor, elm: CodeElementFunctio unresolved_return_type = get_return_type(elm=elm) return_type = visitor.resolve_type(unresolved_return_type) - if not check_felts_only_type(cairo_type=return_type, identifier_manager=visitor.identifiers): + if ( + check_felts_only_type(cairo_type=return_type, identifier_manager=visitor.identifiers) + is None + ): raise PreprocessorError( "The return type of storage variables must consist of felts.", location=elm.returns.location if elm.returns is not None else elm.identifier.location, @@ -229,35 +223,6 @@ def is_storage_var(elm: CodeElementFunction) -> Tuple[bool, Optional[Location]]: return False, None -def check_felts_only_type(cairo_type: CairoType, identifier_manager: IdentifierManager) -> bool: - """ - A felts-only type defined to be either felt or a struct whose members are all felts-only types. - Return True if the given type is felts-only. - """ - - if isinstance(cairo_type, TypeFelt): - return True - elif isinstance(cairo_type, TypeStruct): - struct_definition = get_struct_definition( - cairo_type.resolved_scope, identifier_manager=identifier_manager - ) - for member_def in struct_definition.members.values(): - res = check_felts_only_type( - member_def.cairo_type, identifier_manager=identifier_manager - ) - if not res: - return False - return True - elif isinstance(cairo_type, TypeTuple): - for item_type in cairo_type.members: - res = check_felts_only_type(item_type, identifier_manager=identifier_manager) - if not res: - return False - return True - else: - return False - - class StorageVarDeclVisitor(IdentifierAwareVisitor): """ Replaces @storage_var decorated functions with a namespace with empty functions. diff --git a/src/starkware/starknet/compiler/storage_var_test.py b/src/starkware/starknet/compiler/storage_var_test.py index a24b7429..1d517582 100644 --- a/src/starkware/starknet/compiler/storage_var_test.py +++ b/src/starkware/starknet/compiler/storage_var_test.py @@ -1,62 +1,12 @@ import re -from starkware.cairo.lang.compiler.parser import parse_type from starkware.cairo.lang.compiler.preprocessor.preprocessor_test_utils import ( strip_comments_and_linebreaks, ) -from starkware.cairo.lang.compiler.type_system import mark_type_resolved -from starkware.starknet.compiler.storage_var import check_felts_only_type from starkware.starknet.compiler.test_utils import preprocess_str, verify_exception from starkware.starknet.public.abi import starknet_keccak -def test_check_felts_only_type(): - program = preprocess_str( - """ -struct A: - member x : felt -end - -struct B: -end - -struct C: - member x : felt - member y : (felt, A, B) - member z : A -end - -struct D: - member x : felt* -end - -struct E: - member x : D -end - """ - ) - - for (typ, expected_res) in [ - # Positive cases. - ("test_scope.A", True), - ("test_scope.B", True), - ("test_scope.C", True), - ("(felt, felt)", True), - ("(felt, (felt, test_scope.C))", True), - # Negative cases. - ("test_scope.D", False), - ("test_scope.E", False), - ("(felt, test_scope.D)", False), - ]: - assert ( - check_felts_only_type( - cairo_type=mark_type_resolved(parse_type(typ)), - identifier_manager=program.identifiers, - ) - == expected_res - ) - - def test_storage_var_success(): program = preprocess_str( """ diff --git a/src/starkware/starknet/public/CMakeLists.txt b/src/starkware/starknet/public/CMakeLists.txt index 2260a4cd..39306ece 100644 --- a/src/starkware/starknet/public/CMakeLists.txt +++ b/src/starkware/starknet/public/CMakeLists.txt @@ -2,8 +2,10 @@ python_lib(starknet_abi_lib PREFIX starkware/starknet/public FILES abi.py + abi_structs.py LIBS + cairo_compile_lib cairo_vm_crypto_lib pip_eth_hash pip_pycryptodome @@ -16,6 +18,7 @@ full_python_test(starknet_abi_lib_test TESTED_MODULES starkware/starknet/public/ FILES + abi_structs_test.py abi_test.py LIBS diff --git a/src/starkware/starknet/public/abi_structs.py b/src/starkware/starknet/public/abi_structs.py new file mode 100644 index 00000000..9b480b43 --- /dev/null +++ b/src/starkware/starknet/public/abi_structs.py @@ -0,0 +1,109 @@ +import dataclasses +from typing import Set, Tuple + +from starkware.cairo.lang.compiler.ast.cairo_types import ( + CairoType, + TypeFelt, + TypePointer, + TypeStruct, + TypeTuple, +) +from starkware.cairo.lang.compiler.identifier_definition import MemberDefinition, StructDefinition +from starkware.cairo.lang.compiler.parser import parse_type +from starkware.cairo.lang.compiler.scoped_name import ScopedName +from starkware.cairo.lang.compiler.type_system import mark_type_resolved + + +@dataclasses.dataclass +class AbiTypeInfo: + # The type after removing type qualification (for example, a.b.c.T -> T). + modified_type: CairoType + # All structs that appear inside the type. + structs: Set[ScopedName] + + +def prepare_type_for_abi(cairo_type: CairoType) -> AbiTypeInfo: + """ + Recursively visits the given type and returns an AbiTypeInfo instance. + """ + if isinstance(cairo_type, TypeTuple): + new_members = [] + structs = set() + for inner_type in cairo_type.members: + res = prepare_type_for_abi(inner_type) + structs |= res.structs + new_members.append(res.modified_type) + + return AbiTypeInfo( + modified_type=dataclasses.replace(cairo_type, members=new_members), + structs=structs, + ) + elif isinstance(cairo_type, TypeStruct): + struct_name = cairo_type.resolved_scope.path[-1] + + return AbiTypeInfo( + modified_type=dataclasses.replace( + cairo_type, scope=ScopedName.from_string(struct_name) + ), + structs={cairo_type.resolved_scope}, + ) + elif isinstance(cairo_type, TypeFelt): + return AbiTypeInfo(modified_type=cairo_type, structs=set()) + elif isinstance(cairo_type, TypePointer): + res = prepare_type_for_abi(cairo_type=cairo_type.pointee) + return AbiTypeInfo( + modified_type=dataclasses.replace(cairo_type, pointee=res.modified_type), + structs=res.structs, + ) + else: + raise NotImplementedError(f"Unexpected type: {cairo_type.format()}.") + + +def struct_definition_to_abi_entry( + struct_definition: StructDefinition, +) -> Tuple[dict, Set[ScopedName]]: + """ + Returns a tuple with: + 1. An ABI entry describing the given struct. + 2. A set of struct names that are used inside the struct members. + """ + members = [] + structs = set() + for name, member_definition in struct_definition.members.items(): + abi_type_info = prepare_type_for_abi(member_definition.cairo_type) + members.append( + { + "name": name, + "type": abi_type_info.modified_type.format(), + "offset": member_definition.offset, + } + ) + structs |= abi_type_info.structs + abi_entry = { + "name": struct_definition.full_name.path[-1], + "type": "struct", + "members": members, + "size": struct_definition.size, + } + return abi_entry, structs + + +def struct_definition_from_abi_entry(abi_entry: dict) -> StructDefinition: + """ + Converts an ABI entry of a struct to StructDefinition. + """ + assert ( + abi_entry["type"] == "struct" + ), f"Expected an entry of type 'struct'. Got: '{abi_entry['type']}'." + member_definitions = {} + for member in abi_entry["members"]: + member_definitions[member["name"]] = MemberDefinition( + cairo_type=mark_type_resolved(parse_type(member["type"])), + offset=member["offset"], + ) + return StructDefinition( + full_name=ScopedName.from_string(abi_entry["name"]), + members=member_definitions, + size=abi_entry["size"], + location=None, + ) diff --git a/src/starkware/starknet/public/abi_structs_test.py b/src/starkware/starknet/public/abi_structs_test.py new file mode 100644 index 00000000..1e0f9007 --- /dev/null +++ b/src/starkware/starknet/public/abi_structs_test.py @@ -0,0 +1,68 @@ +from starkware.cairo.lang.compiler.identifier_definition import MemberDefinition, StructDefinition +from starkware.cairo.lang.compiler.parser import parse_type +from starkware.cairo.lang.compiler.scoped_name import ScopedName +from starkware.cairo.lang.compiler.type_system import mark_type_resolved +from starkware.starknet.public.abi_structs import ( + AbiTypeInfo, + prepare_type_for_abi, + struct_definition_from_abi_entry, + struct_definition_to_abi_entry, +) + + +def test_prepare_type_for_abi(): + cairo_type = mark_type_resolved(parse_type("(felt, (a.b.c.MyStruct*, T)**)")) + expected_modified_type = mark_type_resolved(parse_type("(felt, (MyStruct*, T)**)")) + assert prepare_type_for_abi(cairo_type) == AbiTypeInfo( + modified_type=expected_modified_type, + structs={ScopedName.from_string("a.b.c.MyStruct"), ScopedName.from_string("T")}, + ) + + +def test_struct_definition_to_abi_entry(): + struct_definition = StructDefinition( + full_name=ScopedName.from_string("a.b.c.MyStruct"), + members={ + "x": MemberDefinition(offset=7, cairo_type=mark_type_resolved(parse_type("a.b.c.T*"))), + }, + size=1, + location=None, + ) + new_abi_entry, structs = struct_definition_to_abi_entry(struct_definition=struct_definition) + assert new_abi_entry == { + "type": "struct", + "name": "MyStruct", + "members": [ + { + "name": "x", + "offset": 7, + "type": "T*", + } + ], + "size": 1, + } + assert structs == {ScopedName.from_string("a.b.c.T")} + + +def test_abi_structs_both_directions(): + abi_entry = { + "name": "MyStruct", + "type": "struct", + "members": [ + { + "name": "x", + "type": "felt**", + "offset": 0, + }, + { + "name": "y", + "type": "(felt, MyStruct*)", + "offset": 1, + }, + ], + "size": 12, + } + struct_definition = struct_definition_from_abi_entry(abi_entry) + new_abi_entry, structs = struct_definition_to_abi_entry(struct_definition=struct_definition) + assert new_abi_entry == abi_entry + assert structs == {ScopedName.from_string("MyStruct")} diff --git a/src/starkware/starknet/security/CMakeLists.txt b/src/starkware/starknet/security/CMakeLists.txt index 72c58b22..4d9a07b2 100644 --- a/src/starkware/starknet/security/CMakeLists.txt +++ b/src/starkware/starknet/security/CMakeLists.txt @@ -39,8 +39,9 @@ python_lib(starknet_hints_whitelist_lib FILES hints_whitelist.py - whitelists/latest.json whitelists/cairo_keccak.json + whitelists/cairo_sha256.json + whitelists/latest.json LIBS starknet_security_lib diff --git a/src/starkware/starknet/security/hints_whitelist.py b/src/starkware/starknet/security/hints_whitelist.py index 7668b578..3eea00da 100644 --- a/src/starkware/starknet/security/hints_whitelist.py +++ b/src/starkware/starknet/security/hints_whitelist.py @@ -2,7 +2,6 @@ from starkware.starknet.security.secure_hints import HintsWhitelist - WHILTELIST_DIR = os.path.join(os.path.dirname(__file__), "whitelists") diff --git a/src/starkware/starknet/security/starknet_common.cairo b/src/starkware/starknet/security/starknet_common.cairo index 64048715..7d56aa40 100644 --- a/src/starkware/starknet/security/starknet_common.cairo +++ b/src/starkware/starknet/security/starknet_common.cairo @@ -7,10 +7,11 @@ from starkware.cairo.common.keccak import unsafe_keccak from starkware.cairo.common.math import ( abs_value, assert_250_bit, assert_in_range, assert_le, assert_le_felt, assert_lt, assert_lt_felt, assert_nn, assert_nn_le, assert_not_equal, assert_not_zero, sign, - signed_div_rem, split_felt, unsigned_div_rem) + signed_div_rem, split_felt, split_int, unsigned_div_rem) from starkware.cairo.common.math_cmp import ( is_in_range, is_le, is_le_felt, is_nn, is_nn_le, is_not_zero) from starkware.cairo.common.memcpy import memcpy +from starkware.cairo.common.memset import memset from starkware.cairo.common.signature import verify_ecdsa_signature from starkware.cairo.common.squash_dict import squash_dict from starkware.cairo.common.uint256 import ( diff --git a/src/starkware/starknet/security/whitelists/cairo_sha256.json b/src/starkware/starknet/security/whitelists/cairo_sha256.json new file mode 100644 index 00000000..529d99ad --- /dev/null +++ b/src/starkware/starknet/security/whitelists/cairo_sha256.json @@ -0,0 +1,123 @@ +{ + "allowed_reference_expressions_for_hint": [ + { + "allowed_expressions": [ + { + "expr": "[cast(ap + (-2), felt*)]", + "name": "sha256.finalize_sha256.__fp__" + }, + { + "expr": "[cast(fp + (-5), starkware.cairo.common.cairo_builtins.BitwiseBuiltin**)]", + "name": "sha256.finalize_sha256.bitwise_ptr" + }, + { + "expr": "[cast(ap + (-1), felt*)]", + "name": "sha256.finalize_sha256.n" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": "sha256.finalize_sha256.range_check_ptr" + }, + { + "expr": "[cast(ap + (-1), felt**)]", + "name": "sha256.finalize_sha256.round_constants" + }, + { + "expr": "[cast(fp + (-3), felt**)]", + "name": "sha256.finalize_sha256.sha256_ptr_end" + }, + { + "expr": "[cast(fp + (-4), felt**)]", + "name": "sha256.finalize_sha256.sha256_ptr_start" + } + ], + "hint_lines": [ + "# Add dummy pairs of input and output.", + "from starkware.cairo.common.cairo_sha256.sha256_utils import (", + " IV, compute_message_schedule, sha2_compress_function)", + "", + "_block_size = int(ids.BLOCK_SIZE)", + "assert 0 <= _block_size < 20", + "_sha256_input_chunk_size_felts = int(ids.SHA256_INPUT_CHUNK_SIZE_FELTS)", + "assert 0 <= _sha256_input_chunk_size_felts < 100", + "", + "message = [0] * _sha256_input_chunk_size_felts", + "w = compute_message_schedule(message)", + "output = sha2_compress_function(IV, w)", + "padding = (message + IV + output) * (_block_size - 1)", + "segments.write_arg(ids.sha256_ptr_end, padding)" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-4), felt**)]", + "name": "sha256.sha256.input" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "sha256.sha256.n_bytes" + }, + { + "expr": "cast([ap + (-11)] + 10, felt*)", + "name": "sha256.sha256.output" + }, + { + "expr": "[cast(ap + (-2), felt*)]", + "name": "sha256.sha256.range_check_ptr" + }, + { + "expr": "cast([ap + (-11)] + 10, felt*)", + "name": "sha256.sha256.sha256_ptr" + }, + { + "expr": "[cast(fp + (-5), felt**)]", + "name": "sha256.sha256.sha256_start" + } + ], + "hint_lines": [ + "from starkware.cairo.common.cairo_sha256.sha256_utils import (", + " IV, compute_message_schedule, sha2_compress_function)", + "", + "_sha256_input_chunk_size_felts = int(ids.SHA256_INPUT_CHUNK_SIZE_FELTS)", + "assert 0 <= _sha256_input_chunk_size_felts < 100", + "", + "w = compute_message_schedule(memory.get_range(", + " ids.sha256_start, _sha256_input_chunk_size_felts))", + "new_state = sha2_compress_function(IV, w)", + "segments.write_arg(ids.output, new_state)" + ] + }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp, felt*)]", + "name": "sha256._sha256_input.full_word" + }, + { + "expr": "[cast(fp + (-5), felt**)]", + "name": "sha256._sha256_input.input" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "sha256._sha256_input.n_bytes" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "sha256._sha256_input.n_words" + }, + { + "expr": "[cast(fp + (-7), felt*)]", + "name": "sha256._sha256_input.range_check_ptr" + }, + { + "expr": "[cast(fp + (-6), felt**)]", + "name": "sha256._sha256_input.sha256_ptr" + } + ], + "hint_lines": [ + "ids.full_word = int(ids.n_bytes >= 4)" + ] + } + ] +} diff --git a/src/starkware/starknet/security/whitelists/latest.json b/src/starkware/starknet/security/whitelists/latest.json index 7728ad3c..6865d312 100644 --- a/src/starkware/starknet/security/whitelists/latest.json +++ b/src/starkware/starknet/security/whitelists/latest.json @@ -307,6 +307,37 @@ "assert ids.n_used_accesses == len(access_indices[key])" ] }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.math.split_int.base" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.math.split_int.bound" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": "starkware.cairo.common.math.split_int.n" + }, + { + "expr": "[cast(fp + (-3), felt**)]", + "name": "starkware.cairo.common.math.split_int.output" + }, + { + "expr": "[cast(fp + (-8), felt*)]", + "name": "starkware.cairo.common.math.split_int.range_check_ptr" + }, + { + "expr": "[cast(fp + (-7), felt*)]", + "name": "starkware.cairo.common.math.split_int.value" + } + ], + "hint_lines": [ + "assert ids.value == 0, 'split_int(): value is out of range.'" + ] + }, { "allowed_expressions": [ { @@ -1347,6 +1378,38 @@ "memory[ap] = segments.add()" ] }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-5), felt*)]", + "name": "starkware.cairo.common.math.split_int.base" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.math.split_int.bound" + }, + { + "expr": "[cast(fp + (-6), felt*)]", + "name": "starkware.cairo.common.math.split_int.n" + }, + { + "expr": "[cast(fp + (-3), felt**)]", + "name": "starkware.cairo.common.math.split_int.output" + }, + { + "expr": "[cast(fp + (-8), felt*)]", + "name": "starkware.cairo.common.math.split_int.range_check_ptr" + }, + { + "expr": "[cast(fp + (-7), felt*)]", + "name": "starkware.cairo.common.math.split_int.value" + } + ], + "hint_lines": [ + "memory[ids.output] = res = (int(ids.value) % PRIME) % ids.base", + "assert res < ids.bound, f'split_int(): Limb {res} is out of range.'" + ] + }, { "allowed_expressions": [ { @@ -1379,6 +1442,38 @@ "ids.continue_copying = 1 if n > 0 else 0" ] }, + { + "allowed_expressions": [ + { + "expr": "[cast(ap, felt*)]", + "name": "starkware.cairo.common.memset.memset.continue_loop" + }, + { + "expr": "[cast(fp + (-5), felt**)]", + "name": "starkware.cairo.common.memset.memset.dst" + }, + { + "expr": "[cast(ap + (-1), starkware.cairo.common.memset.memset.LoopFrame*)]", + "name": "starkware.cairo.common.memset.memset.frame" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.memset.memset.n" + }, + { + "expr": "cast(ap + 1, starkware.cairo.common.memset.memset.LoopFrame*)", + "name": "starkware.cairo.common.memset.memset.next_frame" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.memset.memset.value" + } + ], + "hint_lines": [ + "n -= 1", + "ids.continue_loop = 1 if n > 0 else 0" + ] + }, { "allowed_expressions": [ { @@ -1609,6 +1704,25 @@ "vm_enter_scope({'n': ids.len})" ] }, + { + "allowed_expressions": [ + { + "expr": "[cast(fp + (-5), felt**)]", + "name": "starkware.cairo.common.memset.memset.dst" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.memset.memset.n" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.memset.memset.value" + } + ], + "hint_lines": [ + "vm_enter_scope({'n': ids.n})" + ] + }, { "allowed_expressions": [ { @@ -1651,6 +1765,30 @@ "expr": "[cast(fp + (-4), felt**)]", "name": "starkware.cairo.common.memcpy.memcpy.src" }, + { + "expr": "[cast(ap, felt*)]", + "name": "starkware.cairo.common.memset.memset.continue_loop" + }, + { + "expr": "[cast(fp + (-5), felt**)]", + "name": "starkware.cairo.common.memset.memset.dst" + }, + { + "expr": "[cast(ap + (-1), starkware.cairo.common.memset.memset.LoopFrame*)]", + "name": "starkware.cairo.common.memset.memset.frame" + }, + { + "expr": "[cast(fp + (-3), felt*)]", + "name": "starkware.cairo.common.memset.memset.n" + }, + { + "expr": "cast(ap + 1, starkware.cairo.common.memset.memset.LoopFrame*)", + "name": "starkware.cairo.common.memset.memset.next_frame" + }, + { + "expr": "[cast(fp + (-4), felt*)]", + "name": "starkware.cairo.common.memset.memset.value" + }, { "expr": "[cast(fp + 2, felt*)]", "name": "starkware.cairo.common.squash_dict.squash_dict.big_keys" diff --git a/src/starkware/starknet/storage/starknet_storage.py b/src/starkware/starknet/storage/starknet_storage.py index 8c4fb184..cc9d3a88 100644 --- a/src/starkware/starknet/storage/starknet_storage.py +++ b/src/starkware/starknet/storage/starknet_storage.py @@ -1,12 +1,12 @@ -from abc import ABC, abstractmethod import asyncio import concurrent import dataclasses +from abc import ABC, abstractmethod from typing import Dict, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME -from starkware.starkware_utils.binary_fact_tree import BinaryFactDict -from starkware.starkware_utils.patricia_tree.patricia_tree import PatriciaTree +from starkware.starkware_utils.commitment_tree.binary_fact_tree import BinaryFactDict +from starkware.starkware_utils.commitment_tree.patricia_tree.patricia_tree import PatriciaTree from starkware.storage.storage import HASH_BYTES, Fact, FactFetchingContext, HashFunctionType diff --git a/src/starkware/starknet/testing/CMakeLists.txt b/src/starkware/starknet/testing/CMakeLists.txt index a2e34295..252379c1 100644 --- a/src/starkware/starknet/testing/CMakeLists.txt +++ b/src/starkware/starknet/testing/CMakeLists.txt @@ -4,6 +4,7 @@ python_lib(starknet_testing_lib FILES contract.py starknet.py + state.py LIBS cairo_vm_crypto_lib diff --git a/src/starkware/starknet/testing/contract.py b/src/starkware/starknet/testing/contract.py index acff487e..10ac5c49 100644 --- a/src/starkware/starknet/testing/contract.py +++ b/src/starkware/starknet/testing/contract.py @@ -1,8 +1,9 @@ -from collections import namedtuple +import sys import types -from typing import Dict, List, Optional, Tuple, Type, Union +from collections import namedtuple +from typing import Any, Dict, List, Optional, Tuple, Type, Union -from starkware.starknet.testing.starknet import Starknet +from starkware.starknet.testing.state import StarknetState class StarknetContract: @@ -10,16 +11,16 @@ class StarknetContract: A high level interface to a StarkNet contract used for testing. Allows invoking functions. Example: contract_definition = compile_starknet_files(...) - starknet = await Starknet.empty() - contract_address = await starknet.deploy(contract_definition=contract_definition) + state = await StarknetState.empty() + contract_address = await state.deploy(contract_definition=contract_definition) contract = StarknetContract( - starknet=starknet, abi=contract_definition.abi, contract_address=contract_address) + state=state, abi=contract_definition.abi, contract_address=contract_address) await contract.foo(a=1, b=[2, 3]).invoke() """ - def __init__(self, starknet: Starknet, abi: dict, contract_address: Union[int, str]): - self.starknet = starknet + def __init__(self, state: StarknetState, abi: List[Any], contract_address: Union[int, str]): + self.state = state self._abi_function_mapping = { abi_entry["name"]: abi_entry for abi_entry in abi if abi_entry["type"] == "function" @@ -55,8 +56,12 @@ def template(): ) # Create a function like template(), but with extra arguments. - func_code = types.CodeType( + if sys.version_info.major != 3: + raise Exception("Must be using Python3.") + posonlyargcount = (0,) if sys.version_info.minor >= 8 else () + func_code = types.CodeType( # type: ignore len(arg_names), # Arg: argcount. + *posonlyargcount, # type: ignore 0, # Arg: kwonlyargcount. len(arg_names), # Arg: nlocals. template.__code__.co_stacksize + len(arg_names), # Arg: stacksize. @@ -171,7 +176,7 @@ def _build_function_call(self, function_abi: dict, args: dict, ret_tuple: Type): raise NotImplementedError return StarknetContractFunctionInvocation( - starknet=self.starknet, + state=self.state, contract_address=self.contract_address, function_abi=function_abi, calldata=calldata, @@ -186,13 +191,13 @@ class StarknetContractFunctionInvocation: def __init__( self, - starknet: Starknet, + state: StarknetState, contract_address: Union[int, str], function_abi: dict, calldata: List[int], ret_tuple: Type, ): - self.starknet = starknet + self.state = state self.contract_address = contract_address self.function_abi = function_abi self.calldata = calldata @@ -202,7 +207,7 @@ async def call(self) -> List[int]: """ Executes the function call, without changing the state. """ - execution_info = await self.starknet.copy().invoke_raw( + execution_info = await self.state.copy().invoke_raw( contract_address=self.contract_address, selector=self.function_abi["name"], calldata=self.calldata, @@ -213,7 +218,7 @@ async def invoke(self) -> List[int]: """ Executes the function call, and apply changes on the state. """ - execution_info = await self.starknet.invoke_raw( + execution_info = await self.state.invoke_raw( contract_address=self.contract_address, selector=self.function_abi["name"], calldata=self.calldata, diff --git a/src/starkware/starknet/testing/contract_test.py b/src/starkware/starknet/testing/contract_test.py index 47a19367..9c399fbe 100644 --- a/src/starkware/starknet/testing/contract_test.py +++ b/src/starkware/starknet/testing/contract_test.py @@ -1,9 +1,10 @@ import os + import pytest from starkware.starknet.compiler.compile import compile_starknet_files -from starkware.starknet.testing.starknet import Starknet from starkware.starknet.testing.contract import StarknetContract +from starkware.starknet.testing.state import StarknetState CONTRACT_FILE = os.path.join(os.path.dirname(__file__), "test.cairo") @@ -11,10 +12,10 @@ @pytest.mark.asyncio async def test_function_call(): contract_definition = compile_starknet_files([CONTRACT_FILE], debug_info=True) - starknet = await Starknet.empty() - contract_address = await starknet.deploy(contract_definition=contract_definition) + state = await StarknetState.empty() + contract_address = await state.deploy(contract_definition=contract_definition) contract = StarknetContract( - starknet=starknet, abi=contract_definition.abi, contract_address=contract_address + state=state, abi=contract_definition.abi, contract_address=contract_address ) await contract.increase_value(address=132, value=3).invoke() diff --git a/src/starkware/starknet/testing/starknet.py b/src/starkware/starknet/testing/starknet.py index 8cd9877d..294a464f 100644 --- a/src/starkware/starknet/testing/starknet.py +++ b/src/starkware/starknet/testing/starknet.py @@ -1,132 +1,50 @@ -import copy -from collections import defaultdict from typing import List, Optional, Union -from starkware.cairo.lang.vm.crypto import async_pedersen_hash_func -from starkware.starknet.business_logic.internal_transaction import ( - InternalDeploy, - InternalInvokeFunction, -) -from starkware.starknet.business_logic.internal_transaction_interface import ( - TransactionExecutionInfo, -) -from starkware.starknet.business_logic.state import CarriedState, SharedState -from starkware.starknet.business_logic.state_objects import ContractCarriedState, ContractState -from starkware.starknet.definitions import fields +from starkware.starknet.compiler.compile import compile_starknet_files from starkware.starknet.definitions.general_config import StarknetGeneralConfig -from starkware.starknet.public.abi import get_selector_from_name -from starkware.starknet.services.api.contract_definition import ContractDefinition, EntryPointType -from starkware.storage.dict_storage import DictStorage -from starkware.storage.storage import FactFetchingContext +from starkware.starknet.services.api.contract_definition import ContractDefinition +from starkware.starknet.testing.contract import StarknetContract +from starkware.starknet.testing.state import StarknetState class Starknet: """ - StarkNet testing object. Represents a state of a StarkNet network. - - Can be deepcopied. - - Example usage: + A high level interface to a StarkNet state object. + Example: starknet = await Starknet.empty() - contract_definition = compile_starknet_files([CONTRACT_FILE], debug_info=True) - contract_address = await starknet.deploy(contract_definition=contract_definition) - res = await starknet.invoke_raw( - contract_address=contract_address, selector="func", calldata=[1, 2]) + contract = await starknet.deploy('contract.cairo') + await contract.foo(a=1, b=[2, 3]).invoke() """ - def __init__(self, state: CarriedState, general_config: StarknetGeneralConfig): - """ - Constructor. Should not be used directly. Use empty() instead. - """ + def __init__(self, state: StarknetState): self.state = state - self.general_config = general_config - - def copy(self) -> "Starknet": - """ - Creates a new Starknet instance with the same state. And modifications to one instance - would not affect the other. - """ - return copy.deepcopy(self) @classmethod - async def empty(self, general_config: Optional[StarknetGeneralConfig] = None) -> "Starknet": - """ - Creates a new Starknet instance. - """ - if general_config is None: - general_config = StarknetGeneralConfig() - ffc = FactFetchingContext(storage=DictStorage(), hash_func=async_pedersen_hash_func) - empty_contract_state = await ContractState.empty( - storage_commitment_tree_height=general_config.contract_storage_commitment_tree_height, - ffc=ffc, - ) - empty_contract_carried_state = ContractCarriedState( - state=empty_contract_state, storage_updates={} - ) - shared_state = await SharedState.empty(ffc=ffc, general_config=general_config) - state = CarriedState.empty(shared_state=shared_state, ffc=ffc) - state.contract_states = defaultdict(lambda: copy.deepcopy(empty_contract_carried_state)) - return Starknet(state=state, general_config=general_config) + async def empty(cls, general_config: Optional[StarknetGeneralConfig] = None) -> "Starknet": + return Starknet(state=await StarknetState.empty(general_config=general_config)) async def deploy( self, - contract_definition: ContractDefinition, + source: Optional[str] = None, + contract_def: Optional[ContractDefinition] = None, contract_address: Optional[Union[int, str]] = None, - ) -> int: - """ - Deploys a contract. Returns the contract address. - - Args: - contract_definition - a compiled StarkNet contract returned by compile_starknet_files(). - contract_address - If supplied, a hexadecimal string or an integer representing the contract - address to use for deploying. Otherwise, the contract address is randomized. - """ - if contract_address is None: - contract_address = fields.ContractAddressField.get_random_value() - if isinstance(contract_address, str): - contract_address = int(contract_address, 16) - assert isinstance(contract_address, int) - - tx = InternalDeploy( - contract_address=contract_address, contract_definition=contract_definition - ) - - with self.state.copy_and_apply() as state_copy: - await tx.apply_state_updates(state=state_copy, general_config=self.general_config) - return contract_address - - async def invoke_raw( - self, - contract_address: Union[int, str], - selector: Union[int, str], - calldata: List[int], - entry_point_type: EntryPointType = EntryPointType.EXTERNAL, - ) -> TransactionExecutionInfo: - """ - Invokes a contract function. Returns the execution info. - - Args: - contract_address - a hexadecimal string or an integer representing the contract address. - selector - either a function name or an integer selector for the entrypoint to invoke. - calldata - a list of integers to pass as calldata to the invoked function. - """ - - if isinstance(contract_address, str): - contract_address = int(contract_address, 16) - assert isinstance(contract_address, int) - - if isinstance(selector, str): - selector = get_selector_from_name(selector) - assert isinstance(selector, int) - - tx = InternalInvokeFunction( - contract_address=contract_address, - entry_point_selector=selector, - entry_point_type=entry_point_type, - calldata=calldata, - ) - - with self.state.copy_and_apply() as state_copy: - return await tx.apply_state_updates( - state=state_copy, general_config=self.general_config + cairo_path: Optional[List[str]] = None, + ) -> StarknetContract: + assert (0 if source is None else 1) + ( + 0 if contract_def is None else 1 + ) == 1, "Exactly one of source, contract_def should be supplied." + if contract_def is None: + contract_def = compile_starknet_files( + files=[source], debug_info=True, cairo_path=cairo_path ) + source = None + cairo_path = None + assert ( + cairo_path is None + ), "The cairo_path argument can only be used with the source argument." + assert contract_def is not None + address = await self.state.deploy( + contract_definition=contract_def, contract_address=contract_address + ) + assert contract_def.abi is not None, "Missing ABI." + return StarknetContract(state=self.state, abi=contract_def.abi, contract_address=address) diff --git a/src/starkware/starknet/testing/starknet_test.py b/src/starkware/starknet/testing/starknet_test.py index 622f043e..7f8c1b2c 100644 --- a/src/starkware/starknet/testing/starknet_test.py +++ b/src/starkware/starknet/testing/starknet_test.py @@ -1,4 +1,5 @@ import os + import pytest from starkware.starknet.compiler.compile import compile_starknet_files @@ -9,16 +10,12 @@ @pytest.mark.asyncio async def test_basic(): - contract_definition = compile_starknet_files([CONTRACT_FILE], debug_info=True) starknet = await Starknet.empty() + contract = await starknet.deploy(CONTRACT_FILE) + res = await contract.increase_value(address=100, value=5).invoke() + assert res == () + assert await contract.get_value(address=100).call() == (5,) - contract_address = await starknet.deploy(contract_definition=contract_definition) - res = await starknet.invoke_raw( - contract_address=contract_address, selector="increase_value", calldata=[100, 5] - ) - assert res.retdata == [] - - res = await starknet.invoke_raw( - contract_address=contract_address, selector="get_value", calldata=[100] - ) - assert res.retdata == [5] + # Check deploy without compilation. + contract_def = compile_starknet_files(files=[CONTRACT_FILE]) + other_contract = await starknet.deploy(contract_def=contract_def) diff --git a/src/starkware/starknet/testing/state.py b/src/starkware/starknet/testing/state.py new file mode 100644 index 00000000..ebdac083 --- /dev/null +++ b/src/starkware/starknet/testing/state.py @@ -0,0 +1,130 @@ +import copy +from collections import defaultdict +from typing import List, Optional, Union + +from starkware.cairo.lang.vm.crypto import async_pedersen_hash_func +from starkware.starknet.business_logic.internal_transaction import ( + InternalDeploy, + InternalInvokeFunction, +) +from starkware.starknet.business_logic.internal_transaction_interface import ( + TransactionExecutionInfo, +) +from starkware.starknet.business_logic.state import CarriedState, SharedState +from starkware.starknet.business_logic.state_objects import ContractCarriedState, ContractState +from starkware.starknet.definitions import fields +from starkware.starknet.definitions.general_config import StarknetGeneralConfig +from starkware.starknet.public.abi import get_selector_from_name +from starkware.starknet.services.api.contract_definition import ContractDefinition, EntryPointType +from starkware.storage.dict_storage import DictStorage +from starkware.storage.storage import FactFetchingContext + + +class StarknetState: + """ + StarkNet testing object. Represents a state of a StarkNet network. + + Example usage: + starknet = await StarknetState.empty() + contract_definition = compile_starknet_files([CONTRACT_FILE], debug_info=True) + contract_address = await starknet.deploy(contract_definition=contract_definition) + res = await starknet.invoke_raw( + contract_address=contract_address, selector="func", calldata=[1, 2]) + """ + + def __init__(self, state: CarriedState, general_config: StarknetGeneralConfig): + """ + Constructor. Should not be used directly. Use empty() instead. + """ + self.state = state + self.general_config = general_config + + def copy(self) -> "StarknetState": + """ + Creates a new StarknetState instance with the same state. And modifications to one instance + would not affect the other. + """ + return copy.deepcopy(self) + + @classmethod + async def empty(cls, general_config: Optional[StarknetGeneralConfig] = None) -> "StarknetState": + """ + Creates a new StarknetState instance. + """ + if general_config is None: + general_config = StarknetGeneralConfig() + ffc = FactFetchingContext(storage=DictStorage(), hash_func=async_pedersen_hash_func) + empty_contract_state = await ContractState.empty( + storage_commitment_tree_height=general_config.contract_storage_commitment_tree_height, + ffc=ffc, + ) + empty_contract_carried_state = ContractCarriedState( + state=empty_contract_state, storage_updates={} + ) + shared_state = await SharedState.empty(ffc=ffc, general_config=general_config) + state = CarriedState.empty(shared_state=shared_state, ffc=ffc) + state.contract_states = defaultdict(lambda: copy.deepcopy(empty_contract_carried_state)) + return cls(state=state, general_config=general_config) + + async def deploy( + self, + contract_definition: ContractDefinition, + contract_address: Optional[Union[int, str]] = None, + ) -> int: + """ + Deploys a contract. Returns the contract address. + + Args: + contract_definition - a compiled StarkNet contract returned by compile_starknet_files(). + contract_address - If supplied, a hexadecimal string or an integer representing the contract + address to use for deploying. Otherwise, the contract address is randomized. + """ + if contract_address is None: + contract_address = fields.ContractAddressField.get_random_value() + if isinstance(contract_address, str): + contract_address = int(contract_address, 16) + assert isinstance(contract_address, int) + + tx = InternalDeploy( + contract_address=contract_address, contract_definition=contract_definition + ) + + with self.state.copy_and_apply() as state_copy: + await tx.apply_state_updates(state=state_copy, general_config=self.general_config) + return contract_address + + async def invoke_raw( + self, + contract_address: Union[int, str], + selector: Union[int, str], + calldata: List[int], + entry_point_type: EntryPointType = EntryPointType.EXTERNAL, + ) -> TransactionExecutionInfo: + """ + Invokes a contract function. Returns the execution info. + + Args: + contract_address - a hexadecimal string or an integer representing the contract address. + selector - either a function name or an integer selector for the entrypoint to invoke. + calldata - a list of integers to pass as calldata to the invoked function. + """ + + if isinstance(contract_address, str): + contract_address = int(contract_address, 16) + assert isinstance(contract_address, int) + + if isinstance(selector, str): + selector = get_selector_from_name(selector) + assert isinstance(selector, int) + + tx = InternalInvokeFunction( + contract_address=contract_address, + entry_point_selector=selector, + entry_point_type=entry_point_type, + calldata=calldata, + ) + + with self.state.copy_and_apply() as state_copy: + return await tx.apply_state_updates( + state=state_copy, general_config=self.general_config + ) diff --git a/src/starkware/starkware_utils/CMakeLists.txt b/src/starkware/starkware_utils/CMakeLists.txt index f6696351..e89e86ff 100644 --- a/src/starkware/starkware_utils/CMakeLists.txt +++ b/src/starkware/starkware_utils/CMakeLists.txt @@ -11,17 +11,17 @@ python_lib(starkware_utils_lib PREFIX starkware/starkware_utils FILES - binary_fact_tree.py - binary_fact_tree_node.py + commitment_tree/binary_fact_tree.py + commitment_tree/binary_fact_tree_node.py + commitment_tree/merkle_tree/traverse_tree.py + commitment_tree/patricia_tree/nodes.py + commitment_tree/patricia_tree/patricia_tree.py + commitment_tree/patricia_tree/virtual_patricia_node.py config_base.py custom_raising_dict.py error_handling.py field_validators.py marshmallow_dataclass_fields.py - merkle_tree/traverse_tree.py - patricia_tree/nodes.py - patricia_tree/patricia_tree.py - patricia_tree/virtual_patricia_node.py subsequence.py validated_dataclass.py validated_fields.py @@ -46,8 +46,8 @@ full_python_test(patricia_tree_test TESTED_MODULES starkware/starkware_utils FILES - patricia_tree/nodes_test.py - patricia_tree/virtual_patricia_node_test.py + commitment_tree/patricia_tree/nodes_test.py + commitment_tree/patricia_tree/virtual_patricia_node_test.py LIBS cairo_common_lib diff --git a/src/starkware/starkware_utils/commitment_tree/CMakeLists.txt b/src/starkware/starkware_utils/commitment_tree/CMakeLists.txt new file mode 100644 index 00000000..bf5f3e72 --- /dev/null +++ b/src/starkware/starkware_utils/commitment_tree/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(merkle_tree) diff --git a/src/starkware/starkware_utils/patricia_tree/__init__.py b/src/starkware/starkware_utils/commitment_tree/__init__.py similarity index 100% rename from src/starkware/starkware_utils/patricia_tree/__init__.py rename to src/starkware/starkware_utils/commitment_tree/__init__.py diff --git a/src/starkware/starkware_utils/binary_fact_tree.py b/src/starkware/starkware_utils/commitment_tree/binary_fact_tree.py similarity index 70% rename from src/starkware/starkware_utils/binary_fact_tree.py rename to src/starkware/starkware_utils/commitment_tree/binary_fact_tree.py index 7ffd313e..1abae122 100644 --- a/src/starkware/starkware_utils/binary_fact_tree.py +++ b/src/starkware/starkware_utils/commitment_tree/binary_fact_tree.py @@ -4,7 +4,7 @@ from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass from starkware.storage.storage import Fact, FactFetchingContext -TFact = TypeVar('TFact', bound=Fact) +TFact = TypeVar("TFact", bound=Fact) BinaryFactDict = Dict[bytes, Tuple[bytes, ...]] @@ -17,23 +17,31 @@ class BinaryFactTree(ValidatedMarshmallowDataclass): @classmethod @abstractmethod async def empty_tree( - cls, ffc: FactFetchingContext, height: int, leaf_fact: Fact) -> 'BinaryFactTree': + cls, ffc: FactFetchingContext, height: int, leaf_fact: Fact + ) -> "BinaryFactTree": """ Initializes an empty BinaryFactTree of the given height. """ @abstractmethod async def get_leaves( - self, ffc: FactFetchingContext, indices: Collection[int], fact_cls: Type[TFact], - facts: Optional[BinaryFactDict] = None) -> Dict[int, TFact]: + self, + ffc: FactFetchingContext, + indices: Collection[int], + fact_cls: Type[TFact], + facts: Optional[BinaryFactDict] = None, + ) -> Dict[int, TFact]: """ Returns the values of the leaves whose indices are given. """ @abstractmethod async def update( - self, ffc: FactFetchingContext, modifications: Collection[Tuple[int, Fact]], - facts: Optional[BinaryFactDict] = None) -> 'BinaryFactTree': + self, + ffc: FactFetchingContext, + modifications: Collection[Tuple[int, Fact]], + facts: Optional[BinaryFactDict] = None, + ) -> "BinaryFactTree": """ Updates the tree with the given list of modifications, writes all the new facts to the storage and returns a new BinaryFactTree representing the fact of the root of the new tree. @@ -47,7 +55,8 @@ async def get_leaf(self, ffc: FactFetchingContext, index: int, fact_cls: Type[TF Returns the value of a single leaf whose index is given. """ leaves = await self.get_leaves(ffc=ffc, indices=[index], fact_cls=fact_cls) - assert leaves.keys() == {index}, ( - f'get_leaves() on single leaf index returned an unexpected result.') + assert leaves.keys() == { + index + }, f"get_leaves() on single leaf index returned an unexpected result." return leaves[index] diff --git a/src/starkware/starkware_utils/binary_fact_tree_node.py b/src/starkware/starkware_utils/commitment_tree/binary_fact_tree_node.py similarity index 79% rename from src/starkware/starkware_utils/binary_fact_tree_node.py rename to src/starkware/starkware_utils/commitment_tree/binary_fact_tree_node.py index e9564600..4abfbcd0 100644 --- a/src/starkware/starkware_utils/binary_fact_tree_node.py +++ b/src/starkware/starkware_utils/commitment_tree/binary_fact_tree_node.py @@ -1,16 +1,28 @@ import asyncio from abc import ABC, abstractmethod from typing import ( - Any, AsyncIterator, Collection, Dict, NamedTuple, Optional, Tuple, Type, TypeVar, Union, cast) - -from starkware.starkware_utils.binary_fact_tree import BinaryFactDict, TFact -from starkware.starkware_utils.merkle_tree.traverse_tree import traverse_tree + Any, + AsyncIterator, + Collection, + Dict, + NamedTuple, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +from starkware.starkware_utils.commitment_tree.binary_fact_tree import BinaryFactDict, TFact +from starkware.starkware_utils.commitment_tree.merkle_tree.traverse_tree import traverse_tree from starkware.storage.storage import Fact, FactFetchingContext -TBinaryFactTreeNode = TypeVar('TBinaryFactTreeNode', bound='BinaryFactTreeNode') +TBinaryFactTreeNode = TypeVar("TBinaryFactTreeNode", bound="BinaryFactTreeNode") UpdateTree = Optional[Union[Tuple[Any, Any], Fact]] -NodeType = NamedTuple('NodeType', [ - ('index', int), ('tree', 'BinaryFactTreeNode'), ('update', UpdateTree)]) +NodeType = NamedTuple( + "NodeType", [("index", int), ("tree", "BinaryFactTreeNode"), ("update", UpdateTree)] +) class BinaryFactTreeNode(ABC): @@ -30,8 +42,9 @@ def leaf_hash(self) -> bytes: Returns the hash of self, which must be a leaf to use this property. """ assert self.is_leaf, ( - f'leaf_hash property must only be called on leaf nodes; got: height ' - f'{self.get_height_in_tree()}.') + f"leaf_hash property must only be called on leaf nodes; got: height " + f"{self.get_height_in_tree()}." + ) return self._leaf_hash @@ -54,9 +67,12 @@ def create_leaf(cls: Type[TBinaryFactTreeNode], hash_value: bytes) -> TBinaryFac @classmethod @abstractmethod async def combine( - cls: Type[TBinaryFactTreeNode], ffc: FactFetchingContext, left: 'BinaryFactTreeNode', - right: 'BinaryFactTreeNode', - facts: Optional[BinaryFactDict] = None) -> TBinaryFactTreeNode: + cls: Type[TBinaryFactTreeNode], + ffc: FactFetchingContext, + left: "BinaryFactTreeNode", + right: "BinaryFactTreeNode", + facts: Optional[BinaryFactDict] = None, + ) -> TBinaryFactTreeNode: """ Gets two BinaryFactTreeNode objects left and right representing children nodes, and builds their parent node. Returns a new BinaryFactTreeNode. @@ -66,8 +82,8 @@ async def combine( @abstractmethod async def get_children( - self, ffc: FactFetchingContext, facts: Optional[BinaryFactDict] = None) -> Tuple[ - 'BinaryFactTreeNode', 'BinaryFactTreeNode']: + self, ffc: FactFetchingContext, facts: Optional[BinaryFactDict] = None + ) -> Tuple["BinaryFactTreeNode", "BinaryFactTreeNode"]: """ Returns the two BinaryFactTreeNode objects which are the roots of the subtrees of the current BinaryFactTreeNode. @@ -76,8 +92,12 @@ async def get_children( """ async def _get_leaves( - self, ffc: FactFetchingContext, indices: Collection[int], - fact_cls: Type[TFact], facts: Optional[BinaryFactDict] = None) -> Dict[int, TFact]: + self, + ffc: FactFetchingContext, + indices: Collection[int], + fact_cls: Type[TFact], + facts: Optional[BinaryFactDict] = None, + ) -> Dict[int, TFact]: """ Returns the values of the leaves whose indices are given. @@ -88,20 +108,23 @@ async def _get_leaves( (derived class of BinaryFactTree). """ assert not issubclass(fact_cls, InnerNodeFact), ( - f'Leaf fact class object {fact_cls.__name__} must not inherit from ' - f'{InnerNodeFact.__name__}.') + f"Leaf fact class object {fact_cls.__name__} must not inherit from " + f"{InnerNodeFact.__name__}." + ) def unify_leaves( - left_leaves: Dict[int, TFact], right_leaves: Dict[int, TFact]) -> Dict[int, TFact]: + left_leaves: Dict[int, TFact], right_leaves: Dict[int, TFact] + ) -> Dict[int, TFact]: return {**left_leaves, **{x + mid: y for x, y in right_leaves.items()}} if len(indices) == 0: return {} if self.is_leaf: - assert set(indices) == {0}, f'Merkle tree indices out of range: {indices}.' + assert set(indices) == {0}, f"Merkle tree indices out of range: {indices}." leaf = await get_node_fact_or_fail( - ffc=ffc, node_fact_cls=fact_cls, fact_hash=self.leaf_hash) + ffc=ffc, node_fact_cls=fact_cls, fact_hash=self.leaf_hash + ) return {0: leaf} @@ -115,24 +138,29 @@ def unify_leaves( # execution of the recursive task. if len(left_indices) == 0: right_leaves = await right_child._get_leaves( - ffc=ffc, indices=right_indices, fact_cls=fact_cls, facts=facts) + ffc=ffc, indices=right_indices, fact_cls=fact_cls, facts=facts + ) return unify_leaves(right_leaves=right_leaves, left_leaves={}) if len(right_indices) == 0: left_leaves = await left_child._get_leaves( - ffc=ffc, indices=left_indices, fact_cls=fact_cls, facts=facts) + ffc=ffc, indices=left_indices, fact_cls=fact_cls, facts=facts + ) return unify_leaves(right_leaves={}, left_leaves=left_leaves) left_leaves, right_leaves = await asyncio.gather( left_child._get_leaves(ffc=ffc, indices=left_indices, fact_cls=fact_cls, facts=facts), - right_child._get_leaves(ffc=ffc, indices=right_indices, fact_cls=fact_cls, facts=facts)) + right_child._get_leaves(ffc=ffc, indices=right_indices, fact_cls=fact_cls, facts=facts), + ) return unify_leaves(left_leaves=left_leaves, right_leaves=right_leaves) async def _update( - self: TBinaryFactTreeNode, ffc: FactFetchingContext, - modifications: Collection[Tuple[int, Fact]], - facts: Optional[BinaryFactDict] = None) -> TBinaryFactTreeNode: + self: TBinaryFactTreeNode, + ffc: FactFetchingContext, + modifications: Collection[Tuple[int, Fact]], + facts: Optional[BinaryFactDict] = None, + ) -> TBinaryFactTreeNode: """ Updates the tree with the given list of modifications, writes all the new facts to the storage and returns a new BinaryFactTree representing the fact of the root of the new tree. @@ -164,7 +192,8 @@ async def update_necessary(node_index: int): ffc=ffc, left=updated_nodes[2 * node_index], right=updated_nodes[2 * node_index + 1], - facts=facts) + facts=facts, + ) del updated_nodes[2 * node_index] del updated_nodes[2 * node_index + 1] @@ -201,10 +230,12 @@ async def traverse_node(node: NodeType) -> AsyncIterator[NodeType]: yield NodeType(index=2 * node_index + 1, tree=right, update=update_subtree[1]) update_tree = build_update_tree( - height=self.get_height_in_tree(), modifications=modifications) + height=self.get_height_in_tree(), modifications=modifications + ) first_node = NodeType(index=1, tree=self, update=update_tree) await traverse_tree( - get_children_callback=traverse_node, root=first_node, n_workers=ffc.n_workers) + get_children_callback=traverse_node, root=first_node, n_workers=ffc.n_workers + ) # Since the updated_nodes dictionary cleans itself, we expect only the new root to be # present, at node index 1. @@ -224,22 +255,27 @@ def to_tuple(self) -> Tuple[bytes, ...]: """ -TInnerNodeFact = TypeVar('TInnerNodeFact', bound=InnerNodeFact) +TInnerNodeFact = TypeVar("TInnerNodeFact", bound=InnerNodeFact) async def get_node_fact_or_fail( - ffc: FactFetchingContext, node_fact_cls: Type[TFact], fact_hash: bytes) -> TFact: + ffc: FactFetchingContext, node_fact_cls: Type[TFact], fact_hash: bytes +) -> TFact: node_fact = await node_fact_cls.get(storage=ffc.storage, suffix=fact_hash) - assert node_fact is not None, f'Fact missing from DB: 0x{fact_hash.hex()}.' + assert node_fact is not None, f"Fact missing from DB: 0x{fact_hash.hex()}." return node_fact async def read_node_fact( - ffc: FactFetchingContext, inner_node_fact_cls: Type[TInnerNodeFact], fact_hash: bytes, - facts: Optional[BinaryFactDict]) -> TInnerNodeFact: + ffc: FactFetchingContext, + inner_node_fact_cls: Type[TInnerNodeFact], + fact_hash: bytes, + facts: Optional[BinaryFactDict], +) -> TInnerNodeFact: node_fact = await get_node_fact_or_fail( - ffc=ffc, node_fact_cls=inner_node_fact_cls, fact_hash=fact_hash) + ffc=ffc, node_fact_cls=inner_node_fact_cls, fact_hash=fact_hash + ) if facts is not None: facts[fact_hash] = node_fact.to_tuple() @@ -248,8 +284,8 @@ async def read_node_fact( async def write_node_fact( - ffc: FactFetchingContext, inner_node_fact: InnerNodeFact, - facts: Optional[BinaryFactDict]) -> bytes: + ffc: FactFetchingContext, inner_node_fact: InnerNodeFact, facts: Optional[BinaryFactDict] +) -> bytes: fact_hash = await inner_node_fact.set_fact(ffc=ffc) if facts is not None: diff --git a/src/starkware/starkware_utils/merkle_tree/traverse_tree.py b/src/starkware/starkware_utils/commitment_tree/merkle_tree/traverse_tree.py similarity index 89% rename from src/starkware/starkware_utils/merkle_tree/traverse_tree.py rename to src/starkware/starkware_utils/commitment_tree/merkle_tree/traverse_tree.py index feba91a3..b9f1acdc 100644 --- a/src/starkware/starkware_utils/merkle_tree/traverse_tree.py +++ b/src/starkware/starkware_utils/commitment_tree/merkle_tree/traverse_tree.py @@ -1,12 +1,14 @@ import asyncio from typing import AsyncIterator, Callable, Optional, TypeVar -NodeType = TypeVar('NodeType') +NodeType = TypeVar("NodeType") async def traverse_tree( - get_children_callback: Callable[[NodeType], AsyncIterator[NodeType]], root: NodeType, - n_workers: Optional[int] = None): + get_children_callback: Callable[[NodeType], AsyncIterator[NodeType]], + root: NodeType, + n_workers: Optional[int] = None, +): """ Traverses a tree as follows: 1. Starts by calling get_children_callback(root). This function should return the children of diff --git a/src/starkware/starkware_utils/commitment_tree/patricia_tree/__init__.py b/src/starkware/starkware_utils/commitment_tree/patricia_tree/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/starkware/starkware_utils/patricia_tree/nodes.py b/src/starkware/starkware_utils/commitment_tree/patricia_tree/nodes.py similarity index 77% rename from src/starkware/starkware_utils/patricia_tree/nodes.py rename to src/starkware/starkware_utils/commitment_tree/patricia_tree/nodes.py index 35a081ab..6e97005d 100644 --- a/src/starkware/starkware_utils/patricia_tree/nodes.py +++ b/src/starkware/starkware_utils/commitment_tree/patricia_tree/nodes.py @@ -3,7 +3,7 @@ from typing import ClassVar, List, Tuple, Type from starkware.python.utils import blockify, from_bytes, to_bytes -from starkware.starkware_utils.binary_fact_tree_node import InnerNodeFact +from starkware.starkware_utils.commitment_tree.binary_fact_tree_node import InnerNodeFact from starkware.starkware_utils.validated_dataclass import ValidatedDataclass from starkware.storage.storage import HASH_BYTES, HashFunctionType @@ -18,7 +18,7 @@ class PatriciaNodeFact(InnerNodeFact, ValidatedDataclass): @classmethod def prefix(cls) -> bytes: - return b'patricia_node' + return b"patricia_node" @property @classmethod @@ -31,7 +31,7 @@ def PREIMAGE_LENGTH(self) -> int: """ @classmethod - def deserialize(cls, data: bytes) -> 'PatriciaNodeFact': + def deserialize(cls, data: bytes) -> "PatriciaNodeFact": node_fact_cls = get_node_type(fact_preimage=data) return node_fact_cls.deserialize(data=data) @@ -46,10 +46,10 @@ class EmptyNodeFact(PatriciaNodeFact): EMPTY_NODE_HASH: ClassVar[bytes] = bytes(PatriciaNodeFact.HASH_BYTES_LENGTH) def serialize(self) -> bytes: - return b'' + return b"" @classmethod - def deserialize(cls, data: bytes) -> 'EmptyNodeFact': + def deserialize(cls, data: bytes) -> "EmptyNodeFact": return cls() async def _hash(self, hash_func: HashFunctionType) -> bytes: @@ -63,8 +63,9 @@ def to_tuple(self) -> Tuple[bytes, ...]: def verify_path_value(path: int, length: int): - assert 0 <= path < (1 << length), ( - f'Edge path must be at most of length {length}; got: {bin(path)}.') + assert ( + 0 <= path < (1 << length) + ), f"Edge path must be at most of length {length}; got: {bin(path)}." @dataclasses.dataclass(frozen=True) @@ -83,17 +84,21 @@ def __post_init__(self): super().__post_init__() legal_binary_node = ( - self.left_node != EmptyNodeFact.EMPTY_NODE_HASH and - self.right_node != EmptyNodeFact.EMPTY_NODE_HASH) - assert legal_binary_node, ( - 'It is not allowed for any child of a binary node to be the empty node.') + self.left_node != EmptyNodeFact.EMPTY_NODE_HASH + and self.right_node != EmptyNodeFact.EMPTY_NODE_HASH + ) + assert ( + legal_binary_node + ), "It is not allowed for any child of a binary node to be the empty node." def serialize(self) -> bytes: return self.left_node + self.right_node @classmethod - def deserialize(cls, data: bytes) -> 'BinaryNodeFact': - return cls(left_node=data[:cls.HASH_BYTES_LENGTH], right_node=data[cls.HASH_BYTES_LENGTH:]) + def deserialize(cls, data: bytes) -> "BinaryNodeFact": + return cls( + left_node=data[: cls.HASH_BYTES_LENGTH], right_node=data[cls.HASH_BYTES_LENGTH :] + ) async def _hash(self, hash_func: HashFunctionType) -> bytes: """ @@ -129,19 +134,22 @@ class EdgeNodeFact(PatriciaNodeFact): def __post_init__(self): super().__post_init__() - assert self.edge_length > 0, ( - f'The length of an edge node must be positive; got: {self.edge_length}.') + assert ( + self.edge_length > 0 + ), f"The length of an edge node must be positive; got: {self.edge_length}." verify_path_value(path=self.edge_path, length=self.edge_length) def serialize(self) -> bytes: return self.bottom_node + to_bytes(self.edge_path) + to_bytes(self.edge_length, length=1) @classmethod - def deserialize(cls, data: bytes) -> 'EdgeNodeFact': + def deserialize(cls, data: bytes) -> "EdgeNodeFact": bottom_node, edge_path, edge_length = blockify(data=data, chunk_size=cls.HASH_BYTES_LENGTH) return cls( - bottom_node=bottom_node, edge_path=from_bytes(edge_path), - edge_length=from_bytes(edge_length)) + bottom_node=bottom_node, + edge_path=from_bytes(edge_path), + edge_length=from_bytes(edge_length), + ) async def _hash(self, hash_func: HashFunctionType) -> bytes: """ @@ -158,7 +166,7 @@ def to_tuple(self) -> Tuple[bytes, ...]: return to_bytes(self.edge_length), to_bytes(self.edge_path), self.bottom_node -def get_node_type(fact_preimage: bytes) -> Type['PatriciaNodeFact']: +def get_node_type(fact_preimage: bytes) -> Type["PatriciaNodeFact"]: """ Returns the node fact type according to the fact preimage length. """ @@ -169,4 +177,4 @@ def get_node_type(fact_preimage: bytes) -> Type['PatriciaNodeFact']: if preimage_length == node_fact_cls.PREIMAGE_LENGTH: return node_fact_cls - raise NotImplementedError(f'Unsupported fact preimage length: {preimage_length}.') + raise NotImplementedError(f"Unsupported fact preimage length: {preimage_length}.") diff --git a/src/starkware/starkware_utils/patricia_tree/nodes_test.py b/src/starkware/starkware_utils/commitment_tree/patricia_tree/nodes_test.py similarity index 85% rename from src/starkware/starkware_utils/patricia_tree/nodes_test.py rename to src/starkware/starkware_utils/commitment_tree/patricia_tree/nodes_test.py index 5520a99f..e10c108b 100644 --- a/src/starkware/starkware_utils/patricia_tree/nodes_test.py +++ b/src/starkware/starkware_utils/commitment_tree/patricia_tree/nodes_test.py @@ -1,8 +1,12 @@ import pytest from starkware.python.utils import from_bytes, to_bytes -from starkware.starkware_utils.patricia_tree.nodes import ( - BinaryNodeFact, EdgeNodeFact, EmptyNodeFact, get_node_type) +from starkware.starkware_utils.commitment_tree.patricia_tree.nodes import ( + BinaryNodeFact, + EdgeNodeFact, + EmptyNodeFact, + get_node_type, +) from starkware.storage.test_utils import hash_func @@ -24,8 +28,9 @@ async def test_binary_node(): left_node, right_node = nodes with pytest.raises( - AssertionError, - match='It is not allowed for any child of a binary node to be the empty node.'): + AssertionError, + match="It is not allowed for any child of a binary node to be the empty node.", + ): BinaryNodeFact(left_node=left_node, right_node=right_node) diff --git a/src/starkware/starkware_utils/patricia_tree/patricia_tree.py b/src/starkware/starkware_utils/commitment_tree/patricia_tree/patricia_tree.py similarity index 64% rename from src/starkware/starkware_utils/patricia_tree/patricia_tree.py rename to src/starkware/starkware_utils/commitment_tree/patricia_tree/patricia_tree.py index e5bbe2f1..c045a8db 100644 --- a/src/starkware/starkware_utils/patricia_tree/patricia_tree.py +++ b/src/starkware/starkware_utils/commitment_tree/patricia_tree/patricia_tree.py @@ -3,9 +3,15 @@ import marshmallow_dataclass -from starkware.starkware_utils.binary_fact_tree import BinaryFactDict, BinaryFactTree, TFact -from starkware.starkware_utils.patricia_tree.nodes import EmptyNodeFact -from starkware.starkware_utils.patricia_tree.virtual_patricia_node import VirtualPatriciaNode +from starkware.starkware_utils.commitment_tree.binary_fact_tree import ( + BinaryFactDict, + BinaryFactTree, + TFact, +) +from starkware.starkware_utils.commitment_tree.patricia_tree.nodes import EmptyNodeFact +from starkware.starkware_utils.commitment_tree.patricia_tree.virtual_patricia_node import ( + VirtualPatriciaNode, +) from starkware.starkware_utils.validated_fields import bytes_as_hex_metadata from starkware.storage.storage import Fact, FactFetchingContext @@ -21,37 +27,48 @@ class PatriciaTree(BinaryFactTree): @classmethod async def empty_tree( - cls, ffc: FactFetchingContext, height: int, leaf_fact: Fact) -> 'PatriciaTree': + cls, ffc: FactFetchingContext, height: int, leaf_fact: Fact + ) -> "PatriciaTree": """ Initializes an empty PatriciaTree of the given height. """ empty_leaf_fact_hash = await leaf_fact.set_fact(ffc=ffc) assert empty_leaf_fact_hash == EmptyNodeFact.EMPTY_NODE_HASH, ( - f'The hash value of an empty leaf fact must be {EmptyNodeFact.EMPTY_NODE_HASH.hex()}; ' - f'got: {empty_leaf_fact_hash.hex()}.') + f"The hash value of an empty leaf fact must be {EmptyNodeFact.EMPTY_NODE_HASH.hex()}; " + f"got: {empty_leaf_fact_hash.hex()}." + ) return PatriciaTree(root=EmptyNodeFact.EMPTY_NODE_HASH, height=height) async def get_leaves( - self, ffc: FactFetchingContext, indices: Collection[int], fact_cls: Type[TFact], - facts: Optional[BinaryFactDict] = None) -> Dict[int, TFact]: + self, + ffc: FactFetchingContext, + indices: Collection[int], + fact_cls: Type[TFact], + facts: Optional[BinaryFactDict] = None, + ) -> Dict[int, TFact]: """ Returns the values of the leaves whose indices are given. """ virtual_root_node = VirtualPatriciaNode.from_hash(hash_value=self.root, height=self.height) return await virtual_root_node._get_leaves( - ffc=ffc, indices=indices, fact_cls=fact_cls, facts=facts) + ffc=ffc, indices=indices, fact_cls=fact_cls, facts=facts + ) async def update( - self, ffc: FactFetchingContext, modifications: Collection[Tuple[int, Fact]], - facts: Optional[BinaryFactDict] = None) -> 'PatriciaTree': + self, + ffc: FactFetchingContext, + modifications: Collection[Tuple[int, Fact]], + facts: Optional[BinaryFactDict] = None, + ) -> "PatriciaTree": """ Updates the tree with the given list of modifications, writes all the new facts to the storage and returns a new PatriciaTree representing the fact of the root of the new tree. """ virtual_root_node = VirtualPatriciaNode.from_hash(hash_value=self.root, height=self.height) updated_virtual_root_node = await virtual_root_node._update( - ffc=ffc, modifications=modifications, facts=facts) + ffc=ffc, modifications=modifications, facts=facts + ) # In case root is an edge node, its fact must be explicitly written to DB. root_hash = await updated_virtual_root_node.commit(ffc=ffc, facts=facts) diff --git a/src/starkware/starkware_utils/patricia_tree/virtual_patricia_node.py b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node.py similarity index 70% rename from src/starkware/starkware_utils/patricia_tree/virtual_patricia_node.py rename to src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node.py index e6b9528e..53182d86 100644 --- a/src/starkware/starkware_utils/patricia_tree/virtual_patricia_node.py +++ b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node.py @@ -2,11 +2,19 @@ import dataclasses from typing import Optional, Tuple -from starkware.starkware_utils.binary_fact_tree import BinaryFactDict -from starkware.starkware_utils.binary_fact_tree_node import ( - BinaryFactTreeNode, read_node_fact, write_node_fact) -from starkware.starkware_utils.patricia_tree.nodes import ( - BinaryNodeFact, EdgeNodeFact, EmptyNodeFact, PatriciaNodeFact, verify_path_value) +from starkware.starkware_utils.commitment_tree.binary_fact_tree import BinaryFactDict +from starkware.starkware_utils.commitment_tree.binary_fact_tree_node import ( + BinaryFactTreeNode, + read_node_fact, + write_node_fact, +) +from starkware.starkware_utils.commitment_tree.patricia_tree.nodes import ( + BinaryNodeFact, + EdgeNodeFact, + EmptyNodeFact, + PatriciaNodeFact, + verify_path_value, +) from starkware.starkware_utils.validated_dataclass import ValidatedDataclass from starkware.storage.storage import FactFetchingContext @@ -40,25 +48,29 @@ def __post_init__(self): verify_path_value(path=self.path, length=self.length) @classmethod - def empty_node(cls, height: int) -> 'VirtualPatriciaNode': + def empty_node(cls, height: int) -> "VirtualPatriciaNode": return cls(bottom_node=EmptyNodeFact.EMPTY_NODE_HASH, path=0, length=0, height=height) @classmethod - def from_hash(cls, hash_value: bytes, height: int) -> 'VirtualPatriciaNode': + def from_hash(cls, hash_value: bytes, height: int) -> "VirtualPatriciaNode": """ Returns a virtual Patricia node of the form (hash, 0, 0). """ return cls(bottom_node=hash_value, path=0, length=0, height=height) @classmethod - def create_leaf(cls, hash_value: bytes) -> 'VirtualPatriciaNode': + def create_leaf(cls, hash_value: bytes) -> "VirtualPatriciaNode": return cls.from_hash(hash_value=hash_value, height=0) async def read_bottom_node_fact( - self, ffc: FactFetchingContext, facts: Optional[BinaryFactDict]) -> PatriciaNodeFact: + self, ffc: FactFetchingContext, facts: Optional[BinaryFactDict] + ) -> PatriciaNodeFact: return await read_node_fact( - ffc=ffc, inner_node_fact_cls=PatriciaNodeFact, # type: ignore - fact_hash=self.bottom_node, facts=facts) + ffc=ffc, + inner_node_fact_cls=PatriciaNodeFact, # type: ignore + fact_hash=self.bottom_node, + facts=facts, + ) @property def is_empty(self) -> bool: @@ -85,12 +97,13 @@ async def commit(self, ffc: FactFetchingContext, facts: Optional[BinaryFactDict] return self.bottom_node edge_node_fact = EdgeNodeFact( - bottom_node=self.bottom_node, edge_path=self.path, edge_length=self.length) + bottom_node=self.bottom_node, edge_path=self.path, edge_length=self.length + ) return await write_node_fact(ffc=ffc, inner_node_fact=edge_node_fact, facts=facts) async def decommit( - self, ffc: FactFetchingContext, - facts: Optional[BinaryFactDict]) -> 'VirtualPatriciaNode': + self, ffc: FactFetchingContext, facts: Optional[BinaryFactDict] + ) -> "VirtualPatriciaNode": """ Returns the canonical representation of the information embedded in self. Returns (bottom, path, length) for an edge node of form (hash, 0, 0), which is the @@ -110,16 +123,22 @@ async def decommit( return self if isinstance(root_node_fact, EdgeNodeFact): return VirtualPatriciaNode( - bottom_node=root_node_fact.bottom_node, path=root_node_fact.edge_path, - length=root_node_fact.edge_length, height=self.height) + bottom_node=root_node_fact.bottom_node, + path=root_node_fact.edge_path, + length=root_node_fact.edge_length, + height=self.height, + ) - raise NotImplementedError(f'Unexpected node fact type: {type(root_node_fact).__name__}.') + raise NotImplementedError(f"Unexpected node fact type: {type(root_node_fact).__name__}.") @classmethod async def combine( - cls, ffc: FactFetchingContext, left: 'BinaryFactTreeNode', - right: 'BinaryFactTreeNode', - facts: Optional[BinaryFactDict] = None) -> 'VirtualPatriciaNode': + cls, + ffc: FactFetchingContext, + left: "BinaryFactTreeNode", + right: "BinaryFactTreeNode", + facts: Optional[BinaryFactDict] = None, + ) -> "VirtualPatriciaNode": """ Gets two VirtualPatriciaNode objects left and right representing children nodes, and builds their parent node. Returns a new VirtualPatriciaNode. @@ -129,8 +148,9 @@ async def combine( # Downcast arguments. assert isinstance(left, VirtualPatriciaNode) and isinstance(right, VirtualPatriciaNode) - assert right.height == left.height, ( - f'Only trees of same height can be combined; got: {right.height} and {left.height}.') + assert ( + right.height == left.height + ), f"Only trees of same height can be combined; got: {right.height} and {left.height}." parent_height = right.height + 1 if left.is_empty and right.is_empty: @@ -142,15 +162,15 @@ async def combine( return await cls._combine_to_virtual_edge_node(ffc=ffc, left=left, right=right, facts=facts) async def get_children( - self, ffc: FactFetchingContext, facts: Optional[BinaryFactDict] = None) -> Tuple[ - 'VirtualPatriciaNode', 'VirtualPatriciaNode']: + self, ffc: FactFetchingContext, facts: Optional[BinaryFactDict] = None + ) -> Tuple["VirtualPatriciaNode", "VirtualPatriciaNode"]: """ Returns the two VirtualPatriciaNode objects which are the subtrees of the current VirtualPatriciaNode. If facts argument is not None, this dictionary is filled with facts read from the DB. """ - assert not self.is_leaf, 'get_children() must not be called on leaves.' + assert not self.is_leaf, "get_children() must not be called on leaves." children_height = self.height - 1 if self.is_empty: @@ -167,46 +187,62 @@ async def get_children( if isinstance(fact, EdgeNodeFact): # A previously committed edge node. edge_node = VirtualPatriciaNode( - bottom_node=fact.bottom_node, path=fact.edge_path, length=fact.edge_length, - height=self.height) + bottom_node=fact.bottom_node, + path=fact.edge_path, + length=fact.edge_length, + height=self.height, + ) return edge_node._get_virtual_edge_node_children() assert isinstance(fact, BinaryNodeFact) return ( self.from_hash(hash_value=fact.left_node, height=children_height), - self.from_hash(hash_value=fact.right_node, height=children_height)) + self.from_hash(hash_value=fact.right_node, height=children_height), + ) # Internal utils. @classmethod async def _combine_to_binary_node( - cls, ffc: FactFetchingContext, left: 'VirtualPatriciaNode', - right: 'VirtualPatriciaNode', facts: Optional[BinaryFactDict]) -> 'VirtualPatriciaNode': + cls, + ffc: FactFetchingContext, + left: "VirtualPatriciaNode", + right: "VirtualPatriciaNode", + facts: Optional[BinaryFactDict], + ) -> "VirtualPatriciaNode": """ Combines two non-empty nodes to form a binary node. Writes the constructed node fact to the DB, as well as (up to) two other facts for the children if they were not previously committed. """ - left_node_hash, right_node_hash = await asyncio.gather(*( - node.commit(ffc=ffc, facts=facts) for node in (left, right))) + left_node_hash, right_node_hash = await asyncio.gather( + *(node.commit(ffc=ffc, facts=facts) for node in (left, right)) + ) parent_node_fact = BinaryNodeFact(left_node=left_node_hash, right_node=right_node_hash) parent_fact_hash = await write_node_fact( - ffc=ffc, inner_node_fact=parent_node_fact, facts=facts) + ffc=ffc, inner_node_fact=parent_node_fact, facts=facts + ) return VirtualPatriciaNode( - bottom_node=parent_fact_hash, path=0, length=0, height=right.height + 1) + bottom_node=parent_fact_hash, path=0, length=0, height=right.height + 1 + ) @classmethod async def _combine_to_virtual_edge_node( - cls, ffc: FactFetchingContext, left: 'VirtualPatriciaNode', - right: 'VirtualPatriciaNode', facts: Optional[BinaryFactDict]) -> 'VirtualPatriciaNode': + cls, + ffc: FactFetchingContext, + left: "VirtualPatriciaNode", + right: "VirtualPatriciaNode", + facts: Optional[BinaryFactDict], + ) -> "VirtualPatriciaNode": """ Combines an empty node and a non-empty node to form a virtual edge node. If the non-empty node is not known to be of canonical form, reads its fact from the DB in order to make it such (or make sure it is). """ - assert left.is_empty != right.is_empty, ( - '_combine_to_virtual_edge_node() must be called on one empty and one non-empty nodes.') + assert ( + left.is_empty != right.is_empty + ), "_combine_to_virtual_edge_node() must be called on one empty and one non-empty nodes." non_empty_child = right if left.is_empty else left non_empty_child = await non_empty_child.decommit(ffc=ffc, facts=facts) @@ -214,14 +250,18 @@ async def _combine_to_virtual_edge_node( parent_path = non_empty_child.path if left.is_empty: # Turn on the MSB bit if the non-empty child is on the right. - parent_path += (1 << non_empty_child.length) + parent_path += 1 << non_empty_child.length return VirtualPatriciaNode( - bottom_node=non_empty_child.bottom_node, path=parent_path, - length=non_empty_child.length + 1, height=non_empty_child.height + 1) + bottom_node=non_empty_child.bottom_node, + path=parent_path, + length=non_empty_child.length + 1, + height=non_empty_child.height + 1, + ) def _get_virtual_edge_node_children( - self) -> Tuple['VirtualPatriciaNode', 'VirtualPatriciaNode']: + self, + ) -> Tuple["VirtualPatriciaNode", "VirtualPatriciaNode"]: """ Returns the children of a virtual edge node: an empty node and a shorter-by-one virtual edge node, according to the direction embedded in the edge path. @@ -231,7 +271,9 @@ def _get_virtual_edge_node_children( non_empty_child = VirtualPatriciaNode( bottom_node=self.bottom_node, path=self.path & ((1 << children_length) - 1), # Turn the MSB bit off. - length=children_length, height=children_height) + length=children_length, + height=children_height, + ) edge_child_direction = self.path >> children_length empty_child = VirtualPatriciaNode.empty_node(height=children_height) diff --git a/src/starkware/starkware_utils/patricia_tree/virtual_patricia_node_test.py b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node_test.py similarity index 81% rename from src/starkware/starkware_utils/patricia_tree/virtual_patricia_node_test.py rename to src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node_test.py index a257eb38..1041f1bd 100644 --- a/src/starkware/starkware_utils/patricia_tree/virtual_patricia_node_test.py +++ b/src/starkware/starkware_utils/commitment_tree/patricia_tree/virtual_patricia_node_test.py @@ -7,8 +7,13 @@ from starkware.cairo.common.patricia_utils import compute_patricia_from_leaves_for_test from starkware.crypto.signature.fast_pedersen_hash import async_pedersen_hash_func, pedersen_hash from starkware.python.utils import from_bytes, to_bytes -from starkware.starkware_utils.patricia_tree.nodes import BinaryNodeFact, EmptyNodeFact -from starkware.starkware_utils.patricia_tree.virtual_patricia_node import VirtualPatriciaNode +from starkware.starkware_utils.commitment_tree.patricia_tree.nodes import ( + BinaryNodeFact, + EmptyNodeFact, +) +from starkware.starkware_utils.commitment_tree.patricia_tree.virtual_patricia_node import ( + VirtualPatriciaNode, +) from starkware.storage.storage import Fact, FactFetchingContext, HashFunctionType from starkware.storage.test_utils import MockStorage @@ -24,7 +29,7 @@ class LeafFact(Fact): @classmethod def prefix(cls) -> bytes: - return b'leaf' + return b"leaf" def serialize(self) -> bytes: return to_bytes(self.value) @@ -33,20 +38,21 @@ async def _hash(self, hash_func: HashFunctionType) -> bytes: return self.serialize() @classmethod - def deserialize(cls, data: bytes) -> 'LeafFact': + def deserialize(cls, data: bytes) -> "LeafFact": return cls(from_bytes(data)) @classmethod - def empty(cls) -> 'LeafFact': + def empty(cls) -> "LeafFact": return cls(value=0) async def make_virtual_edge_non_canonical( - ffc: FactFetchingContext, node: VirtualPatriciaNode) -> VirtualPatriciaNode: + ffc: FactFetchingContext, node: VirtualPatriciaNode +) -> VirtualPatriciaNode: """ Returns the non-canonical form (hash, 0, 0) of a virtual edge node. """ - assert node.is_virtual_edge, 'Node should be of canonical form.' + assert node.is_virtual_edge, "Node should be of canonical form." node_hash = await node.commit(ffc=ffc, facts=None) return VirtualPatriciaNode.from_hash(hash_value=node_hash, height=node.height) @@ -54,7 +60,8 @@ async def make_virtual_edge_non_canonical( def verify_root(leaves: Collection[int], expected_root_hash: bytes): root_hash, _preimage, _node_at_path = compute_patricia_from_leaves_for_test( - leaves=leaves, hash_func=pedersen_hash) + leaves=leaves, hash_func=pedersen_hash + ) assert expected_root_hash == to_bytes(root_hash) @@ -68,15 +75,19 @@ async def test_combine_and_get_children(ffc: FactFetchingContext): # 0 0 0 0 # 0 12 0 0 0 0 30 0 """ - leaf_hash_12, leaf_hash_30, _ = await asyncio.gather(*( - leaf_fact.set_fact(ffc=ffc) - for leaf_fact in (LeafFact(value=12), LeafFact(value=30), LeafFact(value=0)))) + leaf_hash_12, leaf_hash_30, _ = await asyncio.gather( + *( + leaf_fact.set_fact(ffc=ffc) + for leaf_fact in (LeafFact(value=12), LeafFact(value=30), LeafFact(value=0)) + ) + ) # Combine two empty trees. empty_tree_0 = VirtualPatriciaNode.empty_node(height=0) empty_tree_1 = await VirtualPatriciaNode.combine(ffc=ffc, left=empty_tree_0, right=empty_tree_0) assert empty_tree_1 == VirtualPatriciaNode( - bottom_node=EmptyNodeFact.EMPTY_NODE_HASH, path=0, length=0, height=1) + bottom_node=EmptyNodeFact.EMPTY_NODE_HASH, path=0, length=0, height=1 + ) assert await empty_tree_1.get_children(ffc=ffc) == (empty_tree_0, empty_tree_0) # Build left subtree. @@ -94,7 +105,8 @@ async def test_combine_and_get_children(ffc: FactFetchingContext): # Combine left edge node and right empty tree. left_tree_2 = await VirtualPatriciaNode.combine(ffc=ffc, left=left_tree_1, right=empty_tree_1) assert left_tree_2 == VirtualPatriciaNode( - bottom_node=leaf_hash_12, path=int('01', 2), length=2, height=2) + bottom_node=leaf_hash_12, path=int("01", 2), length=2, height=2 + ) # Get children on both forms. expected_children = (left_tree_1, empty_tree_1) @@ -117,7 +129,8 @@ async def test_combine_and_get_children(ffc: FactFetchingContext): # Combine left empty tree and right edge node. right_tree_2 = await VirtualPatriciaNode.combine(ffc=ffc, left=empty_tree_1, right=right_tree_1) assert right_tree_2 == VirtualPatriciaNode( - bottom_node=leaf_hash_30, path=int('10', 2), length=2, height=2) + bottom_node=leaf_hash_30, path=int("10", 2), length=2, height=2 + ) # Get children on both forms. expected_children = (empty_tree_1, right_tree_1) @@ -127,8 +140,9 @@ async def test_combine_and_get_children(ffc: FactFetchingContext): # Build whole tree. # Combine left edge and right edge. - left_node, right_node = await asyncio.gather(*( - node.commit(ffc=ffc, facts=None) for node in (left_tree_2, right_tree_2))) + left_node, right_node = await asyncio.gather( + *(node.commit(ffc=ffc, facts=None) for node in (left_tree_2, right_tree_2)) + ) binary_node_fact = BinaryNodeFact(left_node=left_node, right_node=right_node) root_hash = await binary_node_fact._hash(hash_func=ffc.hash_func) @@ -137,14 +151,17 @@ async def test_combine_and_get_children(ffc: FactFetchingContext): left_edge_child, right_edge_child = await tree.get_children(ffc=ffc) assert (left_edge_child, right_edge_child) == ( VirtualPatriciaNode(bottom_node=left_node, path=0, length=0, height=2), - VirtualPatriciaNode(bottom_node=right_node, path=0, length=0, height=2)) + VirtualPatriciaNode(bottom_node=right_node, path=0, length=0, height=2), + ) # Test operations on the original edge children (now non-canonical). # Combining with an empty node yields another edge with length longer-by-one. parent_edge = await VirtualPatriciaNode.combine( - ffc=ffc, left=left_edge_child, right=VirtualPatriciaNode.empty_node(height=2)) + ffc=ffc, left=left_edge_child, right=VirtualPatriciaNode.empty_node(height=2) + ) assert parent_edge == VirtualPatriciaNode( - bottom_node=left_tree_2.bottom_node, path=int('001', 2), length=3, height=3) + bottom_node=left_tree_2.bottom_node, path=int("001", 2), length=3, height=3 + ) # Getting their children returns another edge with length shorter-by-one. assert await left_edge_child.get_children(ffc=ffc) == (left_tree_1, empty_tree_1) @@ -173,24 +190,32 @@ async def test_update_and_get_leaves(ffc: FactFetchingContext): # Check get_leaves(). expected_leaves = { leaf_id: leaves[leaf_id] if leaf_id in leaves else LeafFact(value=0) - for leaf_id in leaves_range} - assert await tree._get_leaves( - ffc=ffc, indices=leaves_range, fact_cls=LeafFact) == expected_leaves + for leaf_id in leaves_range + } + assert ( + await tree._get_leaves(ffc=ffc, indices=leaves_range, fact_cls=LeafFact) == expected_leaves + ) # Compare to test util result. verify_root( leaves=[leaf.value for leaf in expected_leaves.values()], - expected_root_hash=tree.bottom_node) + expected_root_hash=tree.bottom_node, + ) # Update leaf values again: new leaves contain addition, deletion and updating a key. updated_leaves = { - 0: LeafFact(value=2), 1: LeafFact(value=20), 3: LeafFact(value=6), 6: LeafFact(value=0)} + 0: LeafFact(value=2), + 1: LeafFact(value=20), + 3: LeafFact(value=6), + 6: LeafFact(value=0), + } tree = await tree._update(ffc=ffc, modifications=updated_leaves.items()) # Check get_leaves(). updated_leaves = {**expected_leaves, **updated_leaves} - assert await tree._get_leaves( - ffc=ffc, indices=leaves_range, fact_cls=LeafFact) == updated_leaves + assert ( + await tree._get_leaves(ffc=ffc, indices=leaves_range, fact_cls=LeafFact) == updated_leaves + ) # Compare to test util result. sorted_by_index_leaf_values = [updated_leaves[leaf_id].value for leaf_id in leaves_range] diff --git a/src/starkware/starkware_utils/config_base.py b/src/starkware/starkware_utils/config_base.py index 66c4e199..8dd8a13f 100644 --- a/src/starkware/starkware_utils/config_base.py +++ b/src/starkware/starkware_utils/config_base.py @@ -10,33 +10,36 @@ logger = logging.getLogger(__name__) -TConfig = TypeVar('TConfig', bound='Config') +TConfig = TypeVar("TConfig", bound="Config") # General utilities. + def load_config( - config_file_path: Optional[str] = None, load_logging_config: Optional[bool] = True) -> dict: + config_file_path: Optional[str] = None, load_logging_config: Optional[bool] = True +) -> dict: if config_file_path is None: - config_file_path = '/config.yml' + config_file_path = "/config.yml" - config = yaml.safe_load(open(config_file_path, 'r')) + config = yaml.safe_load(open(config_file_path, "r")) if load_logging_config: - logging.config.dictConfig(config.get('LOGGING', {})) + logging.config.dictConfig(config.get("LOGGING", {})) return config def fetch_application_config(global_config: dict) -> dict: - return global_config.get('application', {}) + return global_config.get("application", {}) def fetch_service_config(global_config: dict) -> dict: - return fetch_application_config(global_config).get('config', {}) + return fetch_application_config(global_config).get("config", {}) # Base class for config schemas. + class Config(ValidatedMarshmallowDataclass): @classmethod def load(cls: Type[TConfig], data: dict) -> TConfig: @@ -52,4 +55,5 @@ def remove_none_values(self, data, many=False): def log_fields(config: Config): for field in dataclasses.fields(config): logger.info( - f'Initialized {field.name} configuration with value: {getattr(config, field.name)}') + f"Initialized {field.name} configuration with value: {getattr(config, field.name)}" + ) diff --git a/src/starkware/starkware_utils/custom_raising_dict.py b/src/starkware/starkware_utils/custom_raising_dict.py index f3a6259f..15cea5e1 100644 --- a/src/starkware/starkware_utils/custom_raising_dict.py +++ b/src/starkware/starkware_utils/custom_raising_dict.py @@ -4,8 +4,8 @@ from frozendict import frozendict -KT = TypeVar('KT') -VT = TypeVar('VT') +KT = TypeVar("KT") +VT = TypeVar("VT") class CustomRaisingDict(ABC, UserDict, Generic[KT, VT]): @@ -23,7 +23,7 @@ def exception_type(cls) -> Type[Exception]: def __init_subclass__(cls, exception_type: Type[Exception], **kwargs): super().__init_subclass__(**kwargs) # type: ignore[call-arg] - assert issubclass(exception_type, KeyError), 'Exception type must subclass KeyError.' + assert issubclass(exception_type, KeyError), "Exception type must subclass KeyError." cls.exception_type = exception_type # type: ignore def __getitem__(self, key: KT) -> VT: @@ -53,8 +53,9 @@ def __init_subclass__(cls, exception_type: Type[Exception], **kwargs): class _CustomRaisingFrozenDict(CustomRaisingDict[KT, VT], exception_type=exception_type): pass - _CustomRaisingFrozenDict.__name__ = _CustomRaisingFrozenDict.__qualname__ = \ - 'CustomRaisingFrozenDict' + _CustomRaisingFrozenDict.__name__ = ( + _CustomRaisingFrozenDict.__qualname__ + ) = "CustomRaisingFrozenDict" cls.dict_cls = _CustomRaisingFrozenDict diff --git a/src/starkware/starkware_utils/error_handling.py b/src/starkware/starkware_utils/error_handling.py index e017e1b5..f544b76c 100644 --- a/src/starkware/starkware_utils/error_handling.py +++ b/src/starkware/starkware_utils/error_handling.py @@ -4,7 +4,7 @@ from enum import Enum, auto from typing import Any, Dict, List, Optional, Type -symbol_to_function = {'!=': operator.ne, '==': operator.eq, '>': operator.gt, '>=': operator.ge} +symbol_to_function = {"!=": operator.ne, "==": operator.eq, ">": operator.gt, ">=": operator.ge} class ErrorCode(Enum): @@ -139,10 +139,10 @@ class StarkException(WebFriendlyException): def __init__(self, code, message: Optional[str] = None): self.code = code self.message = message - super().__init__(status_code=500, body={'code': code, 'message': message}) + super().__init__(status_code=500, body={"code": code, "message": message}) def __repr__(self) -> str: - return f'{type(self).__name__}(code={self.code}, message={self.message})' + return f"{type(self).__name__}(code={self.code}, message={self.message})" def __eq__(self, other: Any) -> bool: if not isinstance(other, StarkException): @@ -165,7 +165,7 @@ def stark_assert_eq(exp_val, actual_val, code, message: Optional[str] = None): Verifies that the expected value is equal to the actual value, raising a StarkException with the appropriate code and message, where the expected and actual values are added to the message. """ - _stark_assert_not_symbol(exp_val, actual_val, symbol='!=', code=code, message=message) + _stark_assert_not_symbol(exp_val, actual_val, symbol="!=", code=code, message=message) def stark_assert_ne(exp_val, actual_val, code, message: Optional[str] = None): @@ -174,7 +174,7 @@ def stark_assert_ne(exp_val, actual_val, code, message: Optional[str] = None): with the appropriate code and message, where the expected and actual values are added to the message. """ - _stark_assert_not_symbol(exp_val, actual_val, symbol='==', code=code, message=message) + _stark_assert_not_symbol(exp_val, actual_val, symbol="==", code=code, message=message) def stark_assert_le(exp_val, actual_val, code, message: Optional[str] = None): @@ -183,7 +183,7 @@ def stark_assert_le(exp_val, actual_val, code, message: Optional[str] = None): StarkException with the appropriate code and message, where the expected and actual values are added to the message. """ - _stark_assert_not_symbol(exp_val, actual_val, symbol='>', code=code, message=message) + _stark_assert_not_symbol(exp_val, actual_val, symbol=">", code=code, message=message) def stark_assert_lt(exp_val, actual_val, code, message: Optional[str] = None): @@ -192,11 +192,10 @@ def stark_assert_lt(exp_val, actual_val, code, message: Optional[str] = None): StarkException with the appropriate code and message, where the expected and actual values are added to the message. """ - _stark_assert_not_symbol(exp_val, actual_val, symbol='>=', code=code, message=message) + _stark_assert_not_symbol(exp_val, actual_val, symbol=">=", code=code, message=message) -def _stark_assert_not_symbol( - exp_val, actual_val, symbol: str, code, message: Optional[str] = None): +def _stark_assert_not_symbol(exp_val, actual_val, symbol: str, code, message: Optional[str] = None): """ Receives a symbol as a string that compares two values (e.g '==', '>') and verifies that: `not exp_val symbol actual_val`. @@ -212,15 +211,18 @@ def _stark_assert_not_symbol( format_val = lambda val: hex(val) if isinstance(val, int) and val > MIN_HEX_SIZE else val if symbol_to_function[symbol](exp_val, actual_val): - eq_log = f'{format_val(exp_val)} {symbol} {format_val(actual_val)}' - message = f'{message}\n{eq_log}' if message else eq_log + eq_log = f"{format_val(exp_val)} {symbol} {format_val(actual_val)}" + message = f"{message}\n{eq_log}" if message else eq_log raise StarkException(code=code, message=message) @contextlib.contextmanager def wrap_with_stark_exception( - code: ErrorCode, message: Optional[str] = None, logger: Optional[logging.Logger] = None, - exception_types: Optional[List[Type[Exception]]] = None): + code: ErrorCode, + message: Optional[str] = None, + logger: Optional[logging.Logger] = None, + exception_types: Optional[List[Type[Exception]]] = None, +): """ Wraps exceptions of types exception_types thrown in the context with StarkException. If exception_types are not provided, only AssertionError is wrapped. diff --git a/src/starkware/starkware_utils/field_validators.py b/src/starkware/starkware_utils/field_validators.py index bf393b23..2274689e 100644 --- a/src/starkware/starkware_utils/field_validators.py +++ b/src/starkware/starkware_utils/field_validators.py @@ -5,21 +5,24 @@ import marshmallow.validate from web3 import Web3 -DNS_REGEX = r'^((\*)|(\*\.))?([a-z0-9-]){1,62}(\.[a-z0-9-]{1,62})*\.?$' +DNS_REGEX = r"^((\*)|(\*\.))?([a-z0-9-]){1,62}(\.[a-z0-9-]{1,62})*\.?$" -T = TypeVar('T') -TypeKey = TypeVar('TypeKey') -TypeValue = TypeVar('TypeValue') +T = TypeVar("T") +TypeKey = TypeVar("TypeKey") +TypeValue = TypeVar("TypeValue") ValidatorType = Callable[[T], Union[T, bool]] # Validators for public use in config dataclasses. + def validate_regex_match( - field_name: str, *, regex: str, allow_none: bool, regex_description: str) -> ValidatorType: - error_message = 'Invalid {field_name}: {{input}}; must be a legal {regex_description}'.format( - field_name=field_name, regex_description=regex_description) + field_name: str, *, regex: str, allow_none: bool, regex_description: str +) -> ValidatorType: + error_message = "Invalid {field_name}: {{input}}; must be a legal {regex_description}".format( + field_name=field_name, regex_description=regex_description + ) validate_regex = marshmallow.validate.Regexp(regex=regex, error=error_message) def validator(value): @@ -36,39 +39,46 @@ def validator(value): def validate_dns(*, allow_none: bool) -> ValidatorType: return validate_regex_match( - field_name='dns', regex=DNS_REGEX, allow_none=allow_none, regex_description='DNS label') + field_name="dns", regex=DNS_REGEX, allow_none=allow_none, regex_description="DNS label" + ) -HEX_REGEX = r'^0x[a-fA-F0-9]+$' +HEX_REGEX = r"^0x[a-fA-F0-9]+$" validate_optional_hex_str = validate_regex_match( - field_name='fact', regex=HEX_REGEX, allow_none=True, regex_description='hex string') + field_name="fact", regex=HEX_REGEX, allow_none=True, regex_description="hex string" +) def validate_url( - *, url_name: str, schemes: marshmallow.types.StrSequenceOrSet, - require_full_url: bool) -> ValidatorType: + *, url_name: str, schemes: marshmallow.types.StrSequenceOrSet, require_full_url: bool +) -> ValidatorType: error_message = ( - 'Invalid {url_name} URL: {{input}}; ' - 'must be a legal URL starting with {schemes}').format( - url_name=url_name, schemes=','.join(schemes)) + "Invalid {url_name} URL: {{input}}; " "must be a legal URL starting with {schemes}" + ).format(url_name=url_name, schemes=",".join(schemes)) return marshmallow.validate.URL( - schemes=schemes, require_tld=require_full_url, error=error_message) + schemes=schemes, require_tld=require_full_url, error=error_message + ) validate_gateway_url = validate_url( - url_name='Gateway endpoint', schemes={'http', 'https'}, require_full_url=False) + url_name="Gateway endpoint", schemes={"http", "https"}, require_full_url=False +) validate_internal_url = validate_url( - url_name='Internal Gateway endpoint', schemes={'http', 'https'}, require_full_url=False) + url_name="Internal Gateway endpoint", schemes={"http", "https"}, require_full_url=False +) validate_node_endpoint = validate_url( - url_name='Node endpoint', schemes={'http', 'https'}, require_full_url=False) + url_name="Node endpoint", schemes={"http", "https"}, require_full_url=False +) def validate_one_of( - field_name: str, *, choices: Iterable, allow_none: bool = False) -> ValidatorType: - error_message = 'Invalid {field_name}: {{input}}; allowed values: {{choices}}'.format( - field_name=field_name) + field_name: str, *, choices: Iterable, allow_none: bool = False +) -> ValidatorType: + error_message = "Invalid {field_name}: {{input}}; allowed values: {{choices}}".format( + field_name=field_name + ) one_of_validator = marshmallow.validate.OneOf(choices=choices, error=error_message) def validator(value): @@ -81,34 +91,49 @@ def validator(value): def validate_equal(field_name: str, *, allowed_value: T) -> ValidatorType: - error_message = 'Invalid {field_name}: {{input}}; must be: {{other}}'.format( - field_name=field_name) + error_message = "Invalid {field_name}: {{input}}; must be: {{other}}".format( + field_name=field_name + ) return marshmallow.validate.Equal(comparable=allowed_value, error=error_message) def validate_length(field_name: str, *, length: int) -> ValidatorType: - error_message = 'Invalid {field_name}: {{input}}; must be of length: {length}'.format( - field_name=field_name, length=length) + error_message = "Invalid {field_name}: {{input}}; must be of length: {length}".format( + field_name=field_name, length=length + ) return marshmallow.validate.Length(equal=length, error=error_message) def validate_in_range( - field_name, *, min_value: Optional[int] = None, max_value: Optional[int] = None, - min_inclusive: bool = True, max_inclusive: bool = True, - allow_none: bool = False, error_message: Optional[str] = None) -> ValidatorType: + field_name, + *, + min_value: Optional[int] = None, + max_value: Optional[int] = None, + min_inclusive: bool = True, + max_inclusive: bool = True, + allow_none: bool = False, + error_message: Optional[str] = None, +) -> ValidatorType: if error_message is None: range_string = ( f'{"[" if min_inclusive else "("}' f'{"-inf" if min_value is None else min_value},' f'{"inf" if max_value is None else max_value}' - f'{"]" if max_inclusive else ")"}') - error_message = \ - 'Invalid {field_name}: {{input}}; must be in the range {range_string}'.format( - field_name=field_name, range_string=range_string) + f'{"]" if max_inclusive else ")"}' + ) + error_message = ( + "Invalid {field_name}: {{input}}; must be in the range {range_string}".format( + field_name=field_name, range_string=range_string + ) + ) range_validator = marshmallow.validate.Range( - min=min_value, max=max_value, min_inclusive=min_inclusive, max_inclusive=max_inclusive, - error=error_message) + min=min_value, + max=max_value, + min_inclusive=min_inclusive, + max_inclusive=max_inclusive, + error=error_message, + ) def validator(value): if allow_none and value is None: @@ -120,24 +145,35 @@ def validator(value): def validate_positive(field_name: str, *, allow_none: bool = False) -> ValidatorType: - error_message = 'Invalid {field_name}: {{input}}; must be a positive value'.format( - field_name=field_name) + error_message = "Invalid {field_name}: {{input}}; must be a positive value".format( + field_name=field_name + ) return validate_in_range( - field_name=field_name, min_value=0, min_inclusive=False, allow_none=allow_none, - error_message=error_message) + field_name=field_name, + min_value=0, + min_inclusive=False, + allow_none=allow_none, + error_message=error_message, + ) def validate_non_negative(field_name, *, allow_none=False): - error_message = 'Invalid {field_name}: {{input}}; must be a non-negative value'.format( - field_name=field_name) + error_message = "Invalid {field_name}: {{input}}; must be a non-negative value".format( + field_name=field_name + ) return validate_in_range( - field_name=field_name, min_value=0, min_inclusive=True, allow_none=allow_none, - error_message=error_message) + field_name=field_name, + min_value=0, + min_inclusive=True, + allow_none=allow_none, + error_message=error_message, + ) def validate_positive_or_infinity(field_name: str) -> ValidatorType: - error_message = 'Invalid {field_name}: {{input}}; must be positive -1 (for unlimited)'.format( - field_name=field_name) + error_message = "Invalid {field_name}: {{input}}; must be positive -1 (for unlimited)".format( + field_name=field_name + ) def validator(value): if value <= 0 and value != -1: @@ -150,15 +186,21 @@ def validator(value): def validate_probability(field_name: str, *, allow_none: bool = False) -> ValidatorType: return validate_in_range( - field_name=field_name, min_value=0, max_value=1, min_inclusive=True, max_inclusive=True, - allow_none=allow_none) + field_name=field_name, + min_value=0, + max_value=1, + min_inclusive=True, + max_inclusive=True, + allow_none=allow_none, + ) def validate_public_key(field_name: str) -> ValidatorType: - error_message = 'Invalid {field_name}: {{input}}; must be a legal Ethereum address'.format( - field_name=field_name) + error_message = "Invalid {field_name}: {{input}}; must be a legal Ethereum address".format( + field_name=field_name + ) - address_regex = r'^0x[a-fA-F0-9]{40}$' + address_regex = r"^0x[a-fA-F0-9]{40}$" validate_address_regex = marshmallow.validate.Regexp(regex=address_regex, error=error_message) def validator(addresses: Union[str, List[str]]): @@ -177,22 +219,25 @@ def validator(addresses: Union[str, List[str]]): def validate_private_key(field_name: str) -> ValidatorType: - error_message = 'Invalid {field_name}: {{input}}; must be a legal Ethereum private key'.format( - field_name=field_name) + error_message = "Invalid {field_name}: {{input}}; must be a legal Ethereum private key".format( + field_name=field_name + ) - private_key_regex = r'^0x[a-fA-F0-9]{64}$' + private_key_regex = r"^0x[a-fA-F0-9]{64}$" return marshmallow.validate.Regexp(regex=private_key_regex, error=error_message) def validate_customer_id(field_name: str) -> ValidatorType: - error_message = 'Invalid {field_name}: {{input}}; must be an alphanumeric string'.format( - field_name=field_name) - return marshmallow.validate.Regexp(regex=r'^[A-Za-z0-9_-]+$', error=error_message) + error_message = "Invalid {field_name}: {{input}}; must be an alphanumeric string".format( + field_name=field_name + ) + return marshmallow.validate.Regexp(regex=r"^[A-Za-z0-9_-]+$", error=error_message) def validate_absolute_linux_path(field_name: str, *, allow_none: bool) -> ValidatorType: - error_message = 'Invalid {field_name}: {{input}}; must be a legal absolute Linux path'.format( - field_name=field_name) + error_message = "Invalid {field_name}: {{input}}; must be a legal absolute Linux path".format( + field_name=field_name + ) def validator(value: str): if allow_none and value is None: @@ -206,42 +251,45 @@ def validator(value: str): return validator -validate_certificates_path = validate_absolute_linux_path('certificates_path', allow_none=True) +validate_certificates_path = validate_absolute_linux_path("certificates_path", allow_none=True) def validate_communication_params(*, url: str, certificates_path: Optional[str]): - https_used = url.startswith('https') + https_used = url.startswith("https") certs_used = certificates_path is not None if certs_used and not https_used: - raise ValueError('Certificates should be used together with a HTTPS URL') + raise ValueError("Certificates should be used together with a HTTPS URL") def validate_dict( - field_name: str, - *, key_validator: Optional[Callable[[str], Callable[[TypeKey], bool]]] = None, - value_validator: Optional[Callable[[str], Callable[[TypeValue], bool]]] = None, - allow_none: bool = False) -> Callable[[Dict[TypeKey, TypeValue]], bool]: + field_name: str, + *, + key_validator: Optional[Callable[[str], Callable[[TypeKey], bool]]] = None, + value_validator: Optional[Callable[[str], Callable[[TypeValue], bool]]] = None, + allow_none: bool = False, +) -> Callable[[Dict[TypeKey, TypeValue]], bool]: """ Returns a validator for a dictionary, that validates the keys according to key_validator, and the values according to value_validator. These validators should be methods that get the field name, and return the validator for that field. Set these validators to None to have empty validators, which will always return True. """ + def validator(dictionary: Dict[TypeKey, TypeValue]): nonlocal key_validator, value_validator if allow_none and dictionary is None: return True if key_validator is None: - key_validator = (lambda name: lambda key: True) + key_validator = lambda name: lambda key: True if value_validator is None: - value_validator = (lambda name: lambda value: True) + value_validator = lambda name: lambda value: True for key, value in dictionary.items(): try: (key_validator(str(key)))(key) (value_validator(str(key)))(value) except Exception as e: - raise type(e)(f'Dictionary {field_name} is not valid: ' + str(e)) + raise type(e)(f"Dictionary {field_name} is not valid: " + str(e)) return True return validator @@ -251,8 +299,9 @@ def validate_power_of_two(field_name: str) -> ValidatorType: """ Return a validator for a number, that validates that the number is a power of 2. """ - error_message = 'Invalid {field_name}: {{input}}; must be a power of 2'.format( - field_name=field_name) + error_message = "Invalid {field_name}: {{input}}; must be a power of 2".format( + field_name=field_name + ) def validator(value): tmp_value = value diff --git a/src/starkware/starkware_utils/marshmallow_dataclass_fields.py b/src/starkware/starkware_utils/marshmallow_dataclass_fields.py index 08189bd7..94c69e58 100644 --- a/src/starkware/starkware_utils/marshmallow_dataclass_fields.py +++ b/src/starkware/starkware_utils/marshmallow_dataclass_fields.py @@ -16,7 +16,8 @@ class IntAsStr(mfields.Field): serialized to strings in the JSONs, so that JavaSscript can handle them (JavaScript cannot handle uint64 numbers). """ - default_error_messages = {'invalid': 'Expected int string, got: "{input}".'} + + default_error_messages = {"invalid": 'Expected int string, got: "{input}".'} def _serialize(self, value, attr, obj, **kwargs): if value is None: @@ -24,8 +25,8 @@ def _serialize(self, value, attr, obj, **kwargs): return str(value) def _deserialize(self, value, attr, data, **kwargs): - if re.match('^-?[0-9]+$', value) is None: - self.fail('invalid', input=value) + if re.match("^-?[0-9]+$", value) is None: + self.fail("invalid", input=value) return int(value) @@ -48,7 +49,8 @@ def _serialize(self, value, attr, obj, **kwargs): return None raise ValidationError( - message=f'Field of type {type(self).__name__} is None, but allow_none=False') + message=f"Field of type {type(self).__name__} is None, but allow_none=False" + ) def _deserialize(self, value, attr, data, **kwargs): # No need to handle the case in which value is None, since public deserialize() method @@ -62,7 +64,7 @@ class IntAsHex(mfields.Field): field elements. """ - default_error_messages = {'invalid': 'Expected hex string, got: "{input}".'} + default_error_messages = {"invalid": 'Expected hex string, got: "{input}".'} def _serialize(self, value, attr, obj, **kwargs): if value is None: @@ -71,8 +73,8 @@ def _serialize(self, value, attr, obj, **kwargs): return hex(value) def _deserialize(self, value, attr, data, **kwargs): - if re.match('^0x[0-9a-f]+$', value) is None: - self.fail('invalid', input=value) + if re.match("^0x[0-9a-f]+$", value) is None: + self.fail("invalid", input=value) return int(value, 16) @@ -82,7 +84,7 @@ class BytesAsHex(mfields.Field): A field that behaves like bytes, but serializes to a hex string. """ - default_error_messages = {'invalid': 'Expected hex string, got: "{input}".'} + default_error_messages = {"invalid": 'Expected hex string, got: "{input}".'} def _serialize(self, value, attr, obj, **kwargs): if value is None: @@ -91,8 +93,8 @@ def _serialize(self, value, attr, obj, **kwargs): return value.hex() def _deserialize(self, value, attr, data, **kwargs): - if re.match('^[0-9a-f]*$', value) is None: - self.fail('invalid', input=value) + if re.match("^[0-9a-f]*$", value) is None: + self.fail("invalid", input=value) return bytes.fromhex(value) @@ -102,16 +104,16 @@ class BytesAsBase64Str(mfields.Field): A field that behaves like bytes, but serializes to base64. """ - default_error_messages = {'invalid': 'Expected Base64 bytes, got: "{input}".'} + default_error_messages = {"invalid": 'Expected Base64 bytes, got: "{input}".'} def _serialize(self, value, attr, obj, **kwargs): if value is None: return None assert isinstance(value, bytes) - return base64.b64encode(value).decode('ascii') + return base64.b64encode(value).decode("ascii") def _deserialize(self, value, attr, data, **kwargs): - return base64.b64decode(value.encode('ascii')) + return base64.b64decode(value.encode("ascii")) class CustomField(ABC): @@ -129,8 +131,9 @@ def _type(cls) -> type: def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) # type: ignore[call-arg] - assert issubclass(cls, FieldABC), \ - 'CustomField must be used along with inheritance from a marshmallow field.' + assert issubclass( + cls, FieldABC + ), "CustomField must be used along with inheritance from a marshmallow field." def _deserialize(self, *args, **kwargs): return self._type(super()._deserialize(*args, **kwargs)) # type: ignore @@ -158,10 +161,13 @@ class CustomRaisingFrozenDictField(CustomField, mfields.Mapping): # Field metadata for general use in marshmallow dataclasses. + def enum_field_metadata( - *, enum_class: type, require: bool = True, allow_none: bool = False) -> dict: + *, enum_class: type, require: bool = True, allow_none: bool = False +) -> dict: return dict( - marshmallow_field=EnumField(enum_cls=enum_class, required=require, allow_none=allow_none)) + marshmallow_field=EnumField(enum_cls=enum_class, required=require, allow_none=allow_none) + ) boolean_field_metadata = dict(marshmallow_field=mfields.Boolean(truthy={True}, falsy={False})) diff --git a/src/starkware/starkware_utils/serializable.py b/src/starkware/starkware_utils/serializable.py index 04c9f384..56bd8e09 100644 --- a/src/starkware/starkware_utils/serializable.py +++ b/src/starkware/starkware_utils/serializable.py @@ -3,8 +3,8 @@ from json import JSONDecoder, JSONEncoder from typing import ClassVar, Dict, Type, TypeVar -TSerializableObject = TypeVar('TSerializableObject', bound='Serializable') -TStrSerializableObject = TypeVar('TStrSerializableObject', bound='StringSerializable') +TSerializableObject = TypeVar("TSerializableObject", bound="Serializable") +TStrSerializableObject = TypeVar("TStrSerializableObject", bound="StringSerializable") class Serializable(ABC): @@ -52,7 +52,7 @@ class StringSerializable(Serializable): then it needs to implement the dumps function. """ - _classes: ClassVar[Dict[str, Type['StringSerializable']]] = {} + _classes: ClassVar[Dict[str, Type["StringSerializable"]]] = {} _serialize_name: ClassVar[str] def __init_subclass__(cls, **kwargs): @@ -64,8 +64,8 @@ def __init_subclass__(cls, **kwargs): if mro_cls is StringSerializable: # The dumps method is abstract. continue - if 'dumps' in mro_cls.__dict__: - cls._serialize_name = f'{mro_cls}' + if "dumps" in mro_cls.__dict__: + cls._serialize_name = f"{mro_cls}" StringSerializable._classes[cls._serialize_name] = cls @abstractmethod @@ -78,20 +78,17 @@ def loads(cls: Type[TStrSerializableObject], data: str) -> TStrSerializableObjec pass def serialize(self) -> bytes: - return self.dumps().encode('ascii') + return self.dumps().encode("ascii") @classmethod def deserialize(cls: Type[TStrSerializableObject], data: bytes) -> TStrSerializableObject: - return cls.loads(data=data.decode('ascii')) + return cls.loads(data=data.decode("ascii")) class SerializableEncoder(JSONEncoder): def default(self, obj): if isinstance(obj, StringSerializable): if obj._serialize_name in StringSerializable._classes: - return { - '_serializable': obj._serialize_name, - 'value': obj.dumps() - } + return {"_serializable": obj._serialize_name, "value": obj.dumps()} return JSONEncoder.default(self, obj) @@ -104,12 +101,12 @@ def __init__(self, *args, **kwargs): super().__init__(object_hook=self.object_hook, *args, **kwargs) def object_hook(self, obj): - if '_serializable' not in obj: + if "_serializable" not in obj: return obj - cls_repr = obj['_serializable'] + cls_repr = obj["_serializable"] serialized_class = StringSerializable._classes.get(cls_repr, None) - assert serialized_class is not None, f'Could not decode the class {cls_repr}.' - return serialized_class.loads(data=obj['value']) + assert serialized_class is not None, f"Could not decode the class {cls_repr}." + return serialized_class.loads(data=obj["value"]) @staticmethod def get_decoder() -> Type[JSONDecoder]: diff --git a/src/starkware/starkware_utils/time/time.py b/src/starkware/starkware_utils/time/time.py index 14ce6810..7d4840fe 100644 --- a/src/starkware/starkware_utils/time/time.py +++ b/src/starkware/starkware_utils/time/time.py @@ -14,7 +14,8 @@ mocked_time_func: ContextVar[Callable[[], float]] = ContextVar( - 'mocked_time_func', default=def_time_func) + "mocked_time_func", default=def_time_func +) def time(): diff --git a/src/starkware/starkware_utils/validated_dataclass.py b/src/starkware/starkware_utils/validated_dataclass.py index 90be8281..dff514ad 100644 --- a/src/starkware/starkware_utils/validated_dataclass.py +++ b/src/starkware/starkware_utils/validated_dataclass.py @@ -12,9 +12,9 @@ from starkware.starkware_utils.serializable import StringSerializable from starkware.starkware_utils.validated_fields import Field -TValidatedDataclass = TypeVar('TValidatedDataclass', bound='ValidatedDataclass') -TSerializableDataclass = TypeVar('TSerializableDataclass', bound='SerializableMarshmallowDataclass') -T = TypeVar('T') +TValidatedDataclass = TypeVar("TValidatedDataclass", bound="ValidatedDataclass") +TSerializableDataclass = TypeVar("TSerializableDataclass", bound="SerializableMarshmallowDataclass") +T = TypeVar("T") class SerializableMarshmallowDataclass(StringSerializable): @@ -30,7 +30,7 @@ class SerializableMarshmallowDataclass(StringSerializable): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) # type: ignore[call-arg] - cls.class_name_prefix = camel_to_snake_case(camel_case_name=cls.__name__).encode('ascii') + cls.class_name_prefix = camel_to_snake_case(camel_case_name=cls.__name__).encode("ascii") def dump(self) -> dict: return self.Schema().dump(obj=self) @@ -69,8 +69,8 @@ def validate_dataclass(self): @classmethod def get_random_element( - cls: Type[TValidatedDataclass], - random_object: Optional[random.Random] = None, **data) -> TValidatedDataclass: + cls: Type[TValidatedDataclass], random_object: Optional[random.Random] = None, **data + ) -> TValidatedDataclass: """ Generates a random object of the given class restricted by the given data. Any field can be either passed as an argument (field_name=field_value), and if not, @@ -104,45 +104,49 @@ class Outer(ValidatedMarshmallowDataclass): validated_field = get_validated_field(field=field) if validated_field is not None: new_object_data[field.name] = validated_field.get_random_value( - random_object=random_object) + random_object=random_object + ) continue # The field is a validated class object. - is_validated_dataclass = ( - inspect.isclass(field.type) and - issubclass(field.type, ValidatedMarshmallowDataclass)) + is_validated_dataclass = inspect.isclass(field.type) and issubclass( + field.type, ValidatedMarshmallowDataclass + ) if is_validated_dataclass: new_object_data[field.name] = field.type.get_random_element( - random_object=random_object) + random_object=random_object + ) continue raise Exception( - f'Could not randomize the field {field.name} in an object of type {cls}.') + f"Could not randomize the field {field.name} in an object of type {cls}." + ) return cls(**new_object_data) # type: ignore def validate_values(self): for field in dataclasses.fields(self): - metadata = getattr(field, 'metadata', None) + metadata = getattr(field, "metadata", None) if metadata is None: continue value = getattr(self, field.name) # First use the field_validated argument, and only if it does not exist, # use the validation inside the marshmallow field argument. - validated_field = metadata.get('validated_field', None) + validated_field = metadata.get("validated_field", None) if validated_field is None: - marshmallow_field = field.metadata.get('marshmallow_field', None) + marshmallow_field = field.metadata.get("marshmallow_field", None) if marshmallow_field is not None: validate_field(field=marshmallow_field, value=value) else: - name_in_messages = metadata.get('name_in_messages', None) + name_in_messages = metadata.get("name_in_messages", None) validated_field.validate(value=value, name=name_in_messages) def validate_types(self): for field in dataclasses.fields(self): typeguard.check_type( - argname=field.name, value=getattr(self, field.name), expected_type=field.type) + argname=field.name, value=getattr(self, field.name), expected_type=field.type + ) class ValidatedMarshmallowDataclass(ValidatedDataclass, SerializableMarshmallowDataclass): @@ -156,8 +160,8 @@ def get_validated_field(field: dataclasses.Field) -> Optional[Field]: Checks if the dataclass field has a validated_field attribute in its metadata. If so returns it, otherwise returns None. """ - if field.metadata is not None and 'validated_field' in field.metadata: - return field.metadata['validated_field'] + if field.metadata is not None and "validated_field" in field.metadata: + return field.metadata["validated_field"] return None @@ -180,6 +184,7 @@ class Child(Base): derived class construction will work as expected. """ if cls is None: # Arguments passed directly to decorator. + def inner(cls): prepare_class_annotations_and_attribute_values(cls) return marshmallow_dataclass.dataclass(cls, **kwargs) @@ -200,7 +205,8 @@ def prepare_class_annotations_and_attribute_values(cls): """ annotations, attr_values = process_class_annotations_and_attribute_values(cls=cls) set_class_annotations_and_attribute_values( - cls=cls, annotations=annotations, attr_values=attr_values) + cls=cls, annotations=annotations, attr_values=attr_values + ) def process_class_annotations_and_attribute_values(cls) -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -213,7 +219,7 @@ def process_class_annotations_and_attribute_values(cls) -> Tuple[Dict[str, Any], attr_values: Dict[str, Any] = {} for base_cls in inspect.getmro(cls): - if '__annotations__' not in base_cls.__dict__: + if "__annotations__" not in base_cls.__dict__: continue for name in base_cls.__annotations__: @@ -225,11 +231,13 @@ def process_class_annotations_and_attribute_values(cls) -> Tuple[Dict[str, Any], attr_values[name] = base_cls.__dict__[name] continue - if ('__dataclass_fields__' in base_cls.__dict__ and - name in base_cls.__dict__['__dataclass_fields__']): + if ( + "__dataclass_fields__" in base_cls.__dict__ + and name in base_cls.__dict__["__dataclass_fields__"] + ): # cls is a dataclass, in which all fields appear in cls.__dataclass_fields__, # rather than directly in cls.__dict__. - attr_values[name] = base_cls.__dict__['__dataclass_fields__'][name] + attr_values[name] = base_cls.__dict__["__dataclass_fields__"][name] continue # Prepand annotations, so that they appear in reverse MRO order. @@ -239,7 +247,8 @@ def process_class_annotations_and_attribute_values(cls) -> Tuple[Dict[str, Any], def set_class_annotations_and_attribute_values( - cls, annotations: Dict[str, Any], attr_values: Dict[str, Any]): + cls, annotations: Dict[str, Any], attr_values: Dict[str, Any] +): """ Sets given attributes to cls.__dict__ and its annotations. The annotations will contain the given annotations, where the attributes with default values @@ -255,8 +264,10 @@ def set_class_annotations_and_attribute_values( # Locate members with default values in the end of the annotations dictionary. cls.__annotations__ = { - name: annotation for name, annotation in annotations.items() - if name not in default_value_annotations} + name: annotation + for name, annotation in annotations.items() + if name not in default_value_annotations + } cls.__annotations__.update(default_value_annotations) @@ -274,17 +285,20 @@ class A: # If member does not appear in __init__'s signature, having a default value is irrelevant. return ( - attr_value.init and - attr_value.default is not dataclasses.MISSING or + attr_value.init + and attr_value.default is not dataclasses.MISSING + or # Mypy has a problem with object members that are callables (it sees access to them as # passing self). This is actually originated in dataclasses' annotations in typeshed, since # the source code has no annotations. # See https://github.com/python/mypy/issues/6910 for details on this problem. - attr_value.default_factory is not dataclasses.MISSING) # type: ignore + attr_value.default_factory is not dataclasses.MISSING # type: ignore + ) # Validators for private use in this file. + def validate_value(*, field: mfields.Field, value: Any): """ Invokes the field's validator, if exists and it is callable. @@ -321,7 +335,7 @@ def validate_list(list_field: mfields.List, list_value: Sequence): if list_field.allow_none: return - raise marshmallow.ValidationError('Field may not be None.') + raise marshmallow.ValidationError("Field may not be None.") for inner_element in list_value: validate_field(field=list_field.inner, value=inner_element) diff --git a/src/starkware/starkware_utils/validated_fields.py b/src/starkware/starkware_utils/validated_fields.py index 0e6a6d7b..e060f2e9 100644 --- a/src/starkware/starkware_utils/validated_fields.py +++ b/src/starkware/starkware_utils/validated_fields.py @@ -10,9 +10,13 @@ from starkware.starkware_utils.error_handling import ErrorCode, StarkErrorCode, stark_assert from starkware.starkware_utils.field_validators import validate_in_range from starkware.starkware_utils.marshmallow_dataclass_fields import ( - BytesAsBase64Str, BytesAsHex, IntAsHex, IntAsStr) + BytesAsBase64Str, + BytesAsHex, + IntAsHex, + IntAsStr, +) -T = TypeVar('T') +T = TypeVar("T") class Field(ABC, Generic[T]): @@ -95,7 +99,8 @@ def metadata(self, field_name: Optional[str] = None): return dict( marshmallow_field=self.get_marshmallow_field(), validated_field=self, - name_in_messages=self.name if field_name is None else field_name) + name_in_messages=self.name if field_name is None else field_name, + ) class OptionalField(Field[Optional[T]]): @@ -125,7 +130,7 @@ def name(self) -> str: def format(self, value: Optional[T]) -> str: if value is None: - return 'None' + return "None" return self.field.format(value=value) # Randomization. @@ -147,9 +152,10 @@ def get_invalid_values(self) -> List[Optional[T]]: return [value for value in self.field.get_invalid_values() if value is not None] def format_invalid_value_error_message( - self, value: Optional[T], name: Optional[str] = None) -> str: + self, value: Optional[T], name: Optional[str] = None + ) -> str: if value is None: - return f'{name} is valid (None).' + return f"{name} is valid (None)." return self.field.format_invalid_value_error_message(value=value, name=name) def get_marshmallow_field(self) -> mfields.Field: @@ -167,7 +173,7 @@ class RangeValidatedField(Field[int]): name_in_error_message: str out_of_range_error_code: ErrorCode formatter: Optional[Callable[[int], str]] = None - out_of_range_message: ClassVar[str] = '{field_name} {field_value} is out of range' + out_of_range_message: ClassVar[str] = "{field_name} {field_value} is out of range" @property def name(self): @@ -185,8 +191,8 @@ def is_valid(self, value: int) -> bool: def format_invalid_value_error_message(self, value: int, name: Optional[str] = None) -> str: return self.out_of_range_message.format( - field_name=self.name if name is None else name, - field_value=self._format_value(value)) + field_name=self.name if name is None else name, field_value=self._format_value(value) + ) @property def error_code(self) -> ErrorCode: @@ -208,8 +214,9 @@ def get_marshmallow_field(self) -> mfields.Field: if self.formatter is None: return mfields.Integer(required=True) raise NotImplementedError( - f'{self.name}: The given formatter {self.formatter.__name__} ' - 'does not have a suitable metadata.') + f"{self.name}: The given formatter {self.formatter.__name__} " + "does not have a suitable metadata." + ) class BytesLengthField(Field[bytes]): @@ -220,7 +227,7 @@ class BytesLengthField(Field[bytes]): def __init__(self, name: str, error_code: StarkErrorCode, length: int): self._name = name self._error_code = error_code - assert length > 0, 'Bytes length must be at least 1.' + assert length > 0, "Bytes length must be at least 1." self.length = length @property @@ -236,7 +243,7 @@ def is_valid(self, value: bytes) -> bool: return len(value) == self.length def get_invalid_values(self) -> List[bytes]: - return [b'\x00' * (self.length - 1), b'\x00' * (self.length + 1)] + return [b"\x00" * (self.length - 1), b"\x00" * (self.length + 1)] @property def error_code(self) -> ErrorCode: @@ -245,7 +252,7 @@ def error_code(self) -> ErrorCode: def format_invalid_value_error_message(self, value: bytes, name: Optional[str] = None) -> str: name = self.name if name is None else name value_repr = self.format(value=value) - return f'{name} {value_repr} length is not {self.length} bytes, instead it is {len(value)}' + return f"{name} {value_repr} length is not {self.length} bytes, instead it is {len(value)}" # Serialization. def get_marshmallow_field(self) -> mfields.Field: @@ -257,9 +264,12 @@ def format(self, value: bytes) -> str: # Field metadata utilities. + def _generate_metadata( - marshmallow_field_cls: Type[mfields.Field], validated_field: Optional[Field], - required: Optional[bool] = None) -> Dict[str, Any]: + marshmallow_field_cls: Type[mfields.Field], + validated_field: Optional[Field], + required: Optional[bool] = None, +) -> Dict[str, Any]: if required is None: required = True @@ -271,38 +281,50 @@ def _generate_metadata( def int_metadata( - validated_field: Optional[Field], required: Optional[bool] = None) -> Dict[str, Any]: + validated_field: Optional[Field], required: Optional[bool] = None +) -> Dict[str, Any]: return _generate_metadata( - marshmallow_field_cls=mfields.Integer, validated_field=validated_field, required=required) + marshmallow_field_cls=mfields.Integer, validated_field=validated_field, required=required + ) def int_as_hex_metadata( - validated_field: Optional[Field], required: Optional[bool] = None) -> Dict[str, Any]: + validated_field: Optional[Field], required: Optional[bool] = None +) -> Dict[str, Any]: return _generate_metadata( - marshmallow_field_cls=IntAsHex, validated_field=validated_field, required=required) + marshmallow_field_cls=IntAsHex, validated_field=validated_field, required=required + ) def int_as_str_metadata( - validated_field: Optional[Field], required: Optional[bool] = None) -> Dict[str, Any]: + validated_field: Optional[Field], required: Optional[bool] = None +) -> Dict[str, Any]: return _generate_metadata( - marshmallow_field_cls=IntAsStr, validated_field=validated_field, required=required) + marshmallow_field_cls=IntAsStr, validated_field=validated_field, required=required + ) def bytes_as_hex_metadata( - validated_field: Optional[Field], required: Optional[bool] = None) -> Dict[str, Any]: + validated_field: Optional[Field], required: Optional[bool] = None +) -> Dict[str, Any]: return _generate_metadata( - marshmallow_field_cls=BytesAsHex, validated_field=validated_field, required=required) + marshmallow_field_cls=BytesAsHex, validated_field=validated_field, required=required + ) def bytes_as_base64_str_metadata( - validated_field: Optional[Field], required: Optional[bool] = None) -> Dict[str, Any]: + validated_field: Optional[Field], required: Optional[bool] = None +) -> Dict[str, Any]: return _generate_metadata( - marshmallow_field_cls=BytesAsBase64Str, validated_field=validated_field, required=required) + marshmallow_field_cls=BytesAsBase64Str, validated_field=validated_field, required=required + ) + def sequential_id_metadata( - field_name: str, required: bool = True, - allow_previous_id: bool = False) -> Dict[str, Any]: + field_name: str, required: bool = True, allow_previous_id: bool = False +) -> Dict[str, Any]: validator = validate_in_range(field_name=field_name, min_value=-1 if allow_previous_id else 0) return dict( - marshmallow_field=mfields.Integer(strict=True, required=required, validate=validator)) + marshmallow_field=mfields.Integer(strict=True, required=required, validate=validator) + ) diff --git a/src/starkware/storage/batch_store.py b/src/starkware/storage/batch_store.py index 87c9c7f5..e24c871f 100644 --- a/src/starkware/storage/batch_store.py +++ b/src/starkware/storage/batch_store.py @@ -14,10 +14,8 @@ def __init__(self, storage: Storage, n_workers_set: int, n_workers_get: int): self.store = storage self.set_queue: Queue[Tuple[bytes, bytes, asyncio.Future]] = Queue() self.get_queue: Queue[Tuple[bytes, asyncio.Future]] = Queue() - self.tasks = [ - asyncio.create_task(self.set_value_thread()) for _ in range(n_workers_set)] - self.tasks += [ - asyncio.create_task(self.get_value_thread()) for _ in range(n_workers_get)] + self.tasks = [asyncio.create_task(self.set_value_thread()) for _ in range(n_workers_set)] + self.tasks += [asyncio.create_task(self.get_value_thread()) for _ in range(n_workers_get)] async def close(self): for task in self.tasks: @@ -26,8 +24,8 @@ async def close(self): await task except Exception as ex: if not isinstance(ex, asyncio.CancelledError): - logger.error(f'Excpetion occurred! Exception: {ex}') - logger.debug('Exception details', exc_info=True) + logger.error(f"Excpetion occurred! Exception: {ex}") + logger.debug("Exception details", exc_info=True) async def set_value(self, key: bytes, value: bytes): # Put value in the set_queue. diff --git a/src/starkware/storage/batch_store_test.py b/src/starkware/storage/batch_store_test.py index 83e32d62..9956e481 100644 --- a/src/starkware/storage/batch_store_test.py +++ b/src/starkware/storage/batch_store_test.py @@ -15,11 +15,12 @@ async def test_batch_store(): storage = BatchStore(storage=inner_store, n_workers_set=2, n_workers_get=2) async def set_value(val_id): - await storage.set_value(f'key{val_id}'.encode('ascii'), f'value{val_id}'.encode('ascii')) + await storage.set_value(f"key{val_id}".encode("ascii"), f"value{val_id}".encode("ascii")) async def get_value(val_id): - assert await storage.get_value(f'key{val_id}'.encode('ascii')) == \ - f'value{val_id}'.encode('ascii') + assert await storage.get_value(f"key{val_id}".encode("ascii")) == f"value{val_id}".encode( + "ascii" + ) tasks = [asyncio.create_task(set_value(i)) for i in range(4)] await asyncio.sleep(0.02) diff --git a/src/starkware/storage/dict_storage.py b/src/starkware/storage/dict_storage.py index f687b493..e4fd16e5 100644 --- a/src/starkware/storage/dict_storage.py +++ b/src/starkware/storage/dict_storage.py @@ -53,4 +53,4 @@ async def get_value(self, key: bytes) -> Optional[bytes]: return value async def del_value(self, key: bytes): - raise NotImplementedError('CachedStorage is expected to handle only immutable items') + raise NotImplementedError("CachedStorage is expected to handle only immutable items") diff --git a/src/starkware/storage/gated_storage.py b/src/starkware/storage/gated_storage.py index 358ab935..99a9520c 100644 --- a/src/starkware/storage/gated_storage.py +++ b/src/starkware/storage/gated_storage.py @@ -4,7 +4,7 @@ from starkware.storage import Storage from starkware.storage.names import generate_unique_key -MAGIC_HEADER = hashlib.sha256(b'Gated storage magic header').digest() +MAGIC_HEADER = hashlib.sha256(b"Gated storage magic header").digest() class GatedStorage(Storage): @@ -31,32 +31,32 @@ async def _compress_value(self, key: bytes, value: bytes) -> Tuple[bytes, bytes] second storage, with a unique key and returns the new value that will be stored to the first storage which indicates that the original value is stored in storage1. """ - if value[:len(MAGIC_HEADER)] != MAGIC_HEADER: + if value[: len(MAGIC_HEADER)] != MAGIC_HEADER: # If the value starts with MAGIC_HEADER, treat the value as a large value; Hence, it # will be stored in the second storage. if len(value) <= self.limit: return key, value ukey = generate_unique_key( - item_type='gated', - props={'orig_key': key.hex()}, + item_type="gated", + props={"orig_key": key.hex()}, ) await self.storage1.set_value(key=ukey, value=value) new_value = MAGIC_HEADER + ukey return key, new_value async def set_value(self, key: bytes, value: bytes): - await self.storage0.set_value(* await self._compress_value(key=key, value=value)) + await self.storage0.set_value(*await self._compress_value(key=key, value=value)) async def setnx_value(self, key: bytes, value: bytes) -> bool: - return await self.storage0.setnx_value(* await self._compress_value(key=key, value=value)) + return await self.storage0.setnx_value(*await self._compress_value(key=key, value=value)) async def get_value(self, key: bytes) -> Optional[bytes]: value = await self.storage0.get_value(key=key) if value is None: return None - if (value[:len(MAGIC_HEADER)]) == MAGIC_HEADER: - ukey = value[len(MAGIC_HEADER):] + if (value[: len(MAGIC_HEADER)]) == MAGIC_HEADER: + ukey = value[len(MAGIC_HEADER) :] return await self.storage1.get_value(key=ukey) return value @@ -68,8 +68,8 @@ async def del_value(self, key: bytes): value = await self.storage0.get_value(key=key) if value is None: return - if (value[:len(MAGIC_HEADER)]) == MAGIC_HEADER: - ukey = value[len(MAGIC_HEADER):] + if (value[: len(MAGIC_HEADER)]) == MAGIC_HEADER: + ukey = value[len(MAGIC_HEADER) :] await self.storage1.del_value(key=ukey) await self.storage0.del_value(key=key) diff --git a/src/starkware/storage/gated_storage_test.py b/src/starkware/storage/gated_storage_test.py index bdc14b15..2ff876a1 100644 --- a/src/starkware/storage/gated_storage_test.py +++ b/src/starkware/storage/gated_storage_test.py @@ -8,15 +8,15 @@ async def test_gated_storage(): storage = GatedStorage(limit=10, storage0=MockStorage(), storage1=MockStorage()) - keys_values = [(b'k0', b'v0'), (b'k1', b'v1' * 6)] + keys_values = [(b"k0", b"v0"), (b"k1", b"v1" * 6)] for k, v in keys_values: assert await storage.get_value(key=k) is None await storage.set_value(key=k, value=v) assert await storage.get_value(key=k) == v - assert not await storage.setnx_value(key=k, value=b'wrong') + assert not await storage.setnx_value(key=k, value=b"wrong") assert await storage.get_value(key=k) == v - assert storage.storage0.db.keys() == {b'k0', b'k1'} + assert storage.storage0.db.keys() == {b"k0", b"k1"} assert len(storage.storage1.db.keys()) == 1 for k, _ in keys_values: @@ -33,10 +33,10 @@ async def test_magic_header_gated_storage(): will be stored in the secondary storage. """ storage = GatedStorage(limit=1000, storage0=MockStorage(), storage1=MockStorage()) - key, value = (b'k0', MAGIC_HEADER + b'v0') + key, value = (b"k0", MAGIC_HEADER + b"v0") await storage.set_value(key=key, value=value) assert await storage.get_value(key=key) == value - assert storage.storage0.db.keys() == {b'k0'} + assert storage.storage0.db.keys() == {b"k0"} assert len(storage.storage1.db.keys()) == 1 await storage.del_value(key=key) assert len(storage.storage0.db.keys()) == 0 diff --git a/src/starkware/storage/imm_storage.py b/src/starkware/storage/imm_storage.py index bca7d2be..dc4e0cc1 100644 --- a/src/starkware/storage/imm_storage.py +++ b/src/starkware/storage/imm_storage.py @@ -6,18 +6,21 @@ class _ImmediateStorage(Storage): - def __init__(self, storage: Storage): + def __init__(self, storage: Storage, avoid_write_through: bool): self.storage = storage self.write_tasks: List[asyncio.Task] = [] self.db: Dict[bytes, bytes] = {} + self.avoid_write_through = avoid_write_through async def set_value(self, key: bytes, value: bytes): - assert isinstance(key, bytes), f'key must be bytes. Got {type(key)}.' + assert isinstance(key, bytes), f"key must be bytes. Got {type(key)}." self.db[key] = value - self.write_tasks.append(asyncio.create_task(self.storage.set_value(key, value))) + + if not self.avoid_write_through: + self.write_tasks.append(asyncio.create_task(self.storage.set_value(key, value))) async def get_value(self, key: bytes) -> Optional[bytes]: - assert isinstance(key, bytes), f'key must be bytes. Got {type(key)}.' + assert isinstance(key, bytes), f"key must be bytes. Got {type(key)}." if key in self.db: return self.db[key] res = await self.storage.get_value(key) @@ -26,10 +29,12 @@ async def get_value(self, key: bytes) -> Optional[bytes]: return res async def del_value(self, key: bytes): - assert isinstance(key, bytes), f'key must be bytes. Got {type(key)}.' + assert isinstance(key, bytes), f"key must be bytes. Got {type(key)}." if key in self.db: del self.db[key] - self.write_tasks.append(asyncio.create_task(self.storage.del_value(key))) + + if not self.avoid_write_through: + self.write_tasks.append(asyncio.create_task(self.storage.del_value(key))) async def wait_for_all(self): for task in self.write_tasks: @@ -37,9 +42,9 @@ async def wait_for_all(self): @asynccontextmanager -async def immediate_storage(storage: Storage): +async def immediate_storage(storage: Storage, avoid_write_through: bool = False): try: - res = _ImmediateStorage(storage) + res = _ImmediateStorage(storage=storage, avoid_write_through=avoid_write_through) yield res finally: await res.wait_for_all() diff --git a/src/starkware/storage/internal_proxy_storage.py b/src/starkware/storage/internal_proxy_storage.py index e100bffa..1a34497e 100644 --- a/src/starkware/storage/internal_proxy_storage.py +++ b/src/starkware/storage/internal_proxy_storage.py @@ -10,10 +10,10 @@ def __init__(self, internal_client): self.internal_client = internal_client async def set_value(self, key, value): - raise NotImplementedError('Cannot set storage values in this version.') + raise NotImplementedError("Cannot set storage values in this version.") async def del_value(self, key): - raise NotImplementedError('Cannot delete storage values in this version.') + raise NotImplementedError("Cannot delete storage values in this version.") async def get_value(self, key): return await self.internal_client.get_value(key) diff --git a/src/starkware/storage/internal_proxy_storage_test.py b/src/starkware/storage/internal_proxy_storage_test.py index fc7cbd98..3ae30093 100644 --- a/src/starkware/storage/internal_proxy_storage_test.py +++ b/src/starkware/storage/internal_proxy_storage_test.py @@ -10,7 +10,7 @@ class MockInternalClient: async def get_value(self, key): - return str(key) + '_result' + return str(key) + "_result" @pytest.mark.asyncio @@ -18,7 +18,7 @@ async def test_internal_proxy_storage(): storage = InternalProxyStorage(internal_client=MockInternalClient()) async def get_value(val_id): - assert await storage.get_value(f'key{val_id}') == f'key{val_id}_result' + assert await storage.get_value(f"key{val_id}") == f"key{val_id}_result" tasks = [asyncio.create_task(get_value(i)) for i in range(4)] await asyncio.sleep(0.02) diff --git a/src/starkware/storage/metrics.py b/src/starkware/storage/metrics.py index d2fb7b7a..4b7d72d2 100644 --- a/src/starkware/storage/metrics.py +++ b/src/starkware/storage/metrics.py @@ -1,17 +1,17 @@ import prometheus_client CACHED_STORAGE_GET_TOTAL = prometheus_client.Counter( - name='starkware_cached_storage_get_total_count', - documentation='Count of total get_value() calls to CachedStorage', + name="starkware_cached_storage_get_total_count", + documentation="Count of total get_value() calls to CachedStorage", labelnames=(), ) CACHED_STORAGE_GET_CACHE = prometheus_client.Counter( - name='starkware_cached_storage_get_cache_count', - documentation='Count of get_value() calls to CachedStorage that got the value from the cache', + name="starkware_cached_storage_get_cache_count", + documentation="Count of get_value() calls to CachedStorage that got the value from the cache", labelnames=(), ) # Metric names may diverge on client argument. -CACHED_STORAGE_GET_TOTAL_NAME = getattr(CACHED_STORAGE_GET_TOTAL, '_name') -CACHED_STORAGE_GET_CACHE_NAME = getattr(CACHED_STORAGE_GET_CACHE, '_name') +CACHED_STORAGE_GET_TOTAL_NAME = getattr(CACHED_STORAGE_GET_TOTAL, "_name") +CACHED_STORAGE_GET_CACHE_NAME = getattr(CACHED_STORAGE_GET_CACHE, "_name") diff --git a/src/starkware/storage/names.py b/src/starkware/storage/names.py index f0ed2323..b0f333f1 100644 --- a/src/starkware/storage/names.py +++ b/src/starkware/storage/names.py @@ -10,11 +10,11 @@ def generate_unique_key(item_type: str, props: Dict[str, str] = {}) -> bytes: """ Generates a unique S3 storage key. """ - suffix = datetime.utcfromtimestamp(time()).strftime('%H%M%S') + suffix = datetime.utcfromtimestamp(time()).strftime("%H%M%S") t = datetime.fromtimestamp(time()) - suffix += '_' + ''.join(random.choices(string.ascii_lowercase + string.digits, k=8)) - key = f'{t:year=%04Y/month=%02m/day=%02d}/type={item_type}' + suffix += "_" + "".join(random.choices(string.ascii_lowercase + string.digits, k=8)) + key = f"{t:year=%04Y/month=%02m/day=%02d}/type={item_type}" for prop, val in props.items(): - key = f'{key}/{prop}={val}' - key = f'{key}/{suffix}' - return key.encode('ascii') + key = f"{key}/{prop}={val}" + key = f"{key}/{suffix}" + return key.encode("ascii") diff --git a/src/starkware/storage/storage.py b/src/starkware/storage/storage.py index a3db9e38..2fd31086 100644 --- a/src/starkware/storage/storage.py +++ b/src/starkware/storage/storage.py @@ -17,16 +17,16 @@ class Storage(ABC): """ @staticmethod - async def from_config(config, logger=None) -> 'Storage': + async def from_config(config, logger=None) -> "Storage": """ Creates a Storage instance from a config dictionary. """ - parts = config['class'].rsplit('.', 1) + parts = config["class"].rsplit(".", 1) storage_class = getattr(import_module(parts[0]), parts[1]) - if hasattr(storage_class, 'create_from_config'): - return await storage_class.create_from_config(**config['config']) - return storage_class(**config.get('config', {})) + if hasattr(storage_class, "create_from_config"): + return await storage_class.create_from_config(**config["config"]) + return storage_class(**config.get("config", {})) @abstractmethod async def set_value(self, key: bytes, value: bytes): @@ -49,13 +49,13 @@ async def mget(self, keys: Sequence[bytes]) -> Tuple[Optional[bytes], ...]: async def set_int(self, key: bytes, value: int): assert isinstance(key, bytes) assert isinstance(value, int) - value_bytes = str(value).encode('ascii') + value_bytes = str(value).encode("ascii") await self.set_value(key, value_bytes) async def setnx_int(self, key: bytes, value: int) -> bool: assert isinstance(key, bytes) assert isinstance(value, int) - value_bytes = str(value).encode('ascii') + value_bytes = str(value).encode("ascii") return await self.setnx_value(key, value_bytes) async def get_int(self, key: bytes, default=None) -> Optional[int]: @@ -66,13 +66,13 @@ async def get_int(self, key: bytes, default=None) -> Optional[int]: async def set_float(self, key: bytes, value: float): assert isinstance(key, bytes) assert isinstance(value, float) - value_bytes = str(value).encode('ascii') + value_bytes = str(value).encode("ascii") await self.set_value(key, value_bytes) async def setnx_float(self, key: bytes, value: float) -> bool: assert isinstance(key, bytes) assert isinstance(value, float) - value_bytes = str(value).encode('ascii') + value_bytes = str(value).encode("ascii") return await self.setnx_value(key, value_bytes) async def get_float(self, key: bytes, default=None) -> Optional[float]: @@ -83,22 +83,22 @@ async def get_float(self, key: bytes, default=None) -> Optional[float]: async def set_str(self, key: bytes, value: str): assert isinstance(key, bytes) assert isinstance(value, str) - value_bytes = value.encode('ascii') + value_bytes = value.encode("ascii") await self.set_value(key, value_bytes) async def setnx_str(self, key: bytes, value: str) -> bool: assert isinstance(key, bytes) assert isinstance(value, str) - value_bytes = value.encode('ascii') + value_bytes = value.encode("ascii") return await self.setnx_value(key, value_bytes) async def get_str(self, key: bytes, default=None) -> Optional[str]: assert isinstance(key, bytes) result = await self.get_value(key) - return default if result is None else result.decode('ascii') + return default if result is None else result.decode("ascii") async def setnx_value(self, key: bytes, value: bytes) -> bool: - raise NotImplementedError(f'{self.__class__.__name__} does not implement setnx_value') + raise NotImplementedError(f"{self.__class__.__name__} does not implement setnx_value") async def setnx_time(self, key: bytes, time: float): assert isinstance(key, bytes) @@ -110,7 +110,7 @@ async def get_time(self, key: bytes) -> Optional[float]: return await self.get_float(key) -TDBObject = TypeVar('TDBObject', bound='DBObject') +TDBObject = TypeVar("TDBObject", bound="DBObject") class DBObject(Serializable): @@ -123,7 +123,7 @@ def prefix(cls) -> bytes: @classmethod def db_key(cls, suffix: bytes) -> bytes: - return cls.prefix() + b':' + suffix + return cls.prefix() + b":" + suffix @classmethod async def get(cls: Type[TDBObject], storage: Storage, suffix: bytes) -> Optional[TDBObject]: @@ -150,7 +150,7 @@ def get_update_for_mset(self, suffix: bytes) -> Tuple[bytes, bytes]: return (self.db_key(suffix), self.serialize()) -TIndexedDBObject = TypeVar('TIndexedDBObject', bound='IndexedDBObject') +TIndexedDBObject = TypeVar("TIndexedDBObject", bound="IndexedDBObject") class IndexedDBObject(DBObject): @@ -160,19 +160,19 @@ class IndexedDBObject(DBObject): @classmethod def key(cls, index: int) -> bytes: - return cls.db_key(str(index).encode('ascii')) + return cls.db_key(str(index).encode("ascii")) @classmethod async def get_obj( - cls: Type[TIndexedDBObject], - storage: Storage, index: int) -> Optional[TIndexedDBObject]: - return await cls.get(storage, str(index).encode('ascii')) + cls: Type[TIndexedDBObject], storage: Storage, index: int + ) -> Optional[TIndexedDBObject]: + return await cls.get(storage, str(index).encode("ascii")) async def set_obj(self, storage: Storage, index: int): - await self.set(storage, str(index).encode('ascii')) + await self.set(storage, str(index).encode("ascii")) async def setnx_obj(self, storage: Storage, index: int) -> bool: - return await self.setnx(storage, str(index).encode('ascii')) + return await self.setnx(storage, str(index).encode("ascii")) def get_indexed_update_for_mset(self, index: int) -> Tuple[bytes, bytes]: """ @@ -193,15 +193,17 @@ class FactFetchingContext: """ def __init__( - self, storage: Storage, hash_func: HashFunctionType, n_workers: Optional[int] = None): + self, storage: Storage, hash_func: HashFunctionType, n_workers: Optional[int] = None + ): self.storage = storage self.hash_func = hash_func self.n_workers = n_workers def __repr__(self) -> str: return ( - f'{type(self)}(storage={self.storage!r}, hash_func={self.hash_func!r}, ' - f'n_workers={self.n_workers!r})') + f"{type(self)}(storage={self.storage!r}, hash_func={self.hash_func!r}, " + f"n_workers={self.n_workers!r})" + ) class Fact(DBObject): @@ -209,6 +211,7 @@ class Fact(DBObject): A fact is a DB object with a DB key that is a hash of its value. Use set_fact() and get() to read and write facts. """ + @abstractmethod async def _hash(self, hash_func: HashFunctionType) -> bytes: pass @@ -229,7 +232,7 @@ async def extend(self): pass @abstractmethod - async def __aenter__(self) -> 'LockObject': + async def __aenter__(self) -> "LockObject": pass @abstractmethod @@ -239,14 +242,14 @@ async def __aexit__(self, exc_type, exc, tb): class LockManager(ABC): @staticmethod - async def from_config(config, logger=None) -> 'LockManager': + async def from_config(config, logger=None) -> "LockManager": """ Creates a LockManager instance from a config dictionary. """ - parts = config['class'].rsplit('.', 1) + parts = config["class"].rsplit(".", 1) lock_manager_class = getattr(import_module(parts[0]), parts[1]) - return lock_manager_class(**config['config']) + return lock_manager_class(**config["config"]) @staticmethod @contextlib.asynccontextmanager @@ -272,6 +275,7 @@ async def destroy(self): @contextlib.contextmanager def distributed_hash_function(hash_function: HashFunctionType, n_hash_workers: int): with concurrent.futures.ProcessPoolExecutor(max_workers=n_hash_workers) as pool: + async def async_hash_funcion(x, y): return await asyncio.get_event_loop().run_in_executor(pool, hash_function, x, y) diff --git a/src/starkware/storage/storage_test.py b/src/starkware/storage/storage_test.py index 12b8815f..afea27bf 100644 --- a/src/starkware/storage/storage_test.py +++ b/src/starkware/storage/storage_test.py @@ -14,12 +14,12 @@ async def test_dummy_lock(): locked = [False] async def try_lock1(): - async with await lock_manager.lock('lock1') as _: + async with await lock_manager.lock("lock1") as _: locked[0] = True - async with await lock_manager.lock('lock0') as lock: + async with await lock_manager.lock("lock0") as lock: await lock.extend() - async with await lock_manager.lock('lock1') as lock: + async with await lock_manager.lock("lock1") as lock: # Try to lock. t = asyncio.create_task(try_lock1()) await asyncio.sleep(0.01) @@ -31,11 +31,11 @@ async def try_lock1(): @pytest.mark.asyncio async def test_from_config(): - config = {'class': 'starkware.storage.dict_storage.DictStorage', 'config': {}} + config = {"class": "starkware.storage.dict_storage.DictStorage", "config": {}} storage = await Storage.from_config(config) assert type(storage) is DictStorage - config['config']['bad_param'] = None - with pytest.raises(TypeError, match='got an unexpected keyword argument'): + config["config"]["bad_param"] = None + with pytest.raises(TypeError, match="got an unexpected keyword argument"): await Storage.from_config(config) diff --git a/src/starkware/storage/test_utils.py b/src/starkware/storage/test_utils.py index 740f4cd3..e32b3d9a 100644 --- a/src/starkware/storage/test_utils.py +++ b/src/starkware/storage/test_utils.py @@ -116,12 +116,11 @@ def check_time(t0, min_t, max_t): """ t1 = asyncio.get_event_loop().time() delta = t1 - t0 - assert min_t <= delta <= max_t, \ - 'Timing test failed' + assert min_t <= delta <= max_t, "Timing test failed" @contextmanager -def timed_call_range(min_t=0, max_t=2**20): +def timed_call_range(min_t=0, max_t=2 ** 20): """ Context manager that asserts that the code within took some amount of time, between min_t and max_t.