Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimizations for Shamir-based Protocol #879

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
// 使用 IntelliSense 了解相关属性。
// 悬停以查看现有属性的描述。
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "Debug",
"program": "${workspaceFolder}/bazel-bin/libspu/mpc/shamir/protocol_test",
"args": ["--gtest_filter=Shamir/ArithmeticTest.A2P/FM32x3"],
"cwd": "${workspaceFolder}"
}
]
}
8 changes: 8 additions & 0 deletions libspu/mpc/ab_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ Value mul_aa(SPUContext* ctx, const Value& x, const Value& y) {
TILED_DISPATCH(ctx, x, y);
}

Value mul_aa_p(SPUContext* ctx, const Value& x, const Value& y) {
TILED_DISPATCH(ctx, x, y);
}

Value square_a(SPUContext* ctx, const Value& x) { TILED_DISPATCH(ctx, x); }

OptionalAPI<Value> mul_av(SPUContext* ctx, const Value& x, const Value& y) {
Expand All @@ -142,6 +146,10 @@ Value trunc_a(SPUContext* ctx, const Value& x, size_t nbits, SignType sign) {
TILED_DISPATCH(ctx, x, nbits, sign);
}

Value mul_aa_trunc(SPUContext *ctx, const Value &x, const Value &y, size_t nbits, SignType sign) {
TILED_DISPATCH(ctx, x, y, nbits, sign);
}

Value mmul_ap(SPUContext* ctx, const Value& x, const Value& y) {
FORCE_DISPATCH(ctx, x, y);
}
Expand Down
3 changes: 3 additions & 0 deletions libspu/mpc/ab_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ OptionalAPI<Value> add_av(SPUContext* ctx, const Value& x, const Value& y);

Value mul_ap(SPUContext* ctx, const Value& x, const Value& y);
Value mul_aa(SPUContext* ctx, const Value& x, const Value& y);
Value mul_aa_p(SPUContext* ctx, const Value& x, const Value& y);

Value square_a(SPUContext* ctx, const Value& x);
OptionalAPI<Value> mul_av(SPUContext* ctx, const Value& x, const Value& y);

Expand All @@ -49,6 +51,7 @@ OptionalAPI<Value> mul_a1bv(SPUContext* ctx, const Value& x, const Value& y);

Value lshift_a(SPUContext* ctx, const Value& x, const Sizes& nbits);
Value trunc_a(SPUContext* ctx, const Value& x, size_t nbits, SignType sign);
Value mul_aa_trunc(SPUContext* ctx, const Value& x, const Value& y, size_t nbits, SignType sign);

Value mmul_ap(SPUContext* ctx, const Value& x, const Value& y);
Value mmul_aa(SPUContext* ctx, const Value& x, const Value& y);
Expand Down
103 changes: 101 additions & 2 deletions libspu/mpc/ab_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,59 @@ TEST_P(ArithmeticTest, MulA1BV) {
});
}

TEST_P(ArithmeticTest, MulAA) {
const auto factory = std::get<0>(GetParam());
const RuntimeConfig& conf = std::get<1>(GetParam());
const size_t npc = std::get<2>(GetParam());

utils::simulate(npc, [&](const std::shared_ptr<yacl::link::Context>& lctx) {
auto sctx = factory(conf, lctx);

auto p0 = rand_p(sctx.get(), kShape);
auto p1 = rand_p(sctx.get(), kShape);

auto v0 = p2v(sctx.get(), p0, 0);
auto v1 = p2v(sctx.get(), p1, 1);

auto a0 = v2a(sctx.get(), v0);
auto a1 = v2a(sctx.get(), v1);

auto prod = mul_aa(sctx.get(), a0, a1);
auto p_prod = a2p(sctx.get(), prod);

auto s = mul_pp(sctx.get(), p0, p1);

/* THEN */
EXPECT_VALUE_EQ(s, p_prod);
});
}

TEST_P(ArithmeticTest, MulAAP) {
const auto factory = std::get<0>(GetParam());
const RuntimeConfig& conf = std::get<1>(GetParam());
const size_t npc = std::get<2>(GetParam());

utils::simulate(npc, [&](const std::shared_ptr<yacl::link::Context>& lctx) {
auto sctx = factory(conf, lctx);

auto p0 = rand_p(sctx.get(), kShape);
auto p1 = rand_p(sctx.get(), kShape);

auto v0 = p2v(sctx.get(), p0, 0);
auto v1 = p2v(sctx.get(), p1, 1);

auto a0 = v2a(sctx.get(), v0);
auto a1 = v2a(sctx.get(), v1);

auto prod = mul_aa_p(sctx.get(), a0, a1);

auto s = mul_pp(sctx.get(), p0, p1);

/* THEN */
EXPECT_VALUE_EQ(s, prod);
});
}

