Skip to content

Commit

Permalink
fix: address compiler warnings originating in our code
Browse files Browse the repository at this point in the history
* build(deps): update blueprint

* fix: address warnings about signedness

* fix: warnings about unused typedefs

* fix: warning for switch covering all possibilities

* fix: warnings about unqualified calls of template functions pre C++20

* fix: warnings about signedness

* fix: warnings about reordering in constructor initializer list

* fix: warnings about unused variables

* fix: duplicate assignment

* fix: warnings about signedness

* fix: warnings about unused variables

* fix: warnings about signedness and unused variables

* build(deps): update blueprint to fix warnings

* fix: signed types for loop checks in memref
  • Loading branch information
dkales authored Feb 15, 2024
1 parent 8b51964 commit ec742ee
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ namespace nil {
// we compare 64 bits with this configuration
auto result =
call_component<PreLimbs, PostLimbs, component_type>(operation, stack, bp, assignment, compParams);
// TODO should we store zero instead???
switch (operation.getPredicate()) {
case mlir::arith::CmpFPredicate::UGT:
case mlir::arith::CmpFPredicate::OGT: {
Expand Down Expand Up @@ -122,9 +121,6 @@ namespace nil {
UNREACHABLE("Unsupported fcmp predicate (UNO, ORD, AlwaysFalse, AlwaysTrue)");
break;
}
default:
UNREACHABLE("Unsupported fcmp predicate");
break;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,12 @@ namespace nil {
circuit_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>> &bp,
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
const common_component_parameters<crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
auto c = stack.get_local(operation.getCondition());
auto x = stack.get_local(operation.getTrueValue());
auto y = stack.get_local(operation.getFalseValue());
const common_component_parameters<
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
using component_type = components::fix_select<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;
BlueprintFieldType,
basic_non_native_policy<BlueprintFieldType>>;

using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams>;
typename component_type::input_type input;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ namespace nil {
&assignment,
const common_component_parameters<
crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>> &compParams) {
using component_type = components::fix_dot_rescale_2_gates<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>,
BlueprintFieldType, basic_non_native_policy<BlueprintFieldType>>;

mlir::Value lhs = operation.getLhs();
mlir::Value rhs = operation.getRhs();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ namespace nil {

typename ComponentType::result_type result =
std::uint8_t(compParams.gen_mode & generation_mode::ASSIGNMENTS) ?
result = components::generate_assignments(component, assignment, input, compParams.start_row) :
result = typename ComponentType::result_type(component, compParams.start_row);
components::generate_assignments(component, assignment, input, compParams.start_row) :
typename ComponentType::result_type(component, compParams.start_row);

// touch result variables
if (std::uint8_t(compParams.gen_mode & generation_mode::ASSIGNMENTS) == 0) {
Expand Down
38 changes: 20 additions & 18 deletions mlir-assigner/include/mlir-assigner/memory/memref.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,21 @@ namespace nil {

memref(std::vector<int64_t> dims, mlir::Type type) : data(), dims(dims), strides(), type(type) {
strides.resize(dims.size());
for (int i = dims.size() - 1; i >= 0; i--) {
strides[i] = (i == dims.size() - 1) ? 1 : strides[i + 1] * dims[i + 1];
ASSERT(dims[i] > 0 && "Dims in tensor must be greater zero. Do you have a model with dynamic input?");
for (ssize_t i = dims.size() - 1; i >= 0; i--) {
strides[i] = (i == ssize_t(dims.size()) - 1) ? 1 : strides[i + 1] * dims[i + 1];
ASSERT(dims[i] > 0 &&
"Dims in tensor must be greater zero. Do you have a model with dynamic input?");
}
// this also handles the case when dims is empty, since we still allocate
// 1 here
data.resize(std::accumulate(std::begin(dims), std::end(dims), 1, std::multiplies<uint32_t>()));
}
memref(llvm::ArrayRef<int64_t> dims, mlir::Type type) : data(), dims(dims), strides(), type(type) {
strides.resize(dims.size());
for (int i = dims.size() - 1; i >= 0; i--) {
strides[i] = (i == dims.size() - 1) ? 1 : strides[i + 1] * dims[i + 1];
ASSERT(dims[i] > 0 && "Dims in tensor must be greater zero. Do you have a model with dynamic input?");
for (ssize_t i = dims.size() - 1; i >= 0; i--) {
strides[i] = (i == ssize_t(dims.size()) - 1) ? 1 : strides[i + 1] * dims[i + 1];
ASSERT(dims[i] > 0 &&
"Dims in tensor must be greater zero. Do you have a model with dynamic input?");
}
// this also handles the case when dims is empty, since we still allocate 1
// here
Expand All @@ -47,8 +49,8 @@ namespace nil {
memref(std::vector<int64_t> dims, std::vector<VarType> data, mlir::Type type) :
data(data), dims(dims), strides() {
strides.resize(dims.size());
for (int i = dims.size() - 1; i >= 0; i--) {
strides[i] = (i == dims.size() - 1) ? 1 : strides[i + 1] * dims[i + 1];
for (ssize_t i = dims.size() - 1; i >= 0; i--) {
strides[i] = (i == ssize_t(dims.size()) - 1) ? 1 : strides[i + 1] * dims[i + 1];
}
assert(data.size() ==
std::accumulate(std::begin(dims), std::end(dims), 1, std::multiplies<uint32_t>()));
Expand All @@ -57,7 +59,7 @@ namespace nil {
const VarType &get(const std::vector<int64_t> &indices) const {
assert(indices.size() == dims.size());
uint32_t offset = 0;
for (int i = 0; i < indices.size(); i++) {
for (size_t i = 0; i < indices.size(); i++) {
assert(indices[i] < dims[i]);
offset += indices[i] * strides[i];
}
Expand All @@ -67,7 +69,7 @@ namespace nil {
const VarType &get(const llvm::SmallVector<int64_t> &indices) const {
assert(indices.size() == dims.size());
uint32_t offset = 0;
for (int i = 0; i < indices.size(); i++) {
for (size_t i = 0; i < indices.size(); i++) {
assert(indices[i] < dims[i]);
offset += indices[i] * strides[i];
}
Expand All @@ -81,7 +83,7 @@ namespace nil {
void put(const std::vector<int64_t> &indices, const VarType &value) {
assert(indices.size() == dims.size());
uint32_t offset = 0;
for (int i = 0; i < indices.size(); i++) {
for (size_t i = 0; i < indices.size(); i++) {
assert(indices[i] < dims[i]);
offset += indices[i] * strides[i];
}
Expand All @@ -91,22 +93,22 @@ namespace nil {
void put(const llvm::SmallVector<int64_t> &indices, const VarType &value) {
assert(indices.size() == dims.size());
uint32_t offset = 0;
for (int i = 0; i < indices.size(); i++) {
for (size_t i = 0; i < indices.size(); i++) {
assert(indices[i] < dims[i]);
offset += indices[i] * strides[i];
}
data[offset] = value;
}

void put_flat(const int64_t idx, const VarType &value) {
assert(idx >= 0 && idx < data.size());
assert(idx >= 0 && size_t(idx) < data.size());
data[idx] = value;
}

mlir::Type getType() const {
return type;
}
int64_t size() const {
size_t size() const {
return data.size();
}

Expand All @@ -126,7 +128,7 @@ namespace nil {
&assignment) {
using FixedPoint = components::FixedPoint<BlueprintFieldType, PreLimbs, PostLimbs>;
os << "memref<";
for (int i = 0; i < dims.size(); i++) {
for (size_t i = 0; i < dims.size(); i++) {
os << dims[i];
os << "x";
}
Expand All @@ -136,7 +138,7 @@ namespace nil {
os << type_str;
if (type.isa<mlir::IntegerType>()) {
if (type.isUnsignedInteger()) {
for (int i = 0; i < data.size(); i++) {
for (size_t i = 0; i < data.size(); i++) {
os << var_value(assignment, data[i]).data;
if (i != data.size() - 1)
os << ",";
Expand All @@ -145,7 +147,7 @@ namespace nil {
static constexpr typename BlueprintFieldType::integral_type half_p =
(BlueprintFieldType::modulus - typename BlueprintFieldType::integral_type(1)) /
typename BlueprintFieldType::integral_type(2);
for (int i = 0; i < data.size(); i++) {
for (size_t i = 0; i < data.size(); i++) {
auto val = static_cast<typename BlueprintFieldType::integral_type>(
var_value(assignment, data[i]).data);
// check if negative
Expand All @@ -159,7 +161,7 @@ namespace nil {
}
}
} else if (type.isa<mlir::FloatType>()) {
for (int i = 0; i < data.size(); i++) {
for (size_t i = 0; i < data.size(); i++) {
auto value = var_value(assignment, data[i]).data;
FixedPoint out(value, FixedPoint::SCALE);
os << out.to_double();
Expand Down
42 changes: 19 additions & 23 deletions mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,9 @@ namespace zk_ml_toolchain {
boost::json::array &public_output, generation_mode gen_mode,
nil::blueprint::print_format print_circuit_format, std::string &clip,
nil::blueprint::logger &logger) :
bp(circuit),
assignmnt(assignment), public_input(public_input), private_input(private_input),
public_output(public_output), gen_mode(gen_mode), print_circuit_format(print_circuit_format),
logger(logger) {
gen_mode(gen_mode),
print_circuit_format(print_circuit_format), logger(logger), bp(circuit), assignmnt(assignment),
public_input(public_input), private_input(private_input), public_output(public_output) {
lower_bound = FixedPoint(1, FixedPoint::SCALE).to_double();
if ("clip" == clip) {
clip_strategy = ClipStrategy::CLIP;
Expand All @@ -208,16 +207,15 @@ namespace zk_ml_toolchain {
evaluator &operator=(const evaluator &pass) = delete;

void handleKrnlEntryOperation(KrnlEntryPointOp &EntryPoint, std::string &func) {
int32_t numInputs = -1;
int32_t numOutputs = -1;

for (auto a : EntryPoint->getAttrs()) {
if (a.getName() == EntryPoint.getEntryPointFuncAttrName()) {
func = a.getValue().cast<SymbolRefAttr>().getLeafReference().str();
} else if (a.getName() == EntryPoint.getNumInputsAttrName()) {
numInputs = a.getValue().cast<IntegerAttr>().getInt();
// do nothing for num inputs atm
// a.getValue().cast<IntegerAttr>().getInt();
} else if (a.getName() == EntryPoint.getNumOutputsAttrName()) {
numOutputs = a.getValue().cast<IntegerAttr>().getInt();
// do nothing for num outputs atm
// a.getValue().cast<IntegerAttr>().getInt();
} else if (a.getName() == EntryPoint.getSignatureAttrName()) {
// do nothing for signature atm
// TODO: check against input types & shapes
Expand Down Expand Up @@ -373,13 +371,13 @@ namespace zk_ml_toolchain {
} else if (arith::SubFOp operation = llvm::dyn_cast<arith::SubFOp>(op)) {
nil::blueprint::handle_sub(operation, stack, bp, assignmnt, compParams);
} else if (arith::MulFOp operation = llvm::dyn_cast<arith::MulFOp>(op)) {
handle_fmul<PostLimbs>(operation, stack, bp, assignmnt, compParams);
nil::blueprint::handle_fmul<PostLimbs>(operation, stack, bp, assignmnt, compParams);
} else if (arith::DivFOp operation = llvm::dyn_cast<arith::DivFOp>(op)) {
handle_fdiv<PreLimbs, PostLimbs>(operation, stack, bp, assignmnt, compParams);
nil::blueprint::handle_fdiv<PreLimbs, PostLimbs>(operation, stack, bp, assignmnt, compParams);
} else if (arith::RemFOp operation = llvm::dyn_cast<arith::RemFOp>(op)) {
handle_frem<PreLimbs, PostLimbs>(operation, stack, bp, assignmnt, compParams);
nil::blueprint::handle_frem<PreLimbs, PostLimbs>(operation, stack, bp, assignmnt, compParams);
} else if (arith::CmpFOp operation = llvm::dyn_cast<arith::CmpFOp>(op)) {
handle_fcmp<PreLimbs, PostLimbs>(operation, stack, bp, assignmnt, compParams);
nil::blueprint::handle_fcmp<PreLimbs, PostLimbs>(operation, stack, bp, assignmnt, compParams);
} else if (arith::SelectOp operation = llvm::dyn_cast<arith::SelectOp>(op)) {
ASSERT(operation.getNumOperands() == 3 && "Select must have three operands");
ASSERT(operation->getOperand(1).getType() == operation->getOperand(2).getType() &&
Expand All @@ -406,7 +404,7 @@ namespace zk_ml_toolchain {
stack.push_local(operation->getResult(0), falsy);
}
} else {
handle_select(operation, stack, bp, assignmnt, compParams);
nil::blueprint::handle_select(operation, stack, bp, assignmnt, compParams);
}
} else {
std::string typeStr;
Expand All @@ -415,7 +413,7 @@ namespace zk_ml_toolchain {
UNREACHABLE(std::string("unhandled select operand: ") + typeStr);
}
} else if (arith::NegFOp operation = llvm::dyn_cast<arith::NegFOp>(op)) {
handle_neg(operation, stack, bp, assignmnt, compParams);
nil::blueprint::handle_neg(operation, stack, bp, assignmnt, compParams);
} else if (arith::AndIOp operation = llvm::dyn_cast<arith::AndIOp>(op)) {
// check if logical and or bitwise and
mlir::Type LhsType = operation.getLhs().getType();
Expand Down Expand Up @@ -608,7 +606,7 @@ namespace zk_ml_toolchain {
if (math::ExpOp operation = llvm::dyn_cast<math::ExpOp>(op)) {
nil::blueprint::handle_exp<PreLimbs, PostLimbs>(operation, stack, bp, assignmnt, compParams);
} else if (math::LogOp operation = llvm::dyn_cast<math::LogOp>(op)) {
handle_log<PreLimbs, PostLimbs>(operation, stack, bp, assignmnt, compParams);
nil::blueprint::handle_log<PreLimbs, PostLimbs>(operation, stack, bp, assignmnt, compParams);
} else if (math::PowFOp operation = llvm::dyn_cast<math::PowFOp>(op)) {
UNREACHABLE("powf not supported. Did you compile the model with standard MLIR?");
} else if (math::IPowIOp operation = llvm::dyn_cast<math::IPowIOp>(op)) {
Expand All @@ -620,11 +618,11 @@ namespace zk_ml_toolchain {
assert(base == 2 && "For now we only support powi to power of 2");
stack.push_constant(operation.getResult(), 1 << pow);
} else if (math::AbsFOp operation = llvm::dyn_cast<math::AbsFOp>(op)) {
handle_abs<PreLimbs, PostLimbs>(operation, stack, bp, assignmnt, compParams);
nil::blueprint::handle_abs<PreLimbs, PostLimbs>(operation, stack, bp, assignmnt, compParams);
} else if (math::CeilOp operation = llvm::dyn_cast<math::CeilOp>(op)) {
handle_ceil<PreLimbs, PostLimbs>(operation, stack, bp, assignmnt, compParams);
nil::blueprint::handle_ceil<PreLimbs, PostLimbs>(operation, stack, bp, assignmnt, compParams);
} else if (math::FloorOp operation = llvm::dyn_cast<math::FloorOp>(op)) {
handle_floor<PreLimbs, PostLimbs>(operation, stack, bp, assignmnt, compParams);
nil::blueprint::handle_floor<PreLimbs, PostLimbs>(operation, stack, bp, assignmnt, compParams);
} else if (math::CopySignOp operation = llvm::dyn_cast<math::CopySignOp>(op)) {
// TODO: do nothing for now since it only comes up during mod, and there
// the component handles this correctly; do we need this later on?
Expand Down Expand Up @@ -828,7 +826,8 @@ namespace zk_ml_toolchain {

void handleZkMlOperation(Operation *op, const ComponentParameters &compParams) {
if (zkml::DotProductOp operation = llvm::dyn_cast<zkml::DotProductOp>(op)) {
handle_dot_product<PreLimbs, PostLimbs>(operation, zero_var, stack, bp, assignmnt, compParams);
nil::blueprint::handle_dot_product<PreLimbs, PostLimbs>(operation, zero_var, stack, bp, assignmnt,
compParams);
} else if (zkml::ArgMinOp operation = llvm::dyn_cast<zkml::ArgMinOp>(op)) {
auto nextIndexVar = put_into_assignment(stack.get_constant(operation.getNextIndex()));
nil::blueprint::handle_argmin<PreLimbs, PostLimbs>(operation, stack, bp, assignmnt, nextIndexVar,
Expand Down Expand Up @@ -868,9 +867,6 @@ namespace zk_ml_toolchain {
logger.debug("allocating memref");
logger << operation;
MemRefType type = operation.getType();
auto uses = operation->getResult(0).getUsers();
auto res = operation->getResult(0);
auto res2 = operation.getMemref();
// check for dynamic size
std::vector<int64_t> dims;
auto operands = operation.getOperands();
Expand Down
4 changes: 2 additions & 2 deletions mlir-assigner/include/mlir-assigner/parser/parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ namespace nil {
std::uint32_t target_prover_idx, const std::string &policy = "",
generation_mode gen_mode = generation_mode::ASSIGNMENTS & generation_mode::CIRCUIT,
print_format output_print_format = no_print) :
max_num_provers(max_num_provers),
gen_mode(gen_mode), print_output_format(output_print_format) {
gen_mode(gen_mode),
print_output_format(output_print_format), max_num_provers(max_num_provers) {
if (max_num_provers != 1) {
throw std::runtime_error(
"Currently only one prover is supported, please "
Expand Down
Loading

0 comments on commit ec742ee

Please sign in to comment.