Skip to content

Commit

Permalink
feat: added log handling and activation functions test (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
0xThemis authored Dec 21, 2023
1 parent 11e2275 commit 7b9f76a
Show file tree
Hide file tree
Showing 37 changed files with 277 additions and 39 deletions.
2 changes: 1 addition & 1 deletion justfile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# build the mlir-assigner
build:
make -C build/ -j 12
make -C build/ -j 12 zkml-onnx-compiler mlir-assigner

# setsup the build folder
setup-build:
Expand Down
92 changes: 92 additions & 0 deletions mlir-assigner/include/mlir-assigner/components/fixedpoint/log.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#ifndef CRYPTO3_ASSIGNER_FIXEDPOINT_LOG_HPP
#define CRYPTO3_ASSIGNER_FIXEDPOINT_LOG_HPP

#include <mlir/Dialect/Math/IR/Math.h>

#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>

#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/log.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <mlir-assigner/policy/policy_manager.hpp>

namespace nil {
namespace blueprint {
namespace detail {

template<typename BlueprintFieldType, typename ArithmetizationParams>
typename components::fix_log<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>::result_type
handle_fixedpoint_log_component(
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>
x,
circuit_proxy<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType,
ArithmetizationParams>> &assignment,
std::uint32_t start_row) {

using var = crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>;

using component_type = components::fix_log<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
const auto p = PolicyManager::get_parameters(
ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_witness(0, 1, 1));
component_type component_instance(
p.witness,
ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_constants(),
ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_public_inputs(),
1, 1);

if constexpr (nil::blueprint::use_custom_lookup_tables<component_type>()) {
auto lookup_tables = component_instance.component_custom_lookup_tables();
for (auto &t : lookup_tables) {
bp.register_lookup_table(
std::shared_ptr<nil::crypto3::zk::snark::lookup_table_definition<BlueprintFieldType>>(t));
}
};

if constexpr (nil::blueprint::use_lookups<component_type>()) {
auto lookup_tables = component_instance.component_lookup_tables();
for (auto &[k, v] : lookup_tables) {
bp.reserve_table(k);
}
};

// TACEO_TODO in the previous line I hardcoded 1 for now!!! CHANGE THAT
// TACEO_TODO make an assert that both have the same scale?
// TACEO_TODO we probably have to extract the field element from the type here

components::generate_circuit(component_instance, bp, assignment, {x}, start_row);
return components::generate_assignments(component_instance, assignment, {x}, start_row);
}

} // namespace detail
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_fixedpoint_log_component(
mlir::math::LogOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
auto operand = frame.locals.find(mlir::hash_value(operation.getOperand()));
ASSERT(operand != frame.locals.end());

auto x = operand->second;

// TACEO_TODO: check types

auto result = detail::handle_fixedpoint_log_component(x, bp, assignment, start_row);
frame.locals[mlir::hash_value(operation.getResult())] = result.output;
}
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_LOG_HPP
5 changes: 5 additions & 0 deletions mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include <mlir-assigner/components/fixedpoint/ceil.hpp>
#include <mlir-assigner/components/fixedpoint/division.hpp>
#include <mlir-assigner/components/fixedpoint/exp.hpp>
#include <mlir-assigner/components/fixedpoint/log.hpp>
#include <mlir-assigner/components/fixedpoint/floor.hpp>
#include <mlir-assigner/components/fixedpoint/mul_rescale.hpp>
#include <mlir-assigner/components/fixedpoint/neg.hpp>
Expand Down Expand Up @@ -323,6 +324,8 @@ namespace zk_ml_toolchain {
std::uint32_t start_row = assignmnt.allocated_rows();
if (math::ExpOp operation = llvm::dyn_cast<math::ExpOp>(op)) {
handle_fixedpoint_exp_component(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::LogOp operation = llvm::dyn_cast<math::LogOp>(op)) {
handle_fixedpoint_log_component(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::AbsFOp operation = llvm::dyn_cast<math::AbsFOp>(op)) {
handle_fixedpoint_abs_component(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::CeilOp operation = llvm::dyn_cast<math::CeilOp>(op)) {
Expand All @@ -336,6 +339,8 @@ namespace zk_ml_toolchain {
frames.back().locals[mlir::hash_value(operation.getLhs())];
} else if (math::SqrtOp operation = llvm::dyn_cast<math::SqrtOp>(op)) {
UNREACHABLE("TODO: sqrt");
} else if (math::ErfOp operation = llvm::dyn_cast<math::ErfOp>(op)) {
UNREACHABLE("TODO: component for erf not ready");
} else {
std::string opName = op->getName().getIdentifier().str();
UNREACHABLE(std::string("unhandled math operation: ") + opName);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.68353271484375, 0.50244140625, 0.590240478515625, 0.337371826171875, 0.681549072265625, 0.1300048828125, 0.109771728515625, 0.8895263671875, 0.0695648193359375, 0.1826171875], "dims": [1, 10], "type": "f32"}}]
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu", "onnx-mlir.symbol-postfix" = "acoshsimple.0.mlir"} {
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "erfsimple.mlir"} {
func.func @main_graph(%arg0: memref<1x10xf32>) -> memref<1x10xf32> attributes {input_names = ["in_a"], llvm.emit_c_interface, output_names = ["out_a"]} {
%alloc = memref.alloc() {alignment = 16 : i64} : memref<1x10xf32>
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 10 {
%0 = affine.load %arg0[%arg1, %arg2] : memref<1x10xf32>
%1 = "krnl.acosh"(%0) : (f32) -> f32
%1 = math.erf %0 : f32
affine.store %1, %alloc[%arg1, %arg2] : memref<1x10xf32>
}
}
Expand Down
13 changes: 13 additions & 0 deletions mlir-assigner/tests/Ops/NeedsBlueprintComponent/Erf/ErfSimple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
 :P

in_aout_a"Erf ErfSimpleZ
in_a



b
out_a



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0.6662865281105042, 0.5226426720619202, 0.596127986907959, 0.3667203485965729, 0.6648818254470825, 0.14587253332138062, 0.12336840480566025, 0.7916010618209839, 0.07836905121803284, 0.2037934958934784]
ADD THE ROWS HERE
16 changes: 0 additions & 16 deletions mlir-assigner/tests/Ops/Onnx/And/AndSimple.mlir

This file was deleted.

16 changes: 0 additions & 16 deletions mlir-assigner/tests/Ops/Onnx/Div/DivSimple.mlir

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.389190673828125, 0.105712890625, 0.56243896484375, 0.2190704345703125, 0.134674072265625, 0.9574127197265625, 0.547149658203125, 0.1509857177734375, 0.1473541259765625, 0.0181121826171875], "dims": [1, 10], "type": "f32"}}]
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0.577838122844696, 0.5211426019668579, 0.61248779296875, 0.5438140630722046, 0.526934802532196, 0.6914825439453125, 0.6094299554824829, 0.5301971435546875, 0.5294708013534546, 0.5036224126815796]
60
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.2706146240234375, 0.7530059814453125, 0.69873046875, 0.49169921875, 0.7146148681640625, 0.91375732421875, 0.047088623046875, 0.3588714599609375, 0.4225006103515625, 0.7489166259765625], "dims": [1, 10], "type": "f32"}}]
15 changes: 15 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/LeakyRelu/LeakyReluSimple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
 :m
)
in_aout_a" LeakyRelu*
alpha
�#<�LeakyReluSimpleZ
in_a



