diff --git a/libspu/mpc/swift/BUILD.bazel b/libspu/mpc/swift/BUILD.bazel new file mode 100644 index 00000000..340150ca --- /dev/null +++ b/libspu/mpc/swift/BUILD.bazel @@ -0,0 +1,152 @@ +# Copyright 2021 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//bazel:spu.bzl", "spu_cc_library", "spu_cc_test") + +package(default_visibility = ["//visibility:public"]) + +spu_cc_library( + name = "type", + srcs = ["type.cc"], + hdrs = ["type.h"], + deps = [ + "//libspu/core:type", + "//libspu/mpc/common:pv2k", + ], +) + +spu_cc_test( + name = "type_test", + srcs = ["type_test.cc"], + deps = [ + ":type", + ], +) + +spu_cc_library( + name = "value", + srcs = ["value.cc"], + hdrs = ["value.h"], + deps = [ + ":type", + "//libspu/core:ndarray_ref", + "//libspu/mpc/utils:ring_ops", + ], +) + +spu_cc_library( + name = "io", + srcs = ["io.cc"], + hdrs = ["io.h"], + deps = [ + ":type", + ":value", + "//libspu/mpc:io_interface", + ], +) + +spu_cc_test( + name = "io_test", + srcs = ["io_test.cc"], + deps = [ + ":io", + "//libspu/mpc:io_test", + ], +) + +spu_cc_library( + name = "arithmetic", + srcs = ["arithmetic.cc"], + hdrs = ["arithmetic.h"], + deps = [ + ":type", + "value", + "commitment", + "//libspu/core:vectorize", + "//libspu/mpc:kernel", + "//libspu/mpc/common:communicator", + "//libspu/mpc/utils:circuits", + "//libspu/mpc/utils:ring_ops", + ], +) + +spu_cc_library( + name = "protocol", + srcs = ["protocol.cc"], + hdrs = ["protocol.h"], + deps = [ + ":arithmetic", + ":boolean", + ":value", + "//libspu/mpc/common:prg_state", + "//libspu/mpc/standard_shape:protocol", + ], +) + +spu_cc_test( + name = "protocol_test", + srcs = ["protocol_test.cc"], + deps = [ + ":protocol", + ":protocol_single_test", + ], +) + +spu_cc_library( + name = "protocol_single_test", + testonly = 1, + srcs = ["protocol_single_test.cc"], + hdrs = ["protocol_single_test.h"], + deps = [ + ":arithmetic", + ":boolean", + ":type", + ":value", + "//libspu/mpc:ab_api", + "//libspu/mpc:api", + "//libspu/mpc:api_test_params", + "//libspu/core:context", + "//libspu/mpc:kernel", + "//libspu/mpc/common:communicator", + "//libspu/mpc/utils:simulate", + "@com_google_googletest//:gtest", + ], + alwayslink = True, +) + +spu_cc_library( + name = "commitment", + srcs = ["commitment.cc"], + hdrs = ["commitment.h"], + deps = [ + "//libspu/core:prelude", + "@yacl//yacl/crypto/hash:blake3", + "@yacl//yacl/crypto/hash:hash_utils", + "@yacl//yacl/crypto/rand", + "@yacl//yacl/link", + ], +) + +spu_cc_library( + name = "boolean", + srcs = ["boolean.cc"], + hdrs = ["boolean.h"], + deps = [ + ":arithmetic", + ":type", + ":value", + "//libspu/mpc:kernel", + "//libspu/mpc/common:communicator", + ], +) \ No newline at end of file diff --git a/libspu/mpc/swift/arithmetic.cc b/libspu/mpc/swift/arithmetic.cc new file mode 100644 index 00000000..a352dbf6 --- /dev/null +++ b/libspu/mpc/swift/arithmetic.cc @@ -0,0 +1,1781 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/mpc/swift/arithmetic.h" + +#include +#include +#include + +#include "libspu/core/type_util.h" +#include "libspu/core/vectorize.h" +#include "libspu/mpc/common/communicator.h" +#include "libspu/mpc/common/prg_state.h" +#include "libspu/mpc/common/pv2k.h" +#include "libspu/mpc/swift/commitment.h" +#include "libspu/mpc/swift/type.h" +#include "libspu/mpc/swift/value.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc::swift { + + NdArrayRef getOrCreateCompactArray(const NdArrayRef& in) { + if (!in.isCompact()) { + return in.clone(); + } + + return in; + } + + NdArrayRef Jmp::proc(KernelEvalContext* ctx, const NdArrayRef& msg, + size_t rank_i, size_t rank_j, size_t rank_k, + std::string_view tag) { + auto const field = msg.eltype().as()->field(); + auto* comm = ctx->getState(); + auto ty = makeType(field); + + NdArrayRef res(ty, msg.shape()); + + bool inconsistent_bit = false; + + auto rank = comm->getRank(); + + if (rank == rank_i) { + // send v to P_k + comm->sendAsync(rank_k, msg, tag); + res = msg; + + // malicious action 1 : P_i send wrong msg + // comm->sendAsync(rank_k, ring_neg(msg), tag); + + // recv inconsistent_bit from P_k + auto recv_b_from_pk = comm->recv(rank_k, tag); + inconsistent_bit = recv_b_from_pk[0]; + + // exchange inconsistent bit between P_i and P_j + // reset inconsistent bit to b_i || b_j + std::array send_b; + send_b[0] = inconsistent_bit; + + // malicious action 2 : P_i send wrong inconsistent bit + // send_b[0] = inconsistent_bit ^ true; + + auto recv_b_from_pj = comm->recv(rank_j, tag); + comm->sendAsync(rank_j, absl::MakeSpan(send_b), tag); + inconsistent_bit = recv_b_from_pk[0] || recv_b_from_pj[0]; + + // std::cout << "consistent_bit of P_i: " << inconsistent_bit << std::endl; + + // broadcast Hash(v) + // without considering the situation that some party is silent + // which means the inconsistent bit of each party is all true or false + if (inconsistent_bit == true) { + std::string broadcast_msg(getOrCreateCompactArray(msg).data(), + msg.numel() * msg.elsize()); + auto broadcast_msg_hash = commit(0, broadcast_msg, tag); + yacl::ByteContainerView broadcast_msg_hash_bytes( + reinterpret_cast(broadcast_msg_hash.data()), + broadcast_msg_hash.size()); + auto all_hash_bytes = + yacl::link::AllGather(comm->lctx(), broadcast_msg_hash_bytes, tag); + std::vector all_hash(3); + for (int i = 0; i < 3; i++) { + all_hash[i] = + std::string(reinterpret_cast(all_hash_bytes[i].data()), + all_hash_bytes[i].size()); + + // for (unsigned char c : all_hash[i]) { + // std::cout << std::hex << std::uppercase << std::setfill('0') + // << std::setw(2) << static_cast(c); + // } + } + if (all_hash[rank_i] != all_hash[rank_j]) { + SPDLOG_INFO( + "inconsistent check fail for tag {} from Party_{}, TTP = Party_{}", + tag, rank_i, rank_k); + } + else if (all_hash[rank_i] != all_hash[rank_k]) { + SPDLOG_INFO( + "inconsistent check fail for tag {} from Party_{}, TTP = Party_{}", + tag, rank_i, rank_j); + } + else { + SPDLOG_INFO( + "inconsistent check fail for tag {} from Party_{}, TTP = Party_{}", + tag, rank_i, rank_i); + } + } + } + if (rank == rank_j) { + res = msg; + + // send hash(v) to P_k + std::string msg_str(getOrCreateCompactArray(msg).data(), + msg.numel() * msg.elsize()); + + // malicious action 1 : P_j send wrong hash + // std::string msg_str(getOrCreateCompactArray(ring_neg(msg)).data(), + // msg.numel() * msg.elsize()); + + auto msg_hash = commit(rank_j, msg_str, tag); + + yacl::ByteContainerView msg_hash_bytes( + reinterpret_cast(msg_hash.data()), msg_hash.size()); + comm->sendAsync(rank_k, absl::MakeSpan(msg_hash_bytes), tag); + + // recv inconsistent_bit from P_k + auto recv_b_from_pk = comm->recv(rank_k, tag); + inconsistent_bit = recv_b_from_pk[0]; + + // exchange inconsistent bit between P_i and P_j + // reset inconsistent bit to b_i || b_j + std::array send_b; + send_b[0] = inconsistent_bit; + comm->sendAsync(rank_i, absl::MakeSpan(send_b), tag); + auto recv_b_from_pi = comm->recv(rank_i, tag); + inconsistent_bit = recv_b_from_pk[0] || recv_b_from_pi[0]; + + // std::cout << "consistent_bit of P_j: " << inconsistent_bit << std::endl; + + // broadcast Hash(v) + // without considering the situation that some party is silent + // which means the inconsistent bit of each party is all true or false + if (inconsistent_bit == true) { + std::string broadcast_msg(getOrCreateCompactArray(msg).data(), + msg.numel() * msg.elsize()); + auto broadcast_msg_hash = commit(0, broadcast_msg, tag); + yacl::ByteContainerView broadcast_msg_hash_bytes( + reinterpret_cast(broadcast_msg_hash.data()), + broadcast_msg_hash.size()); + auto all_hash_bytes = + yacl::link::AllGather(comm->lctx(), broadcast_msg_hash_bytes, tag); + std::vector all_hash(3); + for (int i = 0; i < 3; i++) { + all_hash[i] = + std::string(reinterpret_cast(all_hash_bytes[i].data()), + all_hash_bytes[i].size()); + + // for (unsigned char c : all_hash[i]) { + // std::cout << std::hex << std::uppercase << std::setfill('0') + // << std::setw(2) << static_cast(c); + // } + } + if (all_hash[rank_i] != all_hash[rank_j]) { + SPDLOG_INFO( + "inconsistent check fail for tag {} from Party_{}, TTP = Party_{}", + tag, rank_j, rank_k); + } + else if (all_hash[rank_i] != all_hash[rank_k]) { + SPDLOG_INFO( + "inconsistent check fail for tag {} from Party_{}, TTP = Party_{}", + tag, rank_j, rank_j); + } + else { + SPDLOG_INFO( + "inconsistent check fail for tag {} from Party_{}, TTP = Party_{}", + tag, rank_j, rank_i); + } + } + } + if (rank == rank_k) { + // recv v and H_v from P_i and P_j respectively + auto res_v = comm->recv(rank_i, msg.eltype(), tag); + res_v = res_v.reshape(msg.shape()); + + auto recv_bytes = comm->recv(rank_j, tag); + + // check Hash(v) = H_v + + std::string recv_hash = std::string( + reinterpret_cast(recv_bytes.data()), recv_bytes.size()); + + std::string recv_msg_str(getOrCreateCompactArray(res_v).data(), + res_v.numel() * res_v.elsize()); + auto recv_msg_hash = commit(rank_j, recv_msg_str, tag); + + if (recv_msg_hash != recv_hash){ + inconsistent_bit = true; + } + + if (inconsistent_bit == false) { + // send inconsistent_bit to P_i and P_j + std::array send_b; + send_b[0] = inconsistent_bit; + + comm->sendAsync(rank_j, absl::MakeSpan(send_b), tag); + comm->sendAsync(rank_i, absl::MakeSpan(send_b), tag); + + // std::cout << "consistent_bit of P_k: " << inconsistent_bit << + // std::endl; + } + else { + SPDLOG_INFO("commit check fail for tag {}", tag); + inconsistent_bit = true; + + // send inconsistent_bit to P_i and P_j + std::array send_b; + send_b[0] = inconsistent_bit; + + comm->sendAsync(rank_j, absl::MakeSpan(send_b), tag); + comm->sendAsync(rank_i, absl::MakeSpan(send_b), tag); + + // std::cout << "consistent_bit of P_k: " << inconsistent_bit << + // std::endl; + + // broadcast Hash(v) + // without considering the situation that some party is silent + // which means the inconsistent bit of each party is all true or false + std::string broadcast_msg(getOrCreateCompactArray(res_v).data(), + res_v.numel() * res_v.elsize()); + auto broadcast_msg_hash = commit(0, broadcast_msg, tag); + yacl::ByteContainerView broadcast_msg_hash_bytes( + reinterpret_cast(broadcast_msg_hash.data()), + broadcast_msg_hash.size()); + auto all_hash_bytes = + yacl::link::AllGather(comm->lctx(), broadcast_msg_hash_bytes, tag); + std::vector all_hash(3); + for (int i = 0; i < 3; i++) { + all_hash[i] = + std::string(reinterpret_cast(all_hash_bytes[i].data()), + all_hash_bytes[i].size()); + + // for (unsigned char c : all_hash[i]) { + // std::cout << std::hex << std::uppercase << std::setfill('0') + // << std::setw(2) << static_cast(c); + // } + } + + if (all_hash[rank_i] != all_hash[rank_j]) { + SPDLOG_INFO( + "inconsistent check fail for tag {} from Party_{}, TTP = Party_{}", + tag, rank_k, rank_k); + } + else if (all_hash[rank_i] != all_hash[rank_k]) { + SPDLOG_INFO( + "inconsistent check fail for tag {} from Party_{}, TTP = Party_{}", + tag, rank_k, rank_j); + } + else { + SPDLOG_INFO( + "inconsistent check fail for tag {} from Party_{}, TTP = Party_{}", + tag, rank_k, rank_i); + } + } + res = res_v; + } + // TODO: inconsistent bit check + return res; + } + + NdArrayRef Sharing::proc(KernelEvalContext* ctx, const NdArrayRef& msg, + size_t owner, std::string_view tag) { + auto const field = msg.eltype().as()->field(); + auto* comm = ctx->getState(); + auto* prg_state = ctx->getState(); + auto ty = makeType(field); + auto out_ty = makeType(field); + auto rank = comm->getRank(); + auto jmp = Jmp(); + + NdArrayRef alpha1(ty, msg.shape()); + NdArrayRef alpha2(ty, msg.shape()); + NdArrayRef beta(ty, msg.shape()); + NdArrayRef gamma(ty, msg.shape()); + + if (owner == 0) { + // P0, Pj together sample random alpha_j + auto [r0, r1] = + prg_state->genPrssPair(field, msg.shape(), PrgState::GenPrssCtrl::Both); + if (rank == 0) { + alpha2 = r0; + alpha1 = r1; + } + if (rank == 1) { + alpha1 = r0; + } + if (rank == 2) { + alpha2 = r1; + } + + // parties sample random gamma + auto r2 = prg_state->genPubl(field, msg.shape()); + gamma = r2; + + // P0 send beta = v + alpha to P1 + if (rank == 0) { + beta = ring_add(msg, ring_add(alpha1, alpha2)); + comm->sendAsync(1, beta, "beta_01"); + } + if (rank == 1) { + beta = comm->recv(0, ty, "beta_01"); + beta = beta.reshape(msg.shape()); + } + + // P0 and P1 jmp-send beta to P2 + beta = jmp.proc(ctx, beta, 0, 1, 2, "beta_012"); + } + if (owner == 1) { + // P0, P1 together sample alpha1 + // P1, P2 together sample gamma + auto [r0, r1] = + prg_state->genPrssPair(field, msg.shape(), PrgState::GenPrssCtrl::Both); + if (rank == 0) { + alpha1 = r1; + } + if (rank == 1) { + alpha1 = r0; + gamma = r1; + } + if (rank == 2) { + gamma = r0; + } + + // parties sample random alpha2 + auto r2 = prg_state->genPubl(field, msg.shape()); + alpha2 = r2; + + // P1 send beta = v + alpha to P2 + if (rank == 1) { + beta = ring_add(msg, ring_add(alpha1, alpha2)); + comm->sendAsync(2, beta, "beta_12"); + } + if (rank == 2) { + beta = comm->recv(1, ty, "beta_12"); + beta = beta.reshape(msg.shape()); + } + + // P1, P2 jmp-send beta + gamma to P0 + auto beta_plus_gamma = ring_add(beta, gamma); + beta_plus_gamma = + jmp.proc(ctx, beta_plus_gamma, 1, 2, 0, "beta_plus_gamma_120"); + } + if (owner == 2) { + // P0, P2 together sample alpha2 + // P1, P2 together sample gamma + auto [r0, r1] = + prg_state->genPrssPair(field, msg.shape(), PrgState::GenPrssCtrl::Both); + if (rank == 0) { + alpha2 = r0; + } + if (rank == 1) { + gamma = r1; + } + if (rank == 2) { + alpha2 = r1; + gamma = r0; + } + + // parties sample random alpha1 + auto r2 = prg_state->genPubl(field, msg.shape()); + alpha1 = r2; + + // P2 send beta = v + alpha to P1 + if (rank == 2) { + beta = ring_add(msg, ring_add(alpha1, alpha2)); + comm->sendAsync(1, beta, "beta_12"); + } + if (rank == 1) { + beta = comm->recv(2, ty, "beta_12"); + beta = beta.reshape(msg.shape()); + } + + // P1, P2 jmp-send beta + gamma to P0 + auto beta_plus_gamma = ring_add(beta, gamma); + beta_plus_gamma = + jmp.proc(ctx, beta_plus_gamma, 2, 1, 0, "beta_plus_gamma_210"); + } + + return DISPATCH_ALL_FIELDS(field, [&]() { + using ashr_el_t = ring2k_t; + using ashr_t = std::array; + + NdArrayView _alpha1(alpha1); + NdArrayView _alpha2(alpha2); + NdArrayView _beta(beta); + NdArrayView _gamma(gamma); + + NdArrayRef out(out_ty, msg.shape()); + NdArrayView _out(out); + + if (rank == 0) { + pforeach(0, msg.numel(), [&](int64_t idx) { + _out[idx][0] = _alpha1[idx]; + _out[idx][1] = _alpha2[idx]; + _out[idx][2] = _beta[idx] + _gamma[idx]; + }); + } + if (rank == 1) { + pforeach(0, msg.numel(), [&](int64_t idx) { + _out[idx][0] = _alpha1[idx]; + _out[idx][1] = _beta[idx]; + _out[idx][2] = _gamma[idx]; + }); + } + if (rank == 2) { + pforeach(0, msg.numel(), [&](int64_t idx) { + _out[idx][0] = _alpha2[idx]; + _out[idx][1] = _beta[idx]; + _out[idx][2] = _gamma[idx]; + }); + } + return out; + }); + } + + NdArrayRef JointSharing::proc(KernelEvalContext* ctx, const NdArrayRef& msg, + size_t rank_i, size_t rank_j, + std::string_view tag) { + auto const field = msg.eltype().as()->field(); + auto* comm = ctx->getState(); + auto* prg_state = ctx->getState(); + auto ty = makeType(field); + auto out_ty = makeType(field); + auto rank = comm->getRank(); + + NdArrayRef out(out_ty, msg.shape()); + + return DISPATCH_ALL_FIELDS(field, [&]() { + using ashr_el_t = ring2k_t; + using ashr_t = std::array; + + NdArrayView _out(out); + NdArrayView _msg(msg); + + if ((rank_i == 1 && rank_j == 2) || (rank_i == 2 && rank_j == 1)) { + // 0 0 r + // 0 v r - v + // 0 v r - v + auto r = prg_state->genPubl(field, msg.shape()); + auto r_v = ring_sub(r, msg); + NdArrayView _r_v(r_v); + NdArrayView _r(r); + + pforeach(0, msg.numel(), [&](int64_t idx) { + _out[idx][0] = ring2k_t(0); + _out[idx][1] = rank == 0 ? ring2k_t(0) : _msg[idx]; + _out[idx][2] = rank == 0 ? _r[idx] : _r_v[idx]; + }); + } + else if ((rank_i == 1 && rank_j == 0) || (rank_i == 0 && rank_j == 1)) { + // -v 0 r + // -v 0 r + // 0 0 r + auto r = prg_state->genPubl(field, msg.shape()); + auto neg_msg = ring_neg(msg); + NdArrayView _neg_msg(neg_msg); + NdArrayView _r(r); + + pforeach(0, msg.numel(), [&](int64_t idx) { + _out[idx][0] = rank == 2 ? ring2k_t(0) : _neg_msg[idx]; + _out[idx][1] = ring2k_t(0); + _out[idx][2] = _r[idx]; + }); + } + else if ((rank_i == 2 && rank_j == 0) || (rank_i == 0 && rank_j == 2)) { + // 0 -v r + // 0 0 r + // -v 0 r + auto r = prg_state->genPubl(field, msg.shape()); + auto neg_msg = ring_neg(msg); + NdArrayView _neg_msg(neg_msg); + NdArrayView _r(r); + + pforeach(0, msg.numel(), [&](int64_t idx) { + _out[idx][0] = rank == 2 ? _neg_msg[idx] : ring2k_t(0); + _out[idx][1] = rank == 0 ? _neg_msg[idx] : ring2k_t(0); + _out[idx][2] = _r[idx]; + }); + } + else { + SPU_THROW("Party idx wrong in Joint Sharing"); + } + return out; + }); + } + + NdArrayRef UnaryTest1::proc(KernelEvalContext* ctx, + const NdArrayRef& in) const { + // Sharing Test + // auto sharing = Sharing(); + // auto out = sharing.proc(ctx, in, 0, "sh test"); + // auto out = sharing.proc(ctx, in, 1, "sh test"); + // auto out = sharing.proc(ctx, in, 2, "sh test"); + + // Joint Sharing Test + // auto jsh = JointSharing(); + // auto out = jsh.proc(ctx, in, 1, 2, "jsh test"); + // auto out = jsh.proc(ctx, in, 1, 0, "jsh test"); + // auto out = jsh.proc(ctx, in, 2, 0, "jsh test"); + + // jmp test + auto jmp = Jmp(); + return jmp.proc(ctx, in, 0, 1, 2, "test jmp"); + } + + NdArrayRef P2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { + auto* comm = ctx->getState(); + + const auto* in_ty = in.eltype().as(); + const auto field = in_ty->field(); + + auto rank = comm->getRank(); + + return DISPATCH_ALL_FIELDS(field, [&]() { + using ashr_el_t = ring2k_t; + using ashr_t = std::array; + using pshr_el_t = ring2k_t; + + NdArrayRef out(makeType(field), in.shape()); + NdArrayView _out(out); + NdArrayView _in(in); + + // 0, 0, v + // 0, v, 0 + // 0, v, 0 + + pforeach(0, in.numel(), [&](int64_t idx) { + _out[idx][0] = ring2k_t(0); + _out[idx][1] = rank == 0 ? ring2k_t(0) : _in[idx]; + _out[idx][2] = rank == 0 ? _in[idx] : ring2k_t(0); + }); + + return out; + }); + } + + NdArrayRef A2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { + auto* comm = ctx->getState(); + const auto field = in.eltype().as()->field(); + auto numel = in.numel(); + auto rank = comm->getRank(); + auto ty = makeType(field); + auto jmp = Jmp(); + + return DISPATCH_ALL_FIELDS(field, [&] { + using pshr_el_t = ring2k_t; + using ashr_el_t = ring2k_t; + using ashr_t = std::array; + + NdArrayRef out(makeType(field), in.shape()); + + NdArrayView _out(out); + NdArrayView _in(in); + + NdArrayRef alpha1(ty, in.shape()); + NdArrayRef alpha2(ty, in.shape()); + NdArrayRef beta(ty, in.shape()); + + if (rank == 0) { + alpha1 = getFirstShare(in); + alpha2 = getSecondShare(in); + } + if (rank == 1) { + alpha1 = getFirstShare(in); + beta = getSecondShare(in); + } + if (rank == 2) { + alpha2 = getFirstShare(in); + beta = getSecondShare(in); + } + + // P1, P2 -> P0 : beta + // P0, P1 -> P2 : alpha1 + // P2, P0 -> P1 : alpha2 + beta = jmp.proc(ctx, beta, 1, 2, 0, "beta"); + alpha1 = jmp.proc(ctx, alpha1, 0, 1, 2, "alpha1"); + alpha2 = jmp.proc(ctx, alpha2, 2, 0, 1, "alpha2"); + + NdArrayView _alpha1(alpha1); + NdArrayView _alpha2(alpha2); + NdArrayView _beta(beta); + + pforeach(0, numel, [&](int64_t idx) { + _out[idx] = _beta[idx] - _alpha1[idx] - _alpha2[idx]; + }); + return out; + }); + } + + NdArrayRef A2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t rank_dst) const { + const auto field = in.eltype().as()->field(); + auto* comm = ctx->getState(); + auto rank = comm->getRank(); + auto jmp = Jmp(); + + return DISPATCH_ALL_FIELDS(field, [&]() { + using vshr_el_t = ring2k_t; + using ashr_el_t = ring2k_t; + using ashr_t = std::array; + + NdArrayView _in(in); + auto out_ty = makeType(field, rank); + auto ty = makeType(field); + + NdArrayRef alpha1(ty, in.shape()); + NdArrayRef alpha2(ty, in.shape()); + NdArrayRef beta(ty, in.shape()); + + if (rank == 0) { + alpha1 = getFirstShare(in); + alpha2 = getSecondShare(in); + } + if (rank == 1) { + alpha1 = getFirstShare(in); + beta = getSecondShare(in); + } + if (rank == 2) { + alpha2 = getFirstShare(in); + beta = getSecondShare(in); + } + + if (rank_dst == 0) { + beta = jmp.proc(ctx, beta, 1, 2, 0, "beta"); + } + if (rank_dst == 1) { + alpha2 = jmp.proc(ctx, alpha2, 2, 0, 1, "alpha2"); + } + if (rank_dst == 2) { + alpha1 = jmp.proc(ctx, alpha1, 0, 1, 2, "alpha1"); + } + + if (rank == rank_dst) { + NdArrayView _alpha1(alpha1); + NdArrayView _alpha2(alpha2); + NdArrayView _beta(beta); + + NdArrayRef out(out_ty, in.shape()); + NdArrayView _out(out); + + pforeach(0, in.numel(), [&](int64_t idx) { + _out[idx] = _beta[idx] - _alpha1[idx] - _alpha2[idx]; + }); + return out; + } + else { + return makeConstantArrayRef(out_ty, in.shape()); + } + }); + } + + NdArrayRef V2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { + auto* comm = ctx->getState(); + auto* prg_state = ctx->getState(); + const auto* in_ty = in.eltype().as(); + const auto field = in_ty->field(); + auto rank = comm->getRank(); + auto jmp = Jmp(); + + size_t owner_rank = in_ty->owner(); + + return DISPATCH_ALL_FIELDS(field, [&]() { + using ashr_el_t = ring2k_t; + using ashr_t = std::array; + + auto ty = makeType(field); + + NdArrayRef out(makeType(field), in.shape()); + NdArrayRef alpha1(ty, in.shape()); + NdArrayRef alpha2(ty, in.shape()); + NdArrayRef beta(ty, in.shape()); + NdArrayRef gamma(ty, in.shape()); + + if (owner_rank == 0) { + // P0, Pj together sample random alpha_j + auto [r0, r1] = prg_state->genPrssPair(field, in.shape(), + PrgState::GenPrssCtrl::Both); + if (rank == 0) { + alpha2 = r0; + alpha1 = r1; + } + if (rank == 1) { + alpha1 = r0; + } + if (rank == 2) { + alpha2 = r1; + } + + // parties sample random gamma + auto r2 = prg_state->genPubl(field, in.shape()); + gamma = r2; + + // P0 send beta = v + alpha to P1 + if (rank == 0) { + beta = ring_add(in, ring_add(alpha1, alpha2)); + comm->sendAsync(1, beta, "v2a_01"); + } + if (rank == 1) { + beta = comm->recv(0, ty, "v2a_01"); + } + + // P0 and P1 jmp-send beta to P2 + beta = jmp.proc(ctx, beta, 0, 1, 2, "v2a_012"); + } + + NdArrayView _alpha1(alpha1); + NdArrayView _alpha2(alpha2); + NdArrayView _beta(beta); + NdArrayView _gamma(gamma); + NdArrayView _out(out); + + if (rank == 0) { + pforeach(0, in.numel(), [&](int64_t idx) { + _out[idx][0] = _alpha1[idx]; + _out[idx][1] = _alpha2[idx]; + _out[idx][2] = _beta[idx] + _gamma[idx]; + }); + } + if (rank == 1) { + pforeach(0, in.numel(), [&](int64_t idx) { + _out[idx][0] = _alpha1[idx]; + _out[idx][1] = _beta[idx]; + _out[idx][2] = _gamma[idx]; + }); + } + if (rank == 2) { + pforeach(0, in.numel(), [&](int64_t idx) { + _out[idx][0] = _alpha2[idx]; + _out[idx][1] = _beta[idx]; + _out[idx][2] = _gamma[idx]; + }); + } + return out; + }); + } + + NdArrayRef NegateA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { + const auto* in_ty = in.eltype().as(); + const auto field = in_ty->field(); + + return DISPATCH_ALL_FIELDS(field, [&]() { + using el_t = std::make_unsigned_t; + using shr_t = std::array; + + NdArrayRef out(makeType(field), in.shape()); + NdArrayView _out(out); + NdArrayView _in(in); + + pforeach(0, in.numel(), [&](int64_t idx) { + _out[idx][0] = -_in[idx][0]; + _out[idx][1] = -_in[idx][1]; + _out[idx][2] = -_in[idx][2]; + }); + + return out; + }); + } + + NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { + auto* prg_state = ctx->getState(); + auto* comm = ctx->getState(); + const auto field = ctx->getState()->getDefaultField(); + auto ty = makeType(field); + auto rank = comm->getRank(); + auto jmp = Jmp(); + + NdArrayRef alpha1(ty, shape); + NdArrayRef alpha2(ty, shape); + NdArrayRef beta(ty, shape); + NdArrayRef gamma(ty, shape); + + NdArrayRef out(makeType(field), shape); + + // Comparison only works for [-2^(k-2), 2^(k-2)] + auto [r0, r1] = + prg_state->genPrssPair(field, shape, PrgState::GenPrssCtrl::Both); + auto [r2, r3] = + prg_state->genPrssPair(field, shape, PrgState::GenPrssCtrl::Both); + + r0 = ring_rshift(r0, { 2 }); + r1 = ring_rshift(r1, { 2 }); + + if (rank == 0) { + alpha2 = r0; + alpha1 = r1; + } + if (rank == 1) { + alpha1 = r0; + beta = r1; + gamma = ring_rshift(r3, { 2 }); + } + if (rank == 2) { + alpha2 = r1; + beta = r0; + gamma = ring_rshift(r2, { 2 }); + } + + auto beta_plus_gamma = ring_add(beta, gamma); + beta_plus_gamma = + jmp.proc(ctx, beta_plus_gamma, 1, 2, 0, "beta_plus_gamma_120"); + + DISPATCH_ALL_FIELDS(field, [&]() { + using el_t = ring2k_t; + using ashr_t = std::array; + NdArrayView _alpha1(alpha1); + NdArrayView _alpha2(alpha2); + NdArrayView _beta(beta); + NdArrayView _gamma(gamma); + NdArrayView _beta_plus_gamma(beta_plus_gamma); + + NdArrayView _out(out); + pforeach(0, out.numel(), [&](int64_t idx) { + _out[idx][0] = rank == 2 ? _alpha2[idx] : _alpha1[idx]; + _out[idx][1] = rank == 0 ? _alpha2[idx] : _beta[idx]; + _out[idx][2] = rank == 0 ? _beta_plus_gamma[idx] : _gamma[idx]; + }); + }); + return out; + } + + //////////////////////////////////////////////////////////////////// + // add family + //////////////////////////////////////////////////////////////////// + NdArrayRef AddAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + auto* comm = ctx->getState(); + const auto* lhs_ty = lhs.eltype().as(); + const auto* rhs_ty = rhs.eltype().as(); + + SPU_ENFORCE(lhs_ty->field() == rhs_ty->field()); + const auto field = lhs_ty->field(); + + auto rank = comm->getRank(); + + return DISPATCH_ALL_FIELDS(field, [&]() { + using el_t = ring2k_t; + using ashr_t = std::array; + + NdArrayRef out(makeType(field), lhs.shape()); + NdArrayView _out(out); + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); + + pforeach(0, lhs.numel(), [&](int64_t idx) { + _out[idx][0] = _lhs[idx][0]; + _out[idx][1] = _lhs[idx][1]; + _out[idx][2] = _lhs[idx][2]; + if (rank == 0) _out[idx][2] += _rhs[idx]; + if (rank == 1 || rank == 2) _out[idx][1] += _rhs[idx]; + }); + return out; + }); + + } + + NdArrayRef AddAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + const auto* lhs_ty = lhs.eltype().as(); + const auto* rhs_ty = rhs.eltype().as(); + + SPU_ENFORCE(lhs_ty->field() == rhs_ty->field()); + const auto field = lhs_ty->field(); + + return DISPATCH_ALL_FIELDS(field, [&]() { + using ashr_t = std::array; + + NdArrayRef out(makeType(field), lhs.shape()); + NdArrayView _out(out); + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); + + pforeach(0, lhs.numel(), [&](int64_t idx) { + _out[idx][0] = _lhs[idx][0] + _rhs[idx][0]; + _out[idx][1] = _lhs[idx][1] + _rhs[idx][1]; + _out[idx][2] = _lhs[idx][2] + _rhs[idx][2]; + }); + return out; + }); + } + + //////////////////////////////////////////////////////////////////// + // multiply family + //////////////////////////////////////////////////////////////////// + NdArrayRef MulAP::proc(KernelEvalContext*, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + const auto* lhs_ty = lhs.eltype().as(); + const auto* rhs_ty = rhs.eltype().as(); + + SPU_ENFORCE(lhs_ty->field() == rhs_ty->field()); + const auto field = lhs_ty->field(); + + return DISPATCH_ALL_FIELDS(field, [&]() { + using el_t = ring2k_t; + using ashr_t = std::array; + + NdArrayRef out(makeType(field), lhs.shape()); + NdArrayView _out(out); + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); + + pforeach(0, lhs.numel(), [&](int64_t idx) { + _out[idx][0] = _lhs[idx][0] * _rhs[idx]; + _out[idx][1] = _lhs[idx][1] * _rhs[idx]; + _out[idx][2] = _lhs[idx][2] * _rhs[idx]; + }); + return out; + }); + } + + NdArrayRef MulAA_semi::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + // Debug only + const auto field = lhs.eltype().as()->field(); + auto* comm = ctx->getState(); + auto* prg_state = ctx->getState(); + auto rank = comm->getRank(); + auto ty = makeType(field); + auto shape = lhs.shape(); + + NdArrayRef alpha1(ty, shape); + NdArrayRef alpha2(ty, shape); + NdArrayRef beta_z(ty, shape); + NdArrayRef out(makeType(field), shape); + + // P0, Pj together sample random alpha_j + auto [r0, r1] = + prg_state->genPrssPair(field, lhs.shape(), PrgState::GenPrssCtrl::Both); + if (rank == 0) { + alpha2 = r0; + alpha1 = r1; + } + if (rank == 1) { + alpha1 = r0; + } + if (rank == 2) { + alpha2 = r1; + } + + if (rank == 0) { + auto alpha_x1 = getFirstShare(lhs); + auto alpha_x2 = getSecondShare(rhs); + auto alpha_x = ring_add(alpha_x1, alpha_x2); + + auto alpha_y1 = getFirstShare(lhs); + auto alpha_y2 = getSecondShare(rhs); + auto alpha_y = ring_add(alpha_y1, alpha_y2); + auto Gamma = ring_mul(alpha_x, alpha_y); + auto Gammas = ring_rand_additive_splits(Gamma, 2); + comm->sendAsync(1, Gammas[0], "Gamma_i"); + comm->sendAsync(2, Gammas[1], "Gamma_i"); + } + if (rank == 1 || rank == 2) { + auto Gamma = comm->recv(0, ty, "Gamma_i"); + auto alpha_xi = getFirstShare(lhs); + auto alpha_yi = getFirstShare(rhs); + auto beta_x = getSecondShare(lhs); + auto beta_y = getSecondShare(rhs); + + auto beta_zi = + rank == 2 ? ring_mul(beta_x, beta_y) : ring_zeros(field, shape); + ring_sub_(beta_zi, ring_mul(beta_x, alpha_yi)); + ring_sub_(beta_zi, ring_mul(beta_y, alpha_xi)); + ring_add_(beta_zi, Gamma); + if (rank == 1) { + ring_add_(beta_zi, alpha1); + } + if (rank == 2) { + ring_add_(beta_zi, alpha2); + } + + comm->sendAsync((3 - rank), beta_zi, "beta_zi"); + auto beta_zi_ = comm->recv((3 - rank), ty, "beta_zi"); + beta_z = ring_add(beta_zi, beta_zi_); + } + + return DISPATCH_ALL_FIELDS(field, [&]() { + using el_t = ring2k_t; + using ashr_t = std::array; + + NdArrayView _out(out); + NdArrayView _alpha1(alpha1); + NdArrayView _alpha2(alpha2); + NdArrayView _beta_z(beta_z); + + pforeach(0, rhs.numel(), [&](int64_t idx) { + _out[idx][0] = rank == 2 ? _alpha2[idx] : _alpha1[idx]; + _out[idx][1] = rank == 0 ? _alpha2[idx] : _beta_z[idx]; + }); + + return out; + }); + } + + NdArrayRef MulPre_semi(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) { + // semi-honest mult based on RSS + // store the shares like RSS + // P0 : x0 x1 dummy + // P1 : x1 x2 dummy + // P2 : x2 x0 dummy + const auto field = lhs.eltype().as()->field(); + auto* comm = ctx->getState(); + auto* prg_state = ctx->getState(); + + return DISPATCH_ALL_FIELDS(field, [&]() { + using el_t = ring2k_t; + using shr_t = std::array; + + std::vector r0(lhs.numel()); + std::vector r1(lhs.numel()); + + prg_state->fillPrssPair(r0.data(), r1.data(), r0.size(), + PrgState::GenPrssCtrl::Both); + + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); + + // z1 = (x1 * y1) + (x1 * y2) + (x2 * y1) + (r0 - r1); + pforeach(0, lhs.numel(), [&](int64_t idx) { + r0[idx] = (_lhs[idx][0] * _rhs[idx][0]) + (_lhs[idx][0] * _rhs[idx][1]) + + (_lhs[idx][1] * _rhs[idx][0]) + (r0[idx] - r1[idx]); + }); + + r1 = comm->rotate(r0, "mulpre"); // comm => 1, k + + NdArrayRef out(makeType(field), lhs.shape()); + NdArrayView _out(out); + + pforeach(0, lhs.numel(), [&](int64_t idx) { + _out[idx][0] = r0[idx]; + _out[idx][1] = r1[idx]; + }); + + return out; + }); + } + + NdArrayRef RSS_A2P(KernelEvalContext* ctx, const NdArrayRef& in) { + auto* comm = ctx->getState(); + const auto field = in.eltype().as()->field(); + auto numel = in.numel(); + + return DISPATCH_ALL_FIELDS(field, [&]() { + using pshr_el_t = ring2k_t; + using ashr_el_t = ring2k_t; + using ashr_t = std::array; + + NdArrayRef out(makeType(field), in.shape()); + NdArrayView _out(out); + NdArrayView _in(in); + + std::vector x2(numel); + + pforeach(0, numel, [&](int64_t idx) { x2[idx] = _in[idx][1]; }); + + auto x3 = comm->rotate(x2, "rss_a2p"); // comm => 1, k + + pforeach(0, numel, [&](int64_t idx) { + _out[idx] = _in[idx][0] + _in[idx][1] + x3[idx]; + }); + + return out; + }); + } + + NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + const auto field = lhs.eltype().as()->field(); + auto* comm = ctx->getState(); + auto* prg_state = ctx->getState(); + auto rank = comm->getRank(); + auto ty = makeType(field); + auto shape = lhs.shape(); + auto numel = lhs.numel(); + auto jmp = Jmp(); + + NdArrayRef alpha_z1(ty, shape); + NdArrayRef alpha_z2(ty, shape); + NdArrayRef gamma_z(ty, shape); + NdArrayRef out(makeType(field), shape); + NdArrayRef d(makeType(field), shape); + NdArrayRef e(makeType(field), shape); + + NdArrayRef chi_1(ty, shape); + NdArrayRef chi_2(ty, shape); + NdArrayRef Phi(ty, shape); + + NdArrayRef beta_z1_start(ty, shape); + NdArrayRef beta_z2_start(ty, shape); + + NdArrayRef beta_plus_gamma_z(ty, shape); + + // P0, Pj together sample random alpha_j + auto [r0, r1] = + prg_state->genPrssPair(field, lhs.shape(), PrgState::GenPrssCtrl::Both); + if (rank == 0) { + alpha_z2 = r0; + alpha_z1 = r1; + } + if (rank == 1) { + alpha_z1 = r0; + gamma_z = r1; + } + if (rank == 2) { + alpha_z2 = r1; + gamma_z = r0; + } + + return DISPATCH_ALL_FIELDS(field, [&]() { + using el_t = ring2k_t; + using shr_t = std::array; + + NdArrayView _d(d); + NdArrayView _e(e); + NdArrayView _out(out); + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); + NdArrayView _alpha_z1(alpha_z1); + NdArrayView _alpha_z2(alpha_z2); + NdArrayView _gamma_z(gamma_z); + + // generate RSS of e, d + // refer to Table 3 in Swift + // and init out + if (rank == 0) { + pforeach(0, numel, [&](int64_t idx) { + _d[idx][0] = _lhs[idx][1]; + _d[idx][1] = _lhs[idx][0]; + _e[idx][0] = _rhs[idx][1]; + _e[idx][1] = _rhs[idx][0]; + + _out[idx][0] = _alpha_z1[idx]; + _out[idx][1] = _alpha_z2[idx]; + }); + } + else if (rank == 1) { + pforeach(0, numel, [&](int64_t idx) { + _d[idx][0] = _lhs[idx][0]; + _d[idx][1] = _lhs[idx][2]; + _e[idx][0] = _rhs[idx][0]; + _e[idx][1] = _rhs[idx][2]; + + _out[idx][0] = _alpha_z1[idx]; + _out[idx][2] = _gamma_z[idx]; + }); + } + else if (rank == 2) { + pforeach(0, numel, [&](int64_t idx) { + _d[idx][0] = _lhs[idx][2]; + _d[idx][1] = _lhs[idx][0]; + _e[idx][0] = _rhs[idx][2]; + _e[idx][1] = _rhs[idx][0]; + + _out[idx][0] = _alpha_z2[idx]; + _out[idx][2] = _gamma_z[idx]; + }); + } + + // p0, p1 : chi_1 = f1 + // p0, p2 : chi_2 = f0 + // p1, p2 : Phi = f2 - gamma_x * gamma_y + auto f = MulPre_semi(ctx, d, e); + + // Debug : correctness of Mulpre + // auto open_d = RSS_A2P(ctx, d); + // auto open_e = RSS_A2P(ctx, e); + // auto open_f = RSS_A2P(ctx, f); + // SPU_ENFORCE(ring_all_equal(ring_mul(open_d, open_e), open_f), + // "MulPre_semi error"); + + // if (rank == 1) { + // NdArrayRef gamma_x(ty, shape); + // NdArrayRef gamma_y(ty, shape); + // NdArrayView _gamma_x(gamma_x); + // NdArrayView _gamma_y(gamma_y); + // pforeach(0, numel, [&](int64_t idx) { + // _gamma_x[idx] = _lhs[idx][2]; + // _gamma_y[idx] = _rhs[idx][2]; + // }); + // comm->sendAsync(0, gamma_x, "gamma_x"); + // comm->sendAsync(0, gamma_y, "gamma_y"); + // } + // if (rank == 0) { + // auto gamma_x = comm->recv(1, ty, "gamma_x"); + // auto gamma_y = comm->recv(1, ty, "gamma_y"); + // NdArrayRef alpha_x(ty, shape); + // NdArrayRef alpha_y(ty, shape); + // NdArrayView _alpha_x(alpha_x); + // NdArrayView _alpha_y(alpha_y); + // pforeach(0, numel, [&](int64_t idx) { + // _alpha_x[idx] = _lhs[idx][0] + _lhs[idx][1]; + // _alpha_y[idx] = _rhs[idx][0] + _rhs[idx][1]; + // }); + // auto d_tmp = ring_add(gamma_x, alpha_x); + // auto e_tmp = ring_add(gamma_y, alpha_y); + // auto f_tmp = ring_mul(d_tmp, e_tmp); + // fmt::print("gamma_x\n"); + // ring_print(gamma_x); + // fmt::print("d_tmp\n"); + // ring_print(d_tmp); + // fmt::print("open_d\n"); + // ring_print(open_d); + // SPU_ENFORCE(ring_all_equal(d_tmp, open_d), "d_tmp != open_d"); + // SPU_ENFORCE(ring_all_equal(e_tmp, open_e), "e_tmp != open_e"); + // SPU_ENFORCE(ring_all_equal(f_tmp, open_f), "f_tmp != open_f"); + // } + + NdArrayView _f(f); + NdArrayView _chi_1(chi_1); + NdArrayView _chi_2(chi_2); + NdArrayView _Phi(Phi); + if (rank == 0) { + pforeach(0, numel, [&](int64_t idx) { + _chi_1[idx] = _f[idx][1]; + _chi_2[idx] = _f[idx][0]; + }); + } + else if (rank == 1) { + pforeach(0, numel, [&](int64_t idx) { + _chi_1[idx] = _f[idx][0]; + _Phi[idx] = _f[idx][1] - _lhs[idx][2] * _rhs[idx][2]; + }); + } + else if (rank == 2) { + pforeach(0, numel, [&](int64_t idx) { + _chi_2[idx] = _f[idx][1]; + _Phi[idx] = _f[idx][0] - _lhs[idx][2] * _rhs[idx][2]; + }); + } + + // Debug: chi_1 + chi_2 + Phi =? f - gamma_x * gamma_y + // all send to P1 + // { + // if (rank == 0) { + // comm->sendAsync(1, chi_2, "chi_2"); + // } + // if (rank == 1) { + // chi_2 = comm->recv(0, ty, "chi_2"); + // auto test1 = ring_add(ring_add(chi_1, chi_2), Phi); + // auto test2 = + // ring_sub(open_f, ring_mul(getThirdShare(lhs), + // getThirdShare(rhs))); + // SPU_ENFORCE(ring_all_equal(test1, test2), + // "chi_1 + chi_2 + Phi != f - gamma_x * gamma_y"); + // } + // } + + NdArrayView _beta_z1_start(beta_z1_start); + NdArrayView _beta_z2_start(beta_z2_start); + // [beta*_z] = -(beta_x + gamma_x)[alpha_y] - (beta_y + gamma_y)[alpha_x] + // +[alpha_z] + [chi] + if (rank == 0) { + pforeach(0, numel, [&](int64_t idx) { + _beta_z1_start[idx] = -_lhs[idx][2] * _rhs[idx][0] - + _rhs[idx][2] * _lhs[idx][0] + _alpha_z1[idx] + + _chi_1[idx]; + _beta_z2_start[idx] = -_lhs[idx][2] * _rhs[idx][1] - + _rhs[idx][2] * _lhs[idx][1] + _alpha_z2[idx] + + _chi_2[idx]; + }); + } + else if (rank == 1) { + pforeach(0, numel, [&](int64_t idx) { + _beta_z1_start[idx] = -(_lhs[idx][1] + _lhs[idx][2]) * _rhs[idx][0] - + (_rhs[idx][1] + _rhs[idx][2]) * _lhs[idx][0] + + _alpha_z1[idx] + _chi_1[idx]; + }); + } + else if (rank == 2) { + pforeach(0, numel, [&](int64_t idx) { + _beta_z2_start[idx] = -(_lhs[idx][1] + _lhs[idx][2]) * _rhs[idx][0] - + (_rhs[idx][1] + _rhs[idx][2]) * _lhs[idx][0] + + _alpha_z2[idx] + _chi_2[idx]; + }); + } + + beta_z1_start = jmp.proc(ctx, beta_z1_start, 0, 1, 2, "beta_z1_start"); + beta_z2_start = jmp.proc(ctx, beta_z2_start, 0, 2, 1, "beta_z2_start"); + auto beta_z_start = ring_add(beta_z1_start, beta_z2_start); + + NdArrayView _beta_z_start(beta_z_start); + NdArrayView _beta_plus_gamma_z(beta_plus_gamma_z); + if (rank == 1 || rank == 2) { + pforeach(0, numel, [&](int64_t idx) { + // beta_z = beta*_z + beta_x * beta_y + Phi + _out[idx][1] = + _beta_z_start[idx] + _lhs[idx][1] * _rhs[idx][1] + _Phi[idx]; + + _beta_plus_gamma_z[idx] = _out[idx][1] + _out[idx][2]; + }); + } + beta_plus_gamma_z = + jmp.proc(ctx, beta_plus_gamma_z, 1, 2, 0, "beta_plus_gamma_z"); + if (rank == 0) { + pforeach(0, numel, + [&](int64_t idx) { _out[idx][2] = _beta_plus_gamma_z[idx]; }); + } + + return out; + }); + } + + //////////////////////////////////////////////////////////////////// + // matmul family + //////////////////////////////////////////////////////////////////// + NdArrayRef MatMulAP::proc(KernelEvalContext*, const NdArrayRef& x, + const NdArrayRef& y) const { + const auto field = x.eltype().as()->field(); + + NdArrayRef z(makeType(field), { x.shape()[0], y.shape()[1] }); + + auto x1 = getFirstShare(x); + auto x2 = getSecondShare(x); + auto x3 = getThirdShare(x); + + auto z1 = getFirstShare(z); + auto z2 = getSecondShare(z); + auto z3 = getThirdShare(z); + + ring_mmul_(z1, x1, y); + ring_mmul_(z2, x2, y); + ring_mmul_(z3, x3, y); + + return z; + } + + NdArrayRef MatMulPre_semi(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) { + // semi-honest mult based on RSS + // store the shares like RSS + // P0 : x0 x1 dummy + // P1 : x1 x2 dummy + // P2 : x2 x0 dummy + const auto field = lhs.eltype().as()->field(); + auto* comm = ctx->getState(); + auto* prg_state = ctx->getState(); + + auto M = lhs.shape()[0]; + auto N = rhs.shape()[1]; + + auto [r0, r1] = prg_state->genPrssPair(field, {M, N}, + PrgState::GenPrssCtrl::Both); + + NdArrayRef out(makeType(field), {M, N}); + auto o1 = getFirstShare(out); + auto o2 = getSecondShare(out); + + auto x1 = getFirstShare(lhs); + auto x2 = getSecondShare(lhs); + + auto y1 = getFirstShare(rhs); + auto y2 = getSecondShare(rhs); + + // o2 = (x1 * y1) + (x1 * y2) + (x2 * y1) + (r0 - r1); + auto t1 = ring_mmul(x1, y1); + auto t2 = ring_mmul(x1, y2); + auto t3 = ring_mmul(x2, y1); + auto tmp1 = ring_sum({t1, t2, t3}); + + auto tmp2 = comm->rotate(tmp1, "matmulpre"); + + ring_assign(o1, tmp1); + ring_assign(o2, tmp2); + + return out; + } + + NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const { + auto* prg_state = ctx->getState(); + const auto field = x.eltype().as()->field(); + auto* comm = ctx->getState(); + auto rank = comm->getRank(); + auto ty = makeType(field); + auto jmp = Jmp(); + auto M = x.shape()[0]; + auto N = y.shape()[1]; + + NdArrayRef out(makeType(field), {M, N}); + NdArrayRef d(makeType(field), x.shape()); + NdArrayRef e(makeType(field), y.shape()); + + NdArrayRef chi_1(ty, {M, N}); + NdArrayRef chi_2(ty, {M, N}); + NdArrayRef Phi(ty, {M, N}); + + NdArrayRef beta_z1_start(ty, {M, N}); + NdArrayRef beta_z2_start(ty, {M, N}); + + NdArrayRef beta_plus_gamma_z(ty, {M, N}); + + // P0, Pj together sample random alpha_j + auto [r0, r1] = + prg_state->genPrssPair(field, {M, N}, PrgState::GenPrssCtrl::Both); + + auto d0 = getFirstShare(d); + auto d1 = getSecondShare(d); + + auto e0 = getFirstShare(e); + auto e1 = getSecondShare(e); + + auto x0 = getFirstShare(x); + auto x1 = getSecondShare(x); + auto x2 = getThirdShare(x); + + auto y0 = getFirstShare(y); + auto y1 = getSecondShare(y); + auto y2 = getThirdShare(y); + + auto z0 = getFirstShare(out); + auto z1 = getSecondShare(out); + auto z2 = getThirdShare(out); + + if (rank == 0) { + ring_assign(d0, x1); + ring_assign(d1, x0); + ring_assign(e0, y1); + ring_assign(e1, y0); + + ring_assign(z0, r1); + ring_assign(z1, r0); + } + if (rank == 1) { + ring_assign(d0, x0); + ring_assign(d1, x2); + ring_assign(e0, y0); + ring_assign(e1, y2); + + ring_assign(z0, r0); + ring_assign(z2, r1); + } + if (rank == 2) { + ring_assign(d0, x2); + ring_assign(d1, x0); + ring_assign(e0, y2); + ring_assign(e1, y0); + + ring_assign(z0, r1); + ring_assign(z2, r0); + } + + // p0, p1 : chi_1 = f1 + // p0, p2 : chi_2 = f0 + // p1, p2 : Phi = f2 - gamma_x * gamma_y + auto f = MatMulPre_semi(ctx, d, e); + + auto f0 = getFirstShare(f); + auto f1 = getSecondShare(f); + auto f2 = getThirdShare(f); + + if (rank == 0) { + ring_assign(chi_1, f1); + ring_assign(chi_2, f0); + } + if (rank == 1) { + ring_assign(chi_1, f0); + auto tmp1 = ring_sub(f1, ring_mmul(x2, y2)); + ring_assign(Phi, tmp1); + } + if (rank == 2) { + ring_assign(chi_2, f1); + auto tmp1 = ring_sub(f0, ring_mmul(x2, y2)); + ring_assign(Phi, tmp1); + } + + // [beta*_z] = -(beta_x + gamma_x)[alpha_y] - (beta_y + gamma_y)[alpha_x] + // +[alpha_z] + [chi] + if (rank == 0) { + // auto tmp2 = ring_neg(ring_mmul(x2, y0)); + // auto tmp3 = ring_neg(ring_mmul(x0, y2)); + beta_z1_start = ring_sum({ring_neg(ring_mmul(x2, y0)), ring_neg(ring_mmul(x0, y2)), z0, chi_1}); + beta_z2_start = ring_sum({ring_neg(ring_mmul(x2, y1)), ring_neg(ring_mmul(x1, y2)), z1, chi_2}); + } + if (rank == 1) { + auto tmp2 = ring_neg(ring_add(x1, x2)); + auto tmp3 = ring_neg(ring_add(y1, y2)); + tmp2 = ring_mmul(tmp2, y0); + tmp3 = ring_mmul(x0, tmp3); + beta_z1_start = ring_sum({tmp2, tmp3, z0, chi_1}); + } + if (rank == 2) { + auto tmp2 = ring_neg(ring_add(x1, x2)); + auto tmp3 = ring_neg(ring_add(y1, y2)); + tmp2 = ring_mmul(tmp2, y0); + tmp3 = ring_mmul(x0, tmp3); + beta_z2_start = ring_sum({tmp2, tmp3, z0, chi_2}); + } + beta_z1_start = jmp.proc(ctx, beta_z1_start, 0, 1, 2, "beta_z1_start"); + beta_z2_start = jmp.proc(ctx, beta_z2_start, 0, 2, 1, "beta_z2_start"); + auto beta_z_start = ring_add(beta_z1_start, beta_z2_start); + + if (rank == 1 || rank == 2) { + // beta_z = beta*_z + beta_x * beta_y + Phi + ring_assign(z1, ring_sum({beta_z_start, ring_mmul(x1, y1), Phi})); + ring_assign(beta_plus_gamma_z, ring_add(z1, z2)); + } + beta_plus_gamma_z = jmp.proc(ctx, beta_plus_gamma_z, 1, 2, 0, "beta_plus_gamma_z"); + if (rank == 0) { + ring_assign(z2, beta_plus_gamma_z); + } + + return out; + + } + + NdArrayRef LShiftA::proc(KernelEvalContext*, const NdArrayRef& in, + const Sizes& bits) const { + const auto* in_ty = in.eltype().as(); + const auto field = in_ty->field(); + bool is_splat = bits.size() == 1; + return DISPATCH_ALL_FIELDS(field, [&]() { + using shr_t = std::array; + + NdArrayRef out(makeType(field), in.shape()); + NdArrayView _out(out); + NdArrayView _in(in); + pforeach(0, in.numel(), [&](int64_t idx) { + auto shift_bit = is_splat ? bits[0] : bits[idx]; + _out[idx][0] = _in[idx][0] << shift_bit; + _out[idx][1] = _in[idx][1] << shift_bit; + _out[idx][2] = _in[idx][2] << shift_bit; + }); + + return out; + }); + } + + std::pair TruncA::Trgen(KernelEvalContext* ctx, + int64_t bits, FieldType field, int64_t numel) const { + auto* prg_state = ctx->getState(); + auto* comm = ctx->getState(); + auto ty_ring = makeType(field); + auto ashrty = makeType(field); + auto rank = comm->getRank(); + auto shape = {numel}; + const int64_t k = SizeOf(field) * 8; + + auto jsh = JointSharing(); + auto dotp = MatMulAA(); + auto a2p = A2P(); + + NdArrayRef r(ashrty, shape); + NdArrayRef r1(ty_ring, shape); + NdArrayRef r2(ty_ring, shape); + NdArrayRef rd(ashrty, shape); + + NdArrayRef public_const1(ty_ring, {k - bits}); + NdArrayRef public_const2(ty_ring, {k}); + NdArrayRef x(ashrty, {1, k - bits}); + NdArrayRef y(ashrty, {k - bits, 1}); + NdArrayRef p(ashrty, {1, k}); + NdArrayRef q(ashrty, {k, 1}); + NdArrayRef tmp(ashrty, {1}); + NdArrayRef A(ashrty, shape); + NdArrayRef B(ashrty, shape); + + // pack bits together + NdArrayRef r1_bits(ty_ring, {numel * k}); + NdArrayRef r2_bits(ty_ring, {numel * k}); + NdArrayRef r1_bits_share(ashrty, {numel * k}); + NdArrayRef r2_bits_share(ashrty, {numel * k}); + + // P_0 and P_j generate r_j by PRG + // P0.prg_r0 = P2.prg_r1 = r2 + // P0.prg_r1 = P1.prg_r0 = r1 + auto [prg_r0, prg_r1] = prg_state->genPrssPair(field, shape, + PrgState::GenPrssCtrl::Both); + + // actuall, for the trunc pair: r, rd + // they should satisfy: rd = arshift(r, d) + // but in swift, which generate [[ยท]] share of each bit + // and use the following expression to calculate r and rd + // r = \Sigma_{i=0}^{k-1} (2^i * r[i]) + // rd = \Sigma_{i=d}^{k-1} (2^{i-d} * r[i]) + // so in swift : r = rshift(r, d) + // which cause the truncation result to be wrong + // so we need to guarantee the msb of r is 0, + // so that arshift(r, d) = rshift(r, d) + ring_rshift_(prg_r0, {static_cast(1)}); + ring_rshift_(prg_r1, {static_cast(1)}); + if (rank == 0){ + r1 = prg_r1; + r2 = prg_r0; + } + if (rank == 1) { + r1 = prg_r0; + } + if (rank == 2) { + r2 = prg_r1; + } + + DISPATCH_ALL_FIELDS(field, [&](){ + using el_t = ring2k_t; + using shr_t = std::array; + + NdArrayView _r1(r1); + NdArrayView _r2(r2); + NdArrayView _r1_bits(r1_bits); + NdArrayView _r2_bits(r2_bits); + + // bit decompose r1 and r2 + pforeach(0, numel, [&](int64_t idx) { + for (int64_t i = 0; i < k; i++){ + _r1_bits[idx * k + i] = static_cast((_r1[idx] >> i) & 0x1); + _r2_bits[idx * k + i] = static_cast((_r2[idx] >> i) & 0x1); + } + }); + + // joint share r1_bits and r2_bits + r1_bits_share = jsh.proc(ctx, r1_bits, 0, 1, "r1_bits share"); + r2_bits_share = jsh.proc(ctx, r2_bits, 0, 2, "r2_bits share"); + + // for each r in batch: + // A = x \cdot y + // B = p \cdot q + NdArrayView _x(x); + NdArrayView _y(y); + NdArrayView _p(p); + NdArrayView _q(q); + NdArrayView _r1_bits_share(r1_bits_share); + NdArrayView _r2_bits_share(r2_bits_share); + + // public_const1 = 2 ^ {i - bits + 1} for i \in {d, ..., k - 1} + // public_const2 = 2 ^ {i + 1} for i \in {0, 1, ..., k - 1} + NdArrayView _public_const1(public_const1); + NdArrayView _public_const2(public_const2); + for(int64_t i = bits; i < k; i++) { + _public_const1[i - bits] = (static_cast(1) << (i - bits + 1)); + } + for(int64_t i = 0; i < k; i++) { + _public_const2[i] = (static_cast(1) << (i + 1)); + } + + NdArrayView _tmp(tmp); + NdArrayView _A(A); + NdArrayView _B(B); + pforeach(0, numel, [&](int64_t idx) { + for(int64_t i = bits; i < k; i++) { + // MulAP + _x[i - bits][0] = (ring2k_t(1) << (i - bits + 1)) * _r1_bits_share[idx * k + i][0]; + _x[i - bits][1] = (ring2k_t(1) << (i - bits + 1)) * _r1_bits_share[idx * k + i][1]; + _x[i - bits][2] = (ring2k_t(1) << (i - bits + 1)) * _r1_bits_share[idx * k + i][2]; + + _y[i - bits][0] = _r2_bits_share[idx * k + i][0]; + _y[i - bits][1] = _r2_bits_share[idx * k + i][1]; + _y[i - bits][2] = _r2_bits_share[idx * k + i][2]; + } + for(int64_t i = 0; i < k; i++) { + // MulAP + _p[i][0] = (ring2k_t(1) << (i + 1)) * _r1_bits_share[idx * k + i][0]; + _p[i][1] = (ring2k_t(1) << (i + 1)) * _r1_bits_share[idx * k + i][1]; + _p[i][2] = (ring2k_t(1) << (i + 1)) * _r1_bits_share[idx * k + i][2]; + + _q[i][0] = _r2_bits_share[idx * k + i][0]; + _q[i][1] = _r2_bits_share[idx * k + i][1]; + _q[i][2] = _r2_bits_share[idx * k + i][2]; + } + + // x \cdot y + // x.reshape({1, k - bits}); + // y.reshape({k - bits, 1}); + tmp = dotp.proc(ctx, x, y); + _A[idx][0] = _tmp[0][0]; + _A[idx][1] = _tmp[0][1]; + _A[idx][2] = _tmp[0][2]; + + // p \cdot q + // p.reshape({1, k}); + // q.reshape({k, 1}); + tmp = dotp.proc(ctx, p, q); + _B[idx][0] = _tmp[0][0]; + _B[idx][1] = _tmp[0][1]; + _B[idx][2] = _tmp[0][2]; + }); + + NdArrayView _r(r); + NdArrayView _rd(rd); + + pforeach(0, numel, [&](int64_t idx) { + // use tmp as sum + _rd[idx][0] = (ring2k_t)0; + _rd[idx][1] = (ring2k_t)0; + _rd[idx][2] = (ring2k_t)0; + for (int64_t i = bits; i < k; i++){ + _rd[idx][0] += (((ring2k_t)1 << (i - bits)) * (_r1_bits_share[idx * k + i][0] + _r2_bits_share[idx * k + i][0])); + _rd[idx][1] += (((ring2k_t)1 << (i - bits)) * (_r1_bits_share[idx * k + i][1] + _r2_bits_share[idx * k + i][1])); + _rd[idx][2] += (((ring2k_t)1 << (i - bits)) * (_r1_bits_share[idx * k + i][2] + _r2_bits_share[idx * k + i][2])); + } + _rd[idx][0] -= _A[idx][0]; + _rd[idx][1] -= _A[idx][1]; + _rd[idx][2] -= _A[idx][2]; + + + _r[idx][0] = (ring2k_t)0; + _r[idx][1] = (ring2k_t)0; + _r[idx][2] = (ring2k_t)0; + for (int64_t i = 0; i < k; i++){ + _r[idx][0] += (((ring2k_t)1 << (i)) * (_r1_bits_share[idx * k + i][0] + _r2_bits_share[idx * k + i][0])); + _r[idx][1] += (((ring2k_t)1 << (i)) * (_r1_bits_share[idx * k + i][1] + _r2_bits_share[idx * k + i][1])); + _r[idx][2] += (((ring2k_t)1 << (i)) * (_r1_bits_share[idx * k + i][2] + _r2_bits_share[idx * k + i][2])); + } + _r[idx][0] -= _B[idx][0]; + _r[idx][1] -= _B[idx][1]; + _r[idx][2] -= _B[idx][2]; + }); + }); + return std::make_pair(r, rd); + } + + NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& x, + size_t bits, SignType sign) const { + const auto field = x.eltype().as()->field(); + auto* comm = ctx->getState(); + auto rank = comm->getRank(); + + NdArrayRef out(makeType(field), x.shape()); + auto numel = x.numel(); + auto a2p = A2P(); + + auto [r, rd] = TruncA::Trgen(ctx, static_cast(bits), field, numel); + + r.reshape(x.shape()); + rd.reshape(x.shape()); + + NdArrayRef x_minux_r_share(makeType(field), x.shape()); + DISPATCH_ALL_FIELDS(field, [&](){ + using el_t = ring2k_t; + using shr_t = std::array; + + NdArrayView _x_minux_r_share(x_minux_r_share); + NdArrayView _r(r); + NdArrayView _x(x); + pforeach(0, numel, [&](int64_t idx) { + _x_minux_r_share[idx][0] = _x[idx][0] - _r[idx][0]; + _x_minux_r_share[idx][1] = _x[idx][1] - _r[idx][1]; + _x_minux_r_share[idx][2] = _x[idx][2] - _r[idx][2]; + }); + + auto x_minus_r = a2p.proc(ctx, x_minux_r_share); + + auto x_minus_r_d = ring_arshift(x_minus_r, {static_cast(bits)}); + + NdArrayView _out(out); + NdArrayView _rd(rd); + NdArrayView _x_minus_r_d(x_minus_r_d); + pforeach(0, numel, [&](int64_t idx) { + _out[idx][0] = _rd[idx][0]; + _out[idx][1] = _rd[idx][1]; + _out[idx][2] = _rd[idx][2]; + if (rank == 0) _out[idx][2] += _x_minus_r_d[idx]; + if (rank == 1 || rank == 2) _out[idx][1] += _x_minus_r_d[idx]; + }); + + }); + + // res = (x - r)^d + r^d + return out.as(x.eltype()); + + } + +} // namespace spu::mpc::swift \ No newline at end of file diff --git a/libspu/mpc/swift/arithmetic.h b/libspu/mpc/swift/arithmetic.h new file mode 100644 index 00000000..7f38d105 --- /dev/null +++ b/libspu/mpc/swift/arithmetic.h @@ -0,0 +1,272 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "libspu/mpc/kernel.h" + +namespace spu::mpc::swift { + +class Jmp { + public: + // rank_i : send msg + // rank_j : send H(msg) + // rank_k : receive + // Pi, Pj -> Pj : msg + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& msg, size_t rank_i, + size_t rank_j, size_t rank_k, std::string_view tag); +}; + +class Sharing { + public: + // owner shares msg to parties + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& msg, size_t owner, + std::string_view tag); +}; + +class JointSharing { + public: + // Pi, Pj jonit generate share of a value that is known to both + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& msg, size_t rank_i, + size_t rank_j, std::string_view tag); +}; + +class UnaryTest1 : public UnaryKernel { + public: + static constexpr const char* kBindName() { return "negate_a_tmp"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; +}; + +class P2A : public UnaryKernel { + public: + static constexpr const char* kBindName() { return "p2a"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; +}; + +class A2P : public UnaryKernel { + public: + static constexpr const char* kBindName() { return "a2p"; } + + ce::CExpr latency() const override { return ce::Const(1); } + + ce::CExpr comm() const override { return ce::K(); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; +}; + +class A2V : public RevealToKernel { + public: + static constexpr const char* kBindName() { return "a2v"; } + + Kind kind() const override { return Kind::Dynamic; } + + ce::CExpr latency() const override { + // 1 * send/recv: 1 + return ce::Const(1); + } + + ce::CExpr comm() const override { + // 1 * rotate: k + return ce::K(); + } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t rank_dst) const override; +}; + +class V2A : public UnaryKernel { + public: + static constexpr const char* kBindName() { return "v2a"; } + + Kind kind() const override { return Kind::Dynamic; } + + ce::CExpr latency() const override { + // 1 * rotate: 1 + return ce::Const(1); + } + + ce::CExpr comm() const override { + // 1 * rotate: k + return ce::K(); + } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; +}; + +class NegateA : public UnaryKernel { + public: + static constexpr const char* kBindName() { return "negate_a"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; +}; + +class RandA : public RandKernel { + public: + static constexpr const char* kBindName() { return "rand_a"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const Shape& shape) const override; +}; + +//////////////////////////////////////////////////////////////////// +// add family +//////////////////////////////////////////////////////////////////// +class AddAP : public BinaryKernel { + public: + static constexpr const char* kBindName() { return "add_ap"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override; +}; + +class AddAA : public BinaryKernel { + public: + static constexpr const char* kBindName() { return "add_aa"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override; +}; + +//////////////////////////////////////////////////////////////////// +// multiply family +//////////////////////////////////////////////////////////////////// +class MulAP : public BinaryKernel { + public: + static constexpr const char* kBindName() { return "mul_ap"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override; +}; + +class MulAA_semi : public BinaryKernel { + public: + static constexpr const char* kBindName() { return "mul_aa_semi"; } + + ce::CExpr latency() const override { return ce::Const(1); } + + ce::CExpr comm() const override { return ce::K(); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override; +}; + +class MulAA : public BinaryKernel { + public: + static constexpr const char* kBindName() { return "mul_aa"; } + + ce::CExpr latency() const override { return ce::Const(1); } + + ce::CExpr comm() const override { return ce::K(); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override; +}; + +//////////////////////////////////////////////////////////////////// +// matmul family +//////////////////////////////////////////////////////////////////// +class MatMulAP : public MatmulKernel { + public: + static constexpr const char* kBindName() { return "mmul_ap"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const override; +}; + +class MatMulAA : public MatmulKernel { + public: + static constexpr const char* kBindName() { return "mmul_aa"; } + + ce::CExpr latency() const override { + // 1 * rotate: 1 + return ce::Const(1); + } + + ce::CExpr comm() const override { + // 1 * rotate: k + auto m = ce::Variable("m", "rows of lhs"); + auto n = ce::Variable("n", "cols of rhs"); + return ce::K() * m * n; + } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const override; +}; + +class LShiftA : public ShiftKernel { + public: + static constexpr const char* kBindName() { return "lshift_a"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, + const Sizes& bits) const override; +}; + +class TruncA : public TruncAKernel { + public: + static constexpr const char* kBindName() { return "trunc_a"; } + + ce::CExpr latency() const override { return ce::Const(1); } + + ce::CExpr comm() const override { return ce::K(); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, size_t bits, + SignType sign) const override; + + bool hasMsbError() const override { return true; } + + TruncLsbRounding lsbRounding() const override { + return TruncLsbRounding::Random; + } + + std::pair Trgen(KernelEvalContext* ctx, int64_t bits, FieldType field, int64_t numel) const; +}; + +} // namespace spu::mpc::swift \ No newline at end of file diff --git a/libspu/mpc/swift/boolean.cc b/libspu/mpc/swift/boolean.cc new file mode 100644 index 00000000..4315df3a --- /dev/null +++ b/libspu/mpc/swift/boolean.cc @@ -0,0 +1,564 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/mpc/swift/boolean.h" +#include "libspu/mpc/swift/arithmetic.h" + +#include + +#include "libspu/core/bit_utils.h" +#include "libspu/mpc/common/communicator.h" +#include "libspu/mpc/common/prg_state.h" +#include "libspu/mpc/common/pv2k.h" +#include "libspu/mpc/swift/type.h" +#include "libspu/mpc/swift/value.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc::swift{ +namespace{ + +size_t getNumBits(const NdArrayRef& in) { + if (in.eltype().isa()) { + const auto field = in.eltype().as()->field(); + return DISPATCH_ALL_FIELDS(field, + [&]() { return maxBitWidth(in); }); + } else if (in.eltype().isa()) { + return in.eltype().as()->nbits(); + } else { + SPU_THROW("should not be here, {}", in.eltype()); + } +} +} + +NdArrayRef B2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { + const auto field = in.eltype().as()->field(); + auto* comm = ctx->getState(); + auto rank = comm->getRank(); + auto ty = makeType(field); + auto jmp = Jmp(); + + NdArrayRef out(makeType(field), in.shape()); + + NdArrayRef alpha1(ty, in.shape()); + NdArrayRef alpha2(ty, in.shape()); + NdArrayRef beta(ty, in.shape()); + + if (rank == 0) { + alpha1 = getFirstShare(in); + alpha2 = getSecondShare(in); + } + if (rank == 1){ + alpha1 = getFirstShare(in); + beta = getSecondShare(in); + } + if(rank == 2){ + alpha2 = getFirstShare(in); + beta = getSecondShare(in); + } + + // P1, P2 -> P0 : beta + // P0, P1 -> P2 : alpha1 + // P2, P0 -> P1 : alpha2 + beta = jmp.proc(ctx, beta, 1, 2, 0, "beta"); + alpha1 = jmp.proc(ctx, alpha1, 0, 1, 2, "alpha1"); + alpha2 = jmp.proc(ctx, alpha2, 2, 0, 1, "alpha2"); + + out = ring_xor(ring_xor(beta, alpha1), alpha2); + + return out.as(makeType(field)); +} + +NdArrayRef P2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { + const auto field = in.eltype().as()->field(); + auto* comm = ctx->getState(); + + auto rank = comm->getRank(); + + return DISPATCH_ALL_FIELDS(field, [&]() { + using shr_el_t = ring2k_t; + using shr_t = std::array; + using pshr_el_t = ring2k_t; + + NdArrayRef out(makeType(field), in.shape()); + NdArrayView _out(out); + NdArrayView _in(in); + + // 0, 0, v + // 0, v, 0 + // 0, v, 0 + + pforeach(0, in.numel(), [&](int64_t idx) { + _out[idx][0] = ring2k_t(0); + _out[idx][1] = rank == 0 ? ring2k_t(0) : _in[idx]; + _out[idx][2] = rank == 0 ? _in[idx] : ring2k_t(0); + }); + + return out; + }); +} + + NdArrayRef XorBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + auto* comm = ctx->getState(); + const auto* lhs_ty = lhs.eltype().as(); + + const size_t out_nbits = std::min(getNumBits(lhs), getNumBits(rhs)); + const auto field = lhs_ty->field(); + + auto rank = comm->getRank(); + + return DISPATCH_ALL_FIELDS(field, [&]() { + using el_t = ring2k_t; + using ashr_t = std::array; + + NdArrayRef out(makeType(field, out_nbits), lhs.shape()); + NdArrayView _out(out); + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); + + pforeach(0, lhs.numel(), [&](int64_t idx) { + _out[idx][0] = _lhs[idx][0]; + _out[idx][1] = _lhs[idx][1]; + _out[idx][2] = _lhs[idx][2]; + if (rank == 0) _out[idx][2] ^= _rhs[idx]; + if (rank == 1 || rank == 2) _out[idx][1] ^= _rhs[idx]; + }); + return out.as(makeType(field, out_nbits)); + }); + } + + NdArrayRef XorBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + const auto* lhs_ty = lhs.eltype().as(); + + const auto field = lhs_ty->field(); + + const size_t out_nbits = std::min(getNumBits(lhs), getNumBits(rhs)); + + return DISPATCH_ALL_FIELDS(field, [&]() { + using shr_t = std::array; + + NdArrayRef out(makeType(field, out_nbits), lhs.shape()); + NdArrayView _out(out); + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); + + pforeach(0, lhs.numel(), [&](int64_t idx) { + _out[idx][0] = _lhs[idx][0] ^ _rhs[idx][0]; + _out[idx][1] = _lhs[idx][1] ^ _rhs[idx][1]; + _out[idx][2] = _lhs[idx][2] ^ _rhs[idx][2]; + }); + return out.as(makeType(field, out_nbits)); + }); + } + + NdArrayRef AndBP::proc(KernelEvalContext*, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + const auto* lhs_ty = lhs.eltype().as(); + const auto field = lhs_ty->field(); + const size_t out_nbits = std::min(getNumBits(lhs), getNumBits(rhs)); + + return DISPATCH_ALL_FIELDS(field, [&]() { + using el_t = ring2k_t; + using shr_t = std::array; + + NdArrayRef out(makeType(field, out_nbits), lhs.shape()); + NdArrayView _out(out); + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); + + pforeach(0, lhs.numel(), [&](int64_t idx) { + _out[idx][0] = _lhs[idx][0] & _rhs[idx]; + _out[idx][1] = _lhs[idx][1] & _rhs[idx]; + _out[idx][2] = _lhs[idx][2] & _rhs[idx]; + }); + return out.as(makeType(field, out_nbits)); + }); + } + + NdArrayRef AndPre_semi(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) { + // semi-honest mult based on RSS + // store the shares like RSS + // P0 : x0 x1 dummy + // P1 : x1 x2 dummy + // P2 : x2 x0 dummy + const auto field = lhs.eltype().as()->field(); + auto* comm = ctx->getState(); + auto* prg_state = ctx->getState(); + + const size_t out_nbits = std::min(getNumBits(lhs), getNumBits(rhs)); + + return DISPATCH_ALL_FIELDS(field, [&]() { + using el_t = ring2k_t; + using shr_t = std::array; + + std::vector r0(lhs.numel()); + std::vector r1(lhs.numel()); + + prg_state->fillPrssPair(r0.data(), r1.data(), r0.size(), + PrgState::GenPrssCtrl::Both); + + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); + + // z1 = (x1 & y1) ^ (x1 & y2) ^ (x2 & y1) ^ (r0 ^ r1); + pforeach(0, lhs.numel(), [&](int64_t idx) { + r0[idx] = (_lhs[idx][0] & _rhs[idx][0]) ^ (_lhs[idx][0] & _rhs[idx][1]) ^ + (_lhs[idx][1] & _rhs[idx][0]) ^ (r0[idx] ^ r1[idx]); + }); + + r1 = comm->rotate(r0, "andpre"); // comm => 1, k + + NdArrayRef out(makeType(field, out_nbits), lhs.shape()); + NdArrayView _out(out); + + pforeach(0, lhs.numel(), [&](int64_t idx) { + _out[idx][0] = r0[idx]; + _out[idx][1] = r1[idx]; + }); + + return out.as(makeType(field, out_nbits)); + }); + } + + NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + const auto field = lhs.eltype().as()->field(); + auto* comm = ctx->getState(); + auto* prg_state = ctx->getState(); + const size_t out_nbits = std::min(getNumBits(lhs), getNumBits(rhs)); + auto rank = comm->getRank(); + auto ty = makeType(field); + auto shape = lhs.shape(); + auto numel = lhs.numel(); + auto jmp = Jmp(); + + NdArrayRef alpha_z1(ty, shape); + NdArrayRef alpha_z2(ty, shape); + NdArrayRef gamma_z(ty, shape); + NdArrayRef out(makeType(field, out_nbits), shape); + NdArrayRef d(makeType(field, out_nbits), shape); + NdArrayRef e(makeType(field, out_nbits), shape); + + NdArrayRef chi_1(ty, shape); + NdArrayRef chi_2(ty, shape); + NdArrayRef Phi(ty, shape); + + NdArrayRef beta_z1_start(ty, shape); + NdArrayRef beta_z2_start(ty, shape); + + NdArrayRef beta_plus_gamma_z(ty, shape); + + // P0, Pj together sample random alpha_j + auto [r0, r1] = + prg_state->genPrssPair(field, lhs.shape(), PrgState::GenPrssCtrl::Both); + if (rank == 0) { + alpha_z2 = r0; + alpha_z1 = r1; + } + if (rank == 1) { + alpha_z1 = r0; + gamma_z = r1; + } + if (rank == 2) { + alpha_z2 = r1; + gamma_z = r0; + } + + return DISPATCH_ALL_FIELDS(field, [&]() { + using el_t = ring2k_t; + using shr_t = std::array; + + NdArrayView _d(d); + NdArrayView _e(e); + NdArrayView _out(out); + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); + NdArrayView _alpha_z1(alpha_z1); + NdArrayView _alpha_z2(alpha_z2); + NdArrayView _gamma_z(gamma_z); + + // generate RSS of e, d + // refer to Table 3 in Swift + // and init out + if (rank == 0) { + pforeach(0, numel, [&](int64_t idx) { + _d[idx][0] = _lhs[idx][1]; + _d[idx][1] = _lhs[idx][0]; + _e[idx][0] = _rhs[idx][1]; + _e[idx][1] = _rhs[idx][0]; + + _out[idx][0] = _alpha_z1[idx]; + _out[idx][1] = _alpha_z2[idx]; + }); + } + else if (rank == 1) { + pforeach(0, numel, [&](int64_t idx) { + _d[idx][0] = _lhs[idx][0]; + _d[idx][1] = _lhs[idx][2]; + _e[idx][0] = _rhs[idx][0]; + _e[idx][1] = _rhs[idx][2]; + + _out[idx][0] = _alpha_z1[idx]; + _out[idx][2] = _gamma_z[idx]; + }); + } + else if (rank == 2) { + pforeach(0, numel, [&](int64_t idx) { + _d[idx][0] = _lhs[idx][2]; + _d[idx][1] = _lhs[idx][0]; + _e[idx][0] = _rhs[idx][2]; + _e[idx][1] = _rhs[idx][0]; + + _out[idx][0] = _alpha_z2[idx]; + _out[idx][2] = _gamma_z[idx]; + }); + } + + // p0, p1 : chi_1 = f1 + // p0, p2 : chi_2 = f0 + // p1, p2 : Phi = f2 ^ gamma_x & gamma_y + auto f = AndPre_semi(ctx, d, e); + + NdArrayView _f(f); + NdArrayView _chi_1(chi_1); + NdArrayView _chi_2(chi_2); + NdArrayView _Phi(Phi); + if (rank == 0) { + pforeach(0, numel, [&](int64_t idx) { + _chi_1[idx] = _f[idx][1]; + _chi_2[idx] = _f[idx][0]; + }); + } + else if (rank == 1) { + pforeach(0, numel, [&](int64_t idx) { + _chi_1[idx] = _f[idx][0]; + _Phi[idx] = _f[idx][1] ^ (_lhs[idx][2] & _rhs[idx][2]); + }); + } + else if (rank == 2) { + pforeach(0, numel, [&](int64_t idx) { + _chi_2[idx] = _f[idx][1]; + _Phi[idx] = _f[idx][0] ^ (_lhs[idx][2] & _rhs[idx][2]); + }); + } + + NdArrayView _beta_z1_start(beta_z1_start); + NdArrayView _beta_z2_start(beta_z2_start); + // [beta*_z] = -(beta_x + gamma_x)[alpha_y] - (beta_y + gamma_y)[alpha_x] + // +[alpha_z] + [chi] + if (rank == 0) { + pforeach(0, numel, [&](int64_t idx) { + _beta_z1_start[idx] = (_lhs[idx][2] & _rhs[idx][0]) ^ + (_rhs[idx][2] & _lhs[idx][0]) ^ _alpha_z1[idx] ^ + _chi_1[idx]; + _beta_z2_start[idx] =( _lhs[idx][2] & _rhs[idx][1]) ^ + (_rhs[idx][2] & _lhs[idx][1]) ^ _alpha_z2[idx] ^ + _chi_2[idx]; + }); + } + else if (rank == 1) { + pforeach(0, numel, [&](int64_t idx) { + _beta_z1_start[idx] = ((_lhs[idx][1] ^ _lhs[idx][2]) & _rhs[idx][0]) ^ + ((_rhs[idx][1] ^ _rhs[idx][2]) & _lhs[idx][0]) ^ + _alpha_z1[idx] ^ _chi_1[idx]; + }); + } + else if (rank == 2) { + pforeach(0, numel, [&](int64_t idx) { + _beta_z2_start[idx] = ((_lhs[idx][1] ^ _lhs[idx][2]) & _rhs[idx][0]) ^ + ((_rhs[idx][1] ^ _rhs[idx][2]) & _lhs[idx][0]) ^ + _alpha_z2[idx] ^ _chi_2[idx]; + }); + } + + beta_z1_start = jmp.proc(ctx, beta_z1_start, 0, 1, 2, "beta_z1_start"); + beta_z2_start = jmp.proc(ctx, beta_z2_start, 0, 2, 1, "beta_z2_start"); + auto beta_z_start = ring_xor(beta_z1_start, beta_z2_start); + + NdArrayView _beta_z_start(beta_z_start); + NdArrayView _beta_plus_gamma_z(beta_plus_gamma_z); + if (rank == 1 || rank == 2) { + pforeach(0, numel, [&](int64_t idx) { + // beta_z = beta*_z + beta_x * beta_y + Phi + _out[idx][1] = + _beta_z_start[idx] ^ (_lhs[idx][1] & _rhs[idx][1]) ^ _Phi[idx]; + + _beta_plus_gamma_z[idx] = _out[idx][1] ^ _out[idx][2]; + }); + } + beta_plus_gamma_z = + jmp.proc(ctx, beta_plus_gamma_z, 1, 2, 0, "beta_plus_gamma_z"); + if (rank == 0) { + pforeach(0, numel, + [&](int64_t idx) { _out[idx][2] = _beta_plus_gamma_z[idx]; }); + } + + return out.as(makeType(field, out_nbits)); + }); + } + + NdArrayRef LShiftB::proc(KernelEvalContext*, const NdArrayRef& in, + const Sizes& shift) const { + const auto field = in.eltype().as()->field(); + + size_t out_nbits = in.eltype().as()->nbits() + + *std::max_element(shift.begin(), shift.end()); + out_nbits = std::clamp(out_nbits, static_cast(0), SizeOf(field) * 8); + + NdArrayRef out(makeType(field, out_nbits), in.shape()); + + auto in1 = getFirstShare(in); + auto in2 = getSecondShare(in); + auto in3 = getThirdShare(in); + + auto out1 = getFirstShare(out); + auto out2 = getSecondShare(out); + auto out3 = getThirdShare(out); + + ring_assign(out1, ring_lshift(in1, shift)); + ring_assign(out2, ring_lshift(in2, shift)); + ring_assign(out3, ring_lshift(in3, shift)); + + return out.as(makeType(field, out_nbits)); + } + + NdArrayRef RShiftB::proc(KernelEvalContext*, const NdArrayRef& in, + const Sizes& shift) const { + const auto field = in.eltype().as()->field(); + + int64_t nbits = in.eltype().as()->nbits(); + int64_t out_nbits = + nbits - std::min(nbits, *std::min_element(shift.begin(), shift.end())); + SPU_ENFORCE(nbits <= static_cast(SizeOf(field) * 8)); + + NdArrayRef out(makeType(field, out_nbits), in.shape()); + + auto in1 = getFirstShare(in); + auto in2 = getSecondShare(in); + auto in3 = getThirdShare(in); + + auto out1 = getFirstShare(out); + auto out2 = getSecondShare(out); + auto out3 = getThirdShare(out); + + ring_assign(out1, ring_rshift(in1, shift)); + ring_assign(out2, ring_rshift(in2, shift)); + ring_assign(out3, ring_rshift(in3, shift)); + + return out.as(makeType(field, out_nbits)); + } + + NdArrayRef ARShiftB::proc(KernelEvalContext*, const NdArrayRef& in, + const Sizes& shift) const { + const auto field = in.eltype().as()->field(); + + // arithmetic right shift expects to work on ring, or the behaviour is + // undefined. + + NdArrayRef out(makeType(field, SizeOf(field) * 8), in.shape()); + + auto in1 = getFirstShare(in); + auto in2 = getSecondShare(in); + auto in3 = getThirdShare(in); + + auto out1 = getFirstShare(out); + auto out2 = getSecondShare(out); + auto out3 = getThirdShare(out); + + ring_assign(out1, ring_arshift(in1, shift)); + ring_assign(out2, ring_arshift(in2, shift)); + ring_assign(out3, ring_arshift(in3, shift)); + + return out.as(makeType(field, SizeOf(field) * 8)); + } + + NdArrayRef BitrevB::proc(KernelEvalContext*, const NdArrayRef& in, size_t start, + size_t end) const { + const auto field = in.eltype().as()->field(); + + SPU_ENFORCE(start <= end); + SPU_ENFORCE(end <= SizeOf(field) * 8); + const size_t out_nbits = std::max(getNumBits(in), end); + + NdArrayRef out(makeType(field, out_nbits), in.shape()); + + auto in1 = getFirstShare(in); + auto in2 = getSecondShare(in); + auto in3 = getThirdShare(in); + + auto out1 = getFirstShare(out); + auto out2 = getSecondShare(out); + auto out3 = getThirdShare(out); + + ring_assign(out1, ring_bitrev(in1, start, end)); + ring_assign(out2, ring_bitrev(in2, start, end)); + ring_assign(out3, ring_bitrev(in3, start, end)); + + return out.as(makeType(field, out_nbits)); + } + + NdArrayRef BitIntlB::proc(KernelEvalContext*, const NdArrayRef& in, + size_t stride) const { + const auto field = in.eltype().as()->field(); + const auto nbits = getNumBits(in); + SPU_ENFORCE(absl::has_single_bit(nbits)); + + NdArrayRef out(in.eltype(), in.shape()); + auto numel = in.numel(); + + DISPATCH_ALL_FIELDS(field, [&]() { + using el_t = ring2k_t; + using shr_t = std::array; + NdArrayView _in(in); + NdArrayView _out(out); + + pforeach(0, numel, [&](int64_t idx) { + _out[idx][0] = BitIntl(_in[idx][0], stride, nbits); + _out[idx][1] = BitIntl(_in[idx][1], stride, nbits); + _out[idx][2] = BitIntl(_in[idx][2], stride, nbits); + }); + }); + + return out.as(in.eltype()); + } + + NdArrayRef BitDeintlB::proc(KernelEvalContext*, const NdArrayRef& in, + size_t stride) const { + const auto field = in.eltype().as()->field(); + const auto nbits = getNumBits(in); + SPU_ENFORCE(absl::has_single_bit(nbits)); + + NdArrayRef out(in.eltype(), in.shape()); + auto numel = in.numel(); + + DISPATCH_ALL_FIELDS(field, [&]() { + using el_t = ring2k_t; + using shr_t = std::array; + NdArrayView _in(in); + NdArrayView _out(out); + + pforeach(0, numel, [&](int64_t idx) { + _out[idx][0] = BitDeintl(_in[idx][0], stride, nbits); + _out[idx][1] = BitDeintl(_in[idx][1], stride, nbits); + _out[idx][2] = BitDeintl(_in[idx][2], stride, nbits); + }); + }); + + return out.as(in.eltype()); + } + + +} // namespace spu::mpc::swift \ No newline at end of file diff --git a/libspu/mpc/swift/boolean.h b/libspu/mpc/swift/boolean.h new file mode 100644 index 00000000..3a67a853 --- /dev/null +++ b/libspu/mpc/swift/boolean.h @@ -0,0 +1,164 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "libspu/mpc/kernel.h" + +namespace spu::mpc::swift { + +class B2P : public UnaryKernel { +public: + static constexpr const char* kBindName() {return "b2p";} + + ce::CExpr latency() const override { return ce::Const(1); } + + ce::CExpr comm() const override { return ce::K() * (ce::N() - 1); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; + +}; + +class P2B : public UnaryKernel { + public: + static constexpr const char* kBindName() { return "p2b"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; +}; + +class XorBP : public BinaryKernel { + public: + static constexpr const char* kBindName() { return "xor_bp"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override; +}; + +class XorBB : public BinaryKernel { + public: + static constexpr const char* kBindName() { return "xor_bb"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override; +}; + +class AndBP : public BinaryKernel { + public: + static constexpr const char* kBindName() { return "and_bp"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override; +}; + +class AndBB : public BinaryKernel { + public: + static constexpr const char* kBindName() { return "and_bb"; } + + ce::CExpr latency() const override { return ce::Const(1); } + + ce::CExpr comm() const override { return ce::K() * 2 * (ce::N() - 1); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override; +}; + +class LShiftB : public ShiftKernel { + public: + static constexpr const char* kBindName() { return "lshift_b"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, + const Sizes& shift) const override; +}; + +class RShiftB : public ShiftKernel { + public: + static constexpr const char* kBindName() { return "rshift_b"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, + const Sizes& shift) const override; +}; + +class ARShiftB : public ShiftKernel { + public: + static constexpr const char* kBindName() { return "arshift_b"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, + const Sizes& shift) const override; +}; + +class BitrevB : public BitrevKernel { + public: + static constexpr const char* kBindName() { return "bitrev_b"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, size_t start, + size_t end) const override; +}; + +class BitIntlB : public BitSplitKernel { + public: + static constexpr const char* kBindName() { return "bitintl_b"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t stride) const override; +}; + +class BitDeintlB : public BitSplitKernel { + public: + static constexpr const char* kBindName() { return "bitdeintl_b"; } + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t stride) const override; +}; + +} // namespace spu::mpc::swift \ No newline at end of file diff --git a/libspu/mpc/swift/commitment.cc b/libspu/mpc/swift/commitment.cc new file mode 100644 index 00000000..ba137c73 --- /dev/null +++ b/libspu/mpc/swift/commitment.cc @@ -0,0 +1,90 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/mpc/swift/commitment.h" + +#include "spdlog/spdlog.h" +#include "yacl/crypto/hash/blake3.h" +#include "yacl/crypto/hash/hash_utils.h" +#include "yacl/crypto/rand/rand.h" + +#include "libspu/core/prelude.h" + +namespace spu::mpc { +std::string commit(size_t send_player, absl::string_view msg, + absl::string_view r, size_t hash_len, + yacl::crypto::HashAlgorithm hash_type) { + std::unique_ptr hash_algo; + switch (hash_type) { + case yacl::crypto::HashAlgorithm::BLAKE3: + hash_algo = std::make_unique(); + break; + default: + SPU_THROW("Unsupported hash algo in commitment scheme"); + } + + hash_algo->Update(std::to_string(send_player)); + hash_algo->Update(msg); + hash_algo->Update(r); + std::vector hash = hash_algo->CumulativeHash(); + SPU_ENFORCE(hash_len <= hash.size()); + + std::string hash_str(reinterpret_cast(hash.data()), hash_len); + + return hash_str; +} + +bool commit_and_open(const std::shared_ptr& lctx, + const std::string& z_str, + std::vector* z_strs) { + bool res = true; + size_t send_player = lctx->Rank(); + uint128_t rs = yacl::crypto::SecureRandSeed(); + std::string rs_str(reinterpret_cast(&rs), sizeof(rs)); + // 1. commit and send + auto cmt = commit(send_player, z_str, rs_str); + auto all_cmts = yacl::link::AllGather( + lctx, yacl::ByteContainerView(cmt.data(), cmt.size()), + "COMMITMENT::COMMIT"); + + // 2. open commit + std::string open_str = z_str + rs_str; + auto all_opens = yacl::link::AllGather( + lctx, yacl::ByteContainerView(open_str.data(), open_str.size()), + "COMMITMENT::OPEN"); + + // 3. check consistency + YACL_ENFORCE(z_strs != nullptr); + for (size_t i = 0; i < lctx->WorldSize(); ++i) { + if (i == lctx->Rank()) { + z_strs->emplace_back(z_str); + continue; + } + auto _open = std::string_view(all_opens[i]); + auto _z = _open.substr(0, z_str.size()); + auto _rs = _open.substr(z_str.size(), rs_str.size()); + + auto ref_cmt = commit(i, _z, _rs); + if (ref_cmt != std::string_view(all_cmts[i])) { + res = false; + SPDLOG_INFO("commit check fail for rank {}", i); + } + + z_strs->emplace_back(_z); + } + + return res; +} + +} // namespace spu::mpc diff --git a/libspu/mpc/swift/commitment.h b/libspu/mpc/swift/commitment.h new file mode 100644 index 00000000..18bf7887 --- /dev/null +++ b/libspu/mpc/swift/commitment.h @@ -0,0 +1,31 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yacl/crypto/hash/hash_interface.h" +#include "yacl/link/link.h" + +namespace spu::mpc { + +std::string commit(size_t send_player, absl::string_view msg, + absl::string_view r, size_t hash_len = 32, + yacl::crypto::HashAlgorithm hash_type = + yacl::crypto::HashAlgorithm::BLAKE3); + +bool commit_and_open(const std::shared_ptr& lctx, + const std::string& z_str, + std::vector* z_strs); + +} // namespace spu::mpc diff --git a/libspu/mpc/swift/io.cc b/libspu/mpc/swift/io.cc new file mode 100644 index 00000000..8ff76c1d --- /dev/null +++ b/libspu/mpc/swift/io.cc @@ -0,0 +1,146 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/mpc/swift/io.h" + +#include "yacl/crypto/rand/rand.h" +#include "yacl/crypto/tools/prg.h" + +#include "libspu/core/context.h" +#include "libspu/mpc/common/pv2k.h" +#include "libspu/mpc/swift/type.h" +#include "libspu/mpc/swift/value.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc::swift { + +Type SwiftIo::getShareType(Visibility vis, int owner_rank) const { + if (vis == VIS_PUBLIC) { + return makeType(field_); + } else if (vis == VIS_SECRET) { + if (owner_rank >= 0 && owner_rank <= 2) { + return makeType(field_, owner_rank); + } else { + return makeType(field_); + } + } + + SPU_THROW("unsupported vis type {}", vis); +} + +std::vector SwiftIo::toShares(const NdArrayRef& raw, Visibility vis, + int owner_rank) const { + SPU_ENFORCE(raw.eltype().isa(), "expected RingTy, got {}", + raw.eltype()); + const auto field = raw.eltype().as()->field(); + SPU_ENFORCE(field == field_, "expect raw value encoded in field={}, got={}", + field_, field); + + if (vis == VIS_PUBLIC) { + const auto share = raw.as(makeType(field)); + return std::vector(world_size_, share); + } else if (vis == VIS_SECRET) { + if (owner_rank >= 0 && owner_rank <= 2) { + // indicates private + std::vector shares; + + const auto ty = makeType(field, owner_rank); + for (int idx = 0; idx < 3; idx++) { + if (idx == owner_rank) { + shares.push_back(raw.as(ty)); + } else { + shares.push_back(makeConstantArrayRef(ty, raw.shape())); + } + } + return shares; + } else { + // normal secret + SPU_ENFORCE(owner_rank == -1, "not a valid owner {}", owner_rank); + + // by default, make as arithmetic share. + std::vector shares; + + const auto alpha = ring_rand(field, raw.shape()); + // beta = raw + alpha + const auto gamma = ring_rand(field, raw.shape()); + const auto beta = ring_add(raw, alpha); + const auto gamma_plus_beta = ring_add(gamma, beta); + const auto split_alpha = ring_rand_additive_splits(alpha, 2); + + // P0 : alpha_1, alpha_2, beta + gamma + // P1 : alpha_1, beta, gamma + // P2 : alpha_2, beta, gamma + shares.push_back( + makeAShare(split_alpha[0], split_alpha[1], gamma_plus_beta, field)); + + shares.push_back(makeAShare(split_alpha[0], beta, gamma, field)); + + shares.push_back(makeAShare(split_alpha[1], beta, gamma, field)); + + return shares; + } + } + + SPU_THROW("unsupported vis type {}", vis); +} + +NdArrayRef SwiftIo::fromShares(const std::vector& shares) const { + const auto& eltype = shares.at(0).eltype(); + + if (eltype.isa()) { + SPU_ENFORCE(field_ == eltype.as()->field(), "field_={}, got={}", + field_, eltype.as()->field()); + return shares[0].as(makeType(field_)); + } else if (eltype.isa()) { + SPU_ENFORCE(field_ == eltype.as()->field(), "field_={}, got={}", + field_, eltype.as()->field()); + const size_t owner = eltype.as()->owner(); + return shares[owner].as(makeType(field_)); + } else if (eltype.isa()) { + SPU_ENFORCE(field_ == eltype.as()->field(), "field_={}, got={}", + field_, eltype.as()->field()); + NdArrayRef out(makeType(field_), shares[0].shape()); + SPU_ENFORCE(shares.size() == 3, "expect shares.size()=3, got={}", + shares.size()); + + DISPATCH_ALL_FIELDS(field_, [&]() { + using el_t = ring2k_t; + using shr_t = std::array; + NdArrayView _out(out); + NdArrayView _s0(shares[0]); + NdArrayView _s1(shares[1]); + + if (eltype.isa()) { + for (auto idx = 0; idx < shares[0].numel(); ++idx) { + _out[idx] = _s1[idx][1] - _s0[idx][0] - _s0[idx][1]; + } + } else if (eltype.isa()) { + for (auto idx = 0; idx < shares[0].numel(); ++idx) { + _out[idx] = _s1[idx][1] ^ _s0[idx][0] ^ _s0[idx][1]; + } + } + }); + return out; + } + + SPU_THROW("unsupported eltype {}", eltype); +} + +std::unique_ptr makeSwiftIo(FieldType field, size_t npc) { + SPU_ENFORCE(npc == 3U, "swift is only for 3pc."); + registerTypes(); + return std::make_unique(field, npc); +} + +} // namespace spu::mpc::swift \ No newline at end of file diff --git a/libspu/mpc/swift/io.h b/libspu/mpc/swift/io.h new file mode 100644 index 00000000..372e3bc8 --- /dev/null +++ b/libspu/mpc/swift/io.h @@ -0,0 +1,42 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "libspu/mpc/io_interface.h" + +namespace spu::mpc::swift { + +class SwiftIo final : public BaseIo { + public: + using BaseIo::BaseIo; + + std::vector toShares(const NdArrayRef& raw, Visibility vis, + int owner_rank) const override; + + Type getShareType(Visibility vis, int owner_rank = -1) const override; + + NdArrayRef fromShares(const std::vector& shares) const override; + + // std::vector makeBitSecret(const PtBufferView& in) const + // override; + + // size_t getBitSecretShareSize(size_t numel) const override; + + // bool hasBitSecretSupport() const override { return true; } +}; + +std::unique_ptr makeSwiftIo(FieldType field, size_t npc); + +} // namespace spu::mpc::swift diff --git a/libspu/mpc/swift/io_test.cc b/libspu/mpc/swift/io_test.cc new file mode 100644 index 00000000..e66c5b74 --- /dev/null +++ b/libspu/mpc/swift/io_test.cc @@ -0,0 +1,31 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/mpc/io_test.h" + +#include "libspu/mpc/swift/io.h" + +namespace spu::mpc::swift { + +INSTANTIATE_TEST_SUITE_P( + SwiftIoTest, IoTest, + testing::Combine(testing::Values(makeSwiftIo), // + testing::Values(3), // + testing::Values(FieldType::FM32, FieldType::FM64, + FieldType::FM128)), + [](const testing::TestParamInfo& p) { + return fmt::format("{}x{}", std::get<1>(p.param), std::get<2>(p.param)); + }); + +} // namespace spu::mpc::swift diff --git a/libspu/mpc/swift/protocol.cc b/libspu/mpc/swift/protocol.cc new file mode 100644 index 00000000..672efc90 --- /dev/null +++ b/libspu/mpc/swift/protocol.cc @@ -0,0 +1,69 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/mpc/standard_shape/protocol.h" + +#include "libspu/mpc/common/communicator.h" +#include "libspu/mpc/common/prg_state.h" +#include "libspu/mpc/common/pv2k.h" +#include "libspu/mpc/swift/arithmetic.h" +#include "libspu/mpc/swift/boolean.h" +#include "libspu/mpc/swift/protocol.h" +#include "libspu/mpc/swift/type.h" +#include "libspu/mpc/swift/value.h" + +namespace spu::mpc { + +void regSwiftProtocol(SPUContext* ctx, + const std::shared_ptr& lctx) { + swift::registerTypes(); + + // add communicator + ctx->prot()->addState(lctx); + + // register random states & kernels. + ctx->prot()->addState(lctx); + + // add Z2k state. + ctx->prot()->addState(ctx->config().field()); + + // register public kernels. + regPV2kKernels(ctx->prot()); + + // Register standard shape ops + regStandardShapeOps(ctx); + + ctx->prot() + ->regKernel(); +} + +std::unique_ptr makeSwiftProtocol( + const RuntimeConfig& conf, + const std::shared_ptr& lctx) { + swift::registerTypes(); + + auto ctx = std::make_unique(conf, lctx); + + regSwiftProtocol(ctx.get(), lctx); + + return ctx; +} + +} // namespace spu::mpc \ No newline at end of file diff --git a/libspu/mpc/swift/protocol.h b/libspu/mpc/swift/protocol.h new file mode 100644 index 00000000..8edaa507 --- /dev/null +++ b/libspu/mpc/swift/protocol.h @@ -0,0 +1,30 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yacl/link/context.h" + +#include "libspu/core/context.h" + +namespace spu::mpc { + +void regSwiftProtocol(SPUContext* ctx, + const std::shared_ptr& lctx); + +std::unique_ptr makeSwiftProtocol( + const RuntimeConfig& conf, + const std::shared_ptr& lctx); + +} // namespace spu::mpc diff --git a/libspu/mpc/swift/protocol_single_test.cc b/libspu/mpc/swift/protocol_single_test.cc new file mode 100644 index 00000000..427e932b --- /dev/null +++ b/libspu/mpc/swift/protocol_single_test.cc @@ -0,0 +1,601 @@ +#include "libspu/mpc/swift/protocol_single_test.h" + +#include "libspu/core/prelude.h" +#include "libspu/mpc/ab_api.h" +#include "libspu/mpc/api.h" +#include "libspu/mpc/common/communicator.h" +#include "libspu/mpc/kernel.h" +#include "libspu/mpc/swift/arithmetic.h" +#include "libspu/mpc/swift/type.h" +#include "libspu/mpc/swift/value.h" +#include "libspu/mpc/utils/ring_ops.h" +#include "libspu/mpc/utils/simulate.h" + +namespace spu::mpc::test { +namespace { + +Shape kShape = {20}; +const std::vector kShiftBits = {0, 1, 2, 31, 32, 33, 64, 1000}; + +#define EXPECT_VALUE_EQ(X, Y) \ + { \ + EXPECT_EQ((X).shape(), (Y).shape()); \ + EXPECT_TRUE(ring_all_equal((X).data(), (Y).data())); \ + } + +#define EXPECT_VALUE_ALMOST_EQ(X, Y, ERR) \ + { \ + EXPECT_EQ((X).shape(), (Y).shape()); \ + EXPECT_TRUE(ring_all_equal((X).data(), (Y).data(), ERR)); \ + } + + +TEST_P(ArithmeticTest, A2P_P2A) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + // auto rank = obj->prot()->getState()->getRank(); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + // auto p0 = ring_rand(conf.field(), kShape); + + /* WHEN */ + auto a0 = p2a(obj.get(), p0); + // ring_print(p0.data()); + // if (rank == 0) { + // fmt::print("P_0: "); + // ring_print(spu::mpc::swift::getFirstShare(a0.data())); + // fmt::print("\n"); + // } + + // if (rank == 1) { + // fmt::print("P_1: "); + // ring_print(spu::mpc::swift::getFirstShare(a0.data())); + // fmt::print("\n"); + // } + + // if (rank == 2) { + // fmt::print("P_2: "); + // ring_print(spu::mpc::swift::getFirstShare(a0.data())); + // fmt::print("\n"); + // } + + auto p1 = a2p(obj.get(), a0); + // if (rank == 0) { + // fmt::print("output P_0: "); + // ring_print(p1.data()); + // fmt::print("\n"); + // } + + // if (rank == 1) { + // fmt::print("output P_1: "); + // ring_print(p1.data()); + // fmt::print("\n"); + // } + + // if (rank == 2) { + // fmt::print("output P_2: "); + // ring_print(p1.data()); + // fmt::print("\n"); + // } + + /* THEN */ + EXPECT_VALUE_EQ(p0, p1); + }); +} + +// TEST_P(ArithmeticTest, ShairngTest) { +// const auto factory = std::get<0>(GetParam()); +// const RuntimeConfig& conf = std::get<1>(GetParam()); +// const size_t npc = std::get<2>(GetParam()); + +// utils::simulate(npc, [&](const std::shared_ptr& lctx) +// { +// auto obj = factory(conf, lctx); +// // auto rank = obj->prot()->getState()->getRank(); + +// /* GIVEN */ +// auto p0 = rand_p(obj.get(), kShape); + +// /* WHEN */ +// auto a0 = negate_a(obj.get(), p0); + +// // auto p1 = a2p(obj.get(), a0); + +// /* THEN */ +// // EXPECT_VALUE_EQ(p0, p1); +// }); +// } + +TEST_P(ArithmeticTest, AddAP) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + auto p1 = rand_p(obj.get(), kShape); + + /* WHEN */ + auto a0 = p2a(obj.get(), p0); + + auto tmp = add_ap(obj.get(), a0, p1); + auto re = a2p(obj.get(), tmp); + auto rp = add_pp(obj.get(), p0, p1); + + /* THEN */ + EXPECT_VALUE_EQ(re, rp); + }); +} + +TEST_P(ArithmeticTest, AddAA) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + auto p1 = rand_p(obj.get(), kShape); + + /* WHEN */ + auto a0 = p2a(obj.get(), p0); + auto a1 = p2a(obj.get(), p1); + + auto tmp = add_aa(obj.get(), a0, a1); + auto re = a2p(obj.get(), tmp); + auto rp = add_pp(obj.get(), p0, p1); + + /* THEN */ + EXPECT_VALUE_EQ(re, rp); + }); +} + +TEST_P(ArithmeticTest, MulAP) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + auto p1 = rand_p(obj.get(), kShape); + + /* WHEN */ + auto a0 = p2a(obj.get(), p0); + + auto tmp = mul_ap(obj.get(), a0, p1); + auto re = a2p(obj.get(), tmp); + auto rp = mul_pp(obj.get(), p0, p1); + + /* THEN */ + EXPECT_VALUE_EQ(re, rp); + }); +} + +TEST_P(ArithmeticTest, MulAA) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + auto p1 = rand_p(obj.get(), kShape); + + /* WHEN */ + auto a0 = p2a(obj.get(), p0); + auto a1 = p2a(obj.get(), p1); + + auto tmp = mul_aa(obj.get(), a0, a1); + auto re = a2p(obj.get(), tmp); + auto rp = mul_pp(obj.get(), p0, p1); + + /* THEN */ + EXPECT_VALUE_EQ(re, rp); + }); +} + +TEST_P(ArithmeticTest, MatMulAP) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + const int64_t M = 3; + const int64_t K = 4; + const int64_t N = 3; + const Shape shape_A = {M, K}; + const Shape shape_B = {K, N}; + const Shape shape_C = {M, N}; + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), shape_A); + auto p1 = rand_p(obj.get(), shape_B); + auto a0 = p2a(obj.get(), p0); + + /* WHEN */ + auto tmp = mmul_ap(obj.get(), a0, p1); + + auto r_aa = a2p(obj.get(), tmp); + + auto r_pp = mmul_pp(obj.get(), p0, p1); + + /* THEN */ + EXPECT_VALUE_EQ(r_aa, r_pp); + }); +} + +TEST_P(ArithmeticTest, MatMulAA) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + const int64_t M = 3; + const int64_t K = 4; + const int64_t N = 5; + const Shape shape_A = {M, K}; + const Shape shape_B = {K, N}; + const Shape shape_C = {M, N}; + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), shape_A); + auto p1 = rand_p(obj.get(), shape_B); + auto a0 = p2a(obj.get(), p0); + auto a1 = p2a(obj.get(), p1); + + /* WHEN */ + auto tmp = mmul_aa(obj.get(), a0, a1); + + auto r_aa = a2p(obj.get(), tmp); + auto r_pp = mmul_pp(obj.get(), p0, p1); + + /* THEN */ + EXPECT_VALUE_EQ(r_aa, r_pp); + }); +} + +TEST_P(ArithmeticTest, LShiftA) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + auto a0 = p2a(obj.get(), p0); + + for (auto bits : kShiftBits) { + if (bits >= p0.elsize() * 8) { + // Shift more than elsize is a UB + continue; + } + /* WHEN */ + auto tmp = lshift_a(obj.get(), a0, {static_cast(bits)}); + auto r_b = a2p(obj.get(), tmp); + auto r_p = lshift_p(obj.get(), p0, {static_cast(bits)}); + + /* THEN */ + EXPECT_VALUE_EQ(r_b, r_p); + } + }); +} + +TEST_P(ArithmeticTest, NegateA) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + // auto rank = obj->prot()->getState()->getRank(); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + auto a0 = p2a(obj.get(), p0); + auto neg_p0 = negate_p(obj.get(), p0); + + /* WHEN */ + auto r_a = negate_a(obj.get(), a0); + + auto r_p = a2p(obj.get(), r_a); + auto r_pp = a2p(obj.get(), negate_a(obj.get(), a0)); + + /* THEN */ + EXPECT_VALUE_EQ(r_p, r_pp); + EXPECT_VALUE_EQ(r_p, neg_p0); + }); +} + +TEST_P(ArithmeticTest, TruncA) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + // ArrayRef p0_large = + // ring_rand_range(conf.field(), kShape, -(1 << 28), -(1 << 27)); + // ArrayRef p0_small = ring_rand_range(conf.field(), kShape, 1, 10000); + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + auto* kernel = + static_cast(obj->prot()->getKernel("trunc_a")); + + auto p0 = rand_p(obj.get(), kShape); + // auto p0 = rand_p(obj.get(), {4, 5}); + + if (!kernel->hasMsbError()) { + // trunc requires MSB to be zero. + p0 = arshift_p(obj.get(), p0, {1}); + } else { + // has msb error, only use lowest 10 bits. + p0 = arshift_p(obj.get(), p0, + {static_cast(SizeOf(conf.field()) * 8 - 10)}); + } + + /* GIVEN */ + const size_t bits = 2; + auto a0 = p2a(obj.get(), p0); + + /* WHEN */ + auto a1 = trunc_a(obj.get(), a0, bits, SignType::Unknown); + + auto r_a = a2p(obj.get(), a1); + auto r_p = arshift_p(obj.get(), p0, {static_cast(bits)}); + + /* THEN */ + EXPECT_VALUE_ALMOST_EQ(r_a, r_p, npc); + }); +} + +// BooleanTest +TEST_P(BooleanTest, P2B_B2P) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + + /* WHEN */ + auto b0 = p2b(obj.get(), p0); + auto p1 = b2p(obj.get(), b0); + + /* THEN */ + EXPECT_VALUE_EQ(p0, p1); + }); +} + +TEST_P(BooleanTest, XorBP) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + auto p1 = rand_p(obj.get(), kShape); + + /* WHEN */ + auto a0 = p2b(obj.get(), p0); + + auto tmp = xor_bp(obj.get(), a0, p1); + auto re = b2p(obj.get(), tmp); + auto rp = xor_pp(obj.get(), p0, p1); + + /* THEN */ + EXPECT_VALUE_EQ(re, rp); + }); +} + +TEST_P(BooleanTest, XorBB) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + auto p1 = rand_p(obj.get(), kShape); + + /* WHEN */ + auto a0 = p2b(obj.get(), p0); + auto a1 = p2b(obj.get(), p1); + + auto tmp = xor_bb(obj.get(), a0, a1); + auto re = b2p(obj.get(), tmp); + auto rp = xor_pp(obj.get(), p0, p1); + + /* THEN */ + EXPECT_VALUE_EQ(re, rp); + }); +} + +TEST_P(BooleanTest, AndBP) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + auto p1 = rand_p(obj.get(), kShape); + + /* WHEN */ + auto a0 = p2b(obj.get(), p0); + + auto tmp = and_bp(obj.get(), a0, p1); + auto re = b2p(obj.get(), tmp); + auto rp = and_pp(obj.get(), p0, p1); + + /* THEN */ + EXPECT_VALUE_EQ(re, rp); + }); +} + +TEST_P(BooleanTest, AndBB) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + auto p1 = rand_p(obj.get(), kShape); + + /* WHEN */ + auto a0 = p2b(obj.get(), p0); + auto a1 = p2b(obj.get(), p1); + + auto tmp = and_bb(obj.get(), a0, a1); + auto re = b2p(obj.get(), tmp); + auto rp = and_pp(obj.get(), p0, p1); + + /* THEN */ + EXPECT_VALUE_EQ(re, rp); + }); +} + +TEST_P(BooleanTest, LshiftB) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate( + npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + auto b0 = p2b(obj.get(), p0); + + for (auto bits : kShiftBits) { + if (bits >= p0.elsize() * 8) { + continue; + } + /* WHEN */ + auto tmp = lshift_b(obj.get(), b0, {static_cast(bits)}); + auto r_b = b2p(obj.get(), tmp); + auto r_p = lshift_p(obj.get(), p0, {static_cast(bits)}); + + /* THEN */ + EXPECT_VALUE_EQ(r_b, r_p); + } + }); +} + +TEST_P(BooleanTest, RshiftB) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate( + npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + auto b0 = p2b(obj.get(), p0); + + for (auto bits : kShiftBits) { + if (bits >= p0.elsize() * 8) { + continue; + } + /* WHEN */ + auto tmp = rshift_b(obj.get(), b0, {static_cast(bits)}); + auto r_b = b2p(obj.get(), tmp); + auto r_p = rshift_p(obj.get(), p0, {static_cast(bits)}); + + /* THEN */ + EXPECT_VALUE_EQ(r_b, r_p); + } + }); +} + +TEST_P(BooleanTest, ARshiftB) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate( + npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + auto b0 = p2b(obj.get(), p0); + + for (auto bits : kShiftBits) { + if (bits >= p0.elsize() * 8) { + continue; + } + /* WHEN */ + auto tmp = arshift_b(obj.get(), b0, {static_cast(bits)}); + auto r_b = b2p(obj.get(), tmp); + auto r_p = arshift_p(obj.get(), p0, {static_cast(bits)}); + + /* THEN */ + EXPECT_VALUE_EQ(r_b, r_p); + } + }); +} + +TEST_P(BooleanTest, BitrevB) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + + /* WHEN */ + auto b0 = p2b(obj.get(), p0); + + for (size_t i = 0; i < SizeOf(conf.field()); i++) { + for (size_t j = i; j < SizeOf(conf.field()); j++) { + auto b1 = bitrev_b(obj.get(), b0, i, j); + + auto p1 = b2p(obj.get(), b1); + auto pp1 = bitrev_p(obj.get(), p0, i, j); + EXPECT_VALUE_EQ(p1, pp1); + } + } + }); +} + +} // namespace +} // namespace spu::mpc::test \ No newline at end of file diff --git a/libspu/mpc/swift/protocol_single_test.h b/libspu/mpc/swift/protocol_single_test.h new file mode 100644 index 00000000..3074a7bb --- /dev/null +++ b/libspu/mpc/swift/protocol_single_test.h @@ -0,0 +1,12 @@ +#include "gtest/gtest.h" +#include "yacl/link/link.h" + +#include "libspu/mpc/api_test_params.h" + +namespace spu::mpc::test { + +class ArithmeticTest : public ::testing::TestWithParam {}; + +class BooleanTest : public ::testing::TestWithParam {}; + +} // namespace spu::mpc::test diff --git a/libspu/mpc/swift/protocol_test.cc b/libspu/mpc/swift/protocol_test.cc new file mode 100644 index 00000000..d8d6a87b --- /dev/null +++ b/libspu/mpc/swift/protocol_test.cc @@ -0,0 +1,56 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/mpc/swift/protocol.h" + +#include "yacl/link/link.h" + +#include "libspu/mpc/swift/protocol_single_test.h" + +namespace spu::mpc::test { +namespace { + +RuntimeConfig makeConfig(FieldType field) { + RuntimeConfig conf; + conf.set_field(field); + return conf; +} + +} // namespace + +INSTANTIATE_TEST_SUITE_P( + SwiftTest, ArithmeticTest, + testing::Combine(testing::Values(makeSwiftProtocol), // + testing::Values(makeConfig(FieldType::FM32), + makeConfig(FieldType::FM64), + makeConfig(FieldType::FM128)), // + testing::Values(3)), // + [](const testing::TestParamInfo& p) { + return fmt::format("{}x{}", std::get<1>(p.param).field(), + std::get<2>(p.param)); + }); + +INSTANTIATE_TEST_SUITE_P( + SwiftTest, BooleanTest, + testing::Combine(testing::Values(makeSwiftProtocol), // + testing::Values(makeConfig(FieldType::FM32), + makeConfig(FieldType::FM64), + makeConfig(FieldType::FM128)), // + testing::Values(3)), // + [](const testing::TestParamInfo& p) { + return fmt::format("{}x{}", std::get<1>(p.param).field(), + std::get<2>(p.param)); + }); + +} // namespace spu::mpc::test diff --git a/libspu/mpc/swift/type.cc b/libspu/mpc/swift/type.cc new file mode 100644 index 00000000..334b98bd --- /dev/null +++ b/libspu/mpc/swift/type.cc @@ -0,0 +1,32 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/mpc/swift/type.h" + +#include + +#include "libspu/mpc/common/pv2k.h" + +namespace spu::mpc::swift { + +void registerTypes() { + regPV2kTypes(); + + static std::once_flag flag; + std::call_once(flag, []() { + TypeContext::getTypeContext()->addTypes(); + }); +} + +} // namespace spu::mpc::swift diff --git a/libspu/mpc/swift/type.h b/libspu/mpc/swift/type.h new file mode 100644 index 00000000..09a646a7 --- /dev/null +++ b/libspu/mpc/swift/type.h @@ -0,0 +1,79 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "libspu/core/type.h" + +namespace spu::mpc::swift { + +class AShrTy : public TypeImpl { + using Base = TypeImpl; + + public: + using Base::Base; + static std::string_view getStaticId() { return "swift.AShr"; } + explicit AShrTy(FieldType field) { field_ = field; } + size_t size() const override { return SizeOf(GetStorageType(field_)) * 3; } +}; + +class BShrTy : public TypeImpl { + using Base = TypeImpl; + + static constexpr size_t kDefaultNumBits = std::numeric_limits::max(); + + public: + using Base::Base; + explicit BShrTy(FieldType field, size_t nbits = kDefaultNumBits) { + field_ = field; + nbits_ = nbits == kDefaultNumBits ? SizeOf(field) * 8 : nbits; + SPU_ENFORCE(nbits_ <= SizeOf(field) * 8); + } + + static std::string_view getStaticId() { return "swift.BShr"; } + + void fromString(std::string_view detail) override { + auto comma = detail.find_first_of(','); + auto field_str = detail.substr(0, comma); + auto nbits_str = detail.substr(comma + 1); + SPU_ENFORCE(FieldType_Parse(std::string(field_str), &field_), + "parse failed from={}", detail); + nbits_ = std::stoul(std::string(nbits_str)); + }; + + std::string toString() const override { + return fmt::format("{},{}", FieldType_Name(field()), nbits_); + } + + size_t size() const override { return SizeOf(GetStorageType(field_)) * 3; } + + bool equals(TypeObject const* other) const override { + auto const* derived_other = dynamic_cast(other); + SPU_ENFORCE(derived_other); + return field_ == derived_other->field_ && nbits_ == derived_other->nbits(); + } +}; + +class PShrTy : public TypeImpl { + using Base = TypeImpl; + + public: + using Base::Base; + static std::string_view getStaticId() { return "swift.PShr"; } + explicit PShrTy() { field_ = FieldType::FM64; } +}; + +void registerTypes(); + +} // namespace spu::mpc::swift \ No newline at end of file diff --git a/libspu/mpc/swift/type_test.cc b/libspu/mpc/swift/type_test.cc new file mode 100644 index 00000000..fe7ac858 --- /dev/null +++ b/libspu/mpc/swift/type_test.cc @@ -0,0 +1,74 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/mpc/swift/type.h" + +#include "gtest/gtest.h" + +namespace spu::mpc::swift { + +TEST(AShrTy, Simple) { + registerTypes(); + { + Type ty = makeType(FM32); + EXPECT_EQ(ty.size(), 4); + + EXPECT_TRUE(ty.isa()); + EXPECT_TRUE(ty.isa()); + EXPECT_FALSE(ty.isa()); + EXPECT_TRUE(ty.isa()); + EXPECT_FALSE(ty.isa()); + + EXPECT_EQ(ty.toString(), "swift.AShr"); + + EXPECT_EQ(Type::fromString(ty.toString()), ty); + } +} + +TEST(BShrTy, Simple) { + Type ty = makeType(); + // Swift::BShr constructor with field. + { + Type ty = makeType(FM128); + EXPECT_EQ(ty.size(), 16); + + EXPECT_TRUE(ty.isa()); + EXPECT_TRUE(ty.isa()); + EXPECT_FALSE(ty.isa()); + EXPECT_FALSE(ty.isa()); + EXPECT_TRUE(ty.isa()); + + EXPECT_EQ(ty.toString(), "swift.BShr"); + + EXPECT_EQ(Type::fromString(ty.toString()), ty); + } + + // Swift::BShr constructor with field and nbits. + { + Type ty = makeType(FM128, 7); + EXPECT_EQ(ty.size(), 16); + + EXPECT_TRUE(ty.isa()); + EXPECT_TRUE(ty.isa()); + EXPECT_FALSE(ty.isa()); + EXPECT_FALSE(ty.isa()); + EXPECT_TRUE(ty.isa()); + + EXPECT_EQ(ty.toString(), "swift.BShr"); + + EXPECT_EQ(Type::fromString(ty.toString()), ty); + } +} + +} // namespace spu::mpc::swift diff --git a/libspu/mpc/swift/value.cc b/libspu/mpc/swift/value.cc new file mode 100644 index 00000000..2332bb60 --- /dev/null +++ b/libspu/mpc/swift/value.cc @@ -0,0 +1,106 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/mpc/swift/value.h" + +#include "libspu/core/prelude.h" +#include "libspu/mpc/swift/type.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc::swift { + +NdArrayRef getShare(const NdArrayRef& in, int64_t share_idx) { + SPU_ENFORCE(share_idx == 0 || share_idx == 1 || share_idx == 2, + "expect share_idx = 1 or 2 or 3, got={}", share_idx); + + auto new_strides = in.strides(); + std::transform(new_strides.cbegin(), new_strides.cend(), new_strides.begin(), + [](int64_t s) { return 3 * s; }); + + if (in.eltype().isa()) { + const auto field = in.eltype().as()->field(); + const auto ty = makeType(field); + + return NdArrayRef( + in.buf(), ty, in.shape(), new_strides, + in.offset() + share_idx * static_cast(ty.size())); + } else if(in.eltype().isa()){ + const auto field = in.eltype().as()->field(); + const auto ty = makeType(field); + + return NdArrayRef( + in.buf(), ty, in.shape(), new_strides, + in.offset() + share_idx * static_cast(ty.size())); + } else if (in.eltype().isa()) { + const auto field = in.eltype().as()->field(); + const auto ty = makeType(field); + + return NdArrayRef( + in.buf(), ty, in.shape(), new_strides, + in.offset() + share_idx * static_cast(ty.size())); + } else { + SPU_THROW("unsupported type {}", in.eltype()); + } +} + +NdArrayRef getFirstShare(const NdArrayRef& in) { return getShare(in, 0); } + +NdArrayRef getSecondShare(const NdArrayRef& in) { return getShare(in, 1); } + +NdArrayRef getThirdShare(const NdArrayRef& in) { return getShare(in, 2); } + +NdArrayRef makeAShare(const NdArrayRef& s1, const NdArrayRef& s2, + const NdArrayRef& s3, FieldType field) { + const Type ty = makeType(field); + + SPU_ENFORCE(s2.eltype().as()->field() == field); + SPU_ENFORCE(s1.eltype().as()->field() == field); + SPU_ENFORCE(s1.shape() == s2.shape(), "got s1={}, s2={}", s1, s2); + SPU_ENFORCE(ty.size() == 3 * s1.elsize()); + + NdArrayRef res(ty, s1.shape()); + + if (res.numel() != 0) { + auto res_s1 = getFirstShare(res); + auto res_s2 = getSecondShare(res); + auto res_s3 = getThirdShare(res); + + ring_assign(res_s1, s1); + ring_assign(res_s2, s2); + ring_assign(res_s3, s3); + } + + return res; +} + +PtType calcBShareBacktype(size_t nbits) { + if (nbits <= 8) { + return PT_U8; + } + if (nbits <= 16) { + return PT_U16; + } + if (nbits <= 32) { + return PT_U32; + } + if (nbits <= 64) { + return PT_U64; + } + if (nbits <= 128) { + return PT_U128; + } + SPU_THROW("invalid number of bits={}", nbits); +} + +} // namespace spu::mpc::swift diff --git a/libspu/mpc/swift/value.h b/libspu/mpc/swift/value.h new file mode 100644 index 00000000..d4739b68 --- /dev/null +++ b/libspu/mpc/swift/value.h @@ -0,0 +1,57 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "libspu/core/ndarray_ref.h" +#include "libspu/core/type_util.h" + +namespace spu::mpc::swift { + +NdArrayRef getShare(const NdArrayRef& in, int64_t share_idx); + +NdArrayRef getFirstShare(const NdArrayRef& in); + +NdArrayRef getSecondShare(const NdArrayRef& in); + +NdArrayRef getThirdShare(const NdArrayRef& in); + +NdArrayRef makeAShare(const NdArrayRef& s1, const NdArrayRef& s2, + const NdArrayRef& s3, FieldType field); + +PtType calcBShareBacktype(size_t nbits); + +template +std::vector getShareAs(const NdArrayRef& in, size_t share_idx) { + SPU_ENFORCE(share_idx == 0 || share_idx == 1 || share_idx == 2); + + NdArrayRef share = getShare(in, share_idx); + SPU_ENFORCE(share.elsize() == sizeof(T)); + + auto numel = in.numel(); + + std::vector res(numel); + DISPATCH_UINT_PT_TYPES(share.eltype().as()->pt_type(), [&]() { + NdArrayView _share(share); + for (auto idx = 0; idx < numel; ++idx) { + res[idx] = _share[idx]; + } + }); + + return res; +} + +#define PFOR_GRAIN_SIZE 8192 + +} // namespace spu::mpc::swift \ No newline at end of file