TEST_P(ArithmeticTest, MatMulAP) {
const auto factory = std::get<0>(GetParam());
const RuntimeConfig& conf = std::get<1>(GetParam());
Expand Down Expand Up @@ -521,10 +574,12 @@ TEST_P(ArithmeticTest, TruncA) {
p0 = arshift_p(obj.get(), p0,
{static_cast<int64_t>(SizeOf(conf.field()) * 8 - 10)});
}

auto v0 = p2v(obj.get(), p0, 0);

/* GIVEN */
const size_t bits = 2;
auto a0 = p2a(obj.get(), p0);
auto a0 = v2a(obj.get(), v0);
// auto a0 = p2a(obj.get(), p0);

/* WHEN */
auto prev = obj->prot()->getState<Communicator>()->getStats();
Expand All @@ -541,6 +596,50 @@ TEST_P(ArithmeticTest, TruncA) {
});
}

TEST_P(ArithmeticTest, MulAATrunc) {
const auto factory = std::get<0>(GetParam());
const RuntimeConfig& conf = std::get<1>(GetParam());
const size_t npc = std::get<2>(GetParam());

// ArrayRef p0_large =
// ring_rand_range(conf.field(), kShape, -(1 << 28), -(1 << 27));
// ArrayRef p0_small = ring_rand_range(conf.field(), kShape, 1, 10000);

utils::simulate(npc, [&](const std::shared_ptr<yacl::link::Context>& lctx) {
auto obj = factory(conf, lctx);


auto p0 = rand_p(obj.get(), kShape);
auto p1 = rand_p(obj.get(), kShape);

auto bits_range_gap = p0.elsize() * 8 - (p0.elsize() * 8) / 2;
p0 = arshift_p(obj.get(), p0, {static_cast<int64_t>(bits_range_gap)});
p1 = arshift_p(obj.get(), p1, {static_cast<int64_t>(bits_range_gap)});
auto prod = mul_pp(obj.get(), p0, p1);

auto v0 = p2v(obj.get(), p0, 0);
auto v1 = p2v(obj.get(), p1, 1);

/* GIVEN */
auto a0 = v2a(obj.get(), v0);
auto a1 = v2a(obj.get(), v1);

/* WHEN */
const size_t bits = 2;
auto prev = obj->prot()->getState<Communicator>()->getStats();
auto prod_a = mul_aa_trunc(obj.get(), a0, a1, bits, SignType::Unknown);
auto cost = obj->prot()->getState<Communicator>()->getStats() - prev;

auto r_a = a2p(obj.get(), prod_a);
auto r_p = arshift_p(obj.get(), prod, {static_cast<int64_t>(bits)});

/* THEN */
EXPECT_VALUE_ALMOST_EQ(r_a, r_p, npc);
EXPECT_TRUE(verifyCost(obj->prot()->getKernel("mul_aa_trunc"), "mul_aa_trunc",
conf.field(), kShape, npc, cost));
});
}

