feat: complex ops (#23)
* add tanh activation function
* complex ops that use tanh as activation
* identified problematic operations
0xThemis authored Jan 9, 2024
1 parent 5f9c02a commit ff582cd
Showing 60 changed files with 643 additions and 171 deletions.
@@ -59,6 +59,27 @@ namespace nil {
1, 1);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_tanh(
mlir::math::TanhOp &operation,
stack_frame<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &frame,
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
using component_type = components::fix_tanh<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

auto input = PREPARE_UNARY_INPUT(mlir::math::TanhOp);
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>;
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams, 1, 1>::get_witness(0, 1, 1));

component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs(),
1, 1);
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}
} // namespace blueprint
} // namespace nil
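The fix_tanh component itself is not part of this hunk; fill_trace only wires it into the circuit. As a plain-C++ sketch of the value such a component is expected to produce at the 2^16 fixed-point scale used elsewhere in this commit (a hedged reference, not the in-circuit algorithm):

#include <cmath>     // std::tanh, std::llround
#include <cstdint>

// Hedged reference (an assumption about the expected semantics, not the
// in-circuit algorithm): tanh over raw Q*.16 fixed-point integers, using the
// 2^16 scale hardcoded elsewhere in this commit.
std::int64_t fix_tanh_reference(std::int64_t x_scaled) {
    const double scale = 65536.0;                               // 2^16
    double x = static_cast<double>(x_scaled) / scale;           // raw -> real
    double y = std::tanh(x);                                    // double-precision tanh
    return static_cast<std::int64_t>(std::llround(y * scale));  // real -> raw
}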


This file was deleted.

@@ -31,6 +31,7 @@
#include <nil/blueprint/components/algebra/fixedpoint/plonk/to_fixedpoint.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/sin.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/cos.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/tanh.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/argmin.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/argmax.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/sqrt.hpp>
14 changes: 11 additions & 3 deletions mlir-assigner/include/mlir-assigner/memory/memref.hpp
@@ -118,6 +118,14 @@ namespace nil {
return data.size();
}

void copyFrom(memref &src, uint64_t num_elements, uint64_t dst_offset, uint64_t src_offset) {
assert(this->size() >= num_elements + dst_offset && "Out of bounds access");
assert(src.size() >= num_elements + src_offset && "Out of bounds access");
for (unsigned i = 0; i < num_elements; ++i) {
this->data[dst_offset + i] = src.data[src_offset + i];
}
}

template<typename BlueprintFieldType, typename ArithmetizationParams>
void print(
std::ostream &os,
@@ -144,13 +152,13 @@
(BlueprintFieldType::modulus - typename BlueprintFieldType::integral_type(1)) /
typename BlueprintFieldType::integral_type(2);
for (int i = 0; i < data.size(); i++) {
auto val =
static_cast<typename BlueprintFieldType::integral_type>(var_value(assignment, data[i]).data);
auto val = static_cast<typename BlueprintFieldType::integral_type>(
var_value(assignment, data[i]).data);
// check if negative
if (val > half_p) {
val = BlueprintFieldType::modulus - val;
os << "-";
}
}
os << val;
if (i != data.size() - 1)
os << ",";
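For illustration, a minimal usage sketch of the new copyFrom helper (dst and src are hypothetical memref objects, not part of this diff; the real call site is the KrnlMemcpyOp branch added to evaluator.hpp below):

// Copy four elements from src[2..5] into dst[0..3]; both memrefs must already
// be large enough, otherwise the bounds asserts in copyFrom fire.
dst.copyFrom(src, /*num_elements=*/4, /*dst_offset=*/0, /*src_offset=*/2);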
39 changes: 32 additions & 7 deletions mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
@@ -261,6 +261,12 @@ namespace zk_ml_toolchain {
}
}

double toFixpoint(VarType toConvert) {
auto val = var_value(assignmnt, toConvert).data;
nil::blueprint::components::FixedPoint<BlueprintFieldType, 1, 1> out(val, 16);
return out.to_double();
}
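// Worked example for toFixpoint (hedged; assumes FixedPoint(val, 16) treats
// the second argument as the number of fractional bits): a variable whose
// field value is 8296 decodes to 8296 / 2^16 = 0.1265869140625, one of the
// f32 entries in the test input added by this commit.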

void handleArithOperation(Operation *op) {
std::uint32_t start_row = assignmnt.allocated_rows();
if (arith::AddFOp operation = llvm::dyn_cast<arith::AddFOp>(op)) {
@@ -516,6 +522,8 @@
std::uint32_t start_row = assignmnt.allocated_rows();
if (math::ExpOp operation = llvm::dyn_cast<math::ExpOp>(op)) {
handle_fixedpoint_exp_component(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::Exp2Op operation = llvm::dyn_cast<math::Exp2Op>(op)) {
UNREACHABLE("TODO: component for exp2 not ready");
} else if (math::LogOp operation = llvm::dyn_cast<math::LogOp>(op)) {
handle_fixedpoint_log_component(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::PowFOp operation = llvm::dyn_cast<math::PowFOp>(op)) {
@@ -538,9 +546,9 @@
} else if (math::CosOp operation = llvm::dyn_cast<math::CosOp>(op)) {
handle_cos(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::AtanOp operation = llvm::dyn_cast<math::AtanOp>(op)) {
UNREACHABLE("TODO: component for atanh not ready");
UNREACHABLE("TODO: component for atan not ready");
} else if (math::TanhOp operation = llvm::dyn_cast<math::TanhOp>(op)) {
UNREACHABLE("TODO: component for tanh not ready");
handle_tanh(operation, frames.back(), bp, assignmnt, start_row);
} else if (math::ErfOp operation = llvm::dyn_cast<math::ErfOp>(op)) {
UNREACHABLE("TODO: component for erf not ready");
} else {
@@ -745,7 +753,7 @@
assert(funcOp != functions.end());

// only can handle single outputs atm
assert(numOutputs == 1);
// assert(numOutputs == 1);

// prepare the arguments for the function
frames.push_back(nil::blueprint::stack_frame<VarType>());
Expand Down Expand Up @@ -775,6 +783,21 @@ namespace zk_ml_toolchain {
// TODO: what to do when done...
// maybe print output?
return;
} else if (KrnlMemcpyOp operation = llvm::dyn_cast<KrnlMemcpyOp>(op)) {
// get dst and src memref
auto DstMemref = frames.back().memrefs.find(mlir::hash_value(operation.getDest()));
auto SrcMemref = frames.back().memrefs.find(mlir::hash_value(operation.getSrc()));
assert(DstMemref != frames.back().memrefs.end());
assert(SrcMemref != frames.back().memrefs.end());
// get num elements and offset
auto NumElements = frames.back().constant_values.find(mlir::hash_value(operation.getNumElems()));
auto DstOffset = frames.back().constant_values.find(mlir::hash_value(operation.getDestOffset()));
auto SrcOffset = frames.back().constant_values.find(mlir::hash_value(operation.getSrcOffset()));
assert(NumElements != frames.back().constant_values.end());
assert(DstOffset != frames.back().constant_values.end());
assert(SrcOffset != frames.back().constant_values.end());
DstMemref->second.copyFrom(SrcMemref->second, NumElements->second, DstOffset->second,
SrcOffset->second);
} else if (KrnlAcosOp operation = llvm::dyn_cast<KrnlAcosOp>(op)) {
UNREACHABLE(std::string("TODO KrnlAcos: link to bluebrint component"));
} else if (KrnlAsinOp operation = llvm::dyn_cast<KrnlAsinOp>(op)) {
@@ -990,14 +1013,16 @@

if (func::ReturnOp operation = llvm::dyn_cast<func::ReturnOp>(op)) {
auto ops = operation.getOperands();
assert(ops.size() == 1); // only handle single return value atm
// assert(ops.size() == 1); // only handle single return value atm
// the ops[0] is something that we can hash_value to grab the result
// from maps
auto retval = frames.back().memrefs.find(mlir::hash_value(ops[0]));
assert(retval != frames.back().memrefs.end());
if (PrintCircuitOutput) {
std::cout << "Result:\n";
retval->second.print(std::cout, assignmnt);
for (unsigned i = 0; i < ops.size(); ++i) {
auto retval = frames.back().memrefs.find(mlir::hash_value(ops[i]));
assert(retval != frames.back().memrefs.end());
retval->second.print(std::cout, assignmnt);
}
}
return;
}
29 changes: 14 additions & 15 deletions mlir-assigner/include/mlir-assigner/parser/input_reader.hpp
@@ -28,6 +28,7 @@

#include <boost/json/kind.hpp>
#include "mlir-assigner/helper/asserts.hpp"
#include "onnx/string_utils.h"
#include <cstdint>
#include <mlir-assigner/memory/stack_frame.hpp>

@@ -62,12 +63,12 @@ namespace nil {
}

bool parse_fixedpoint(const boost::json::value &value, typename BlueprintFieldType::value_type &out) {
// for now only double, but later we most likely will need strings as well
// we hardcode the scale with 2^16 for now. Let's see later down the line
double d;
if (value.kind() == boost::json::kind::double_) {
d = value.as_double();
} else {
} else if (value.kind() == boost::json::kind::int64) {
d = static_cast<double>(value.as_int64());
}else {
UNREACHABLE("TODO add string support");
}
nil::blueprint::components::FixedPoint<BlueprintFieldType, 1, 1> fixed(d);
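To make the hardcoded 2^16 scale concrete, a hypothetical helper (illustration only; the FixedPoint constructor's exact rounding is not shown in this diff):

#include <cmath>     // std::llround
#include <cstdint>

// Quantize a double to the raw Q*.16 integer the circuit works with.
// Hypothetical sketch, not part of the commit.
std::int64_t quantize_q16(double d) {
    return static_cast<std::int64_t>(std::llround(d * 65536.0));   // d * 2^16, rounded
}
// quantize_q16(0.5) == 32768; quantize_q16(1.0) == 65536.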
@@ -82,14 +82,14 @@
}

bool parse_int(const boost::json::value &value, typename BlueprintFieldType::value_type &out) {
switch (value.kind()) {
case boost::json::kind::int64:
case boost::json::kind::uint64:
return parse_scalar(value, out);
default:
std::cerr << "unsupported int type: " << value.as_string() << std::endl;
UNREACHABLE("int must be int64 or uint64");
};
switch (value.kind()) {
case boost::json::kind::int64:
case boost::json::kind::uint64:
return parse_scalar(value, out);
default:
std::cerr << "unsupported int type: " << value.as_string() << std::endl;
UNREACHABLE("int must be int64 or uint64");
};
}

bool parse_scalar(const boost::json::value &value, typename BlueprintFieldType::value_type &out) {
@@ -243,7 +244,6 @@
}

bool parse_memref_data(memref<var> &data, const boost::json::array &tensor_arr, std::string &type) {

if (type == "f32") {
for (size_t i = 0; i < tensor_arr.size(); ++i) {
if (!parse_fixedpoint(tensor_arr[i], assignmnt.public_input(0, public_input_idx))) {
Expand All @@ -253,16 +253,15 @@ namespace nil {
data.put_flat(i, var(0, public_input_idx++, false, var::column_type::public_input));
}
} else if (type == "int") {
//TODO do we have to handle uint?
// TODO do we have to handle uint?
for (size_t i = 0; i < tensor_arr.size(); ++i) {
if (!parse_int(tensor_arr[i], assignmnt.public_input(0, public_input_idx))) {
llvm::errs() << "expect fixedpoints in tensor\n";
return false;
}
data.put_flat(i, var(0, public_input_idx++, false, var::column_type::public_input));
}
}
else if (type == "bool") {
} else if (type == "bool") {
for (size_t i = 0; i < tensor_arr.size(); ++i) {
if (!parse_bool(tensor_arr[i], assignmnt.public_input(0, public_input_idx))) {
llvm::errs() << "expect fixedpoints in tensor\n";
@@ -0,0 +1 @@
[{"memref": {"data": [0.863128662109375, 0.9586029052734375, 0.882354736328125, 0.1265869140625, 0.845703125, 0.1936492919921875, 0.6169281005859375, 0.9657135009765625, 0.6363525390625, 0.4542388916015625, 0.6410369873046875, 0.8735504150390625, 0.3882904052734375, 0.835235595703125, 0.3632965087890625, 0.9285125732421875, 0.348907470703125, 0.7984161376953125, 0.7083587646484375, 0.02813720703125, 0.947784423828125, 0.0034942626953125, 0.203033447265625, 0.0853729248046875, 0.5583343505859375], "dims": [1, 25], "type": "f32"}}]
