Skip to content

Commit

Permalink
feat: where and exponential ops (#19)
Browse files Browse the repository at this point in the history
* added where op
* added tests for exponential ops
* added sqrt and ReduceL2 functionality
* updated readme
  • Loading branch information
0xThemis authored Jan 4, 2024
1 parent 0b75763 commit 526bcf7
Show file tree
Hide file tree
Showing 32 changed files with 214 additions and 43 deletions.
44 changes: 44 additions & 0 deletions mlir-assigner/include/mlir-assigner/components/fixedpoint/sqrt.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#ifndef CRYPTO3_ASSIGNER_FIXEDPOINT_SQRT_HPP
#define CRYPTO3_ASSIGNER_FIXEDPOINT_SQRT_HPP

#include "mlir/Dialect/zkml/IR/DotProduct.h"
#include <mlir/Dialect/Arith/IR/Arith.h>

#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>

#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <mlir-assigner/components/handle_component.hpp>

namespace nil {
namespace blueprint {

template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_sqrt(
mlir::math::SqrtOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::fix_sqrt<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

auto input = PREPARE_UNARY_INPUT(mlir::math::SqrtOp);
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_witness(0, 1, 1));

component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(),
1, 1);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_SQRT_HPP
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <nil/blueprint/components/algebra/fixedpoint/plonk/cos.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/argmin.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/argmax.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/sqrt.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/non_native/lookup_logic_ops.hpp>

#define PREPARE_UNARY_INPUT(OP) \
Expand Down
5 changes: 4 additions & 1 deletion mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include <mlir-assigner/components/fixedpoint/ceil.hpp>
#include <mlir-assigner/components/fixedpoint/division.hpp>
#include <mlir-assigner/components/fixedpoint/exp.hpp>
#include <mlir-assigner/components/fixedpoint/sqrt.hpp>
#include <mlir-assigner/components/fixedpoint/log.hpp>
#include <mlir-assigner/components/fixedpoint/floor.hpp>
#include <mlir-assigner/components/fixedpoint/mul_rescale.hpp>
Expand Down Expand Up @@ -376,6 +377,8 @@ namespace zk_ml_toolchain {
handle_fixedpoint_exp_component(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::LogOp operation = llvm::dyn_cast<math::LogOp>(op)) {
handle_fixedpoint_log_component(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::PowFOp operation = llvm::dyn_cast<math::PowFOp>(op)) {
UNREACHABLE("TODO: component for powf not ready");
} else if (math::AbsFOp operation = llvm::dyn_cast<math::AbsFOp>(op)) {
handle_fixedpoint_abs_component(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::CeilOp operation = llvm::dyn_cast<math::CeilOp>(op)) {
Expand All @@ -388,7 +391,7 @@ namespace zk_ml_toolchain {
frames.back().locals[mlir::hash_value(operation.getResult())] =
frames.back().locals[mlir::hash_value(operation.getLhs())];
} else if (math::SqrtOp operation = llvm::dyn_cast<math::SqrtOp>(op)) {
UNREACHABLE("TODO: component for sqrt not ready");
handle_sqrt(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::SinOp operation = llvm::dyn_cast<math::SinOp>(op)) {
handle_sin(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::CosOp operation = llvm::dyn_cast<math::CosOp>(op)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.25299072265625, 0.5795745849609375, 0.419647216796875, 0.5677490234375, 0.5795135498046875, 0.6259613037109375, 0.6096954345703125, 0.8977203369140625, 0.34307861328125, 0.5006256103515625], "dims": [1, 10], "type": "f32"}}]
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "powpublicbase.mlir"} {
func.func @main_graph(%arg0: memref<1x10xf32>) -> memref<1x10xf32> attributes {input_names = ["in_b"], llvm.emit_c_interface, output_names = ["out_a"]} {
%c0 = arith.constant 0 : index
%0 = "krnl.global"() {name = "constant_0", shape = [1, 10], value = dense<2.000000e+00> : tensor<1x10xf32>} : () -> memref<1x10xf32>
%alloc = memref.alloc() {alignment = 16 : i64} : memref<1x10xf32>
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 10 {
%1 = affine.load %0[%c0, %arg2] : memref<1x10xf32>
%2 = affine.load %arg0[%c0, %arg2] : memref<1x10xf32>
%3 = math.powf %1, %2 : f32
affine.store %3, %alloc[%arg1, %arg2] : memref<1x10xf32>
}
}
return %alloc : memref<1x10xf32>
}
"krnl.entry_point"() {func = @main_graph, numInputs = 1 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22in_b\22 }\0A\0A]\00@[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22out_a\22 }\0A\0A]\00"} : () -> ()
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
ADD THE ROWS HERE
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.919586181640625, 0.2015838623046875, 0.2564697265625, 0.3241424560546875, 0.3890228271484375, 0.4170989990234375, 0.6596832275390625, 0.7839508056640625, 0.1458587646484375, 0.4071502685546875], "dims": [1, 10], "type": "f32"}}, {"memref": {"data": [3, 3, 3, 3, 3, 3, 3, 3, 3, 3], "dims": [1, 10], "type": "f32"}}]
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "powpublicexponent.mlir"} {
func.func @main_graph(%arg0: memref<1x10xf32>) -> memref<1x10xf32> attributes {input_names = ["in_a"], llvm.emit_c_interface, output_names = ["out_a"]} {
%c0 = arith.constant 0 : index
%0 = "krnl.global"() {name = "constant_0", shape = [1, 10], value = dense<3.000000e+00> : tensor<1x10xf32>} : () -> memref<1x10xf32>
%alloc = memref.alloc() {alignment = 16 : i64} : memref<1x10xf32>
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 10 {
%1 = affine.load %arg0[%c0, %arg2] : memref<1x10xf32>
%2 = affine.load %0[%c0, %arg2] : memref<1x10xf32>
%3 = math.powf %1, %2 : f32
affine.store %3, %alloc[%arg1, %arg2] : memref<1x10xf32>
}
}
return %alloc : memref<1x10xf32>
}
"krnl.entry_point"() {func = @main_graph, numInputs = 1 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22in_a\22 }\0A\0A]\00@[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22out_a\22 }\0A\0A]\00"} : () -> ()
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0.7776377201080322, 0.00819157250225544, 0.016869736835360527, 0.03405710682272911, 0.058874230831861496, 0.07256337255239487, 0.2870822548866272, 0.4817996025085449, 0.0031031130347400904, 0.06749384850263596]
ADD THE ROWS HERE
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.6160430908203125, 0.445709228515625, 0.0120391845703125, 0.0461273193359375, 0.1510009765625, 0.1910400390625, 0.03076171875, 0.043548583984375, 0.2318572998046875, 0.4149627685546875], "dims": [1, 10], "type": "f32"}}, {"memref": {"data": [0.05859375, 0.8663330078125, 0.753448486328125, 0.1318359375, 0.8713531494140625, 0.0277557373046875, 0.5159149169921875, 0.480560302734375, 0.127532958984375, 0.01123046875], "dims": [1, 10], "type": "f32"}}]
16 changes: 16 additions & 0 deletions mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowSimple.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "powsimple.mlir"} {
func.func @main_graph(%arg0: memref<1x10xf32>, %arg1: memref<1x10xf32>) -> memref<1x10xf32> attributes {input_names = ["in_a", "in_b"], llvm.emit_c_interface, output_names = ["out_a"]} {
%c0 = arith.constant 0 : index
%alloc = memref.alloc() {alignment = 16 : i64} : memref<1x10xf32>
affine.for %arg2 = 0 to 1 {
affine.for %arg3 = 0 to 10 {
%0 = affine.load %arg0[%c0, %arg3] : memref<1x10xf32>
%1 = affine.load %arg1[%c0, %arg3] : memref<1x10xf32>
%2 = math.powf %0, %1 : f32
affine.store %2, %alloc[%arg2, %arg3] : memref<1x10xf32>
}
}
return %alloc : memref<1x10xf32>
}
"krnl.entry_point"() {func = @main_graph, numInputs = 2 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22in_a\22 }\0A , { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22in_b\22 }\0A\0A]\00@[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22out_a\22 }\0A\0A]\00"} : () -> ()
}
19 changes: 19 additions & 0 deletions mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowSimple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
 :n