b
out_a



B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/LeakyRelu/LeakyReluSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0.2706146240234375, 0.7530059814453125, 0.69873046875, 0.49169921875, 0.7146148681640625, 0.91375732421875, 0.047088623046875, 0.3588714599609375, 0.4225006103515625, 0.7489166259765625]
30
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.9263153076171875, 0.5552978515625, 0.488739013671875, 0.1752166748046875, 0.75836181640625, 0.498870849609375, 0.184600830078125, 0.7884979248046875, 0.725982666015625, 0.805877685546875], "dims": [1, 10], "type": "f32"}}]
15 changes: 15 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/LogSoftmax/LogSoftmaxBasicMnist.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
 :o
&
in_aout_a"
LogSoftmax*
axis�LogSoftmaxBasicMnistZ
in_a



b
out_a



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[-1.9957325458526611, -2.3667500019073486, -2.4333088397979736, -2.746831178665161, -2.1636860370635986, -2.4231770038604736, -2.7374470233917236, -2.133549928665161, -2.1960651874542236, -2.1161701679229736]
140
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/PRelu/PReluConst.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.7888336181640625, 0.274871826171875, 0.862274169921875, 0.5487213134765625, 0.911895751953125, 0.7361602783203125, 0.2882537841796875, 0.3439788818359375, 0.0133514404296875, 0.8032379150390625], "dims": [1, 10], "type": "f32"}}]
16 changes: 16 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/PRelu/PReluConst.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
 :�

in_a
in_bout_a"PRelu
PReluConst*6
"(���=���=���=���=���=���=���=���=���=���=Bin_bZ
in_a



b
out_a



