From df674f73232a1e25317ca7bddcfba3227e2ad41c Mon Sep 17 00:00:00 2001 From: fnieddu <118167989+fnieddu@users.noreply.github.com> Date: Tue, 9 Jan 2024 17:07:20 +0100 Subject: [PATCH] feat: to int (#24) * to int working (except 8bits version) * pinned blueprint to latest version --- libs/blueprint | 2 +- .../components/fixedpoint/conversion.hpp | 133 ++++++++++++++++++ .../components/fixedpoint/to_fixpoint.hpp | 44 ------ .../components/handle_component.hpp | 1 + .../mlir-assigner/parser/evaluator.hpp | 22 ++- .../Cast/CastToInt64.json | 1 - .../Cast/CastToInt64.mlir | 14 -- .../Cast/CastToInt64.res | 3 - .../tests/Ops/Onnx/Cast/CastToInt64.json | 1 + .../Cast/CastToInt64.onnx | 0 .../tests/Ops/Onnx/Cast/CastToInt64.res | 3 + .../tests/Ops/Onnx/Cast/CastToUInt64.json | 1 + .../tests/Ops/Onnx/Cast/CastToUInt64.onnx | 14 ++ .../tests/Ops/Onnx/Cast/CastToUInt64.res | 3 + .../Resize/ResizeSimple.json | 0 .../Resize/ResizeSimple.mlir | 0 .../Resize/ResizeSimple.onnx | Bin .../Resize/ResizeSimple.res | 0 18 files changed, 177 insertions(+), 65 deletions(-) create mode 100644 mlir-assigner/include/mlir-assigner/components/fixedpoint/conversion.hpp delete mode 100644 mlir-assigner/include/mlir-assigner/components/fixedpoint/to_fixpoint.hpp delete mode 100644 mlir-assigner/tests/Ops/NeedsBlueprintComponent/Cast/CastToInt64.json delete mode 100644 mlir-assigner/tests/Ops/NeedsBlueprintComponent/Cast/CastToInt64.mlir delete mode 100644 mlir-assigner/tests/Ops/NeedsBlueprintComponent/Cast/CastToInt64.res create mode 100644 mlir-assigner/tests/Ops/Onnx/Cast/CastToInt64.json rename mlir-assigner/tests/Ops/{NeedsBlueprintComponent => Onnx}/Cast/CastToInt64.onnx (100%) create mode 100644 mlir-assigner/tests/Ops/Onnx/Cast/CastToInt64.res create mode 100644 mlir-assigner/tests/Ops/Onnx/Cast/CastToUInt64.json create mode 100644 mlir-assigner/tests/Ops/Onnx/Cast/CastToUInt64.onnx create mode 100644 mlir-assigner/tests/Ops/Onnx/Cast/CastToUInt64.res rename mlir-assigner/tests/Ops/{NeedsBlueprintComponent => Problematic}/Resize/ResizeSimple.json (100%) rename mlir-assigner/tests/Ops/{NeedsBlueprintComponent => Problematic}/Resize/ResizeSimple.mlir (100%) rename mlir-assigner/tests/Ops/{NeedsBlueprintComponent => Problematic}/Resize/ResizeSimple.onnx (100%) rename mlir-assigner/tests/Ops/{NeedsBlueprintComponent => Problematic}/Resize/ResizeSimple.res (100%) diff --git a/libs/blueprint b/libs/blueprint index 8c8ef7c..5fc4ba2 160000 --- a/libs/blueprint +++ b/libs/blueprint @@ -1 +1 @@ -Subproject commit 8c8ef7c89d6a5cd3a367809a86161afa6964148c +Subproject commit 5fc4ba251865a418fc6dc8394f24911d3c6a1567 diff --git a/mlir-assigner/include/mlir-assigner/components/fixedpoint/conversion.hpp b/mlir-assigner/include/mlir-assigner/components/fixedpoint/conversion.hpp new file mode 100644 index 0000000..f0b1866 --- /dev/null +++ b/mlir-assigner/include/mlir-assigner/components/fixedpoint/conversion.hpp @@ -0,0 +1,133 @@ +#ifndef CRYPTO3_ASSIGNER_FIXEDPOINT_TO_FIXEDPOINT_HPP +#define CRYPTO3_ASSIGNER_FIXEDPOINT_TO_FIXEDPOINT_HPP + +#include "mlir/Dialect/zkml/IR/DotProduct.h" +#include +#include + +#include + +#include +#include +#include // TODO: check if there is a new mechanism for this in nil upstream + +#include +#include +#include + +namespace nil { + namespace blueprint { + + template + void handle_to_fixedpoint( + MlirOp &operation, + stack_frame> &frame, + circuit_proxy> &bp, + assignment_proxy> + &assignment, + std::uint32_t start_row) { + using component_type = components::int_to_fix< + crypto3::zk::snark::plonk_constraint_system, + BlueprintFieldType, basic_non_native_policy>; + + auto input = PREPARE_UNARY_INPUT(MlirOp); + using manifest_reader = detail::ManifestReader; + const auto p = detail::PolicyManager::get_parameters( + detail::ManifestReader::get_witness(0)); + + component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(), + 1); + fill_trace(component, input, operation, frame, bp, assignment, start_row); + } + namespace detail { + template + void handle_to_int( + MlirOp &operation, + stack_frame> &frame, + circuit_proxy> + &bp, + assignment_proxy> + &assignment, + std::uint32_t start_row) { + using component_type = components::fix_to_int< + crypto3::zk::snark::plonk_constraint_system, + BlueprintFieldType, basic_non_native_policy>; + typename component_type::OutputType outputType = + static_cast(OutputType); + auto input = PREPARE_UNARY_INPUT(MlirOp); + using manifest_reader = detail::ManifestReader; + const auto p = detail::PolicyManager::get_parameters(manifest_reader::get_witness(0)); + + component_type component(p.witness, manifest_reader::get_constants(), + manifest_reader::get_public_inputs(), 1, 1, outputType); + fill_trace(component, input, operation, frame, bp, assignment, start_row); + } + } // namespace detail + +#define HANDLE_TO_INT(TY) \ + detail::handle_to_int( \ + operation, frame, bp, assignment, start_row); + +#define HANDLE_TO_UINT(TY) \ + detail::handle_to_int( \ + operation, frame, bp, assignment, start_row); + template + void handle_to_int( + mlir::arith::FPToSIOp &operation, + stack_frame> &frame, + circuit_proxy> &bp, + assignment_proxy> + &assignment, + std::uint32_t start_row) { + using component_type = components::fix_to_int< + crypto3::zk::snark::plonk_constraint_system, + BlueprintFieldType, basic_non_native_policy>; + switch (operation->getResult(0).getType().getIntOrFloatBitWidth()) { + case 8: + HANDLE_TO_INT(component_type::OutputType::I8); + break; + case 16: + HANDLE_TO_INT(component_type::OutputType::I16); + break; + case 32: + HANDLE_TO_INT(component_type::OutputType::I32); + break; + case 64: + HANDLE_TO_INT(component_type::OutputType::I64); + break; + } + } + + template + void handle_to_int( + mlir::arith::FPToUIOp &operation, + stack_frame> &frame, + circuit_proxy> &bp, + assignment_proxy> + &assignment, + std::uint32_t start_row) { + using component_type = components::fix_to_int< + crypto3::zk::snark::plonk_constraint_system, + BlueprintFieldType, basic_non_native_policy>; + switch (operation->getResult(0).getType().getIntOrFloatBitWidth()) { + case 8: + HANDLE_TO_UINT(component_type::OutputType::U8); + break; + case 16: + HANDLE_TO_UINT(component_type::OutputType::U16); + break; + case 32: + HANDLE_TO_UINT(component_type::OutputType::U32); + break; + case 64: + HANDLE_TO_UINT(component_type::OutputType::U64); + break; + } + } + +#undef HANDLE_TO_INT +#undef HANDLE_TO_UINT + } // namespace blueprint +} // namespace nil + +#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_TO_FIXEDPOINT_HPP diff --git a/mlir-assigner/include/mlir-assigner/components/fixedpoint/to_fixpoint.hpp b/mlir-assigner/include/mlir-assigner/components/fixedpoint/to_fixpoint.hpp deleted file mode 100644 index e855e6b..0000000 --- a/mlir-assigner/include/mlir-assigner/components/fixedpoint/to_fixpoint.hpp +++ /dev/null @@ -1,44 +0,0 @@ -#ifndef CRYPTO3_ASSIGNER_FIXEDPOINT_TO_FIXEDPOINT_HPP -#define CRYPTO3_ASSIGNER_FIXEDPOINT_TO_FIXEDPOINT_HPP - -#include "mlir/Dialect/zkml/IR/DotProduct.h" -#include - -#include - -#include -#include -#include // TODO: check if there is a new mechanism for this in nil upstream - -#include -#include -#include - -namespace nil { - namespace blueprint { - - template - void handle_to_fixedpoint( - MlirOp &operation, - stack_frame> &frame, - circuit_proxy> &bp, - assignment_proxy> - &assignment, - std::uint32_t start_row) { - using component_type = components::int_to_fix< - crypto3::zk::snark::plonk_constraint_system, - BlueprintFieldType, basic_non_native_policy>; - - auto input = PREPARE_UNARY_INPUT(MlirOp); - using manifest_reader = detail::ManifestReader; - const auto p = detail::PolicyManager::get_parameters( - detail::ManifestReader::get_witness(0)); - - component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(), - 1); - fill_trace(component, input, operation, frame, bp, assignment, start_row); - } - } // namespace blueprint -} // namespace nil - -#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_TO_FIXEDPOINT_HPP diff --git a/mlir-assigner/include/mlir-assigner/components/handle_component.hpp b/mlir-assigner/include/mlir-assigner/components/handle_component.hpp index d38c0b2..0a6bb08 100644 --- a/mlir-assigner/include/mlir-assigner/components/handle_component.hpp +++ b/mlir-assigner/include/mlir-assigner/components/handle_component.hpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include diff --git a/mlir-assigner/include/mlir-assigner/parser/evaluator.hpp b/mlir-assigner/include/mlir-assigner/parser/evaluator.hpp index 069e33f..3b126f8 100644 --- a/mlir-assigner/include/mlir-assigner/parser/evaluator.hpp +++ b/mlir-assigner/include/mlir-assigner/parser/evaluator.hpp @@ -56,7 +56,7 @@ #include #include #include -#include +#include #include #include @@ -506,7 +506,9 @@ namespace zk_ml_toolchain { } else if (arith::UIToFPOp operation = llvm::dyn_cast(op)) { handle_to_fixedpoint(operation, frames.back(), bp, assignmnt, start_row); } else if (arith::FPToSIOp operation = llvm::dyn_cast(op)) { - UNREACHABLE("Cast from FixedPoint to Int??"); + handle_to_int(operation, frames.back(), bp, assignmnt, start_row); + } else if (arith::FPToUIOp operation = llvm::dyn_cast(op)) { + handle_to_int(operation, frames.back(), bp, assignmnt, start_row); } else if (llvm::isa(op) || llvm::isa(op) || llvm::isa(op)) { auto toExtend = frames.back().locals.find(mlir::hash_value(op->getOperand(0))); @@ -1028,6 +1030,22 @@ namespace zk_ml_toolchain { return; } + if (mlir::UnrealizedConversionCastOp operation = llvm::dyn_cast(op)) { + // we do not like this but when onnx-mlir lowers from onnx.Cast to unsigned it uses this to cast + // from signless integers (e.g. i64) to unsigned integer(e.g. ui64) + // SO if we transform from one signless integer to an unsigned integer with the SAME bit length + // we indulge, otherwise we panic + mlir::Type SrcType = operation->getOperand(0).getType(); + mlir::Type DstType = operation->getResult(0).getType(); + assert(SrcType.isSignlessInteger() && "src must be signless integertype for conversion cast"); + assert(DstType.isUnsignedInteger(SrcType.getIntOrFloatBitWidth()) && + "dst must be unsigned integer with same bit width as src"); + auto Src = frames.back().locals.find(mlir::hash_value(operation->getOperand(0))); + assert(Src != frames.back().locals.end()); + frames.back().locals[mlir::hash_value(operation->getResult(0))] = Src->second; + return; + } + std::string opName = op->getName().getIdentifier().str(); llvm::outs() << op->getDialect()->getNamespace() << "\n"; UNREACHABLE(std::string("unhandled operation: ") + opName); diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Cast/CastToInt64.json b/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Cast/CastToInt64.json deleted file mode 100644 index 932a112..0000000 --- a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Cast/CastToInt64.json +++ /dev/null @@ -1 +0,0 @@ -[{"memref": {"data": [0.052490234375, 0.0849761962890625, 0.290435791015625, 0.7718963623046875, 0.4334564208984375, 0.4372100830078125, 0.5269012451171875, 0.173980712890625, 0.0650787353515625, 0.6441802978515625], "dims": [1, 10], "type": "f32"}}] \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Cast/CastToInt64.mlir b/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Cast/CastToInt64.mlir deleted file mode 100644 index c586633..0000000 --- a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Cast/CastToInt64.mlir +++ /dev/null @@ -1,14 +0,0 @@ -module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "casttoint64.mlir"} { - func.func @main_graph(%arg0: memref<1x10xf32>) -> memref<1x10xi64> attributes {input_names = ["in_a"], llvm.emit_c_interface, output_names = ["out_a"]} { - %alloc = memref.alloc() {alignment = 16 : i64} : memref<1x10xi64> - affine.for %arg1 = 0 to 1 { - affine.for %arg2 = 0 to 10 { - %0 = affine.load %arg0[%arg1, %arg2] : memref<1x10xf32> - %1 = arith.fptosi %0 : f32 to i64 - affine.store %1, %alloc[%arg1, %arg2] : memref<1x10xi64> - } - } - return %alloc : memref<1x10xi64> - } - "krnl.entry_point"() {func = @main_graph, numInputs = 1 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22in_a\22 }\0A\0A]\00@[ { \22type\22 : \22i64\22 , \22dims\22 : [1 , 10] , \22name\22 : \22out_a\22 }\0A\0A]\00"} : () -> () -} diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Cast/CastToInt64.res b/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Cast/CastToInt64.res deleted file mode 100644 index 77bec27..0000000 --- a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Cast/CastToInt64.res +++ /dev/null @@ -1,3 +0,0 @@ -Result: -memref<1x10xint>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -ADD THE ROWS HERE \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/Onnx/Cast/CastToInt64.json b/mlir-assigner/tests/Ops/Onnx/Cast/CastToInt64.json new file mode 100644 index 0000000..74c887f --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Cast/CastToInt64.json @@ -0,0 +1 @@ +[{"memref": {"data": [-4.334595012532761, 92.92536192655963, -83.79716946092812, -47.00000000000001, -58.99999999999999, 68.14344958903294, 97.32986424021, 43.620091845482136, 75.29771234198004, 57.999999999999], "dims": [1, 10], "type": "f32"}}] diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Cast/CastToInt64.onnx b/mlir-assigner/tests/Ops/Onnx/Cast/CastToInt64.onnx similarity index 100% rename from mlir-assigner/tests/Ops/NeedsBlueprintComponent/Cast/CastToInt64.onnx rename to mlir-assigner/tests/Ops/Onnx/Cast/CastToInt64.onnx diff --git a/mlir-assigner/tests/Ops/Onnx/Cast/CastToInt64.res b/mlir-assigner/tests/Ops/Onnx/Cast/CastToInt64.res new file mode 100644 index 0000000..4b39298 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Cast/CastToInt64.res @@ -0,0 +1,3 @@ +Result: +memref<1x10xi64>[-4, 92, -83, -47, -58, 68, 97, 43, 75, 57] +12 diff --git a/mlir-assigner/tests/Ops/Onnx/Cast/CastToUInt64.json b/mlir-assigner/tests/Ops/Onnx/Cast/CastToUInt64.json new file mode 100644 index 0000000..b26f43d --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Cast/CastToUInt64.json @@ -0,0 +1 @@ +[{"memref": {"data": [4.000000000000001, 92.92536192655963, 83.79716946092812, 47.67719849791331, 58.17853088366004, 68.14344958903294, 97.32986424021, 43.620091845482136, 75.29771234198004, 57.999999999999], "dims": [1, 10], "type": "f32"}}] diff --git a/mlir-assigner/tests/Ops/Onnx/Cast/CastToUInt64.onnx b/mlir-assigner/tests/Ops/Onnx/Cast/CastToUInt64.onnx new file mode 100644 index 0000000..661a7c5 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Cast/CastToUInt64.onnx @@ -0,0 +1,14 @@ + :_ + +in_aout_a"Cast* +to � CastToUInt64Z +in_a +  + + +b +out_a +   + + +B \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/Onnx/Cast/CastToUInt64.res b/mlir-assigner/tests/Ops/Onnx/Cast/CastToUInt64.res new file mode 100644 index 0000000..cc9c94c --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Cast/CastToUInt64.res @@ -0,0 +1,3 @@ +Result: +memref<1x10xui64>[4, 92, 83, 47, 58, 68, 97, 43, 75, 57] +12 diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Resize/ResizeSimple.json b/mlir-assigner/tests/Ops/Problematic/Resize/ResizeSimple.json similarity index 100% rename from mlir-assigner/tests/Ops/NeedsBlueprintComponent/Resize/ResizeSimple.json rename to mlir-assigner/tests/Ops/Problematic/Resize/ResizeSimple.json diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Resize/ResizeSimple.mlir b/mlir-assigner/tests/Ops/Problematic/Resize/ResizeSimple.mlir similarity index 100% rename from mlir-assigner/tests/Ops/NeedsBlueprintComponent/Resize/ResizeSimple.mlir rename to mlir-assigner/tests/Ops/Problematic/Resize/ResizeSimple.mlir diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Resize/ResizeSimple.onnx b/mlir-assigner/tests/Ops/Problematic/Resize/ResizeSimple.onnx similarity index 100% rename from mlir-assigner/tests/Ops/NeedsBlueprintComponent/Resize/ResizeSimple.onnx rename to mlir-assigner/tests/Ops/Problematic/Resize/ResizeSimple.onnx diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Resize/ResizeSimple.res b/mlir-assigner/tests/Ops/Problematic/Resize/ResizeSimple.res similarity index 100% rename from mlir-assigner/tests/Ops/NeedsBlueprintComponent/Resize/ResizeSimple.res rename to mlir-assigner/tests/Ops/Problematic/Resize/ResizeSimple.res