in_a
in_bout_a"Pow PowSimpleZ
in_a



Z
in_b



b
out_a



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0.9720140099525452, 0.49654868245124817, 0.03579552844166756, 0.6665944457054138, 0.19257567822933197, 0.9550961256027222, 0.16593657433986664, 0.22179152071475983, 0.8299362659454346, 0.9901706576347351]
ADD THE ROWS HERE

This file was deleted.

1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Exp/ExpSimple.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.7295379638671875, 0.6165008544921875, 0.676361083984375, 0.9804840087890625, 0.72015380859375, 0.775482177734375, 0.8330841064453125, 0.0845489501953125, 0.8875274658203125, 0.79486083984375], "dims": [1, 10], "type": "f32"}}]
13 changes: 13 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Exp/ExpSimple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
 :P

in_aout_a"Exp ExpSimpleZ
in_a



b
out_a



B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Exp/ExpSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[2.0741219520568848, 1.852434754371643, 1.9667080640792847, 2.6657462120056152, 2.0547492504119873, 2.1716389656066895, 2.3004024028778076, 1.088226079940796, 2.4291162490844727, 2.214132785797119]
12
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Log/LogSimple.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.980804443359375, 0.6964569091796875, 0.820343017578125, 0.1553955078125, 0.2219390869140625, 0.8354949951171875, 0.4540557861328125, 0.4211883544921875, 0.133148193359375, 0.9198150634765625], "dims": [1, 10], "type": "f32"}}]
13 changes: 13 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Log/LogSimple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
 :P

