Skip to content

Commit

Permalink
feat: comparasion ops (#18)
Browse files Browse the repository at this point in the history
* Basic Comparisions
* stub for icmp
* pinned net onnx-mlir patch
* pinned new zk-ml-dialect
* ArgMin and ArgMax 
* hardmax moved to todo
  • Loading branch information
0xThemis authored Jan 4, 2024
1 parent 023c4a5 commit 4d18e8a
Show file tree
Hide file tree
Showing 54 changed files with 493 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#ifndef CRYPTO3_ASSIGNER_FIXEDPOINT_ARGMINMAX_HPP
#define CRYPTO3_ASSIGNER_FIXEDPOINT_ARGMINMAX_HPP

#include "mlir/Dialect/zkml/IR/DotProduct.h"
#include <mlir/Dialect/Arith/IR/Arith.h>

#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>

#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <mlir-assigner/components/handle_component.hpp>

namespace nil {
namespace blueprint {

template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_argmin(
mlir::zkml::ArgMinOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type> &nextIndex,
std::uint32_t start_row) {
using component_type = components::fix_argmin<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

using input_type = typename component_type::input_type;
auto acc = frame.locals.find(mlir::hash_value(operation.getAcc()));
auto next = frame.locals.find(mlir::hash_value(operation.getNext()));
auto accIndex = frame.locals.find(mlir::hash_value(operation.getAccIndex()));
ASSERT(acc != frame.locals.end());
ASSERT(next != frame.locals.end());
ASSERT(accIndex != frame.locals.end());
input_type instance_input;
instance_input.x = acc->second;
instance_input.y = next->second;
instance_input.index_x = accIndex->second;

using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_witness(0, 1, 1));
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(),
1, 1, var_value(assignment, nextIndex), operation.getSelectLastIndex());

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

handle_component_input<BlueprintFieldType, ArithmetizationParams, component_type>(assignment, instance_input);

components::generate_circuit(component, bp, assignment, instance_input, start_row);
auto result = components::generate_assignments(component, assignment, instance_input, start_row);
frame.locals[mlir::hash_value(operation.getResult(0))] = result.min;
frame.locals[mlir::hash_value(operation.getResult(1))] = result.index;
}

template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_argmax(
mlir::zkml::ArgMaxOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type> &nextIndex,
std::uint32_t start_row) {
using component_type = components::fix_argmax<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

using input_type = typename component_type::input_type;
auto acc = frame.locals.find(mlir::hash_value(operation.getAcc()));
auto next = frame.locals.find(mlir::hash_value(operation.getNext()));
auto accIndex = frame.locals.find(mlir::hash_value(operation.getAccIndex()));
ASSERT(acc != frame.locals.end());
ASSERT(next != frame.locals.end());
ASSERT(accIndex != frame.locals.end());
input_type instance_input;
instance_input.x = acc->second;
instance_input.y = next->second;
instance_input.index_x = accIndex->second;

using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_witness(0, 1, 1));
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(),
1, 1, var_value(assignment, nextIndex), operation.getSelectLastIndex());

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

handle_component_input<BlueprintFieldType, ArithmetizationParams, component_type>(assignment, instance_input);

