Skip to content

Commit

Permalink
feat: trigonemtric pooling ops (#16)
Browse files Browse the repository at this point in the history
* added AveragePool; int to fix no range check
* added sin,sinh,cos,cosh
* rest of trigonometric tests
  • Loading branch information
0xThemis committed Jan 3, 2024
1 parent e76da81 commit 0b30bf6
Show file tree
Hide file tree
Showing 53 changed files with 510 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>

#include <nil/blueprint/components/algebra/fields/plonk/non_native/lookup_logic_ops.hpp>
#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream
Expand All @@ -51,7 +50,7 @@ namespace nil {
using component_type = components::lookup_logic_and<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;

auto input = PREPARE_INPUT(mlir::arith::AndIOp);
auto input = PREPARE_BINARY_INPUT(mlir::arith::AndIOp);
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

Expand All @@ -72,7 +71,7 @@ namespace nil {
// using component_type = components::lookup_logic_or<
// crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;
//
// auto input = PREPARE_INPUT(mlir::arith::OrIOp);
// auto input = PREPARE_BINARY_INPUT(mlir::arith::OrIOp);
// const auto p = detail::PolicyManager::get_parameters(
// detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
//
Expand All @@ -91,7 +90,7 @@ namespace nil {
using component_type = components::lookup_logic_xor<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;

auto input = PREPARE_INPUT(mlir::arith::XOrIOp);
auto input = PREPARE_BINARY_INPUT(mlir::arith::XOrIOp);
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#ifndef CRYPTO3_ASSIGNER_FIXEDPOINT_TO_FIXEDPOINT_HPP
#define CRYPTO3_ASSIGNER_FIXEDPOINT_TO_FIXEDPOINT_HPP

#include "mlir/Dialect/zkml/IR/DotProduct.h"
#include <mlir/Dialect/Arith/IR/Arith.h>

#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>

#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <mlir-assigner/components/handle_component.hpp>

namespace nil {
namespace blueprint {

template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_to_fixedpoint(
mlir::arith::SIToFPOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::int_to_fix<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

auto input = PREPARE_UNARY_INPUT(mlir::arith::SIToFPOp);
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(),
1);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_TO_FIXEDPOINT_HPP
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#ifndef CRYPTO3_ASSIGNER_FIXEDPOINT_TRIGONOMETRIC_HPP
#define CRYPTO3_ASSIGNER_FIXEDPOINT_TRIGONOMETRIC_HPP

#include "mlir/Dialect/zkml/IR/DotProduct.h"
#include <mlir/Dialect/Arith/IR/Arith.h>

#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>

#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <mlir-assigner/components/handle_component.hpp>

namespace nil {
namespace blueprint {

template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_sin(
mlir::math::SinOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::fix_sin<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

auto input = PREPARE_UNARY_INPUT(mlir::math::SinOp);
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_witness(0, 1, 1));

component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(),
1, 1);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_cos(
mlir::math::CosOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::fix_cos<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

auto input = PREPARE_UNARY_INPUT(mlir::math::CosOp);
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_witness(0, 1, 1));

component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(),
1, 1);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_TRIGONOMETRIC_HPP
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,38 @@
#ifndef CRYPTO3_ASSIGNER_HANDLE_COMPONENT_HPP
#define CRYPTO3_ASSIGNER_HANDLE_COMPONENT_HPP

#define PREPARE_INPUT(OP) \
#include <mlir-assigner/memory/stack_frame.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/to_fixedpoint.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/sin.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/cos.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/non_native/lookup_logic_ops.hpp>

#define PREPARE_UNARY_INPUT(OP) \
prepare_unary_operation_input<BlueprintFieldType, ArithmetizationParams, OP, \
typename component_type::input_type>(operation, frame, bp, assignment);
#define PREPARE_BINARY_INPUT(OP) \
prepare_binary_operation_input<BlueprintFieldType, ArithmetizationParams, OP, \
typename component_type::input_type>(operation, frame, bp, assignment);

#include <mlir-assigner/memory/stack_frame.hpp>

namespace nil {
namespace blueprint {
template<typename BlueprintFieldType, typename ArithmetizationParams, typename UnaryOp, typename input_type>
input_type prepare_unary_operation_input(
UnaryOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment) {

assert(operation->getNumOperands() == 1 && "unary operand must have only one operand");
auto operand = frame.locals.find(mlir::hash_value(operation->getOperand(0)));
ASSERT(operand != frame.locals.end());

input_type instance_input;
instance_input.x = operand->second;
return instance_input;
}
template<typename BlueprintFieldType, typename ArithmetizationParams, typename BinOp, typename input_type>
input_type prepare_binary_operation_input(
BinOp &operation,
Expand Down Expand Up @@ -80,11 +105,11 @@ namespace nil {
}
}

template<typename BlueprintFieldType, typename ArithmetizationParams, typename component_type, typename BinOp>
template<typename BlueprintFieldType, typename ArithmetizationParams, typename component_type, typename Op>
void fill_trace(
component_type &component,
typename component_type::input_type &input,
BinOp &mlir_op,
Op &mlir_op,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
Expand Down
37 changes: 35 additions & 2 deletions mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@
#include <mlir-assigner/components/fixedpoint/remainder.hpp>
#include <mlir-assigner/components/fixedpoint/subtraction.hpp>
#include <mlir-assigner/components/fixedpoint/dot_product.hpp>
#include <mlir-assigner/components/fixedpoint/trigonometric.hpp>
#include <mlir-assigner/components/boolean/and.hpp>
#include <mlir-assigner/components/fixedpoint/to_fixpoint.hpp>

#include <mlir-assigner/memory/memref.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
Expand Down Expand Up @@ -347,6 +349,15 @@ namespace zk_ml_toolchain {
logger << constantValue;
UNREACHABLE("unhandled constant");
}
} else if (arith::IndexCastOp operation = llvm::dyn_cast<arith::IndexCastOp>(op)) {
assert(operation->getNumOperands() == 1 && "IndexCast must have exactly one operand");
auto index = frames.back().constant_values[mlir::hash_value(operation->getOperand(0))];
typename BlueprintFieldType::value_type field_constant = index;
auto val = put_into_assignment(field_constant);
frames.back().locals.insert(std::make_pair(mlir::hash_value(operation.getResult()), val));
} else if (arith::SIToFPOp operation = llvm::dyn_cast<arith::SIToFPOp>(op)) {
// TODO this does not respect negative and no different ranges for ints...
handle_to_fixedpoint(operation, frames.back(), bp, assignmnt, start_row);
} else {
std::string opName = op->getName().getIdentifier().str();
UNREACHABLE(std::string("unhandled arith operation: ") + opName);
Expand All @@ -372,6 +383,14 @@ namespace zk_ml_toolchain {
frames.back().locals[mlir::hash_value(operation.getLhs())];
} else if (math::SqrtOp operation = llvm::dyn_cast<math::SqrtOp>(op)) {
UNREACHABLE("TODO: component for sqrt not ready");
} else if (math::SinOp operation = llvm::dyn_cast<math::SinOp>(op)) {
handle_sin(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::CosOp operation = llvm::dyn_cast<math::CosOp>(op)) {
handle_cos(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::AtanOp operation = llvm::dyn_cast<math::AtanOp>(op)) {
UNREACHABLE("TODO: component for atanh not ready");
} else if (math::TanhOp operation = llvm::dyn_cast<math::TanhOp>(op)) {
UNREACHABLE("TODO: component for tanh not ready");
} else if (math::ErfOp operation = llvm::dyn_cast<math::ErfOp>(op)) {
UNREACHABLE("TODO: component for erf not ready");
} else {
Expand Down Expand Up @@ -414,7 +433,9 @@ namespace zk_ml_toolchain {
assert(res != frames.back().constant_values.end());
mapDims.push_back(res->second);
}
auto affineMap = castFromAttr<AffineMapAttr>(operation->getAttr(affine::AffineLoadOp::getMapAttrStrName())).getAffineMap();
auto affineMap =
castFromAttr<AffineMapAttr>(operation->getAttr(affine::AffineLoadOp::getMapAttrStrName()))
.getAffineMap();
auto value = memref->second.get(evalAffineMap(affineMap, mapDims));
frames.back().locals[mlir::hash_value(operation.getResult())] = value;
} else if (affine::AffineStoreOp operation = llvm::dyn_cast<affine::AffineStoreOp>(op)) {
Expand All @@ -437,7 +458,9 @@ namespace zk_ml_toolchain {
auto value = frames.back().locals.find(mlir::hash_value(operation.getValue()));
assert(value != frames.back().locals.end());
// put the element from the memref using index vector
auto affineMap = castFromAttr<AffineMapAttr>(operation->getAttr(affine::AffineStoreOp::getMapAttrStrName())).getAffineMap();
auto affineMap =
castFromAttr<AffineMapAttr>(operation->getAttr(affine::AffineStoreOp::getMapAttrStrName()))
.getAffineMap();
memref->second.put(evalAffineMap(affineMap, mapDims), value->second);

} else if (affine::AffineYieldOp operation = llvm::dyn_cast<affine::AffineYieldOp>(op)) {
Expand Down Expand Up @@ -594,8 +617,18 @@ namespace zk_ml_toolchain {
return;
} else if (KrnlAcosOp operation = llvm::dyn_cast<KrnlAcosOp>(op)) {
UNREACHABLE(std::string("TODO KrnlAcos: link to bluebrint component"));
} else if (KrnlAsinOp operation = llvm::dyn_cast<KrnlAsinOp>(op)) {
UNREACHABLE(std::string("TODO KrnlSin: link to bluebrint component"));
} else if (KrnlAcoshOp operation = llvm::dyn_cast<KrnlAcoshOp>(op)) {
UNREACHABLE(std::string("TODO KrnlAcosh: link to bluebrint component"));
} else if (KrnlAsinhOp operation = llvm::dyn_cast<KrnlAsinhOp>(op)) {
UNREACHABLE(std::string("TODO KrnlSinh: link to bluebrint component"));
} else if (KrnlTanOp operation = llvm::dyn_cast<KrnlTanOp>(op)) {
UNREACHABLE("TODO: component for tan not ready");
} else if (KrnlAtanOp operation = llvm::dyn_cast<KrnlAtanOp>(op)) {
UNREACHABLE("TODO: component for atan not ready");
} else if (KrnlAtanhOp operation = llvm::dyn_cast<KrnlAtanhOp>(op)) {
UNREACHABLE("TODO: component for atanh not ready");
} else {
std::string opName = op->getName().getIdentifier().str();
UNREACHABLE(std::string("unhandled krnl operation: ") + opName);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.6515045166015625, 0.259490966796875, 0.200439453125, 0.084564208984375, 0.223297119140625, 0.867034912109375, 0.0014801025390625, 0.4618377685546875, 0.059051513671875, 0.2735137939453125], "dims": [1, 10], "type": "f32"}}]
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "acossimple.mlir"} {
func.func @main_graph(%arg0: memref<1x10xf32>) -> memref<1x10xf32> attributes {input_names = ["in_a"], llvm.emit_c_interface, output_names = ["out_a"]} {
%alloc = memref.alloc() {alignment = 16 : i64} : memref<1x10xf32>
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 10 {
%0 = affine.load %arg0[%arg1, %arg2] : memref<1x10xf32>
%1 = "krnl.acos"(%0) : (f32) -> f32
affine.store %1, %alloc[%arg1, %arg2] : memref<1x10xf32>
}
}
return %alloc : memref<1x10xf32>
}
"krnl.entry_point"() {func = @main_graph, numInputs = 1 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22in_a\22 }\0A\0A]\00@[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22out_a\22 }\0A\0A]\00"} : () -> ()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
 :R

in_aout_a"Acos
AcosSimpleZ
in_a



b
out_a



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0.8612304329872131, 1.3083012104034424, 1.3689898252487183, 1.486130952835083, 1.3456006050109863, 0.5215762257575989, 1.5693162679672241, 1.0907303094863892, 1.5117104053497314, 1.2937520742416382]
ADD THE ROWS HERE
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.5222625732421875, 0.6072845458984375, 0.3314361572265625, 0.698211669921875, 0.192108154296875, 0.699432373046875, 0.5330963134765625, 0.62982177734375, 0.908538818359375, 0.012664794921875], "dims": [1, 10], "type": "f32"}}]
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "acoshsimple.mlir"} {
func.func @main_graph(%arg0: memref<1x10xf32>) -> memref<1x10xf32> attributes {input_names = ["in_a"], llvm.emit_c_interface, output_names = ["out_a"]} {
%alloc = memref.alloc() {alignment = 16 : i64} : memref<1x10xf32>
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 10 {
%0 = affine.load %arg0[%arg1, %arg2] : memref<1x10xf32>
%1 = "krnl.acosh"(%0) : (f32) -> f32
affine.store %1, %alloc[%arg1, %arg2] : memref<1x10xf32>
}
}
return %alloc : memref<1x10xf32>
}
"krnl.entry_point"() {func = @main_graph, numInputs = 1 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22in_a\22 }\0A\0A]\00@[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22out_a\22 }\0A\0A]\00"} : () -> ()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
 :T

in_aout_a"Acosh AcoshSimpleZ
in_a



b
out_a



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>['-nan', '-nan', '-nan', '-nan', '-nan', '-nan', '-nan', '-nan', '-nan', '-nan']
ADD THE ROWS HERE
Loading

0 comments on commit 0b30bf6

Please sign in to comment.