Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/trigonemtric pooling ops #16

Merged
merged 3 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>

#include <nil/blueprint/components/algebra/fields/plonk/non_native/lookup_logic_ops.hpp>
#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream
Expand All @@ -51,7 +50,7 @@ namespace nil {
using component_type = components::lookup_logic_and<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;

auto input = PREPARE_INPUT(mlir::arith::AndIOp);
auto input = PREPARE_BINARY_INPUT(mlir::arith::AndIOp);
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

Expand All @@ -72,7 +71,7 @@ namespace nil {
// using component_type = components::lookup_logic_or<
// crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;
//
// auto input = PREPARE_INPUT(mlir::arith::OrIOp);
// auto input = PREPARE_BINARY_INPUT(mlir::arith::OrIOp);
// const auto p = detail::PolicyManager::get_parameters(
// detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
//
Expand All @@ -91,7 +90,7 @@ namespace nil {
using component_type = components::lookup_logic_xor<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;

auto input = PREPARE_INPUT(mlir::arith::XOrIOp);
auto input = PREPARE_BINARY_INPUT(mlir::arith::XOrIOp);
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#ifndef CRYPTO3_ASSIGNER_FIXEDPOINT_TO_FIXEDPOINT_HPP
#define CRYPTO3_ASSIGNER_FIXEDPOINT_TO_FIXEDPOINT_HPP

#include "mlir/Dialect/zkml/IR/DotProduct.h"
#include <mlir/Dialect/Arith/IR/Arith.h>

#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>

#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <mlir-assigner/components/handle_component.hpp>

namespace nil {
namespace blueprint {

template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_to_fixedpoint(
mlir::arith::SIToFPOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::int_to_fix<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

auto input = PREPARE_UNARY_INPUT(mlir::arith::SIToFPOp);
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(),
1);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_TO_FIXEDPOINT_HPP
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#ifndef CRYPTO3_ASSIGNER_FIXEDPOINT_TRIGONOMETRIC_HPP
#define CRYPTO3_ASSIGNER_FIXEDPOINT_TRIGONOMETRIC_HPP

#include "mlir/Dialect/zkml/IR/DotProduct.h"
#include <mlir/Dialect/Arith/IR/Arith.h>

#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>

#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <mlir-assigner/components/handle_component.hpp>

namespace nil {
namespace blueprint {

template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_sin(
mlir::math::SinOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::fix_sin<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

auto input = PREPARE_UNARY_INPUT(mlir::math::SinOp);
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_witness(0, 1, 1));

component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(),
1, 1);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_cos(
mlir::math::CosOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::fix_cos<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

auto input = PREPARE_UNARY_INPUT(mlir::math::CosOp);
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_witness(0, 1, 1));

component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(),
1, 1);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_TRIGONOMETRIC_HPP
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,38 @@
#ifndef CRYPTO3_ASSIGNER_HANDLE_COMPONENT_HPP
#define CRYPTO3_ASSIGNER_HANDLE_COMPONENT_HPP

#define PREPARE_INPUT(OP) \
#include <mlir-assigner/memory/stack_frame.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/to_fixedpoint.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/sin.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/cos.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/non_native/lookup_logic_ops.hpp>

#define PREPARE_UNARY_INPUT(OP) \
prepare_unary_operation_input<BlueprintFieldType, ArithmetizationParams, OP, \
typename component_type::input_type>(operation, frame, bp, assignment);
#define PREPARE_BINARY_INPUT(OP) \
prepare_binary_operation_input<BlueprintFieldType, ArithmetizationParams, OP, \
typename component_type::input_type>(operation, frame, bp, assignment);

#include <mlir-assigner/memory/stack_frame.hpp>

namespace nil {
namespace blueprint {
template<typename BlueprintFieldType, typename ArithmetizationParams, typename UnaryOp, typename input_type>
input_type prepare_unary_operation_input(
UnaryOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment) {

assert(operation->getNumOperands() == 1 && "unary operand must have only one operand");
auto operand = frame.locals.find(mlir::hash_value(operation->getOperand(0)));
ASSERT(operand != frame.locals.end());

input_type instance_input;
instance_input.x = operand->second;
return instance_input;
}
template<typename BlueprintFieldType, typename ArithmetizationParams, typename BinOp, typename input_type>
input_type prepare_binary_operation_input(
BinOp &operation,
Expand Down Expand Up @@ -80,11 +105,11 @@ namespace nil {
}
}

template<typename BlueprintFieldType, typename ArithmetizationParams, typename component_type, typename BinOp>
template<typename BlueprintFieldType, typename ArithmetizationParams, typename component_type, typename Op>
void fill_trace(
component_type &component,
typename component_type::input_type &input,
BinOp &mlir_op,
Op &mlir_op,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
Expand Down
37 changes: 35 additions & 2 deletions mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@
#include <mlir-assigner/components/fixedpoint/remainder.hpp>
#include <mlir-assigner/components/fixedpoint/subtraction.hpp>
#include <mlir-assigner/components/fixedpoint/dot_product.hpp>
#include <mlir-assigner/components/fixedpoint/trigonometric.hpp>
#include <mlir-assigner/components/boolean/and.hpp>
#include <mlir-assigner/components/fixedpoint/to_fixpoint.hpp>

#include <mlir-assigner/memory/memref.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
Expand Down Expand Up @@ -347,6 +349,15 @@ namespace zk_ml_toolchain {
logger << constantValue;
UNREACHABLE("unhandled constant");
}
} else if (arith::IndexCastOp operation = llvm::dyn_cast<arith::IndexCastOp>(op)) {
assert(operation->getNumOperands() == 1 && "IndexCast must have exactly one operand");
auto index = frames.back().constant_values[mlir::hash_value(operation->getOperand(0))];
typename BlueprintFieldType::value_type field_constant = index;
auto val = put_into_assignment(field_constant);
frames.back().locals.insert(std::make_pair(mlir::hash_value(operation.getResult()), val));
} else if (arith::SIToFPOp operation = llvm::dyn_cast<arith::SIToFPOp>(op)) {
// TODO this does not respect negative and no different ranges for ints...
handle_to_fixedpoint(operation, frames.back(), bp, assignmnt, start_row);
} else {
std::string opName = op->getName().getIdentifier().str();
UNREACHABLE(std::string("unhandled arith operation: ") + opName);
Expand All @@ -372,6 +383,14 @@ namespace zk_ml_toolchain {
frames.back().locals[mlir::hash_value(operation.getLhs())];
} else if (math::SqrtOp operation = llvm::dyn_cast<math::SqrtOp>(op)) {
UNREACHABLE("TODO: component for sqrt not ready");
} else if (math::SinOp operation = llvm::dyn_cast<math::SinOp>(op)) {
handle_sin(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::CosOp operation = llvm::dyn_cast<math::CosOp>(op)) {
handle_cos(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::AtanOp operation = llvm::dyn_cast<math::AtanOp>(op)) {
UNREACHABLE("TODO: component for atanh not ready");
} else if (math::TanhOp operation = llvm::dyn_cast<math::TanhOp>(op)) {
UNREACHABLE("TODO: component for tanh not ready");
} else if (math::ErfOp operation = llvm::dyn_cast<math::ErfOp>(op)) {
UNREACHABLE("TODO: component for erf not ready");
} else {
Expand Down Expand Up @@ -414,7 +433,9 @@ namespace zk_ml_toolchain {
assert(res != frames.back().constant_values.end());
mapDims.push_back(res->second);
}
auto affineMap = castFromAttr<AffineMapAttr>(operation->getAttr(affine::AffineLoadOp::getMapAttrStrName())).getAffineMap();
auto affineMap =
castFromAttr<AffineMapAttr>(operation->getAttr(affine::AffineLoadOp::getMapAttrStrName()))
.getAffineMap();
auto value = memref->second.get(evalAffineMap(affineMap, mapDims));
frames.back().locals[mlir::hash_value(operation.getResult())] = value;
} else if (affine::AffineStoreOp operation = llvm::dyn_cast<affine::AffineStoreOp>(op)) {
Expand All @@ -437,7 +458,9 @@ namespace zk_ml_toolchain {
auto value = frames.back().locals.find(mlir::hash_value(operation.getValue()));
assert(value != frames.back().locals.end());
// put the element from the memref using index vector
auto affineMap = castFromAttr<AffineMapAttr>(operation->getAttr(affine::AffineStoreOp::getMapAttrStrName())).getAffineMap();
auto affineMap =
castFromAttr<AffineMapAttr>(operation->getAttr(affine::AffineStoreOp::getMapAttrStrName()))
.getAffineMap();
memref->second.put(evalAffineMap(affineMap, mapDims), value->second);

} else if (affine::AffineYieldOp operation = llvm::dyn_cast<affine::AffineYieldOp>(op)) {
Expand Down Expand Up @@ -594,8 +617,18 @@ namespace zk_ml_toolchain {
return;
} else if (KrnlAcosOp operation = llvm::dyn_cast<KrnlAcosOp>(op)) {
UNREACHABLE(std::string("TODO KrnlAcos: link to bluebrint component"));
} else if (KrnlAsinOp operation = llvm::dyn_cast<KrnlAsinOp>(op)) {
UNREACHABLE(std::string("TODO KrnlSin: link to bluebrint component"));
} else if (KrnlAcoshOp operation = llvm::dyn_cast<KrnlAcoshOp>(op)) {
UNREACHABLE(std::string("TODO KrnlAcosh: link to bluebrint component"));
} else if (KrnlAsinhOp operation = llvm::dyn_cast<KrnlAsinhOp>(op)) {
UNREACHABLE(std::string("TODO KrnlSinh: link to bluebrint component"));
} else if (KrnlTanOp operation = llvm::dyn_cast<KrnlTanOp>(op)) {
UNREACHABLE("TODO: component for tan not ready");
} else if (KrnlAtanOp operation = llvm::dyn_cast<KrnlAtanOp>(op)) {
UNREACHABLE("TODO: component for atan not ready");
} else if (KrnlAtanhOp operation = llvm::dyn_cast<KrnlAtanhOp>(op)) {
UNREACHABLE("TODO: component for atanh not ready");
} else {
std::string opName = op->getName().getIdentifier().str();
UNREACHABLE(std::string("unhandled krnl operation: ") + opName);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.6515045166015625, 0.259490966796875, 0.200439453125, 0.084564208984375, 0.223297119140625, 0.867034912109375, 0.0014801025390625, 0.4618377685546875, 0.059051513671875, 0.2735137939453125], "dims": [1, 10], "type": "f32"}}]
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "acossimple.mlir"} {
func.func @main_graph(%arg0: memref<1x10xf32>) -> memref<1x10xf32> attributes {input_names = ["in_a"], llvm.emit_c_interface, output_names = ["out_a"]} {
%alloc = memref.alloc() {alignment = 16 : i64} : memref<1x10xf32>
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 10 {
%0 = affine.load %arg0[%arg1, %arg2] : memref<1x10xf32>
%1 = "krnl.acos"(%0) : (f32) -> f32
affine.store %1, %alloc[%arg1, %arg2] : memref<1x10xf32>
}
}
return %alloc : memref<1x10xf32>
}
"krnl.entry_point"() {func = @main_graph, numInputs = 1 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22in_a\22 }\0A\0A]\00@[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22out_a\22 }\0A\0A]\00"} : () -> ()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
 :R

in_aout_a"Acos
AcosSimpleZ
in_a



b
out_a



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0.8612304329872131, 1.3083012104034424, 1.3689898252487183, 1.486130952835083, 1.3456006050109863, 0.5215762257575989, 1.5693162679672241, 1.0907303094863892, 1.5117104053497314, 1.2937520742416382]
ADD THE ROWS HERE
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.5222625732421875, 0.6072845458984375, 0.3314361572265625, 0.698211669921875, 0.192108154296875, 0.699432373046875, 0.5330963134765625, 0.62982177734375, 0.908538818359375, 0.012664794921875], "dims": [1, 10], "type": "f32"}}]
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "acoshsimple.mlir"} {
func.func @main_graph(%arg0: memref<1x10xf32>) -> memref<1x10xf32> attributes {input_names = ["in_a"], llvm.emit_c_interface, output_names = ["out_a"]} {
%alloc = memref.alloc() {alignment = 16 : i64} : memref<1x10xf32>
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 10 {
%0 = affine.load %arg0[%arg1, %arg2] : memref<1x10xf32>
%1 = "krnl.acosh"(%0) : (f32) -> f32
affine.store %1, %alloc[%arg1, %arg2] : memref<1x10xf32>
}
}
return %alloc : memref<1x10xf32>
}
"krnl.entry_point"() {func = @main_graph, numInputs = 1 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22in_a\22 }\0A\0A]\00@[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22out_a\22 }\0A\0A]\00"} : () -> ()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
 :T

in_aout_a"Acosh AcoshSimpleZ
in_a



b
out_a



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>['-nan', '-nan', '-nan', '-nan', '-nan', '-nan', '-nan', '-nan', '-nan', '-nan']
ADD THE ROWS HERE
Loading