components::generate_circuit(component, bp, assignment, instance_input, start_row);
auto result = components::generate_assignments(component, assignment, instance_input, start_row);
frame.locals[mlir::hash_value(operation.getResult(0))] = result.max;
frame.locals[mlir::hash_value(operation.getResult(1))] = result.index;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_ARGMINMAX_HPP
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ namespace nil {
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment) {
&assignment,
std::uint32_t start_row) {

auto lhs = frame.memrefs.find(mlir::hash_value(operation.getLhs()));
ASSERT(lhs != frame.memrefs.end());
Expand All @@ -96,8 +97,7 @@ namespace nil {

// TACEO_TODO: check types

auto result = detail::handle_fixedpoint_dot_product_component(x, y, zero_var, bp, assignment,
assignment.allocated_rows());
auto result = detail::handle_fixedpoint_dot_product_component(x, y, zero_var, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
#include <nil/blueprint/components/algebra/fixedpoint/plonk/to_fixedpoint.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/sin.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/cos.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/argmin.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/argmax.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/non_native/lookup_logic_ops.hpp>

#define PREPARE_UNARY_INPUT(OP) \
Expand Down
7 changes: 1 addition & 6 deletions mlir-assigner/include/mlir-assigner/memory/memref.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ namespace nil {

template<typename BlueprintFieldType, typename ArithmetizationParams>
void print(
std::ostream& os,
std::ostream &os,
const assignment<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment) {
os << "memref<";
Expand All @@ -133,16 +133,11 @@ namespace nil {
ss << type << ">[";
os << type_str;
if (type.isa<mlir::IntegerType>()) {
if (type.getIntOrFloatBitWidth() == 1) {
//bool
for (int i = 0; i < data.size(); i++) {
os << var_value(assignment, data[i]).data;
if (i != data.size() - 1)
os << ",";
}
} else {
//int
}
} else if (type.isa<mlir::FloatType>()) {
for (int i = 0; i < data.size(); i++) {
auto value = var_value(assignment, data[i]).data;
Expand Down
19 changes: 18 additions & 1 deletion mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Dialect/zkml/IR/DotProduct.h"
#include "mlir/Dialect/zkml/IR/ArgMin.h"
#include "mlir/Dialect/zkml/IR/ArgMax.h"

#include <cstddef>
#include <cstdlib>
Expand All @@ -35,6 +37,7 @@
#include <nil/blueprint/components/algebra/fixedpoint/type.hpp>

#include <mlir-assigner/components/comparison/fixed_comparison.hpp>
#include <mlir-assigner/components/comparison/argminmax.hpp>
#include <mlir-assigner/components/comparison/select.hpp>
#include <mlir-assigner/components/fixedpoint/abs.hpp>
#include <mlir-assigner/components/fixedpoint/addition.hpp>
Expand Down Expand Up @@ -318,6 +321,9 @@ namespace zk_ml_toolchain {
auto result = lhs->second * rhs->second;
frames.back().constant_values[mlir::hash_value(operation.getResult())] = result;

} else if (arith::CmpIOp operation = llvm::dyn_cast<arith::CmpIOp>(op)) {
llvm::outs() << "icmp\n";
exit(0);
} else if (arith::ConstantOp operation = llvm::dyn_cast<arith::ConstantOp>(op)) {
TypedAttr constantValue = operation.getValueAttr();
if (constantValue.isa<IntegerAttr>()) {
Expand Down Expand Up @@ -636,15 +642,26 @@ namespace zk_ml_toolchain {
}

void handleZkMlOperation(Operation *op) {
std::uint32_t start_row = assignmnt.allocated_rows();
if (zkml::DotProductOp operation = llvm::dyn_cast<zkml::DotProductOp>(op)) {
mlir::Value lhs = operation.getLhs();
mlir::Value rhs = operation.getRhs();
assert(lhs.getType() == rhs.getType() && "memrefs must be same type for DotProduct");
mlir::MemRefType MemRefType = mlir::cast<mlir::MemRefType>(lhs.getType());
assert(MemRefType.getShape().size() == 1 && "DotProduct must have tensors of rank 1");
logger.debug("computing DotProduct with %d x %d", MemRefType.getShape().back());
handle_fixedpoint_dot_product_component(operation, zero_var, frames.back(), bp, assignmnt);
handle_fixedpoint_dot_product_component(operation, zero_var, frames.back(), bp, assignmnt, start_row);
return;
} else if (zkml::ArgMinOp operation = llvm::dyn_cast<zkml::ArgMinOp>(op)) {
auto nextIndex = frames.back().constant_values.find(mlir::hash_value(operation.getNextIndex()));
ASSERT(nextIndex != frames.back().constant_values.end());
auto nextIndexVar = put_into_assignment(nextIndex->second);
handle_argmin(operation, frames.back(), bp, assignmnt, nextIndexVar, start_row);
} else if (zkml::ArgMaxOp operation = llvm::dyn_cast<zkml::ArgMaxOp>(op)) {
auto nextIndex = frames.back().constant_values.find(mlir::hash_value(operation.getNextIndex()));
ASSERT(nextIndex != frames.back().constant_values.end());
auto nextIndexVar = put_into_assignment(nextIndex->second);
handle_argmax(operation, frames.back(), bp, assignmnt, nextIndexVar, start_row);
} else {
std::string opName = op->getName().getIdentifier().str();
UNREACHABLE(std::string("unhandled zkML operation: ") + opName);
Expand Down
1 change: 1 addition & 0 deletions mlir-assigner/tests/Models/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
**.mlir

Large diffs are not rendered by default.

50 changes: 50 additions & 0 deletions mlir-assigner/tests/Ops/NeedsCompilerWork/Hardmax/HardMaxAxis.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "hardmaxaxis.mlir"} {
func.func @main_graph(%arg0: memref<1x10x20x30xf32>) -> memref<1x10x20x30xf32> attributes {input_names = ["in_a"], llvm.emit_c_interface, output_names = ["out_a"]} {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 1.000000e+00 : f32
%c0 = arith.constant 0 : index
%alloc = memref.alloc() {alignment = 16 : i64} : memref<1x10x20x30xf32>
%alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1x10x20x30xindex>
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 10 {
affine.for %arg3 = 0 to 20 {
affine.for %arg4 = 0 to 30 {
affine.store %c0, %alloc_1[%arg1, %arg2, %arg3, %arg4] : memref<1x10x20x30xindex>
}
}
}
}
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 10 {
affine.for %arg3 = 0 to 20 {
affine.for %arg4 = 0 to 30 {
%0 = affine.load %alloc_1[%c0, %arg2, %arg3, %arg4] : memref<1x10x20x30xindex>
%1 = memref.load %arg0[%0, %arg2, %arg3, %arg4] : memref<1x10x20x30xf32>
%2 = affine.load %arg0[%arg1, %arg2, %arg3, %arg4] : memref<1x10x20x30xf32>
%3 = arith.cmpf ogt, %2, %1 : f32
scf.if %3 {
affine.store %arg1, %alloc_1[%c0, %arg2, %arg3, %arg4] : memref<1x10x20x30xindex>
}
}
}
}
}
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 10 {
affine.for %arg3 = 0 to 20 {
affine.for %arg4 = 0 to 30 {
%0 = affine.load %alloc_1[%c0, %arg2, %arg3, %arg4] : memref<1x10x20x30xindex>
%1 = arith.cmpi eq, %0, %arg1 : index
scf.if %1 {
affine.store %cst_0, %alloc[%arg1, %arg2, %arg3, %arg4] : memref<1x10x20x30xf32>
} else {
affine.store %cst, %alloc[%arg1, %arg2, %arg3, %arg4] : memref<1x10x20x30xf32>
}
}
}
}
}
return %alloc : memref<1x10x20x30xf32>
}
"krnl.entry_point"() {func = @main_graph, numInputs = 1 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10 , 20 , 30] , \22name\22 : \22in_a\22 }\0A\0A]\00@[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10 , 20 , 30] , \22name\22 : \22out_a\22 }\0A\0A]\00"} : () -> ()
}
Binary file not shown.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "hardmaxsimple.mlir"} {
func.func @main_graph(%arg0: memref<1x10x20x30xf32>) -> memref<1x10x20x30xf32> attributes {input_names = ["in_a"], llvm.emit_c_interface, output_names = ["out_a"]} {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 1.000000e+00 : f32
%c0 = arith.constant 0 : index
%alloc = memref.alloc() {alignment = 16 : i64} : memref<1x10x20x30xf32>
%alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1x10x20x1xindex>
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 10 {
affine.for %arg3 = 0 to 20 {
affine.for %arg4 = 0 to 1 {
affine.store %c0, %alloc_1[%arg1, %arg2, %arg3, %arg4] : memref<1x10x20x1xindex>
}
}
}
}
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 10 {
affine.for %arg3 = 0 to 20 {
affine.for %arg4 = 0 to 30 {
%0 = affine.load %alloc_1[%arg1, %arg2, %arg3, %c0] : memref<1x10x20x1xindex>
%1 = memref.load %arg0[%arg1, %arg2, %arg3, %0] : memref<1x10x20x30xf32>
%2 = affine.load %arg0[%arg1, %arg2, %arg3, %arg4] : memref<1x10x20x30xf32>
%3 = arith.cmpf ogt, %2, %1 : f32
scf.if %3 {
affine.store %arg4, %alloc_1[%arg1, %arg2, %arg3, %c0] : memref<1x10x20x1xindex>
}
}
}
}
}
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 10 {
affine.for %arg3 = 0 to 20 {
affine.for %arg4 = 0 to 30 {
%0 = affine.load %alloc_1[%arg1, %arg2, %arg3, %c0] : memref<1x10x20x1xindex>
%1 = arith.cmpi eq, %0, %arg4 : index
scf.if %1 {
affine.store %cst_0, %alloc[%arg1, %arg2, %arg3, %arg4] : memref<1x10x20x30xf32>
} else {
affine.store %cst, %alloc[%arg1, %arg2, %arg3, %arg4] : memref<1x10x20x30xf32>
}
}
}
}
}
return %alloc : memref<1x10x20x30xf32>
}
"krnl.entry_point"() {func = @main_graph, numInputs = 1 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10 , 20 , 30] , \22name\22 : \22in_a\22 }\0A\0A]\00@[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10 , 20 , 30] , \22name\22 : \22out_a\22 }\0A\0A]\00"} : () -> ()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
 :h

in_aout_a"HardmaxHardMaxSimpleZ
in_a





b
out_a





B
Expand Down

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
**.mlir
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/ArgMax/ArgMaxSimple.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.0930633544921875, 0.1421966552734375, 0.0830841064453125, 0.876251220703125, 0.2035980224609375, 0.38433837890625, 0.6446533203125, 0.8219451904296875, 0.4010162353515625, 0.7171173095703125], "dims": [1, 10], "type": "f32"}}]
13 changes: 13 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/ArgMax/ArgMaxSimple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
 :c
"
in_aout_a"ArgMax*
axis� ArgMaxSimpleZ
in_a



b
out_a


B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/ArgMax/ArgMaxSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x1xi64>[3]
23
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/ArgMin/ArgMinLastIndex.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.2682342529296875, 0.8263092041015625, 0.6944732666015625, 0.6793212890625, 0.09832763671875, 0.728607177734375, 0.28619384765625, 0.761749267578125, 0.6746368408203125, 0.40509033203125], "dims": [1, 10], "type": "f32"}}]
Binary file not shown.
Loading

0 comments on commit 4d18e8a

Please sign in to comment.