Skip to content

Commit

Permalink
feat: shape transformation ops (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
0xThemis authored Jan 8, 2024
1 parent 798f046 commit 5f9c02a
Show file tree
Hide file tree
Showing 128 changed files with 604 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
namespace nil {
namespace blueprint {

template<typename BlueprintFieldType, typename ArithmetizationParams>
template<typename BlueprintFieldType, typename ArithmetizationParams, typename MlirOp>
void handle_to_fixedpoint(
mlir::arith::SIToFPOp &operation,
MlirOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
Expand All @@ -29,7 +29,7 @@ namespace nil {
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

auto input = PREPARE_UNARY_INPUT(mlir::arith::SIToFPOp);
auto input = PREPARE_UNARY_INPUT(MlirOp);
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
Expand Down
26 changes: 22 additions & 4 deletions mlir-assigner/include/mlir-assigner/memory/memref.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,28 @@ namespace nil {
ss << type << ">[";
os << type_str;
if (type.isa<mlir::IntegerType>()) {
for (int i = 0; i < data.size(); i++) {
os << var_value(assignment, data[i]).data;
if (i != data.size() - 1)
os << ",";
if (type.isUnsignedInteger()) {
for (int i = 0; i < data.size(); i++) {
os << var_value(assignment, data[i]).data;
if (i != data.size() - 1)
os << ",";
}
} else {
static constexpr typename BlueprintFieldType::integral_type half_p =
(BlueprintFieldType::modulus - typename BlueprintFieldType::integral_type(1)) /
typename BlueprintFieldType::integral_type(2);
for (int i = 0; i < data.size(); i++) {
auto val =
static_cast<typename BlueprintFieldType::integral_type>(var_value(assignment, data[i]).data);
// check if negative
if (val > half_p) {
val = BlueprintFieldType::modulus - val;
os << "-";
}
os << val;
if (i != data.size() - 1)
os << ",";
}
}
} else if (type.isa<mlir::FloatType>()) {
for (int i = 0; i < data.size(); i++) {
Expand Down
48 changes: 35 additions & 13 deletions mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "nil/blueprint/blueprint/plonk/assignment.hpp"
#include <cassert>
#include <cstdint>
#include <limits>
#define TEST_WITHOUT_LOOKUP_TABLES

#include "mlir-assigner/helper/asserts.hpp"
Expand Down Expand Up @@ -191,15 +192,25 @@ namespace zk_ml_toolchain {
bool PrintCircuitOutput;
nil::blueprint::logger &logger;

template<typename NumberType>
NumberType resolve_number(VarType scalar) {
int64_t resolve_number(VarType scalar) {
auto scalar_value = var_value(assignmnt, scalar);
static constexpr auto limit_value =
typename BlueprintFieldType::integral_type(std::numeric_limits<NumberType>::max());
static constexpr auto limit_value_max =
typename BlueprintFieldType::integral_type(std::numeric_limits<int64_t>::max());
static constexpr auto limit_value_min =
BlueprintFieldType::modulus - limit_value_max - typename BlueprintFieldType::integral_type(1);
static constexpr typename BlueprintFieldType::integral_type half_p =
(BlueprintFieldType::modulus - typename BlueprintFieldType::integral_type(1)) /
typename BlueprintFieldType::integral_type(2);
auto integral_value = static_cast<typename BlueprintFieldType::integral_type>(scalar_value.data);
ASSERT_MSG(integral_value < limit_value, "Too large to cast");
NumberType number = static_cast<NumberType>(integral_value);
return number;
ASSERT_MSG(integral_value <= limit_value_max || integral_value >= limit_value_min,
"cannot fit into requested number");
// check if negative
if (integral_value > half_p) {
integral_value = BlueprintFieldType::modulus - integral_value;
return -static_cast<int64_t>(integral_value);
} else {
return static_cast<int64_t>(integral_value);
}
}

void doAffineFor(affine::AffineForOp &op, int64_t from, int64_t to, int64_t step) {
Expand Down Expand Up @@ -435,7 +446,10 @@ namespace zk_ml_toolchain {
frames.back().constant_values[mlir::hash_value(operation.getResult())] = cmpResult;
} else if (arith::ConstantOp operation = llvm::dyn_cast<arith::ConstantOp>(op)) {
TypedAttr constantValue = operation.getValueAttr();
if (constantValue.isa<IntegerAttr>()) {
if (operation->getResult(0).getType().isa<IndexType>()) {
frames.back().constant_values.insert(std::make_pair(
mlir::hash_value(operation.getResult()), llvm::dyn_cast<IntegerAttr>(constantValue).getInt()));
} else if (constantValue.isa<IntegerAttr>()) {
// this insert is ok, since this should never change, so we don't
// override it if it is already there

Expand Down Expand Up @@ -467,12 +481,12 @@ namespace zk_ml_toolchain {
} else if (arith::IndexCastOp operation = llvm::dyn_cast<arith::IndexCastOp>(op)) {
assert(operation->getNumOperands() == 1 && "IndexCast must have exactly one operand");
auto opHash = mlir::hash_value(operation->getOperand(0));
// from int to index
if (operation->getOperand(0).getType().isa<IntegerType>()) {
Type casteeType = operation->getOperand(0).getType();
if (casteeType.isa<IntegerType>()) {
auto i = frames.back().locals.find(opHash);
assert(i != frames.back().locals.end());
frames.back().constant_values[mlir::hash_value(operation.getResult())] = resolve_number<int64_t>(i->second);
} else if (operation->getOperand(0).getType().isa<IndexType>()) {
frames.back().constant_values[mlir::hash_value(operation.getResult())] = resolve_number(i->second);
} else if (casteeType.isa<IndexType>()) {
auto index = frames.back().constant_values.find(opHash);
assert(index != frames.back().constant_values.end());
typename BlueprintFieldType::value_type field_constant = index->second;
Expand All @@ -482,8 +496,16 @@ namespace zk_ml_toolchain {
UNREACHABLE("unsupported Index Cast");
}
} else if (arith::SIToFPOp operation = llvm::dyn_cast<arith::SIToFPOp>(op)) {
// TODO this does not respect negative and no different ranges for ints...
handle_to_fixedpoint(operation, frames.back(), bp, assignmnt, start_row);
} else if (arith::UIToFPOp operation = llvm::dyn_cast<arith::UIToFPOp>(op)) {
handle_to_fixedpoint(operation, frames.back(), bp, assignmnt, start_row);
} else if (arith::FPToSIOp operation = llvm::dyn_cast<arith::FPToSIOp>(op)) {
UNREACHABLE("Cast from FixedPoint to Int??");
} else if (llvm::isa<arith::ExtUIOp>(op) || llvm::isa<arith::ExtSIOp>(op) ||
llvm::isa<arith::TruncIOp>(op)) {
auto toExtend = frames.back().locals.find(mlir::hash_value(op->getOperand(0)));
assert(toExtend != frames.back().locals.end());
frames.back().locals[mlir::hash_value(op->getResult(0))] = toExtend->second;
} else {
std::string opName = op->getName().getIdentifier().str();
UNREACHABLE(std::string("unhandled arith operation: ") + opName);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.052490234375, 0.0849761962890625, 0.290435791015625, 0.7718963623046875, 0.4334564208984375, 0.4372100830078125, 0.5269012451171875, 0.173980712890625, 0.0650787353515625, 0.6441802978515625], "dims": [1, 10], "type": "f32"}}]
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "casttoint64.mlir"} {
func.func @main_graph(%arg0: memref<1x10xf32>) -> memref<1x10xi64> attributes {input_names = ["in_a"], llvm.emit_c_interface, output_names = ["out_a"]} {
%alloc = memref.alloc() {alignment = 16 : i64} : memref<1x10xi64>
affine.for %arg1 = 0 to 1 {
affine.for %arg2 = 0 to 10 {
%0 = affine.load %arg0[%arg1, %arg2] : memref<1x10xf32>
%1 = arith.fptosi %0 : f32 to i64
affine.store %1, %alloc[%arg1, %arg2] : memref<1x10xi64>
}
}
return %alloc : memref<1x10xi64>
}
"krnl.entry_point"() {func = @main_graph, numInputs = 1 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22in_a\22 }\0A\0A]\00@[ { \22type\22 : \22i64\22 , \22dims\22 : [1 , 10] , \22name\22 : \22out_a\22 }\0A\0A]\00"} : () -> ()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
 :^

in_aout_a"Cast*
to� CastToInt64Z
in_a



b
out_a



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xint>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
ADD THE ROWS HERE
2 changes: 1 addition & 1 deletion mlir-assigner/tests/Ops/Onnx/Add/AddSimple.res
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Result:
memref<1x10xf32>[1.2744598388671875, 0.8813934326171875, 1.0463104248046875, 0.8135986328125, 0.8283538818359375, 0.2144317626953125, 0.900787353515625, 1.1385650634765625, 0.3757171630859375, 1.0577239990234375]
23 rows
22 rows
2 changes: 1 addition & 1 deletion mlir-assigner/tests/Ops/Onnx/And/AndSimple.res
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Result:
memref<1x10xi1>[0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
23
22
2 changes: 1 addition & 1 deletion mlir-assigner/tests/Ops/Onnx/ArgMax/ArgMaxSimple.res
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Result:
memref<1x1xi64>[3]
23
22
2 changes: 1 addition & 1 deletion mlir-assigner/tests/Ops/Onnx/ArgMin/ArgMinLastIndex.res
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Result:
memref<1x1xi64>[4]
23
22
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Result:
memref<1xi64>[3]
23
22
2 changes: 1 addition & 1 deletion mlir-assigner/tests/Ops/Onnx/ArgMin/ArgMinNoKeepDims.res
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Result:
memref<1xi64>[2]
23
22
2 changes: 1 addition & 1 deletion mlir-assigner/tests/Ops/Onnx/ArgMin/ArgMinSimple.res
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Result:
memref<1x1xi64>[6]
23
22
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastBoolToFloat.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0, 1, 1, 0, 0, 1, 0, 1, 0, 0], "dims": [1, 10], "type": "bool"}}]
14 changes: 14 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastBoolToFloat.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
 :b

in_aout_a"Cast*
to�CastBoolToFloatZ
in_a
 


b
out_a



B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastBoolToFloat.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0, 1, 1, 0, 0, 1, 0, 1, 0, 0]
12
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastBoolToInt.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0, 1, 0, 1, 0, 0, 0, 0, 1, 1], "dims": [1, 10], "type": "bool"}}]
14 changes: 14 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastBoolToInt.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
 :`

in_aout_a"Cast*
to�CastBoolToIntZ
in_a
 


b
out_a



B
Expand Down
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastBoolToInt.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xi64>[0, 1, 0, 1, 0, 0, 0, 0, 1, 1]
12
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastIntToInt.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [-922, -18644, 15583, -18992, -23256, 16318, 19071, -2610, 25807, -12559], "dims": [1, 10], "type": "int"}}]
14 changes: 14 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastIntToInt.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
 :_

in_aout_a"Cast*
to� CastIntToIntZ
in_a



b
out_a



B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastIntToInt.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xi64>[-922, -18644, 15583, -18992, -23256, 16318, 19071, -2610, 25807, -12559]
12
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastIntToSmallerInt.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [-9436, -9740, 9984, 19795, -1273, 29365, -5597, -4249, 27125, 31185], "dims": [1, 10], "type": "int"}}]
14 changes: 14 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastIntToSmallerInt.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
 :f

in_aout_a"Cast*
to�CastIntToSmallerIntZ
in_a



b
out_a



B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastIntToSmallerInt.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xi32>[-9436, -9740, 9984, 19795, -1273, 29365, -5597, -4249, 27125, 31185]
12
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastNoop.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.4317779541015625, 0.4131011962890625, 0.35601806640625, 0.360595703125, 0.748809814453125, 0.9254913330078125, 0.939361572265625, 0.5069427490234375, 0.842010498046875, 0.088775634765625], "dims": [1, 10], "type": "f32"}}]
14 changes: 14 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastNoop.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
 :[

in_aout_a"Cast*
to�CastNoopZ
in_a



b
out_a



B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastNoop.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0.4317779541015625, 0.4131011962890625, 0.35601806640625, 0.360595703125, 0.748809814453125, 0.9254913330078125, 0.939361572265625, 0.5069427490234375, 0.842010498046875, 0.088775634765625]
12
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastToFloat.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [10830, -25594, -30056, -23995, 5899, -17295, -27610, 28494, -2621, 9410], "dims": [1, 10], "type": "int"}}]
14 changes: 14 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastToFloat.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
 :^

in_aout_a"Cast*
to� CastToFloatZ
in_a



b
out_a



B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastToFloat.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[10830, -25594, -30056, -23995, 5899, -17295, -27610, 28494, -2621, 9410]
12
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Concat/ConcatSimple.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.8958892822265625, 0.5222930908203125, 0.00238037109375, 0.0584869384765625, 0.2328948974609375, 0.0062255859375, 0.7711029052734375, 0.9699249267578125, 0.9550933837890625, 0.1429443359375], "dims": [1, 10], "type": "f32"}}, {"memref": {"data": [0.2341461181640625, 0.132080078125, 0.9774627685546875, 0.3519134521484375, 0.51129150390625, 0.9049072265625, 0.265106201171875, 0.4313507080078125, 0.0768890380859375, 0.9449615478515625], "dims": [1, 10], "type": "f32"}}]
19 changes: 19 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Concat/ConcatSimple.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
 :�
(
in_a
in_bout_a"Concat*
axis� ConcatSimpleZ
in_a



Z
in_b



b
out_a


B
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Concat/ConcatSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x20xf32>[0.8958892822265625, 0.5222930908203125, 0.00238037109375, 0.0584869384765625, 0.2328948974609375, 0.0062255859375, 0.7711029052734375, 0.9699249267578125, 0.9550933837890625, 0.1429443359375, 0.2341461181640625, 0.132080078125, 0.9774627685546875, 0.3519134521484375, 0.51129150390625, 0.9049072265625, 0.265106201171875, 0.4313507080078125, 0.0768890380859375, 0.9449615478515625]
22
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Concat/ConcatSimple2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [0.23565673828125, 0.96636962890625, 0.41424560546875, 0.226409912109375, 0.861175537109375, 0.000701904296875, 0.7342987060546875, 0.146636962890625, 0.6063232421875, 0.6813812255859375], "dims": [1, 10], "type": "f32"}}, {"memref": {"data": [0.4473876953125, 0.227752685546875, 0.33148193359375, 0.32470703125, 0.6154937744140625, 0.5596160888671875, 0.816192626953125, 0.0675048828125, 0.5163116455078125, 0.7897491455078125], "dims": [1, 10], "type": "f32"}}]
Binary file not shown.
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Concat/ConcatSimple2.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<2x10xf32>[0.23565673828125, 0.96636962890625, 0.41424560546875, 0.226409912109375, 0.861175537109375, 0.000701904296875, 0.7342987060546875, 0.146636962890625, 0.6063232421875, 0.6813812255859375, 0.4473876953125, 0.227752685546875, 0.33148193359375, 0.32470703125, 0.6154937744140625, 0.5596160888671875, 0.816192626953125, 0.0675048828125, 0.5163116455078125, 0.7897491455078125]
22
Loading

0 comments on commit 5f9c02a

Please sign in to comment.