Skip to content

Commit

Permalink
add subtract operation (#50)
Browse files Browse the repository at this point in the history
Add some APIs for subtraction on ciphertexts
  • Loading branch information
RuiyuZhu authored and GitHub Enterprise committed Mar 22, 2024
1 parent 2c2ee35 commit 358b9c8
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 83 deletions.
2 changes: 1 addition & 1 deletion Sources/SwiftHe/Bfv/Bfv+Encrypt.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ extension Bfv {
var ciphertext: Ciphertext<Self, Self.CanonicalCiphertextFormat> = try encryptZero(
for: plaintext.context,
secretKey: secretKey)
try Self.addAssign(lhs: &ciphertext, rhs: plaintext)
try Self.addAssign(&ciphertext, plaintext)
return ciphertext
}

Expand Down
30 changes: 12 additions & 18 deletions Sources/SwiftHe/Bfv/Bfv.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,20 @@ public struct Bfv<T: ScalarType>: HeScheme {
preconditionFailure("Unimplemented")
}

public static func mulAssign(lhs _: inout CanonicalCiphertext, rhs _: CanonicalCiphertext) throws {
public static func mulAssign(_: inout CanonicalCiphertext, _: CanonicalCiphertext) throws {
preconditionFailure("Unimplemented")
}

// swiftlint:enable unavailable_function

// for Plaintext
public static func addAssign(lhs: inout CoeffPlaintext, rhs: CoeffPlaintext) throws {
try checkContextConsistency(lhs: lhs.context, rhs: rhs.context)
lhs.poly += rhs.poly
}

public static func addAssign(lhs: inout EvalPlaintext, rhs: EvalPlaintext) throws {
try checkContextConsistency(lhs: lhs.context, rhs: rhs.context)
public static func addAssign<F: HeFormat>(_ lhs: inout Plaintext<Bfv<T>, F>, _ rhs: Plaintext<Bfv<T>, F>) throws {
try checkContextConsistency(lhs.context, rhs.context)
lhs.poly += rhs.poly
}

@inlinable
public static func addAssign(lhs: inout CanonicalCiphertext, rhs: CanonicalCiphertext) throws {
try checkContextConsistency(lhs: lhs.context, rhs: rhs.context)
public static func addAssign<F: HeFormat>(_ lhs: inout Ciphertext<Bfv<T>, F>, _ rhs: Ciphertext<Bfv<T>, F>) throws {
try checkContextConsistency(lhs.context, rhs.context)
guard lhs.polys.count == rhs.polys.count else {
throw HeError.incompatibleCiphertexts(lhs, rhs)
}
Expand All @@ -52,23 +46,23 @@ public struct Bfv<T: ScalarType>: HeScheme {
}

@inlinable
public static func addAssign(lhs: inout EvalCiphertext, rhs: EvalCiphertext) throws {
try checkContextConsistency(lhs: lhs.context, rhs: rhs.context)
public static func subAssign<F: HeFormat>(_ lhs: inout Ciphertext<Bfv<T>, F>, _ rhs: Ciphertext<Bfv<T>, F>) throws {
try checkContextConsistency(lhs.context, rhs.context)
guard lhs.polys.count == rhs.polys.count else {
throw HeError.incompatibleCiphertexts(lhs, rhs)
}
for (polyIndex, rhsPoly) in rhs.polys.enumerated() {
lhs.polys[polyIndex] += rhsPoly
lhs.polys[polyIndex] -= rhsPoly
}
}

public static func addAssign(lhs ciphertext: inout CoeffCiphertext, rhs plaintext: CoeffPlaintext) throws {
public static func addAssign(_ ciphertext: inout CoeffCiphertext, _ plaintext: CoeffPlaintext) throws {
try plaintextTranslate(ciphertext: &ciphertext, plaintext: plaintext, op: PlaintextTranslateOp.Add)
}

@inlinable
public static func mulAssign(_ ciphertext: inout EvalCiphertext, _ plaintext: EvalPlaintext) throws {
try checkContextConsistency(lhs: ciphertext.context, rhs: plaintext.context)
try checkContextConsistency(ciphertext.context, plaintext.context)
guard ciphertext.moduli.count == plaintext.moduli.count else {
throw HeError.incompatibleCiphertextAndPlaintext(ciphertext: ciphertext, plaintext: plaintext)
}
Expand All @@ -82,12 +76,12 @@ public struct Bfv<T: ScalarType>: HeScheme {
// These operations could be supported with extra NTT conversions, but NTTs are expensive, so we prefer to
// keep NTT conversions explicit

public static func addAssign(lhs _: inout EvalCiphertext, rhs _: EvalPlaintext) throws {
public static func addAssign(_: inout EvalCiphertext, _: EvalPlaintext) throws {
// TODO: rdar://124413535 (swift-he: Check if BFV EvalCiphertext + EvalPlaintext is possible)
throw HeError.unsupportedHeOperation()
}

public static func addAssign(lhs _: inout CanonicalCiphertext, rhs _: EvalPlaintext) throws {
public static func addAssign(_: inout CanonicalCiphertext, _: EvalPlaintext) throws {
throw HeError.unsupportedHeOperation()
}

Expand Down
101 changes: 84 additions & 17 deletions Sources/SwiftHe/Ciphertext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,72 +16,96 @@ public struct Ciphertext<Scheme: HeScheme, Format: HeFormat> {
public static func += (lhs: inout Ciphertext<Scheme, Format>, rhs: Plaintext<Scheme, Coeff>) throws
where Format == Coeff
{
try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context)
try Scheme.addAssign(lhs: &lhs, rhs: rhs)
try Scheme.checkContextConsistency(lhs.context, rhs.context)
try Scheme.addAssign(&lhs, rhs)
}

@inlinable
public static func += (lhs: inout Ciphertext<Scheme, Format>, rhs: Ciphertext<Scheme, Format>) throws
where Format == Coeff
{
try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context)
try Scheme.addAssign(lhs: &lhs, rhs: rhs)
try Scheme.checkContextConsistency(lhs.context, rhs.context)
try Scheme.addAssign(&lhs, rhs)
}

@inlinable
public static func += (lhs: inout Ciphertext<Scheme, Format>, rhs: Plaintext<Scheme, Eval>) throws
where Format == Eval
{
try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context)
try Scheme.addAssign(lhs: &lhs, rhs: rhs)
try Scheme.checkContextConsistency(lhs.context, rhs.context)
try Scheme.addAssign(&lhs, rhs)
}

@inlinable
public static func += (lhs: inout Ciphertext<Scheme, Format>, rhs: Plaintext<Scheme, Coeff>) throws
where Format == Scheme.CanonicalCiphertextFormat
{
try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context)
try Scheme.addAssign(lhs: &lhs, rhs: rhs)
try Scheme.checkContextConsistency(lhs.context, rhs.context)
try Scheme.addAssign(&lhs, rhs)
}

@inlinable
public static func += (lhs: inout Ciphertext<Scheme, Format>, rhs: Plaintext<Scheme, Eval>) throws
where Format == Scheme.CanonicalCiphertextFormat
{
try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context)
try Scheme.addAssign(lhs: &lhs, rhs: rhs)
try Scheme.checkContextConsistency(lhs.context, rhs.context)
try Scheme.addAssign(&lhs, rhs)
}

@inlinable
public static func += (lhs: inout Ciphertext<Scheme, Format>, rhs: Ciphertext<Scheme, Format>) throws
where Format == Scheme.CanonicalCiphertextFormat
{
try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context)
try Scheme.addAssign(lhs: &lhs, rhs: rhs)
try Scheme.checkContextConsistency(lhs.context, rhs.context)
try Scheme.addAssign(&lhs, rhs)
}

@inlinable
public static func += (lhs: inout Ciphertext<Scheme, Format>, rhs: Ciphertext<Scheme, Format>) throws
where Format == Eval
{
try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context)
try Scheme.addAssign(lhs: &lhs, rhs: rhs)
try Scheme.checkContextConsistency(lhs.context, rhs.context)
try Scheme.addAssign(&lhs, rhs)
}

@inlinable
public static func -= (lhs: inout Ciphertext<Scheme, Format>, rhs: Ciphertext<Scheme, Format>) throws
where Format == Scheme.CanonicalCiphertextFormat
{
try Scheme.checkContextConsistency(lhs.context, rhs.context)
try Scheme.subAssign(&lhs, rhs)
}

@inlinable
public static func -= (lhs: inout Ciphertext<Scheme, Format>, rhs: Ciphertext<Scheme, Format>) throws
where Format == Coeff
{
try Scheme.checkContextConsistency(lhs.context, rhs.context)
try Scheme.subAssign(&lhs, rhs)
}

@inlinable
public static func -= (lhs: inout Ciphertext<Scheme, Format>, rhs: Ciphertext<Scheme, Format>) throws
where Format == Eval
{
try Scheme.checkContextConsistency(lhs.context, rhs.context)
try Scheme.subAssign(&lhs, rhs)
}

@inlinable
public static func *= (lhs: inout Ciphertext<Scheme, Format>, rhs: Plaintext<Scheme, Eval>) throws
where Format == Eval
{
try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context)
try Scheme.checkContextConsistency(lhs.context, rhs.context)
try Scheme.mulAssign(&lhs, rhs)
}

@inlinable
public static func *= (lhs: inout Ciphertext<Scheme, Format>, rhs: Ciphertext<Scheme, Format>) throws
where Format == Scheme.CanonicalCiphertextFormat
{
try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context)
try Scheme.mulAssign(lhs: &lhs, rhs: rhs)
try Scheme.checkContextConsistency(lhs.context, rhs.context)
try Scheme.mulAssign(&lhs, rhs)
}

@inlinable
Expand All @@ -96,6 +120,22 @@ public struct Ciphertext<Scheme: HeScheme, Format: HeFormat> {
return Ciphertext<Scheme, Coeff>(context: context, polys: polys, correctionFactor: correctionFactor)
}

@inlinable
public func convertToCoeffFormat() throws -> Ciphertext<Scheme, Coeff>
where Format == Scheme.CanonicalCiphertextFormat
{
if Format.self == Coeff.self {
if let ciphertext = self as? Ciphertext<Scheme, Coeff> {
return ciphertext
}
throw HeError.errorInSameFormatCasting(Format.self, Coeff.self)
}
if let ciphertext = self as? Ciphertext<Scheme, Eval> {
return try ciphertext.inverseNtt()
}
throw HeError.errorInSameFormatCasting(Format.self, Eval.self)
}

@inlinable
public func convertToEvalFormat() throws -> Ciphertext<Scheme, Eval>
where Format == Scheme.CanonicalCiphertextFormat
Expand Down Expand Up @@ -216,6 +256,15 @@ extension Ciphertext {
return result
}

@inlinable
static func - (lhs: Ciphertext<Scheme, Format>, rhs: Ciphertext<Scheme, Format>) throws -> Self
where Format == Coeff
{
var result = lhs
try result -= rhs
return result
}

@inlinable
static func + (lhs: Ciphertext<Scheme, Format>, rhs: Plaintext<Scheme, Eval>) throws -> Self
where Format == Eval
Expand Down Expand Up @@ -252,6 +301,15 @@ extension Ciphertext {
return result
}

@inlinable
static func - (lhs: Ciphertext<Scheme, Format>, rhs: Ciphertext<Scheme, Format>) throws -> Self
where Format == Eval
{
var result = lhs
try result -= rhs
return result
}

@inlinable
static func + (lhs: Ciphertext<Scheme, Format>, rhs: Ciphertext<Scheme, Format>) throws -> Self
where Format == Scheme.CanonicalCiphertextFormat
Expand All @@ -261,6 +319,15 @@ extension Ciphertext {
return result
}

@inlinable
static func - (lhs: Ciphertext<Scheme, Format>, rhs: Ciphertext<Scheme, Format>) throws -> Self
where Format == Scheme.CanonicalCiphertextFormat
{
var result = lhs
try result -= rhs
return result
}

@inlinable
static func * (lhs: Ciphertext<Scheme, Format>, rhs: Plaintext<Scheme, Eval>) throws -> Self
where Format == Eval
Expand Down
27 changes: 15 additions & 12 deletions Sources/SwiftHe/HeScheme.swift
Original file line number Diff line number Diff line change
Expand Up @@ -73,24 +73,27 @@ public protocol HeScheme {
static func swapColumns(ciphertext: inout CanonicalCiphertext, evaluationKey: EvaluationKey) throws

// support PT-PT addition for plaintext of same format
static func addAssign(lhs: inout CoeffPlaintext, rhs: CoeffPlaintext) throws
static func addAssign(lhs: inout EvalPlaintext, rhs: EvalPlaintext) throws
static func addAssign(_ lhs: inout CoeffPlaintext, _ rhs: CoeffPlaintext) throws
static func addAssign(_ lhs: inout EvalPlaintext, _ rhs: EvalPlaintext) throws
// support CT-CT addition for ciphertext of same format
static func addAssign(lhs: inout CoeffCiphertext, rhs: CoeffCiphertext) throws
static func addAssign(lhs: inout EvalCiphertext, rhs: EvalCiphertext) throws
static func addAssign(_ lhs: inout CoeffCiphertext, _ rhs: CoeffCiphertext) throws
static func addAssign(_ lhs: inout EvalCiphertext, _ rhs: EvalCiphertext) throws
static func subAssign(_ lhs: inout CoeffCiphertext, _ rhs: CoeffCiphertext) throws
static func subAssign(_ lhs: inout EvalCiphertext, _ rhs: EvalCiphertext) throws
// support CT-PT addition
static func addAssign(lhs: inout CoeffCiphertext, rhs: CoeffPlaintext) throws
static func addAssign(lhs: inout EvalCiphertext, rhs: EvalPlaintext) throws
static func addAssign(_ lhs: inout CoeffCiphertext, _ rhs: CoeffPlaintext) throws
static func addAssign(_ lhs: inout EvalCiphertext, _ rhs: EvalPlaintext) throws

// support CT-PT mul for eval format plaintext/ciphertext
static func mulAssign(_ ciphertext: inout EvalCiphertext, _ plaintext: EvalPlaintext) throws

// support CT-CT add/mul for canonical format ciphertexts, note these functions don't need to be implemented
// separately
static func addAssign(lhs: inout CanonicalCiphertext, rhs: CanonicalCiphertext) throws
static func mulAssign(lhs: inout CanonicalCiphertext, rhs: CanonicalCiphertext) throws
static func addAssign(lhs: inout CanonicalCiphertext, rhs: CoeffPlaintext) throws
static func addAssign(lhs: inout CanonicalCiphertext, rhs: EvalPlaintext) throws
static func addAssign(_ lhs: inout CanonicalCiphertext, _ rhs: CanonicalCiphertext) throws
static func mulAssign(_ lhs: inout CanonicalCiphertext, _ rhs: CanonicalCiphertext) throws
static func addAssign(_ lhs: inout CanonicalCiphertext, _ rhs: CoeffPlaintext) throws
static func addAssign(_ lhs: inout CanonicalCiphertext, _ rhs: EvalPlaintext) throws
static func subAssign(_ lhs: inout CanonicalCiphertext, _ rhs: CanonicalCiphertext) throws

// ciphertext mod switch down protocol
static func modSwitchDown(ciphertext: inout CanonicalCiphertext) throws
Expand All @@ -102,11 +105,11 @@ public protocol HeScheme {
using key: GaloisKey<Self>) throws

// helper function check if two contexts are the same
static func checkContextConsistency(lhs: Context<Self>, rhs: Context<Self>) throws
static func checkContextConsistency(_ lhs: Context<Self>, _ rhs: Context<Self>) throws
}

public extension HeScheme {
static func checkContextConsistency(lhs: Context<Self>, rhs: Context<Self>) throws {
static func checkContextConsistency(_ lhs: Context<Self>, _ rhs: Context<Self>) throws {
guard lhs == rhs else {
throw HeError.inconsistentContext(got: lhs, expected: rhs)
}
Expand Down
Loading

0 comments on commit 358b9c8

Please sign in to comment.