in_aout_a"Log LogSimpleZ
in_a



b
out_a



B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Log/LogSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[-0.019382184371352196, -0.3617493510246277, -0.1980327069759369, -1.8617817163467407, -1.5053523778915405, -0.17973092198371887, -0.7895352244377136, -0.8646751642227173, -2.0162925720214844, -0.08358264714479446]
70
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Result:
memref<1x1xf32>[1.4269211292266846]
21
22
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Sqrt/SqrtSimple.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.39276123046875, 0.1326446533203125, 0.533203125, 0.90936279296875, 0.8437042236328125, 0.28973388671875, 0.6842041015625, 0.512542724609375, 0.26873779296875, 0.184295654296875], "dims": [1, 10], "type": "f32"}}]
14 changes: 14 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Sqrt/SqrtSimple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
 :R

in_aout_a"Sqrt
SqrtSimpleZ
in_a



b
out_a



B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Sqrt/SqrtSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0.6267066597938538, 0.36420413851737976, 0.7302075624465942, 0.9536051750183105, 0.9185337424278259, 0.5382693409919739, 0.8271663188934326, 0.7159208655357361, 0.5183992385864258, 0.42929670214653015]
20
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Where/Where/WhereSimple.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0, 1, 1, 0, 0, 1, 0, 1, 1, 0], "dims": [1, 10], "type": "bool"}}, {"memref": {"data": [0.1871185302734375, 0.6104278564453125, 0.41326904296875, 0.79345703125, 0.325286865234375, 0.1401519775390625, 0.7781219482421875, 0.2669830322265625, 0.8950042724609375, 0.134918212890625], "dims": [1, 10], "type": "f32"}}, {"memref": {"data": [0.2599639892578125, 0.578643798828125, 0.417266845703125, 0.1128997802734375, 0.112060546875, 0.093170166015625, 0.9546661376953125, 0.3450469970703125, 0.1695098876953125, 0.98834228515625], "dims": [1, 10], "type": "f32"}}]
25 changes: 25 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Where/Where/WhereSimple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
 :�

in_a
in_b
in_cout_a"Where WhereSimpleZ
in_a
 


Z
in_b



Z
in_c



b
out_a



B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Where/Where/WhereSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0.2599639892578125, 0.6104278564453125, 0.41326904296875, 0.1128997802734375, 0.112060546875, 0.1401519775390625, 0.9546661376953125, 0.2669830322265625, 0.8950042724609375, 0.98834228515625]
33
4 changes: 2 additions & 2 deletions mlir-assigner/tests/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ long as it is applicable for ZK).
| **Range** | :x: | :white_check_mark: | |
| **Reciprocal** | :white_check_mark: | :white_check_mark: | No support for integers at the moment. |
| **ReduceL1** | :white_check_mark: | :white_check_mark: | |
| **ReduceL2** | :x: | :white_check_mark: | |
| **ReduceL2** | :white_check_mark: | :white_check_mark: | |
| **ReduceLogSum** | :white_check_mark: | :white_check_mark: | |
| **ReduceLogSumExp** | :white_check_mark: | :white_check_mark: | |
| **ReduceMax** | :white_check_mark: | :white_check_mark: | |
Expand Down Expand Up @@ -265,7 +265,7 @@ long as it is applicable for ZK).
| **SpaceToDepth** | :x: | :white_check_mark: | |
| **Split** | :x: | :white_check_mark: | |
| **SplitToSequence** | :x: | :x: | |
| **Sqrt** | :x: | :white_check_mark: | |
| **Sqrt** | :white_check_mark: | :white_check_mark: | |
| **Squeeze** | :x: | :white_check_mark: | |
| **StringNormalizer** | :x: | :x: | |
| **Sub** | :white_check_mark: | :white_check_mark: | No support for integers at the moment. |
Expand Down

0 comments on commit 526bcf7

Please sign in to comment.