From 358b9c88ad12f414c5c67f74d8f527aa4c9c53a1 Mon Sep 17 00:00:00 2001 From: Ruiyu Zhu Date: Fri, 22 Mar 2024 13:36:40 -0700 Subject: [PATCH] add subtract operation (#50) Add some APIs for subtraction on ciphertexts --- Sources/SwiftHe/Bfv/Bfv+Encrypt.swift | 2 +- Sources/SwiftHe/Bfv/Bfv.swift | 30 +++----- Sources/SwiftHe/Ciphertext.swift | 101 +++++++++++++++++++++----- Sources/SwiftHe/HeScheme.swift | 27 ++++--- Sources/SwiftHe/NoOpScheme.swift | 74 ++++++++++--------- Sources/SwiftHe/Plaintext.swift | 4 +- Tests/SwiftHeTests/HeAPITests.swift | 38 ++++++++++ 7 files changed, 193 insertions(+), 83 deletions(-) diff --git a/Sources/SwiftHe/Bfv/Bfv+Encrypt.swift b/Sources/SwiftHe/Bfv/Bfv+Encrypt.swift index 2c5d3642..c91d6838 100644 --- a/Sources/SwiftHe/Bfv/Bfv+Encrypt.swift +++ b/Sources/SwiftHe/Bfv/Bfv+Encrypt.swift @@ -12,7 +12,7 @@ extension Bfv { var ciphertext: Ciphertext = try encryptZero( for: plaintext.context, secretKey: secretKey) - try Self.addAssign(lhs: &ciphertext, rhs: plaintext) + try Self.addAssign(&ciphertext, plaintext) return ciphertext } diff --git a/Sources/SwiftHe/Bfv/Bfv.swift b/Sources/SwiftHe/Bfv/Bfv.swift index 633d0f9e..12266e71 100644 --- a/Sources/SwiftHe/Bfv/Bfv.swift +++ b/Sources/SwiftHe/Bfv/Bfv.swift @@ -23,26 +23,20 @@ public struct Bfv: HeScheme { preconditionFailure("Unimplemented") } - public static func mulAssign(lhs _: inout CanonicalCiphertext, rhs _: CanonicalCiphertext) throws { + public static func mulAssign(_: inout CanonicalCiphertext, _: CanonicalCiphertext) throws { preconditionFailure("Unimplemented") } // swiftlint:enable unavailable_function - // for Plaintext - public static func addAssign(lhs: inout CoeffPlaintext, rhs: CoeffPlaintext) throws { - try checkContextConsistency(lhs: lhs.context, rhs: rhs.context) - lhs.poly += rhs.poly - } - - public static func addAssign(lhs: inout EvalPlaintext, rhs: EvalPlaintext) throws { - try checkContextConsistency(lhs: lhs.context, rhs: rhs.context) + public static func addAssign(_ lhs: inout Plaintext, F>, _ rhs: Plaintext, F>) throws { + try checkContextConsistency(lhs.context, rhs.context) lhs.poly += rhs.poly } @inlinable - public static func addAssign(lhs: inout CanonicalCiphertext, rhs: CanonicalCiphertext) throws { - try checkContextConsistency(lhs: lhs.context, rhs: rhs.context) + public static func addAssign(_ lhs: inout Ciphertext, F>, _ rhs: Ciphertext, F>) throws { + try checkContextConsistency(lhs.context, rhs.context) guard lhs.polys.count == rhs.polys.count else { throw HeError.incompatibleCiphertexts(lhs, rhs) } @@ -52,23 +46,23 @@ public struct Bfv: HeScheme { } @inlinable - public static func addAssign(lhs: inout EvalCiphertext, rhs: EvalCiphertext) throws { - try checkContextConsistency(lhs: lhs.context, rhs: rhs.context) + public static func subAssign(_ lhs: inout Ciphertext, F>, _ rhs: Ciphertext, F>) throws { + try checkContextConsistency(lhs.context, rhs.context) guard lhs.polys.count == rhs.polys.count else { throw HeError.incompatibleCiphertexts(lhs, rhs) } for (polyIndex, rhsPoly) in rhs.polys.enumerated() { - lhs.polys[polyIndex] += rhsPoly + lhs.polys[polyIndex] -= rhsPoly } } - public static func addAssign(lhs ciphertext: inout CoeffCiphertext, rhs plaintext: CoeffPlaintext) throws { + public static func addAssign(_ ciphertext: inout CoeffCiphertext, _ plaintext: CoeffPlaintext) throws { try plaintextTranslate(ciphertext: &ciphertext, plaintext: plaintext, op: PlaintextTranslateOp.Add) } @inlinable public static func mulAssign(_ ciphertext: inout EvalCiphertext, _ plaintext: EvalPlaintext) throws { - try checkContextConsistency(lhs: ciphertext.context, rhs: plaintext.context) + try checkContextConsistency(ciphertext.context, plaintext.context) guard ciphertext.moduli.count == plaintext.moduli.count else { throw HeError.incompatibleCiphertextAndPlaintext(ciphertext: ciphertext, plaintext: plaintext) } @@ -82,12 +76,12 @@ public struct Bfv: HeScheme { // These operations could be supported with extra NTT conversions, but NTTs are expensive, so we prefer to // keep NTT conversions explicit - public static func addAssign(lhs _: inout EvalCiphertext, rhs _: EvalPlaintext) throws { + public static func addAssign(_: inout EvalCiphertext, _: EvalPlaintext) throws { // TODO: rdar://124413535 (swift-he: Check if BFV EvalCiphertext + EvalPlaintext is possible) throw HeError.unsupportedHeOperation() } - public static func addAssign(lhs _: inout CanonicalCiphertext, rhs _: EvalPlaintext) throws { + public static func addAssign(_: inout CanonicalCiphertext, _: EvalPlaintext) throws { throw HeError.unsupportedHeOperation() } diff --git a/Sources/SwiftHe/Ciphertext.swift b/Sources/SwiftHe/Ciphertext.swift index b8847067..f21938f0 100644 --- a/Sources/SwiftHe/Ciphertext.swift +++ b/Sources/SwiftHe/Ciphertext.swift @@ -16,63 +16,87 @@ public struct Ciphertext { public static func += (lhs: inout Ciphertext, rhs: Plaintext) throws where Format == Coeff { - try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context) - try Scheme.addAssign(lhs: &lhs, rhs: rhs) + try Scheme.checkContextConsistency(lhs.context, rhs.context) + try Scheme.addAssign(&lhs, rhs) } @inlinable public static func += (lhs: inout Ciphertext, rhs: Ciphertext) throws where Format == Coeff { - try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context) - try Scheme.addAssign(lhs: &lhs, rhs: rhs) + try Scheme.checkContextConsistency(lhs.context, rhs.context) + try Scheme.addAssign(&lhs, rhs) } @inlinable public static func += (lhs: inout Ciphertext, rhs: Plaintext) throws where Format == Eval { - try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context) - try Scheme.addAssign(lhs: &lhs, rhs: rhs) + try Scheme.checkContextConsistency(lhs.context, rhs.context) + try Scheme.addAssign(&lhs, rhs) } @inlinable public static func += (lhs: inout Ciphertext, rhs: Plaintext) throws where Format == Scheme.CanonicalCiphertextFormat { - try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context) - try Scheme.addAssign(lhs: &lhs, rhs: rhs) + try Scheme.checkContextConsistency(lhs.context, rhs.context) + try Scheme.addAssign(&lhs, rhs) } @inlinable public static func += (lhs: inout Ciphertext, rhs: Plaintext) throws where Format == Scheme.CanonicalCiphertextFormat { - try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context) - try Scheme.addAssign(lhs: &lhs, rhs: rhs) + try Scheme.checkContextConsistency(lhs.context, rhs.context) + try Scheme.addAssign(&lhs, rhs) } @inlinable public static func += (lhs: inout Ciphertext, rhs: Ciphertext) throws where Format == Scheme.CanonicalCiphertextFormat { - try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context) - try Scheme.addAssign(lhs: &lhs, rhs: rhs) + try Scheme.checkContextConsistency(lhs.context, rhs.context) + try Scheme.addAssign(&lhs, rhs) } @inlinable public static func += (lhs: inout Ciphertext, rhs: Ciphertext) throws where Format == Eval { - try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context) - try Scheme.addAssign(lhs: &lhs, rhs: rhs) + try Scheme.checkContextConsistency(lhs.context, rhs.context) + try Scheme.addAssign(&lhs, rhs) + } + + @inlinable + public static func -= (lhs: inout Ciphertext, rhs: Ciphertext) throws + where Format == Scheme.CanonicalCiphertextFormat + { + try Scheme.checkContextConsistency(lhs.context, rhs.context) + try Scheme.subAssign(&lhs, rhs) + } + + @inlinable + public static func -= (lhs: inout Ciphertext, rhs: Ciphertext) throws + where Format == Coeff + { + try Scheme.checkContextConsistency(lhs.context, rhs.context) + try Scheme.subAssign(&lhs, rhs) + } + + @inlinable + public static func -= (lhs: inout Ciphertext, rhs: Ciphertext) throws + where Format == Eval + { + try Scheme.checkContextConsistency(lhs.context, rhs.context) + try Scheme.subAssign(&lhs, rhs) } @inlinable public static func *= (lhs: inout Ciphertext, rhs: Plaintext) throws where Format == Eval { - try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context) + try Scheme.checkContextConsistency(lhs.context, rhs.context) try Scheme.mulAssign(&lhs, rhs) } @@ -80,8 +104,8 @@ public struct Ciphertext { public static func *= (lhs: inout Ciphertext, rhs: Ciphertext) throws where Format == Scheme.CanonicalCiphertextFormat { - try Scheme.checkContextConsistency(lhs: lhs.context, rhs: rhs.context) - try Scheme.mulAssign(lhs: &lhs, rhs: rhs) + try Scheme.checkContextConsistency(lhs.context, rhs.context) + try Scheme.mulAssign(&lhs, rhs) } @inlinable @@ -96,6 +120,22 @@ public struct Ciphertext { return Ciphertext(context: context, polys: polys, correctionFactor: correctionFactor) } + @inlinable + public func convertToCoeffFormat() throws -> Ciphertext + where Format == Scheme.CanonicalCiphertextFormat + { + if Format.self == Coeff.self { + if let ciphertext = self as? Ciphertext { + return ciphertext + } + throw HeError.errorInSameFormatCasting(Format.self, Coeff.self) + } + if let ciphertext = self as? Ciphertext { + return try ciphertext.inverseNtt() + } + throw HeError.errorInSameFormatCasting(Format.self, Eval.self) + } + @inlinable public func convertToEvalFormat() throws -> Ciphertext where Format == Scheme.CanonicalCiphertextFormat @@ -216,6 +256,15 @@ extension Ciphertext { return result } + @inlinable + static func - (lhs: Ciphertext, rhs: Ciphertext) throws -> Self + where Format == Coeff + { + var result = lhs + try result -= rhs + return result + } + @inlinable static func + (lhs: Ciphertext, rhs: Plaintext) throws -> Self where Format == Eval @@ -252,6 +301,15 @@ extension Ciphertext { return result } + @inlinable + static func - (lhs: Ciphertext, rhs: Ciphertext) throws -> Self + where Format == Eval + { + var result = lhs + try result -= rhs + return result + } + @inlinable static func + (lhs: Ciphertext, rhs: Ciphertext) throws -> Self where Format == Scheme.CanonicalCiphertextFormat @@ -261,6 +319,15 @@ extension Ciphertext { return result } + @inlinable + static func - (lhs: Ciphertext, rhs: Ciphertext) throws -> Self + where Format == Scheme.CanonicalCiphertextFormat + { + var result = lhs + try result -= rhs + return result + } + @inlinable static func * (lhs: Ciphertext, rhs: Plaintext) throws -> Self where Format == Eval diff --git a/Sources/SwiftHe/HeScheme.swift b/Sources/SwiftHe/HeScheme.swift index d581d6db..252656eb 100644 --- a/Sources/SwiftHe/HeScheme.swift +++ b/Sources/SwiftHe/HeScheme.swift @@ -73,24 +73,27 @@ public protocol HeScheme { static func swapColumns(ciphertext: inout CanonicalCiphertext, evaluationKey: EvaluationKey) throws // support PT-PT addition for plaintext of same format - static func addAssign(lhs: inout CoeffPlaintext, rhs: CoeffPlaintext) throws - static func addAssign(lhs: inout EvalPlaintext, rhs: EvalPlaintext) throws + static func addAssign(_ lhs: inout CoeffPlaintext, _ rhs: CoeffPlaintext) throws + static func addAssign(_ lhs: inout EvalPlaintext, _ rhs: EvalPlaintext) throws // support CT-CT addition for ciphertext of same format - static func addAssign(lhs: inout CoeffCiphertext, rhs: CoeffCiphertext) throws - static func addAssign(lhs: inout EvalCiphertext, rhs: EvalCiphertext) throws + static func addAssign(_ lhs: inout CoeffCiphertext, _ rhs: CoeffCiphertext) throws + static func addAssign(_ lhs: inout EvalCiphertext, _ rhs: EvalCiphertext) throws + static func subAssign(_ lhs: inout CoeffCiphertext, _ rhs: CoeffCiphertext) throws + static func subAssign(_ lhs: inout EvalCiphertext, _ rhs: EvalCiphertext) throws // support CT-PT addition - static func addAssign(lhs: inout CoeffCiphertext, rhs: CoeffPlaintext) throws - static func addAssign(lhs: inout EvalCiphertext, rhs: EvalPlaintext) throws + static func addAssign(_ lhs: inout CoeffCiphertext, _ rhs: CoeffPlaintext) throws + static func addAssign(_ lhs: inout EvalCiphertext, _ rhs: EvalPlaintext) throws // support CT-PT mul for eval format plaintext/ciphertext static func mulAssign(_ ciphertext: inout EvalCiphertext, _ plaintext: EvalPlaintext) throws // support CT-CT add/mul for canonical format ciphertexts, note these functions don't need to be implemented // separately - static func addAssign(lhs: inout CanonicalCiphertext, rhs: CanonicalCiphertext) throws - static func mulAssign(lhs: inout CanonicalCiphertext, rhs: CanonicalCiphertext) throws - static func addAssign(lhs: inout CanonicalCiphertext, rhs: CoeffPlaintext) throws - static func addAssign(lhs: inout CanonicalCiphertext, rhs: EvalPlaintext) throws + static func addAssign(_ lhs: inout CanonicalCiphertext, _ rhs: CanonicalCiphertext) throws + static func mulAssign(_ lhs: inout CanonicalCiphertext, _ rhs: CanonicalCiphertext) throws + static func addAssign(_ lhs: inout CanonicalCiphertext, _ rhs: CoeffPlaintext) throws + static func addAssign(_ lhs: inout CanonicalCiphertext, _ rhs: EvalPlaintext) throws + static func subAssign(_ lhs: inout CanonicalCiphertext, _ rhs: CanonicalCiphertext) throws // ciphertext mod switch down protocol static func modSwitchDown(ciphertext: inout CanonicalCiphertext) throws @@ -102,11 +105,11 @@ public protocol HeScheme { using key: GaloisKey) throws // helper function check if two contexts are the same - static func checkContextConsistency(lhs: Context, rhs: Context) throws + static func checkContextConsistency(_ lhs: Context, _ rhs: Context) throws } public extension HeScheme { - static func checkContextConsistency(lhs: Context, rhs: Context) throws { + static func checkContextConsistency(_ lhs: Context, _ rhs: Context) throws { guard lhs == rhs else { throw HeError.inconsistentContext(got: lhs, expected: rhs) } diff --git a/Sources/SwiftHe/NoOpScheme.swift b/Sources/SwiftHe/NoOpScheme.swift index a46928ac..52125f3c 100644 --- a/Sources/SwiftHe/NoOpScheme.swift +++ b/Sources/SwiftHe/NoOpScheme.swift @@ -79,7 +79,7 @@ struct NoOpScheme: HeScheme { evaluationKey _: EvaluationKey) throws { let element = try getGaloisElement(for: step, degree: ciphertext.context.parameter.polyDegree) - ciphertext.polys[0] = try ciphertext.polys[0].applyGalois(galoisElement: element) + ciphertext.polys[0] = ciphertext.polys[0].applyGalois(galoisElement: element) } static func swapColumns( @@ -87,84 +87,92 @@ struct NoOpScheme: HeScheme { evaluationKey _: EvaluationKey) throws { let element = getGaloisElementForColumnRotation(degree: ciphertext.context.parameter.polyDegree) - ciphertext.polys[0] = try ciphertext.polys[0].applyGalois(galoisElement: element) + ciphertext.polys[0] = ciphertext.polys[0].applyGalois(galoisElement: element) } // for Plaintext - static func addAssign(lhs: inout CoeffPlaintext, rhs: CoeffPlaintext) throws { - try checkContextConsistency(lhs: lhs.context, rhs: rhs.context) + static func addAssign(_ lhs: inout CoeffPlaintext, _ rhs: CoeffPlaintext) throws { + try checkContextConsistency(lhs.context, rhs.context) lhs.poly += rhs.poly } - static func addAssign(lhs: inout EvalPlaintext, rhs: EvalPlaintext) throws { - try checkContextConsistency(lhs: lhs.context, rhs: rhs.context) + static func addAssign(_ lhs: inout EvalPlaintext, _ rhs: EvalPlaintext) throws { + try checkContextConsistency(lhs.context, rhs.context) lhs.poly += rhs.poly } // for CoeffCiphertext - static func addAssign(lhs: inout CoeffCiphertext, rhs: CoeffPlaintext) throws { - try checkContextConsistency(lhs: lhs.context, rhs: rhs.context) + static func addAssign(_ lhs: inout CoeffCiphertext, _ rhs: CoeffPlaintext) throws { + try checkContextConsistency(lhs.context, rhs.context) lhs.polys[0] += rhs.poly } - static func addAssign(lhs: inout CoeffCiphertext, rhs: EvalPlaintext) throws { - try addAssign(lhs: &lhs, rhs: rhs.inverseNtt()) + static func addAssign(_ lhs: inout CoeffCiphertext, _ rhs: EvalPlaintext) throws { + try addAssign(&lhs, rhs.inverseNtt()) } - static func mulAssign(lhs: inout CoeffCiphertext, rhs: CoeffPlaintext) throws { - try checkContextConsistency(lhs: lhs.context, rhs: rhs.context) + static func mulAssign(_ lhs: inout CoeffCiphertext, _ rhs: CoeffPlaintext) throws { + try checkContextConsistency(lhs.context, rhs.context) var evalLhs = try lhs.forwardNtt() let evalRhs = try rhs.forwardNtt() try mulAssign(&evalLhs, evalRhs) lhs = try evalLhs.inverseNtt() } - static func mulAssign(lhs: inout CoeffCiphertext, rhs: EvalPlaintext) throws { - try checkContextConsistency(lhs: lhs.context, rhs: rhs.context) + static func mulAssign(_ lhs: inout CoeffCiphertext, _ rhs: EvalPlaintext) throws { + try checkContextConsistency(lhs.context, rhs.context) var evalLhs = try lhs.forwardNtt() try mulAssign(&evalLhs, rhs) lhs = try evalLhs.inverseNtt() } - static func addAssign(lhs: inout CoeffCiphertext, rhs: CoeffCiphertext) throws { - try checkContextConsistency(lhs: lhs.context, rhs: rhs.context) + static func addAssign(_ lhs: inout CoeffCiphertext, _ rhs: CoeffCiphertext) throws { + try checkContextConsistency(lhs.context, rhs.context) lhs.polys[0] += rhs.polys[0] } + static func subAssign(_ lhs: inout CoeffCiphertext, _ rhs: CoeffCiphertext) throws { + try checkContextConsistency(lhs.context, rhs.context) + lhs.polys[0] -= rhs.polys[0] + } + // for EvalCiphertext - static func addAssign(lhs: inout EvalCiphertext, rhs: CoeffPlaintext) throws { - try checkContextConsistency(lhs: lhs.context, rhs: rhs.context) + static func addAssign(_ lhs: inout EvalCiphertext, _ rhs: CoeffPlaintext) throws { + try checkContextConsistency(lhs.context, rhs.context) let evalRhs = try rhs.forwardNtt() - try addAssign(lhs: &lhs, rhs: evalRhs) + try addAssign(&lhs, evalRhs) } - static func addAssign(lhs: inout EvalCiphertext, rhs: EvalPlaintext) throws { - try checkContextConsistency(lhs: lhs.context, rhs: rhs.context) + static func addAssign(_ lhs: inout EvalCiphertext, _ rhs: EvalPlaintext) throws { + try checkContextConsistency(lhs.context, rhs.context) lhs.polys[0] += rhs.poly } - static func mulAssign(lhs: inout EvalCiphertext, rhs: CoeffPlaintext) throws { - try checkContextConsistency(lhs: lhs.context, rhs: rhs.context) + static func subAssign(_ lhs: inout EvalCiphertext, _ rhs: EvalCiphertext) throws { + try checkContextConsistency(lhs.context, rhs.context) + lhs.polys[0] -= rhs.polys[0] + } + + static func mulAssign(_ lhs: inout EvalCiphertext, _ rhs: CoeffPlaintext) throws { + try checkContextConsistency(lhs.context, rhs.context) let evalRhs = try rhs.forwardNtt() try mulAssign(&lhs, evalRhs) } - static func mulAssign(_ ciphertext: inout EvalCiphertext, _ plaintext: EvalPlaintext) throws { - try checkContextConsistency(lhs: ciphertext.context, rhs: plaintext.context) -// let coeffPlaintext = try plaintext.convertToCoeffFormat() -// let plaintextPoly = try coeffPlaintext.poly.forwardNtt() - ciphertext.polys[0] *= plaintext.poly + static func mulAssign(_ lhs: inout EvalCiphertext, _ rhs: EvalPlaintext) throws { + try checkContextConsistency(lhs.context, rhs.context) + lhs.polys[0] *= rhs.poly } - static func mulAssign(lhs: inout NoOpScheme.CanonicalCiphertext, - rhs: NoOpScheme.CanonicalCiphertext) throws + static func mulAssign(_ lhs: inout NoOpScheme.CanonicalCiphertext, + _ rhs: NoOpScheme.CanonicalCiphertext) throws { - try checkContextConsistency(lhs: lhs.context, rhs: rhs.context) + try checkContextConsistency(lhs.context, rhs.context) lhs.polys[0] = try (lhs.polys[0].forwardNtt() * rhs.polys[0].forwardNtt()).inverseNtt() } - static func addAssign(lhs: inout EvalCiphertext, rhs: EvalCiphertext) throws { - try checkContextConsistency(lhs: lhs.context, rhs: rhs.context) + static func addAssign(_ lhs: inout EvalCiphertext, _ rhs: EvalCiphertext) throws { + try checkContextConsistency(lhs.context, rhs.context) lhs.polys[0] += rhs.polys[0] } diff --git a/Sources/SwiftHe/Plaintext.swift b/Sources/SwiftHe/Plaintext.swift index 82030440..f40fb0c4 100644 --- a/Sources/SwiftHe/Plaintext.swift +++ b/Sources/SwiftHe/Plaintext.swift @@ -11,11 +11,11 @@ public struct Plaintext { } static func += (lhs: inout Plaintext, rhs: Plaintext) throws where Format == Coeff { - try Scheme.addAssign(lhs: &lhs, rhs: rhs) + try Scheme.addAssign(&lhs, rhs) } static func += (lhs: inout Plaintext, rhs: Plaintext) throws where Format == Eval { - try Scheme.addAssign(lhs: &lhs, rhs: rhs) + try Scheme.addAssign(&lhs, rhs) } func forwardNtt() throws -> Plaintext where Format == Coeff { diff --git a/Tests/SwiftHeTests/HeAPITests.swift b/Tests/SwiftHeTests/HeAPITests.swift index 4077d4e4..9451038f 100644 --- a/Tests/SwiftHeTests/HeAPITests.swift +++ b/Tests/SwiftHeTests/HeAPITests.swift @@ -169,6 +169,41 @@ class HeAPITests: XCTestCase { XCTAssertEqual(decryptedData3, sumData) } + private func schemeSameTypeSubtractionTest(context: Context) throws { + let testEnv = try TestEnv(context: context, format: .coefficient) + let data1 = testEnv.data1 + let data2 = testEnv.data2 + var diffData = [Scheme.Scalar](repeating: 0, count: testPolyDegree) + for index in diffData.indices { + diffData[index] = data1[index].subtractMod(data2[index], modulus: context.parameter.plaintextModulus) + } + + var ciphertext1 = testEnv.ciphertext1 + var ciphertext2 = testEnv.ciphertext2 + let ciphertextDiff1 = try ciphertext1 - ciphertext2 + try ciphertext1.modSwitchDownToSingle() + try ciphertext2.modSwitchDownToSingle() + let ciphertextDiff2 = try ciphertext1 - ciphertext2 + + let evalCiphertext: Ciphertext = try ciphertextDiff1.convertToEvalFormat() + let coeffCiphertext: Ciphertext = try evalCiphertext.inverseNtt() + + let secretKey = testEnv.secretKey + let decryptedData1: [Scheme.Scalar] = try context.decode( + plaintext: Scheme.decrypt(ciphertext: coeffCiphertext, secretKey: secretKey), + format: .coefficient) + let decryptedData2: [Scheme.Scalar] = try context.decode( + plaintext: Scheme.decrypt(ciphertext: evalCiphertext, secretKey: secretKey), + format: .coefficient) + let decryptedData3: [Scheme.Scalar] = try context.decode( + plaintext: Scheme.decrypt(ciphertext: ciphertextDiff2, secretKey: secretKey), + format: .coefficient) + + XCTAssertEqual(decryptedData1, diffData) + XCTAssertEqual(decryptedData2, diffData) + XCTAssertEqual(decryptedData3, diffData) + } + private func schemeCiphertextCiphertextMultiplicationTest(context: Context) throws { let testEnv = try TestEnv(context: context, format: .simd) let data1 = testEnv.data1 @@ -317,6 +352,7 @@ class HeAPITests: XCTestCase { try schemeEncodeDecodeTest(context: context) try schemeEncryptDecryptTest(context: context) try schemeSameTypeAdditionTest(context: context) + try schemeSameTypeSubtractionTest(context: context) try schemeCiphertextCiphertextMultiplicationTest(context: context) try schemeCiphertextPlaintextAdditionTest(context: context) try schemeCiphertextPlaintextMultiplicationTest(context: context) @@ -330,6 +366,7 @@ class HeAPITests: XCTestCase { try schemeEncodeDecodeTest(context: context) try schemeEncryptDecryptTest(context: context) try schemeSameTypeAdditionTest(context: context) + try schemeSameTypeSubtractionTest(context: context) try schemeCiphertextPlaintextAdditionTest(context: context) try schemeCiphertextPlaintextMultiplicationTest(context: context) } @@ -339,6 +376,7 @@ class HeAPITests: XCTestCase { try schemeEncodeDecodeTest(context: context) try schemeEncryptDecryptTest(context: context) try schemeSameTypeAdditionTest(context: context) + try schemeSameTypeSubtractionTest(context: context) try schemeCiphertextPlaintextAdditionTest(context: context) try schemeCiphertextPlaintextMultiplicationTest(context: context) }