TEST_P(ArithmeticTest, P2A) {
const auto factory = std::get<0>(GetParam());
const RuntimeConfig& conf = std::get<1>(GetParam());
Expand Down
1 change: 1 addition & 0 deletions libspu/mpc/api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ namespace spu::mpc::test {
namespace {

Shape kShape = {20, 30};
// Shape kShape = {3};
const std::vector<size_t> kShiftBits = {0, 1, 2, 31, 32, 33, 64, 1000};

#define EXPECT_VALUE_EQ(X, Y) \
Expand Down
11 changes: 11 additions & 0 deletions libspu/mpc/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,17 @@ void TruncAKernel::evaluate(KernelEvalContext* ctx) const {
ctx->pushOutput(WrapValue(z));
}

void MulTruncAKernel::evaluate(KernelEvalContext* ctx) const {
const auto& lhs = ctx->getParam<Value>(0);
const auto& rhs = ctx->getParam<Value>(1);
size_t bits = ctx->getParam<size_t>(2);
SignType sign = ctx->getParam<SignType>(3);

auto z = proc(ctx, UnwrapValue(lhs), UnwrapValue(rhs), bits, sign);

ctx->pushOutput(WrapValue(z));
}

void BitSplitKernel::evaluate(KernelEvalContext* ctx) const {
const auto& in = ctx->getParam<Value>(0);
size_t stride = ctx->getParam<size_t>(1);
Expand Down
16 changes: 16 additions & 0 deletions libspu/mpc/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,22 @@ class TruncAKernel : public Kernel {
size_t bits, SignType sign) const = 0;
};

class MulTruncAKernel : public Kernel {
public:
void evaluate(KernelEvalContext* ctx) const override;

// For protocol like SecureML, the most significant bit may have error with
// low probability, which lead to huge calculation error.
//
// Return true if the protocol has this kind of error.
virtual bool hasMsbError() const = 0;

virtual TruncLsbRounding lsbRounding() const = 0;

virtual NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs,
size_t bits, SignType sign) const = 0;
};

class BitSplitKernel : public Kernel {
public:
void evaluate(KernelEvalContext* ctx) const override;
Expand Down
51 changes: 43 additions & 8 deletions libspu/mpc/shamir/arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ NdArrayRef wrap_a2p(SPUContext* ctx, const NdArrayRef& x) {
return UnwrapValue(a2p(ctx, WrapValue(x)));
}

// Generate zero sharings of degree = threshold
NdArrayRef gen_zero_shares(KernelEvalContext* ctx, int64_t numel, int64_t threshold) {
const auto field = ctx->getState<Z2kState>()->getDefaultField();
auto* prg_state = ctx->getState<PrgState>();
auto* comm = ctx->getState<Communicator>();
auto ty = makeType<PubGfmpTy>(field);
auto r = prg_state->genPubl(field, {threshold * numel}).as(ty);
auto coeffs = gfmp_mod(r);
NdArrayRef zeros = ring_zeros(field, {numel}).as(makeType<GfmpTy>(field));
auto shares = gfmp_rand_shamir_shares(zeros, coeffs, comm->getWorldSize(), threshold);
return shares[comm->getRank()].as(makeType<AShrTy>(field));
}

// Ref: DN'07 protocol
// https://www.iacr.org/archive/crypto2007/46220565/46220565.pdf
std::pair<NdArrayRef, NdArrayRef> gen_double_shares(KernelEvalContext* ctx,
Expand Down Expand Up @@ -143,6 +156,9 @@ NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const {
NdArrayRef r_t(ty, {dn_times * (world_size - th)});
NdArrayView<ring2k_t> _r_t(r_t);
auto van = GenVandermondeMatrix<ring2k_t>(world_size, world_size - th);
// TODO optimize me: all random shares can be done by a mmut between van^T * r_shrs
// van^T is a n-t by n
// r_shrs is a n by dn_times matrix
pforeach(0, dn_times, [&](int64_t idx) {
GfmpMatrix<ring2k_t> s_t(1, world_size);
for (auto i = 0; i < world_size; ++i) {
Expand All @@ -162,14 +178,7 @@ NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const {

NdArrayRef P2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
const auto field = in.eltype().as<Ring2k>()->field();
auto* prg_state = ctx->getState<PrgState>();
auto* comm = ctx->getState<Communicator>();
int64_t th = ctx->sctx()->config().sss_threshold();
auto ty = makeType<PubGfmpTy>(field);
auto r = prg_state->genPubl(field, {th * in.numel()}).as(ty);
auto coeffs = gfmp_mod(r);
auto shares = gfmp_rand_shamir_shares(in, coeffs, comm->getWorldSize(), th);
return shares[comm->getRank()].as(makeType<AShrTy>(field));
return in.as(makeType<AShrTy>(field));
}

NdArrayRef A2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
Expand Down Expand Up @@ -300,6 +309,32 @@ NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs,
return out;
}

// Combine MulAA and A2P in 1 round
NdArrayRef MulAAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs,
const NdArrayRef& rhs) const {
SPU_ENFORCE(lhs.numel() == rhs.numel());
SPU_ENFORCE_EQ(lhs.eltype(), rhs.eltype());
const auto field = lhs.eltype().as<Ring2k>()->field();

// local mul
auto tmp_2t = gfmp_mul_mod(lhs, rhs).as(lhs.eltype());

// generate zero sharings of degree-2t
auto zero_shares = gen_zero_shares(ctx, lhs.numel(), ctx->sctx()->config().sss_threshold()<<1);

// add zero sharings
NdArrayRef out(lhs.eltype(), lhs.shape());
DISPATCH_ALL_FIELDS(field, [&]() {
NdArrayView<ring2k_t> _zero(zero_shares);
NdArrayView<ring2k_t> _tmp_2t(tmp_2t);
NdArrayView<ring2k_t> _out(out);
pforeach(0, lhs.numel(),
[&](int64_t idx){ _out[idx] = add_mod(_tmp_2t[idx], _zero[idx]); });
});

return wrap_a2p(ctx->sctx(), out);
}

////////////////////////////////////////////////////////////////////
// matmul family
////////////////////////////////////////////////////////////////////
Expand Down
14 changes: 14 additions & 0 deletions libspu/mpc/shamir/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,20 @@ class MulAA : public BinaryKernel {
const NdArrayRef& rhs) const override;
};

class MulAAP : public BinaryKernel {
public:
static constexpr char kBindName[] = "mul_aa_p";

Kind kind() const override { return Kind::Dynamic; }

ce::CExpr latency() const override { return ce::Const(1); }

ce::CExpr comm() const override { return ce::K() * 2 * (ce::N() - 1); }

NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs,
const NdArrayRef& rhs) const override;
};

////////////////////////////////////////////////////////////////////
// matmul family
////////////////////////////////////////////////////////////////////
Expand Down
Loading
Loading