diff --git a/libs/blueprint b/libs/blueprint index 466989d..ce2204c 160000 --- a/libs/blueprint +++ b/libs/blueprint @@ -1 +1 @@ -Subproject commit 466989d28382262daa0c5ce5b2e63732a47e9474 +Subproject commit ce2204c50763bc4d9b771d4fdf1ba567ee181f16 diff --git a/mlir-assigner/include/mlir-assigner/components/comparison/fixed_comparison.hpp b/mlir-assigner/include/mlir-assigner/components/comparison/fixed_comparison.hpp index 505a91b..ee6efe0 100644 --- a/mlir-assigner/include/mlir-assigner/components/comparison/fixed_comparison.hpp +++ b/mlir-assigner/include/mlir-assigner/components/comparison/fixed_comparison.hpp @@ -83,7 +83,6 @@ namespace nil { // we compare 64 bits with this configuration auto result = call_component(operation, stack, bp, assignment, compParams); - // TODO should we store zero instead??? switch (operation.getPredicate()) { case mlir::arith::CmpFPredicate::UGT: case mlir::arith::CmpFPredicate::OGT: { @@ -122,9 +121,6 @@ namespace nil { UNREACHABLE("Unsupported fcmp predicate (UNO, ORD, AlwaysFalse, AlwaysTrue)"); break; } - default: - UNREACHABLE("Unsupported fcmp predicate"); - break; } } diff --git a/mlir-assigner/include/mlir-assigner/components/comparison/select.hpp b/mlir-assigner/include/mlir-assigner/components/comparison/select.hpp index a8dd6cd..e95a794 100644 --- a/mlir-assigner/include/mlir-assigner/components/comparison/select.hpp +++ b/mlir-assigner/include/mlir-assigner/components/comparison/select.hpp @@ -48,13 +48,12 @@ namespace nil { circuit_proxy> &bp, assignment_proxy> &assignment, - const common_component_parameters> &compParams) { - auto c = stack.get_local(operation.getCondition()); - auto x = stack.get_local(operation.getTrueValue()); - auto y = stack.get_local(operation.getFalseValue()); + const common_component_parameters< + crypto3::zk::snark::plonk_variable> &compParams) { using component_type = components::fix_select< crypto3::zk::snark::plonk_constraint_system, - BlueprintFieldType, basic_non_native_policy>; + BlueprintFieldType, + basic_non_native_policy>; using manifest_reader = detail::ManifestReader; typename component_type::input_type input; diff --git a/mlir-assigner/include/mlir-assigner/components/fixedpoint/dot_product.hpp b/mlir-assigner/include/mlir-assigner/components/fixedpoint/dot_product.hpp index 774f09e..619b542 100644 --- a/mlir-assigner/include/mlir-assigner/components/fixedpoint/dot_product.hpp +++ b/mlir-assigner/include/mlir-assigner/components/fixedpoint/dot_product.hpp @@ -30,9 +30,6 @@ namespace nil { &assignment, const common_component_parameters< crypto3::zk::snark::plonk_variable> &compParams) { - using component_type = components::fix_dot_rescale_2_gates< - crypto3::zk::snark::plonk_constraint_system, - BlueprintFieldType, basic_non_native_policy>; mlir::Value lhs = operation.getLhs(); mlir::Value rhs = operation.getRhs(); diff --git a/mlir-assigner/include/mlir-assigner/components/handle_component.hpp b/mlir-assigner/include/mlir-assigner/components/handle_component.hpp index 159ff2f..2677784 100644 --- a/mlir-assigner/include/mlir-assigner/components/handle_component.hpp +++ b/mlir-assigner/include/mlir-assigner/components/handle_component.hpp @@ -187,8 +187,8 @@ namespace nil { typename ComponentType::result_type result = std::uint8_t(compParams.gen_mode & generation_mode::ASSIGNMENTS) ? - result = components::generate_assignments(component, assignment, input, compParams.start_row) : - result = typename ComponentType::result_type(component, compParams.start_row); + components::generate_assignments(component, assignment, input, compParams.start_row) : + typename ComponentType::result_type(component, compParams.start_row); // touch result variables if (std::uint8_t(compParams.gen_mode & generation_mode::ASSIGNMENTS) == 0) { diff --git a/mlir-assigner/include/mlir-assigner/memory/memref.hpp b/mlir-assigner/include/mlir-assigner/memory/memref.hpp index 1323c1f..357907f 100644 --- a/mlir-assigner/include/mlir-assigner/memory/memref.hpp +++ b/mlir-assigner/include/mlir-assigner/memory/memref.hpp @@ -25,9 +25,10 @@ namespace nil { memref(std::vector dims, mlir::Type type) : data(), dims(dims), strides(), type(type) { strides.resize(dims.size()); - for (int i = dims.size() - 1; i >= 0; i--) { - strides[i] = (i == dims.size() - 1) ? 1 : strides[i + 1] * dims[i + 1]; - ASSERT(dims[i] > 0 && "Dims in tensor must be greater zero. Do you have a model with dynamic input?"); + for (ssize_t i = dims.size() - 1; i >= 0; i--) { + strides[i] = (i == ssize_t(dims.size()) - 1) ? 1 : strides[i + 1] * dims[i + 1]; + ASSERT(dims[i] > 0 && + "Dims in tensor must be greater zero. Do you have a model with dynamic input?"); } // this also handles the case when dims is empty, since we still allocate // 1 here @@ -35,9 +36,10 @@ namespace nil { } memref(llvm::ArrayRef dims, mlir::Type type) : data(), dims(dims), strides(), type(type) { strides.resize(dims.size()); - for (int i = dims.size() - 1; i >= 0; i--) { - strides[i] = (i == dims.size() - 1) ? 1 : strides[i + 1] * dims[i + 1]; - ASSERT(dims[i] > 0 && "Dims in tensor must be greater zero. Do you have a model with dynamic input?"); + for (ssize_t i = dims.size() - 1; i >= 0; i--) { + strides[i] = (i == ssize_t(dims.size()) - 1) ? 1 : strides[i + 1] * dims[i + 1]; + ASSERT(dims[i] > 0 && + "Dims in tensor must be greater zero. Do you have a model with dynamic input?"); } // this also handles the case when dims is empty, since we still allocate 1 // here @@ -47,8 +49,8 @@ namespace nil { memref(std::vector dims, std::vector data, mlir::Type type) : data(data), dims(dims), strides() { strides.resize(dims.size()); - for (int i = dims.size() - 1; i >= 0; i--) { - strides[i] = (i == dims.size() - 1) ? 1 : strides[i + 1] * dims[i + 1]; + for (ssize_t i = dims.size() - 1; i >= 0; i--) { + strides[i] = (i == ssize_t(dims.size()) - 1) ? 1 : strides[i + 1] * dims[i + 1]; } assert(data.size() == std::accumulate(std::begin(dims), std::end(dims), 1, std::multiplies())); @@ -57,7 +59,7 @@ namespace nil { const VarType &get(const std::vector &indices) const { assert(indices.size() == dims.size()); uint32_t offset = 0; - for (int i = 0; i < indices.size(); i++) { + for (size_t i = 0; i < indices.size(); i++) { assert(indices[i] < dims[i]); offset += indices[i] * strides[i]; } @@ -67,7 +69,7 @@ namespace nil { const VarType &get(const llvm::SmallVector &indices) const { assert(indices.size() == dims.size()); uint32_t offset = 0; - for (int i = 0; i < indices.size(); i++) { + for (size_t i = 0; i < indices.size(); i++) { assert(indices[i] < dims[i]); offset += indices[i] * strides[i]; } @@ -81,7 +83,7 @@ namespace nil { void put(const std::vector &indices, const VarType &value) { assert(indices.size() == dims.size()); uint32_t offset = 0; - for (int i = 0; i < indices.size(); i++) { + for (size_t i = 0; i < indices.size(); i++) { assert(indices[i] < dims[i]); offset += indices[i] * strides[i]; } @@ -91,7 +93,7 @@ namespace nil { void put(const llvm::SmallVector &indices, const VarType &value) { assert(indices.size() == dims.size()); uint32_t offset = 0; - for (int i = 0; i < indices.size(); i++) { + for (size_t i = 0; i < indices.size(); i++) { assert(indices[i] < dims[i]); offset += indices[i] * strides[i]; } @@ -99,14 +101,14 @@ namespace nil { } void put_flat(const int64_t idx, const VarType &value) { - assert(idx >= 0 && idx < data.size()); + assert(idx >= 0 && size_t(idx) < data.size()); data[idx] = value; } mlir::Type getType() const { return type; } - int64_t size() const { + size_t size() const { return data.size(); } @@ -126,7 +128,7 @@ namespace nil { &assignment) { using FixedPoint = components::FixedPoint; os << "memref<"; - for (int i = 0; i < dims.size(); i++) { + for (size_t i = 0; i < dims.size(); i++) { os << dims[i]; os << "x"; } @@ -136,7 +138,7 @@ namespace nil { os << type_str; if (type.isa()) { if (type.isUnsignedInteger()) { - for (int i = 0; i < data.size(); i++) { + for (size_t i = 0; i < data.size(); i++) { os << var_value(assignment, data[i]).data; if (i != data.size() - 1) os << ","; @@ -145,7 +147,7 @@ namespace nil { static constexpr typename BlueprintFieldType::integral_type half_p = (BlueprintFieldType::modulus - typename BlueprintFieldType::integral_type(1)) / typename BlueprintFieldType::integral_type(2); - for (int i = 0; i < data.size(); i++) { + for (size_t i = 0; i < data.size(); i++) { auto val = static_cast( var_value(assignment, data[i]).data); // check if negative @@ -159,7 +161,7 @@ namespace nil { } } } else if (type.isa()) { - for (int i = 0; i < data.size(); i++) { + for (size_t i = 0; i < data.size(); i++) { auto value = var_value(assignment, data[i]).data; FixedPoint out(value, FixedPoint::SCALE); os << out.to_double(); diff --git a/mlir-assigner/include/mlir-assigner/parser/evaluator.hpp b/mlir-assigner/include/mlir-assigner/parser/evaluator.hpp index 42fa0eb..d570feb 100644 --- a/mlir-assigner/include/mlir-assigner/parser/evaluator.hpp +++ b/mlir-assigner/include/mlir-assigner/parser/evaluator.hpp @@ -187,10 +187,9 @@ namespace zk_ml_toolchain { boost::json::array &public_output, generation_mode gen_mode, nil::blueprint::print_format print_circuit_format, std::string &clip, nil::blueprint::logger &logger) : - bp(circuit), - assignmnt(assignment), public_input(public_input), private_input(private_input), - public_output(public_output), gen_mode(gen_mode), print_circuit_format(print_circuit_format), - logger(logger) { + gen_mode(gen_mode), + print_circuit_format(print_circuit_format), logger(logger), bp(circuit), assignmnt(assignment), + public_input(public_input), private_input(private_input), public_output(public_output) { lower_bound = FixedPoint(1, FixedPoint::SCALE).to_double(); if ("clip" == clip) { clip_strategy = ClipStrategy::CLIP; @@ -208,16 +207,15 @@ namespace zk_ml_toolchain { evaluator &operator=(const evaluator &pass) = delete; void handleKrnlEntryOperation(KrnlEntryPointOp &EntryPoint, std::string &func) { - int32_t numInputs = -1; - int32_t numOutputs = -1; - for (auto a : EntryPoint->getAttrs()) { if (a.getName() == EntryPoint.getEntryPointFuncAttrName()) { func = a.getValue().cast().getLeafReference().str(); } else if (a.getName() == EntryPoint.getNumInputsAttrName()) { - numInputs = a.getValue().cast().getInt(); + // do nothing for num inputs atm + // a.getValue().cast().getInt(); } else if (a.getName() == EntryPoint.getNumOutputsAttrName()) { - numOutputs = a.getValue().cast().getInt(); + // do nothing for num outputs atm + // a.getValue().cast().getInt(); } else if (a.getName() == EntryPoint.getSignatureAttrName()) { // do nothing for signature atm // TODO: check against input types & shapes @@ -373,13 +371,13 @@ namespace zk_ml_toolchain { } else if (arith::SubFOp operation = llvm::dyn_cast(op)) { nil::blueprint::handle_sub(operation, stack, bp, assignmnt, compParams); } else if (arith::MulFOp operation = llvm::dyn_cast(op)) { - handle_fmul(operation, stack, bp, assignmnt, compParams); + nil::blueprint::handle_fmul(operation, stack, bp, assignmnt, compParams); } else if (arith::DivFOp operation = llvm::dyn_cast(op)) { - handle_fdiv(operation, stack, bp, assignmnt, compParams); + nil::blueprint::handle_fdiv(operation, stack, bp, assignmnt, compParams); } else if (arith::RemFOp operation = llvm::dyn_cast(op)) { - handle_frem(operation, stack, bp, assignmnt, compParams); + nil::blueprint::handle_frem(operation, stack, bp, assignmnt, compParams); } else if (arith::CmpFOp operation = llvm::dyn_cast(op)) { - handle_fcmp(operation, stack, bp, assignmnt, compParams); + nil::blueprint::handle_fcmp(operation, stack, bp, assignmnt, compParams); } else if (arith::SelectOp operation = llvm::dyn_cast(op)) { ASSERT(operation.getNumOperands() == 3 && "Select must have three operands"); ASSERT(operation->getOperand(1).getType() == operation->getOperand(2).getType() && @@ -406,7 +404,7 @@ namespace zk_ml_toolchain { stack.push_local(operation->getResult(0), falsy); } } else { - handle_select(operation, stack, bp, assignmnt, compParams); + nil::blueprint::handle_select(operation, stack, bp, assignmnt, compParams); } } else { std::string typeStr; @@ -415,7 +413,7 @@ namespace zk_ml_toolchain { UNREACHABLE(std::string("unhandled select operand: ") + typeStr); } } else if (arith::NegFOp operation = llvm::dyn_cast(op)) { - handle_neg(operation, stack, bp, assignmnt, compParams); + nil::blueprint::handle_neg(operation, stack, bp, assignmnt, compParams); } else if (arith::AndIOp operation = llvm::dyn_cast(op)) { // check if logical and or bitwise and mlir::Type LhsType = operation.getLhs().getType(); @@ -608,7 +606,7 @@ namespace zk_ml_toolchain { if (math::ExpOp operation = llvm::dyn_cast(op)) { nil::blueprint::handle_exp(operation, stack, bp, assignmnt, compParams); } else if (math::LogOp operation = llvm::dyn_cast(op)) { - handle_log(operation, stack, bp, assignmnt, compParams); + nil::blueprint::handle_log(operation, stack, bp, assignmnt, compParams); } else if (math::PowFOp operation = llvm::dyn_cast(op)) { UNREACHABLE("powf not supported. Did you compile the model with standard MLIR?"); } else if (math::IPowIOp operation = llvm::dyn_cast(op)) { @@ -620,11 +618,11 @@ namespace zk_ml_toolchain { assert(base == 2 && "For now we only support powi to power of 2"); stack.push_constant(operation.getResult(), 1 << pow); } else if (math::AbsFOp operation = llvm::dyn_cast(op)) { - handle_abs(operation, stack, bp, assignmnt, compParams); + nil::blueprint::handle_abs(operation, stack, bp, assignmnt, compParams); } else if (math::CeilOp operation = llvm::dyn_cast(op)) { - handle_ceil(operation, stack, bp, assignmnt, compParams); + nil::blueprint::handle_ceil(operation, stack, bp, assignmnt, compParams); } else if (math::FloorOp operation = llvm::dyn_cast(op)) { - handle_floor(operation, stack, bp, assignmnt, compParams); + nil::blueprint::handle_floor(operation, stack, bp, assignmnt, compParams); } else if (math::CopySignOp operation = llvm::dyn_cast(op)) { // TODO: do nothing for now since it only comes up during mod, and there // the component handles this correctly; do we need this later on? @@ -828,7 +826,8 @@ namespace zk_ml_toolchain { void handleZkMlOperation(Operation *op, const ComponentParameters &compParams) { if (zkml::DotProductOp operation = llvm::dyn_cast(op)) { - handle_dot_product(operation, zero_var, stack, bp, assignmnt, compParams); + nil::blueprint::handle_dot_product(operation, zero_var, stack, bp, assignmnt, + compParams); } else if (zkml::ArgMinOp operation = llvm::dyn_cast(op)) { auto nextIndexVar = put_into_assignment(stack.get_constant(operation.getNextIndex())); nil::blueprint::handle_argmin(operation, stack, bp, assignmnt, nextIndexVar, @@ -868,9 +867,6 @@ namespace zk_ml_toolchain { logger.debug("allocating memref"); logger << operation; MemRefType type = operation.getType(); - auto uses = operation->getResult(0).getUsers(); - auto res = operation->getResult(0); - auto res2 = operation.getMemref(); // check for dynamic size std::vector dims; auto operands = operation.getOperands(); diff --git a/mlir-assigner/include/mlir-assigner/parser/parser.hpp b/mlir-assigner/include/mlir-assigner/parser/parser.hpp index e1f47e4..3ab9a37 100644 --- a/mlir-assigner/include/mlir-assigner/parser/parser.hpp +++ b/mlir-assigner/include/mlir-assigner/parser/parser.hpp @@ -59,8 +59,8 @@ namespace nil { std::uint32_t target_prover_idx, const std::string &policy = "", generation_mode gen_mode = generation_mode::ASSIGNMENTS & generation_mode::CIRCUIT, print_format output_print_format = no_print) : - max_num_provers(max_num_provers), - gen_mode(gen_mode), print_output_format(output_print_format) { + gen_mode(gen_mode), + print_output_format(output_print_format), max_num_provers(max_num_provers) { if (max_num_provers != 1) { throw std::runtime_error( "Currently only one prover is supported, please " diff --git a/mlir-assigner/src/main.cpp b/mlir-assigner/src/main.cpp index db2585e..5a2f0e2 100644 --- a/mlir-assigner/src/main.cpp +++ b/mlir-assigner/src/main.cpp @@ -110,12 +110,12 @@ void print_circuit(const circuit_proxy &circuit_proxy, const auto second_var = constraint.second; if ((first_var.type == variable_type::column_type::witness || first_var.type == variable_type::column_type::constant) && - first_var.rotation == row) { + uint32_t(first_var.rotation) == row) { constraint.first = variable_type(first_var.index, local_row, first_var.relative, first_var.type); } if ((second_var.type == variable_type::column_type::witness || second_var.type == variable_type::column_type::constant) && - second_var.rotation == row) { + uint32_t(second_var.rotation) == row) { constraint.second = variable_type(second_var.index, local_row, second_var.relative, second_var.type); } @@ -428,11 +428,7 @@ int curve_dependent_main(std::string bytecode_file_name, using ArithmetizationParams = zk::snark::plonk_arithmetization_params; using ConstraintSystemType = zk::snark::plonk_constraint_system; - using ConstraintSystemProxyType = - zk::snark::plonk_table>; using ArithmetizationType = crypto3::zk::snark::plonk_constraint_system; - using AssignmentTableType = - zk::snark::plonk_table>; boost::json::value public_input_json_value; if (public_input_file_name.empty()) { @@ -507,15 +503,14 @@ int curve_dependent_main(std::string bytecode_file_name, // fill ComponentConstantColumns, ComponentConstantColumns + 1, ... std::iota(lookup_columns_indices.begin(), lookup_columns_indices.end(), ComponentConstantColumns); - auto usable_rows_amount = - zk::snark::pack_lookup_tables_horizontal(parser_instance.circuits[0].get_reserved_indices(), - parser_instance.circuits[0].get_reserved_tables(), - parser_instance.circuits[0].get(), - parser_instance.assignments[0].get(), - lookup_columns_indices, - ComponentSelectorColumns, - 0, - max_lookup_rows); + zk::snark::pack_lookup_tables_horizontal(parser_instance.circuits[0].get_reserved_indices(), + parser_instance.circuits[0].get_reserved_tables(), + parser_instance.circuits[0].get(), + parser_instance.assignments[0].get(), + lookup_columns_indices, + ComponentSelectorColumns, + 0, + max_lookup_rows); } constexpr std::uint32_t invalid_target_prover = std::numeric_limits::max(); @@ -628,8 +623,8 @@ int curve_dependent_main(std::string bytecode_file_name, ASSERT_MSG( nil::blueprint::is_satisfied(parser_instance.circuits[0].get(), parser_instance.assignments[0].get()), "The circuit is not satisfied"); - } else if (parser_instance.assignments.size() > 1 && (target_prover < parser_instance.assignments.size() || - invalid_target_prover == invalid_target_prover)) { + } else if (parser_instance.assignments.size() > 1 && + (target_prover < parser_instance.assignments.size() || target_prover == invalid_target_prover)) { // check only for target prover if set std::uint32_t start_idx = (target_prover == invalid_target_prover) ? 0 : target_prover; std::uint32_t end_idx =