diff --git a/Benchmarks/RlweBenchmark/RlweBenchmark.swift b/Benchmarks/RlweBenchmark/RlweBenchmark.swift index f38e93ad..5128885a 100644 --- a/Benchmarks/RlweBenchmark/RlweBenchmark.swift +++ b/Benchmarks/RlweBenchmark/RlweBenchmark.swift @@ -85,7 +85,7 @@ struct RlweBenchmarkContext: Sendable { self.serializedEvaluationKey = evaluationKey.serialize() self.data = getRandomPlaintextData(count: polyDegree, in: 0..(_: Scheme.Type) -> () -> Void benchmark.startMeasurement() var plaintext: Scheme.CoeffPlaintext? for _ in benchmark.scaledIterations { - try blackHole(plaintext = Scheme.encode(context: benchmarkContext.context, - values: benchmarkContext.data, - format: .coefficient)) + try blackHole(plaintext = benchmarkContext.context.encode(values: benchmarkContext.data, + format: .coefficient)) } // Avoid warning about variable written to, but never read withExtendedLifetime(plaintext) {} @@ -140,9 +139,8 @@ func encodeSimdBenchmark(_: Scheme.Type) -> () -> Void { benchmark.startMeasurement() var plaintext: Scheme.CoeffPlaintext? for _ in benchmark.scaledIterations { - try blackHole(plaintext = Scheme.encode(context: benchmarkContext.context, - values: benchmarkContext.data, - format: .simd)) + try blackHole(plaintext = benchmarkContext.context.encode(values: benchmarkContext.data, + format: .simd)) } // Avoid warning about variable written to, but never read withExtendedLifetime(plaintext) {} diff --git a/Snippets/HomomorphicEncryption/SerializationSnippet.swift b/Snippets/HomomorphicEncryption/SerializationSnippet.swift index cc7556c6..80fb9297 100644 --- a/Snippets/HomomorphicEncryption/SerializationSnippet.swift +++ b/Snippets/HomomorphicEncryption/SerializationSnippet.swift @@ -127,7 +127,7 @@ deserialized = try Ciphertext( context: context, moduliCount: 1) let decryptedIndices = try deserialized.decrypt(using: secretKey) -let clientDecoded = try decryptedIndices.decode(format: .coefficient) +let clientDecoded: [UInt32] = try decryptedIndices.decode(format: .coefficient) for index in indices { precondition(clientDecoded[index] == expectedValues[index]) } diff --git a/Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift b/Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift index a9afbd35..25287ff6 100644 --- a/Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift +++ b/Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift @@ -24,28 +24,22 @@ extension Bfv { @inlinable // swiftlint:disable:next missing_docs attributes public static func encode(context: Context>, values: [some ScalarType], format: EncodeFormat, - moduliCount: Int) throws -> EvalPlaintext + moduliCount: Int?) throws -> EvalPlaintext { - try context.encode(values: values, format: format, moduliCount: moduliCount) + let coeffPlaintext = try Self.encode(context: context, values: values, format: format) + return try coeffPlaintext.convertToEvalFormat(moduliCount: moduliCount) } @inlinable // swiftlint:disable:next missing_docs attributes - public static func encode(context: Context>, values: [some ScalarType], - format: EncodeFormat) throws -> EvalPlaintext - { - try context.encode(values: values, format: format, moduliCount: nil) - } - - @inlinable - // swiftlint:disable:next missing_docs attributes - public static func decode(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [V] where V: ScalarType { + public static func decode(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [V] { try plaintext.context.decode(plaintext: plaintext, format: format) } @inlinable // swiftlint:disable:next missing_docs attributes - public static func decode(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [V] where V: ScalarType { - try plaintext.context.decode(plaintext: plaintext, format: format) + public static func decode(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [V] { + let coeffPlaintext = try plaintext.convertToCoeffFormat() + return try coeffPlaintext.decode(format: format) } } diff --git a/Sources/HomomorphicEncryption/Context.swift b/Sources/HomomorphicEncryption/Context.swift index 2dc1aa28..86041483 100644 --- a/Sources/HomomorphicEncryption/Context.swift +++ b/Sources/HomomorphicEncryption/Context.swift @@ -39,7 +39,7 @@ public final class Context: Equatable, Sendable { /// `keySwitchingContexts[0].next.moduli = [q_0, q_1]` @usableFromInline let keySwitchingContexts: [PolyContext] - /// the rns tools for each level of ciphertexts, with number of modulis in descending order. + /// The rns tools for each level of ciphertexts, with number of moduli in descending order. @usableFromInline let rnsTools: [RnsTool] /// The plaintext modulus,`t`. diff --git a/Sources/HomomorphicEncryption/Encoding.swift b/Sources/HomomorphicEncryption/Encoding.swift index 7d32429f..5eaabe98 100644 --- a/Sources/HomomorphicEncryption/Encoding.swift +++ b/Sources/HomomorphicEncryption/Encoding.swift @@ -16,25 +16,6 @@ // corresponding encode/decode functions in specific Scheme instead. extension Context { - /// Encodes `values` in the given format. - /// - Parameters: - /// - values: Values to encode. - /// - format: Encoding format. - /// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext with all the - /// moduli. - /// - Returns: The plaintext encoding `values`. - /// - Throws: Error upon failure to encode. - @inlinable - public func encode( - values: [some ScalarType], - format: EncodeFormat, - moduliCount: Int? = nil) throws -> Plaintext - where Scheme == Bfv - { - let coeffPlaintext: Plaintext = try encode(values: values, format: format) - return try coeffPlaintext.convertToEvalFormat(moduliCount: moduliCount) - } - /// Encodes `values` in the given format. /// /// Encoding will use the top-level ciphertext context with all moduli. @@ -44,15 +25,16 @@ extension Context { /// - Returns: The plaintext encoding `values`. /// - Throws: Error upon failure to encode. @inlinable - public func encode(values: [some ScalarType], format: EncodeFormat) throws -> Plaintext - where Scheme == Bfv - { - let coeffPlaintext: Plaintext = try encode(values: values, format: format) - return try coeffPlaintext.convertToEvalFormat(moduliCount: nil) + public func encode(values: [some ScalarType], format: EncodeFormat) throws -> Plaintext { + try validDataForEncoding(values: values) + switch format { + case .coefficient: + return try encodeCoefficient(values: values) + case .simd: + return try encodeSimd(values: values) + } } -} -extension Context { /// Encodes `values` in the given format. /// - Parameters: /// - values: Values to encode. @@ -62,43 +44,12 @@ extension Context { /// - Returns: The plaintext encoding `values`. /// - Throws: Error upon failure to encode. @inlinable - public func encode(values: [some ScalarType], format: EncodeFormat, - moduliCount: Int? = nil) throws -> Plaintext + public func encode( + values: [some ScalarType], + format: EncodeFormat, + moduliCount: Int? = nil) throws -> Plaintext { - let coeffPlaintext: Plaintext = try encode(values: values, format: format) - return try coeffPlaintext.convertToEvalFormat(moduliCount: moduliCount) - } - - /// Encodes `values` in the given format. - /// - /// Encoding will use the top-level ciphertext context with all moduli. - /// - Parameters: - /// - values: Values to encode. - /// - format: Encoding format. - /// - Returns: The plaintext encoding `values`. - /// - Throws: Error upon failure to encode. - @inlinable - public func encode(values: [some ScalarType], format: EncodeFormat) throws -> Plaintext { - try encode(values: values, format: format, moduliCount: nil) - } - - /// Encodes `values` in the given format. - /// - /// Encoding will use the top-level ciphertext context with all moduli. - /// - Parameters: - /// - values: Values to encode. - /// - format: Encoding format. - /// - Returns: The plaintext encoding `values`. - /// - Throws: Error upon failure to encode. - @inlinable - public func encode(values: [some ScalarType], format: EncodeFormat) throws -> Plaintext { - try validDataForEncoding(values: values) - switch format { - case .coefficient: - return try encodeCoefficient(values: values) - case .simd: - return try encodeSimd(values: values) - } + try Scheme.encode(context: self, values: values, format: format, moduliCount: moduliCount) } /// Decodes a plaintext with the given format. @@ -131,7 +82,7 @@ extension Context { public func decode(plaintext: Plaintext, format: EncodeFormat) throws -> [T] { - try decode(plaintext: plaintext.convertToCoeffFormat(), format: format) + try Scheme.decode(plaintext: plaintext, format: format) } @inlinable diff --git a/Sources/HomomorphicEncryption/HeScheme.swift b/Sources/HomomorphicEncryption/HeScheme.swift index d88a9ebd..7e97154c 100644 --- a/Sources/HomomorphicEncryption/HeScheme.swift +++ b/Sources/HomomorphicEncryption/HeScheme.swift @@ -63,6 +63,9 @@ public enum EncodeFormat: CaseIterable { } /// Protocol for HE schemes. +/// +/// The protocol should be implemented when adding a new HE scheme. +/// However, several functions have an alternative API which is more ergonomic and should be preferred. public protocol HeScheme { /// Coefficient type for each polynomial. associatedtype Scalar: ScalarType @@ -140,6 +143,7 @@ public protocol HeScheme { /// - format: Encoding format. /// - Returns: A plaintext encoding `values`. /// - Throws: Error upon failure to encode. + /// - seealso: ``Context/encode(values:format:)`` for an alternative API. static func encode(context: Context, values: [some ScalarType], format: EncodeFormat) throws -> CoeffPlaintext /// Encodes values into a plaintext with evaluation format. @@ -149,22 +153,13 @@ public protocol HeScheme { /// - context: Context for HE computation. /// - values: Values to encode. /// - format: Encoding format. - /// - moduliCount: Number of coefficient moduli in the encoded plaintext. + /// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext with all the + /// moduli. /// - Returns: A plaintext encoding `values`. /// - Throws: Error upon failure to encode. + /// - seealso: ``Context/encode(values:format:moduliCount:)`` for an alternative API. static func encode(context: Context, values: [some ScalarType], format: EncodeFormat, - moduliCount: Int) throws -> EvalPlaintext - - /// Encodes `values` into a plaintext with evaluation format and with top-level ciphertext context with all moduli. - /// - seealso: ``HeScheme/encode(context:values:format:moduliCount:)`` - /// for an alternative which allows specifying the `moduliCount`. - /// - Parameters: - /// - context: Context for HE computation. - /// - values: Values to encode. - /// - format: Encoding format. - /// - Returns: A plaintext encoding `values`. - /// - Throws: Error upon failure to encode. - static func encode(context: Context, values: [some ScalarType], format: EncodeFormat) throws -> EvalPlaintext + moduliCount: Int?) throws -> EvalPlaintext /// Decodes a plaintext in ``Coeff`` format. /// - Parameters: diff --git a/Sources/HomomorphicEncryption/NoOpScheme.swift b/Sources/HomomorphicEncryption/NoOpScheme.swift index 22dab5c0..b9d36ee8 100644 --- a/Sources/HomomorphicEncryption/NoOpScheme.swift +++ b/Sources/HomomorphicEncryption/NoOpScheme.swift @@ -53,15 +53,10 @@ public enum NoOpScheme: HeScheme { } public static func encode(context: Context, values: [some ScalarType], - format: EncodeFormat, moduliCount _: Int) throws -> EvalPlaintext + format: EncodeFormat, moduliCount _: Int?) throws -> EvalPlaintext { - try encode(context: context, values: values, format: format).forwardNtt() - } - - public static func encode(context: Context, values: [some ScalarType], - format: EncodeFormat) throws -> EvalPlaintext - { - try encode(context: context, values: values, format: format, moduliCount: 1) + let coeffPlaintext = try Self.encode(context: context, values: values, format: format) + return try EvalPlaintext(context: context, poly: coeffPlaintext.poly.forwardNtt()) } public static func decode(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T] where T: ScalarType { diff --git a/Sources/HomomorphicEncryption/Plaintext.swift b/Sources/HomomorphicEncryption/Plaintext.swift index 3592c4ee..17ef0962 100644 --- a/Sources/HomomorphicEncryption/Plaintext.swift +++ b/Sources/HomomorphicEncryption/Plaintext.swift @@ -176,8 +176,8 @@ extension Plaintext { /// - Throws: Error upon failure to decode the plaintext. /// - seealso: ``HeScheme/decode(plaintext:format:)-h6vl`` for an alternative API. @inlinable - public func decode(format: EncodeFormat) throws -> [Scheme.Scalar] where Format == Coeff { - try context.decode(plaintext: self, format: format) + public func decode(format: EncodeFormat) throws -> [T] where Format == Coeff { + try Scheme.decode(plaintext: self, format: format) } /// Decodes a plaintext in ``Eval`` format. @@ -186,8 +186,8 @@ extension Plaintext { /// - Throws: Error upon failure to decode the plaintext. /// - seealso: ``HeScheme/decode(plaintext:format:)-663x4`` for an alternative API. @inlinable - public func decode(format: EncodeFormat) throws -> [Scheme.Scalar] where Format == Eval { - try context.decode(plaintext: self, format: format) + public func decode(format: EncodeFormat) throws -> [T] where Format == Eval { + try Scheme.decode(plaintext: self, format: format) } /// Symmetric secret key encryption of the plaintext. diff --git a/Sources/HomomorphicEncryption/SerializedCiphertext.swift b/Sources/HomomorphicEncryption/SerializedCiphertext.swift index 363b02e8..8a3b98ac 100644 --- a/Sources/HomomorphicEncryption/SerializedCiphertext.swift +++ b/Sources/HomomorphicEncryption/SerializedCiphertext.swift @@ -37,7 +37,8 @@ extension Ciphertext { /// - Parameters: /// - serialized: Serialized ciphertext. /// - context: Context to associate with the ciphertext. - /// - moduliCount: Number of moduli in the serialized ciphertext. + /// - moduliCount: Number of moduli in the serialized ciphertext. If not set, deserialization will use the + /// top-level ciphertext with all the moduli. /// - Throws: Error upon failure to deserialize the ciphertext. @inlinable public init( diff --git a/Sources/HomomorphicEncryption/SerializedPlaintext.swift b/Sources/HomomorphicEncryption/SerializedPlaintext.swift index 60cf5b01..7e03ef6e 100644 --- a/Sources/HomomorphicEncryption/SerializedPlaintext.swift +++ b/Sources/HomomorphicEncryption/SerializedPlaintext.swift @@ -48,8 +48,9 @@ extension Plaintext where Format == Eval { /// - Parameters: /// - serialized: Serialized plaintext. /// - context: Context to associate with the plaintext. - /// - moduliCount: Optional number of moduli to associate with the plaintext. If `nil`, the deserialized plaintext - /// will have the ciphertext context with `moduliCount` moduli. + /// - moduliCount: Optional number of moduli to associate with the plaintext. If not set, the plaintext will have + /// the top-level ciphertext context with all the + /// moduli. /// - Throws: Error upon failure to deserialize. public init(deserialize serialized: SerializedPlaintext, context: Context, moduliCount: Int? = nil) throws { self.context = context diff --git a/Sources/PrivateInformationRetrieval/MulPir.swift b/Sources/PrivateInformationRetrieval/MulPir.swift index 4127d0f6..fb450957 100644 --- a/Sources/PrivateInformationRetrieval/MulPir.swift +++ b/Sources/PrivateInformationRetrieval/MulPir.swift @@ -453,7 +453,7 @@ extension MulPirServer { if coefficients.allSatisfy({ $0 == 0 }) { return nil } - return try Scheme.encode(context: context, values: coefficients, format: .coefficient) + return try context.encode(values: coefficients, format: .coefficient) } } @@ -502,7 +502,7 @@ extension MulPirServer { if plaintextCoefficients.allSatisfy({ $0 == 0 }) { return nil } - return try Scheme.encode(context: context, values: plaintextCoefficients, format: .coefficient) + return try context.encode(values: plaintextCoefficients, format: .coefficient) } let perChunkPlaintextCount = IndexPir.computePerChunkPlaintextCount(for: parameter) while plaintexts.count < perChunkPlaintextCount { diff --git a/Sources/PrivateInformationRetrieval/PirUtil.swift b/Sources/PrivateInformationRetrieval/PirUtil.swift index 6a0897f2..c8986039 100644 --- a/Sources/PrivateInformationRetrieval/PirUtil.swift +++ b/Sources/PrivateInformationRetrieval/PirUtil.swift @@ -155,7 +155,7 @@ enum PirUtil { for index in nonZeroInputs { rawData[index] = inverseInputCountCeilLog } - return try Scheme.encode(context: context, values: rawData, format: .coefficient) + return try context.encode(values: rawData, format: .coefficient) } /// Generate the ciphertext based on the given non-zero positions. diff --git a/Sources/TestUtilities/TestUtilities.swift b/Sources/TestUtilities/TestUtilities.swift index 84324388..959fddac 100644 --- a/Sources/TestUtilities/TestUtilities.swift +++ b/Sources/TestUtilities/TestUtilities.swift @@ -199,9 +199,7 @@ extension TestUtils { package static func getRandomPlaintextData(count: Int, in range: Range) -> [T] { - (0..(poly: PolyRq) { diff --git a/Tests/HomomorphicEncryptionProtobufTests/ConversionTests.swift b/Tests/HomomorphicEncryptionProtobufTests/ConversionTests.swift index 86a6587e..1f404250 100644 --- a/Tests/HomomorphicEncryptionProtobufTests/ConversionTests.swift +++ b/Tests/HomomorphicEncryptionProtobufTests/ConversionTests.swift @@ -57,9 +57,9 @@ class ConversionTests: XCTestCase { func runTest(_: Scheme.Type) throws { let context: Context = try TestUtils.getTestContext() let values = TestUtils.getRandomPlaintextData(count: context.degree, in: 0.. = try TestUtils.getTestContext() let values = TestUtils.getRandomPlaintextData(count: context.degree, in: 0.. = try TestUtils.getTestContext() let values = TestUtils.getRandomPlaintextData(count: context.degree, in: 0.. = try Scheme.zeroCiphertext( context: context, - moduliCount: context.ciphertextContext.moduli.count) + moduliCount: testEnv.evalPlaintext1.moduli.count) let product = try zeroCiphertext * testEnv.evalPlaintext1 XCTAssert(product.isTransparent()) @@ -846,10 +840,10 @@ class HeAPITests: XCTestCase { var ciphertext = testEnv.ciphertext1 try ciphertext.modSwitchDown() let evalCiphertext = try ciphertext.convertToEvalFormat() - let evalPlaintext = try Scheme.encode(context: testEnv.context, - values: testEnv.data2, - format: .simd, - moduliCount: evalCiphertext.moduli.count) + let evalPlaintext = try testEnv.context.encode( + values: testEnv.data2, + format: .simd, + moduliCount: evalCiphertext.moduli.count) try testEnv.checkDecryptsDecodes( ciphertext: evalCiphertext * evalPlaintext, format: .simd, diff --git a/Tests/HomomorphicEncryptionTests/SerializationTests.swift b/Tests/HomomorphicEncryptionTests/SerializationTests.swift index aa54828f..c8e121b2 100644 --- a/Tests/HomomorphicEncryptionTests/SerializationTests.swift +++ b/Tests/HomomorphicEncryptionTests/SerializationTests.swift @@ -21,9 +21,9 @@ class SerializationTests: XCTestCase { func runTest(_: Scheme.Type) throws { let context: Context = try TestUtils.getTestContext() let values = TestUtils.getRandomPlaintextData(count: context.degree, in: 0.. = try TestUtils.getTestContext() let values = TestUtils.getRandomPlaintextData(count: context.degree, in: 0.. = try TestUtils.getTestContext() let values = TestUtils.getRandomPlaintextData(count: context.degree, in: 0.. = try Scheme.encode(context: context, - values: data, - format: .coefficient) + let plaintext: Plaintext = try context.encode(values: data, + format: .coefficient) let secretKey = try context.generateSecretKey() let expandedQueryCount = degree