Skip to content

Commit

Permalink
feat: MatMulInteger support (#49)
Browse files Browse the repository at this point in the history
* matmulint working
* tests: updated randomforest rows
  • Loading branch information
0xThemis authored Feb 15, 2024
1 parent a13bb8d commit d91fb31
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#ifndef CRYPTO3_ASSIGNER_FIXEDPOINT_DOT_PRODUCT_HPP
#define CRYPTO3_ASSIGNER_FIXEDPOINT_DOT_PRODUCT_HPP

#include "mlir-assigner/memory/memref.hpp"
#include "mlir/Dialect/zkml/IR/DotProduct.h"
#include <cstdint>
#include <mlir/Dialect/Arith/IR/Arith.h>

#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>
Expand All @@ -26,28 +28,49 @@ namespace nil {
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
const common_component_parameters<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
const common_component_parameters<
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
using component_type = components::fix_dot_rescale_2_gates<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, PreLimbs, PostLimbs>;
mlir::Value lhs = operation.getLhs();
mlir::Value rhs = operation.getRhs();
assert(lhs.getType() == rhs.getType() && "memrefs must be same type for DotProduct");
mlir::MemRefType MemRefType = mlir::cast<mlir::MemRefType>(lhs.getType());
assert(MemRefType.getShape().size() == 1 && "DotProduct must have tensors of rank 1");

auto &x = stack.get_memref(operation.getLhs());
auto &y = stack.get_memref(operation.getRhs());
auto dims = x.getDims();
ASSERT(dims.size() == 1 && "must be one-dim for dot product");
const auto p =
detail::PolicyManager::get_parameters(manifest_reader::get_witness(0, dims.front(), PostLimbs));
component_type component_instance(p.witness, manifest_reader::get_constants(),
manifest_reader::get_public_inputs(), dims.front(), PostLimbs);

component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(),
dims.front(), PostLimbs);
typename component_type::input_type input = {x.getData(), y.getData(), zero_var};

fill_trace(component, input, operation, stack, bp, assignment, compParams);
if (MemRefType.getElementType().isa<mlir::IntegerType>()) {
using ComponentType = components::dot_2_gates<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
using manifest_reader = detail::ManifestReader<ComponentType, ArithmetizationParams>;
const auto p = detail::PolicyManager::get_parameters(manifest_reader::get_witness(0, dims.front()));
ComponentType component(p.witness, manifest_reader::get_constants(),
manifest_reader::get_public_inputs(), dims.front());
typename ComponentType::input_type input = {x.getData(), y.getData(), zero_var};
fill_trace(component, input, operation, stack, bp, assignment, compParams);
} else if (MemRefType.getElementType().isa<mlir::FloatType>()) {
using ComponentType = components::fix_dot_rescale_2_gates<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
// FIXME the PreLimbs paramter here gets ignored. It is in fact
using manifest_reader = detail::ManifestReader<ComponentType, ArithmetizationParams, PostLimbs>;
const auto p =
detail::PolicyManager::get_parameters(manifest_reader::get_witness(0, dims.front(), PostLimbs));
ComponentType component(p.witness, manifest_reader::get_constants(),
manifest_reader::get_public_inputs(), dims.front(), PostLimbs);
typename ComponentType::input_type input = {x.getData(), y.getData(), zero_var};
fill_trace(component, input, operation, stack, bp, assignment, compParams);
} else {
UNREACHABLE("Unsupported type for dot-product. Only floats and ints supported");
}
}

} // namespace blueprint
} // namespace nil

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,14 @@
#include <nil/blueprint/components/algebra/fixedpoint/plonk/erf.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/div.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/rem.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/select.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/dot_rescale_2_gates.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/dot_2_gates.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/bitwise_and.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/bitwise_or.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/bitwise_xor.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/dot_rescale_2_gates.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/non_native/logic_ops.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/addition.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/select.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/subtraction.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/multiplication.hpp>
#include <optional>
Expand Down
6 changes: 0 additions & 6 deletions mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -828,12 +828,6 @@ namespace zk_ml_toolchain {

void handleZkMlOperation(Operation *op, const ComponentParameters &compParams) {
if (zkml::DotProductOp operation = llvm::dyn_cast<zkml::DotProductOp>(op)) {
mlir::Value lhs = operation.getLhs();
mlir::Value rhs = operation.getRhs();
assert(lhs.getType() == rhs.getType() && "memrefs must be same type for DotProduct");
mlir::MemRefType MemRefType = mlir::cast<mlir::MemRefType>(lhs.getType());
assert(MemRefType.getShape().size() == 1 && "DotProduct must have tensors of rank 1");
logger.debug("computing DotProduct with %d x %d", MemRefType.getShape().back());
handle_dot_product<PreLimbs, PostLimbs>(operation, zero_var, stack, bp, assignmnt, compParams);
} else if (zkml::ArgMinOp operation = llvm::dyn_cast<zkml::ArgMinOp>(op)) {
auto nextIndexVar = put_into_assignment(stack.get_constant(operation.getNextIndex()));
Expand Down
2 changes: 1 addition & 1 deletion mlir-assigner/tests/Models/RandomForest/RandomForest.res
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Result:
memref<1xf64>[0.0]
30605
29239
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
[
{
"memref": {
"idx": 0,
"data": [
-8,
6,
-6,
5,
-13,
12,
6,
12,
6,
3,
-10,
3,
8,
1,
-8,
1
],
"dims": [4, 4],
"type": "int"
}
},
{
"memref": {
"idx": 1,
"data": [
-1,
13,
-4,
11,
0,
5,
14,
-4,
12,
-13,
-6,
11,
7,
-15,
-4,
2
],
"dims": [4, 4],
"type": "int"
}
}
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
 :�
"
in_a
in_bout_a"MatMulIntegerMatMulIntegerSimpleZ
in_a


Z
in_b


b
out_a


B
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<4x4xi32>[ -29, -71, 132, -168, 169, -367, 136, -101, -105, 178, 66, -50, -97, 198, 26, -2]
48
2 changes: 1 addition & 1 deletion mlir-assigner/tests/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ long as it is applicable for ZK).
| **LpNormalization** | :x: | :x: | |
| **LpPool** | :x: | :x: | |
| **MatMul** | :white_check_mark: | :white_check_mark: | |
| **MatMulInteger** | :x: | :white_check_mark: | |
| **MatMulInteger** | :white_check_mark: | :white_check_mark: | |
| **Max** | :white_check_mark: | :white_check_mark: | |
| **MaxPool** | :white_check_mark: | :white_check_mark: | No support for integers at the moment. |
| **MaxRoiPool** | :x: | :x: | |
Expand Down

0 comments on commit d91fb31

Please sign in to comment.