Skip to content

Commit

Permalink
feat: to int (#24)
Browse files Browse the repository at this point in the history
* to int working (except 8bits version)
* pinned blueprint to latest version
  • Loading branch information
0xThemis authored Jan 9, 2024
1 parent 640bc38 commit df674f7
Show file tree
Hide file tree
Showing 18 changed files with 177 additions and 65 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
#ifndef CRYPTO3_ASSIGNER_FIXEDPOINT_TO_FIXEDPOINT_HPP
#define CRYPTO3_ASSIGNER_FIXEDPOINT_TO_FIXEDPOINT_HPP

#include "mlir/Dialect/zkml/IR/DotProduct.h"
#include <cstdint>
#include <mlir/Dialect/Arith/IR/Arith.h>

#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>

#include <nil/blueprint/component.hpp>
#include <nil/blueprint/basic_non_native_policy.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/lookup_tables/tester.hpp> // TODO: check if there is a new mechanism for this in nil upstream

#include <mlir-assigner/helper/asserts.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <mlir-assigner/components/handle_component.hpp>

namespace nil {
namespace blueprint {

template<typename BlueprintFieldType, typename ArithmetizationParams, typename MlirOp>
void handle_to_fixedpoint(
MlirOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::int_to_fix<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

auto input = PREPARE_UNARY_INPUT(MlirOp);
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(),
1);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}
namespace detail {
template<typename BlueprintFieldType, typename ArithmetizationParams, typename MlirOp, uint8_t OutputType>
void handle_to_int(
MlirOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::fix_to_int<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
typename component_type::OutputType outputType =
static_cast<typename component_type::OutputType>(OutputType);
auto input = PREPARE_UNARY_INPUT(MlirOp);
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, 1, 1, OutputType>;
const auto p = detail::PolicyManager::get_parameters(manifest_reader::get_witness(0));

component_type component(p.witness, manifest_reader::get_constants(),
manifest_reader::get_public_inputs(), 1, 1, outputType);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}
} // namespace detail

#define HANDLE_TO_INT(TY) \
detail::handle_to_int<BlueprintFieldType, ArithmetizationParams, mlir::arith::FPToSIOp, TY>( \
operation, frame, bp, assignment, start_row);

#define HANDLE_TO_UINT(TY) \
detail::handle_to_int<BlueprintFieldType, ArithmetizationParams, mlir::arith::FPToUIOp, TY>( \
operation, frame, bp, assignment, start_row);
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_to_int(
mlir::arith::FPToSIOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::fix_to_int<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
switch (operation->getResult(0).getType().getIntOrFloatBitWidth()) {
case 8:
HANDLE_TO_INT(component_type::OutputType::I8);
break;
case 16:
HANDLE_TO_INT(component_type::OutputType::I16);
break;
case 32:
HANDLE_TO_INT(component_type::OutputType::I32);
break;
case 64:
HANDLE_TO_INT(component_type::OutputType::I64);
break;
}
}

template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_to_int(
mlir::arith::FPToUIOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::fix_to_int<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
switch (operation->getResult(0).getType().getIntOrFloatBitWidth()) {
case 8:
HANDLE_TO_UINT(component_type::OutputType::U8);
break;
case 16:
HANDLE_TO_UINT(component_type::OutputType::U16);
break;
case 32:
HANDLE_TO_UINT(component_type::OutputType::U32);
break;
case 64:
HANDLE_TO_UINT(component_type::OutputType::U64);
break;
}
}

#undef HANDLE_TO_INT
#undef HANDLE_TO_UINT
} // namespace blueprint
} // namespace nil