B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/PRelu/PReluConst.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0.7888336181640625, 0.274871826171875, 0.862274169921875, 0.5487213134765625, 0.911895751953125, 0.7361602783203125, 0.2882537841796875, 0.3439788818359375, 0.0133514404296875, 0.8032379150390625]
30
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/PRelu/PReluSimple.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.891021728515625, 0.5626983642578125, 0.2191619873046875, 0.7632598876953125, 0.9755706787109375, 0.7555694580078125, 0.72698974609375, 0.3393096923828125, 0.0352630615234375, 0.398712158203125], "dims": [1, 10], "type": "f32"}}, {"memref": {"data": [0.0781097412109375, 0.0261383056640625, 0.6385345458984375, 0.70574951171875, 0.1352691650390625, 0.2262420654296875, 0.101715087890625, 0.2310028076171875, 0.1707611083984375, 0.8590545654296875], "dims": [1, 10], "type": "f32"}}]
19 changes: 19 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/PRelu/PReluSimple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
 :r

in_a
in_bout_a"PRelu PReluSimpleZ
in_a



Z
in_b



b
out_a



B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/PRelu/PReluSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0.891021728515625, 0.5626983642578125, 0.2191619873046875, 0.7632598876953125, 0.9755706787109375, 0.7555694580078125, 0.72698974609375, 0.3393096923828125, 0.0352630615234375, 0.398712158203125]
30
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Selu/SeluSimple.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.335845947265625, 0.8036041259765625, 0.855743408203125, 0.6555633544921875, 0.09368896484375, 0.340057373046875, 0.5518798828125, 0.47216796875, 0.814453125, 0.56549072265625], "dims": [1, 10], "type": "f32"}}]
16 changes: 16 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Selu/SeluSimple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
 :t
5
in_aout_a"Selu*
alphab-�?�*
gammaV}�?�
SeluSimpleZ
in_a



b
out_a



B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Selu/SeluSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0.3528733253479004, 0.844346821308136, 0.8991295695304871, 0.6888003945350647, 0.09843899309635162, 0.35729825496673584, 0.5798601508140564, 0.4961068630218506, 0.8557458519935608, 0.5941610932350159]
60
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Sigmoid/SigmoidSimple.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.420074462890625, 0.58660888671875, 0.7588348388671875, 0.926849365234375, 0.4851531982421875, 0.1836090087890625, 0.9698486328125, 0.407806396484375, 0.208160400390625, 0.3989105224609375], "dims": [1, 10], "type": "f32"}}]
13 changes: 13 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Sigmoid/SigmoidSimple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
 :X

in_aout_a"SigmoidSigmoidSimpleZ
in_a



b
out_a



B
Expand Down
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Sigmoid/SigmoidSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0.6035010814666748, 0.6425867080688477, 0.6811007261276245, 0.7164356112480164, 0.6189640164375305, 0.5457737445831299, 0.7250893712043762, 0.600561797618866, 0.5518530011177063, 0.5984258651733398]
50
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.0689697265625, 0.609283447265625, 0.095550537109375, 0.4514007568359375, 0.439605712890625, 0.26214599609375, 0.4383544921875, 0.354248046875, 0.21820068359375, 0.04864501953125], "dims": [1, 10], "type": "f32"}}]
13 changes: 13 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Softplus/SoftplusBasicMnist.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
 :^

in_aout_a"SoftplusSoftplusBasicMnistZ
in_a



b
out_a



B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Softplus/SoftplusBasicMnist.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0.7282265424728394, 1.0434917211532593, 0.7420632839202881, 0.9441044926643372, 0.9369146823883057, 0.8327857255935669, 0.9361538290977478, 0.8858763575553894, 0.8081871867179871, 0.7177654504776001]
90
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.293365478515625, 0.330474853515625, 0.533905029296875, 0.4608154296875, 0.0807037353515625, 0.4675445556640625, 0.616790771484375, 0.2982177734375, 0.3966522216796875, 0.8043975830078125], "dims": [1, 10], "type": "f32"}}]
13 changes: 13 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Softsign/SoftsignBasicMnist.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
 :^

in_aout_a"SoftsignSoftsignBasicMnistZ
in_a



b
out_a



B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Softsign/SoftsignBasicMnist.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0.2268233448266983, 0.2483886480331421, 0.3480691611766815, 0.3154508173465729, 0.07467702031135559, 0.3185896873474121, 0.3814907670021057, 0.22971321642398834, 0.2840021550655365, 0.44579842686653137]
40
Loading

0 comments on commit 7b9f76a

Please sign in to comment.