#endif // CRYPTO3_ASSIGNER_FIXEDPOINT_TO_FIXEDPOINT_HPP

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <functional>
#include <mlir-assigner/memory/stack_frame.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/to_fixedpoint.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/to_int.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/sin.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/cos.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/tan.hpp>
Expand Down
22 changes: 20 additions & 2 deletions mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
#include <mlir-assigner/components/fixedpoint/dot_product.hpp>
#include <mlir-assigner/components/fixedpoint/trigonometric.hpp>
#include <mlir-assigner/components/boolean/logic_ops.hpp>
#include <mlir-assigner/components/fixedpoint/to_fixpoint.hpp>
#include <mlir-assigner/components/fixedpoint/conversion.hpp>

#include <mlir-assigner/memory/memref.hpp>
#include <mlir-assigner/memory/stack_frame.hpp>
Expand Down Expand Up @@ -506,7 +506,9 @@ namespace zk_ml_toolchain {
} else if (arith::UIToFPOp operation = llvm::dyn_cast<arith::UIToFPOp>(op)) {
handle_to_fixedpoint(operation, frames.back(), bp, assignmnt, start_row);
} else if (arith::FPToSIOp operation = llvm::dyn_cast<arith::FPToSIOp>(op)) {
UNREACHABLE("Cast from FixedPoint to Int??");
handle_to_int(operation, frames.back(), bp, assignmnt, start_row);
} else if (arith::FPToUIOp operation = llvm::dyn_cast<arith::FPToUIOp>(op)) {
handle_to_int(operation, frames.back(), bp, assignmnt, start_row);
} else if (llvm::isa<arith::ExtUIOp>(op) || llvm::isa<arith::ExtSIOp>(op) ||
llvm::isa<arith::TruncIOp>(op)) {
auto toExtend = frames.back().locals.find(mlir::hash_value(op->getOperand(0)));
Expand Down Expand Up @@ -1028,6 +1030,22 @@ namespace zk_ml_toolchain {
return;
}

if (mlir::UnrealizedConversionCastOp operation = llvm::dyn_cast<mlir::UnrealizedConversionCastOp>(op)) {
// we do not like this but when onnx-mlir lowers from onnx.Cast to unsigned it uses this to cast
// from signless integers (e.g. i64) to unsigned integer(e.g. ui64)
// SO if we transform from one signless integer to an unsigned integer with the SAME bit length
// we indulge, otherwise we panic
mlir::Type SrcType = operation->getOperand(0).getType();
mlir::Type DstType = operation->getResult(0).getType();
assert(SrcType.isSignlessInteger() && "src must be signless integertype for conversion cast");
assert(DstType.isUnsignedInteger(SrcType.getIntOrFloatBitWidth()) &&
"dst must be unsigned integer with same bit width as src");
auto Src = frames.back().locals.find(mlir::hash_value(operation->getOperand(0)));
assert(Src != frames.back().locals.end());
frames.back().locals[mlir::hash_value(operation->getResult(0))] = Src->second;
return;
}

std::string opName = op->getName().getIdentifier().str();
llvm::outs() << op->getDialect()->getNamespace() << "\n";
UNREACHABLE(std::string("unhandled operation: ") + opName);
Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastToInt64.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [-4.334595012532761, 92.92536192655963, -83.79716946092812, -47.00000000000001, -58.99999999999999, 68.14344958903294, 97.32986424021, 43.620091845482136, 75.29771234198004, 57.999999999999], "dims": [1, 10], "type": "f32"}}]
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastToInt64.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xi64>[-4, 92, -83, -47, -58, 68, 97, 43, 75, 57]
12
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastToUInt64.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"memref": {"data": [4.000000000000001, 92.92536192655963, 83.79716946092812, 47.67719849791331, 58.17853088366004, 68.14344958903294, 97.32986424021, 43.620091845482136, 75.29771234198004, 57.999999999999], "dims": [1, 10], "type": "f32"}}]
14 changes: 14 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastToUInt64.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
 :_

in_aout_a"Cast*
to� CastToUInt64Z
in_a



b
out_a



B
Expand Down
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Cast/CastToUInt64.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xui64>[4, 92, 83, 47, 58, 68, 97, 43, 75, 57]
12

0 comments on commit df674f7

Please sign in to comment.