Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix NoOp scheme context.encode/decode API #51

Merged
merged 1 commit into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions Benchmarks/RlweBenchmark/RlweBenchmark.swift
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ struct RlweBenchmarkContext<Scheme: HeScheme>: Sendable {
self.serializedEvaluationKey = evaluationKey.serialize()

self.data = getRandomPlaintextData(count: polyDegree, in: 0..<Scheme.Scalar(plaintextModulus))
self.coeffPlaintext = try Scheme.encode(context: context, values: data, format: .simd)
self.coeffPlaintext = try context.encode(values: data, format: .simd)
self.evalPlaintext = try coeffPlaintext.convertToEvalFormat()
self.ciphertext = try coeffPlaintext.encrypt(using: secretKey)
self.evalCiphertext = try ciphertext.convertToEvalFormat()
Expand Down Expand Up @@ -123,9 +123,8 @@ func encodeCoefficientBenchmark<Scheme: HeScheme>(_: Scheme.Type) -> () -> Void
benchmark.startMeasurement()
var plaintext: Scheme.CoeffPlaintext?
for _ in benchmark.scaledIterations {
try blackHole(plaintext = Scheme.encode(context: benchmarkContext.context,
values: benchmarkContext.data,
format: .coefficient))
try blackHole(plaintext = benchmarkContext.context.encode(values: benchmarkContext.data,
format: .coefficient))
}
// Avoid warning about variable written to, but never read
withExtendedLifetime(plaintext) {}
Expand All @@ -140,9 +139,8 @@ func encodeSimdBenchmark<Scheme: HeScheme>(_: Scheme.Type) -> () -> Void {
benchmark.startMeasurement()
var plaintext: Scheme.CoeffPlaintext?
for _ in benchmark.scaledIterations {
try blackHole(plaintext = Scheme.encode(context: benchmarkContext.context,
values: benchmarkContext.data,
format: .simd))
try blackHole(plaintext = benchmarkContext.context.encode(values: benchmarkContext.data,
format: .simd))
}
// Avoid warning about variable written to, but never read
withExtendedLifetime(plaintext) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ deserialized = try Ciphertext(
context: context,
moduliCount: 1)
let decryptedIndices = try deserialized.decrypt(using: secretKey)
let clientDecoded = try decryptedIndices.decode(format: .coefficient)
let clientDecoded: [UInt32] = try decryptedIndices.decode(format: .coefficient)
for index in indices {
precondition(clientDecoded[index] == expectedValues[index])
}
20 changes: 7 additions & 13 deletions Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,22 @@ extension Bfv {
@inlinable
// swiftlint:disable:next missing_docs attributes
public static func encode(context: Context<Bfv<T>>, values: [some ScalarType], format: EncodeFormat,
moduliCount: Int) throws -> EvalPlaintext
moduliCount: Int?) throws -> EvalPlaintext
{
try context.encode(values: values, format: format, moduliCount: moduliCount)
let coeffPlaintext = try Self.encode(context: context, values: values, format: format)
return try coeffPlaintext.convertToEvalFormat(moduliCount: moduliCount)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func encode(context: Context<Bfv<T>>, values: [some ScalarType],
format: EncodeFormat) throws -> EvalPlaintext
{
try context.encode(values: values, format: format, moduliCount: nil)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func decode<V>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [V] where V: ScalarType {
public static func decode<V: ScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [V] {
try plaintext.context.decode(plaintext: plaintext, format: format)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func decode<V>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [V] where V: ScalarType {
try plaintext.context.decode(plaintext: plaintext, format: format)
public static func decode<V: ScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [V] {
let coeffPlaintext = try plaintext.convertToCoeffFormat()
return try coeffPlaintext.decode(format: format)
}
}
2 changes: 1 addition & 1 deletion Sources/HomomorphicEncryption/Context.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public final class Context<Scheme: HeScheme>: Equatable, Sendable {
/// `keySwitchingContexts[0].next.moduli = [q_0, q_1]`
@usableFromInline let keySwitchingContexts: [PolyContext<Scheme.Scalar>]

/// the rns tools for each level of ciphertexts, with number of modulis in descending order.
/// The rns tools for each level of ciphertexts, with number of moduli in descending order.
@usableFromInline let rnsTools: [RnsTool<Scheme.Scalar>]

/// The plaintext modulus,`t`.
Expand Down
77 changes: 14 additions & 63 deletions Sources/HomomorphicEncryption/Encoding.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,6 @@
// corresponding encode/decode functions in specific Scheme instead.

extension Context {
/// Encodes `values` in the given format.
/// - Parameters:
/// - values: Values to encode.
/// - format: Encoding format.
/// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext with all the
/// moduli.
/// - Returns: The plaintext encoding `values`.
/// - Throws: Error upon failure to encode.
@inlinable
public func encode<T: ScalarType>(
values: [some ScalarType],
format: EncodeFormat,
moduliCount: Int? = nil) throws -> Plaintext<Scheme, Eval>
where Scheme == Bfv<T>
{
let coeffPlaintext: Plaintext<Scheme, Coeff> = try encode(values: values, format: format)
return try coeffPlaintext.convertToEvalFormat(moduliCount: moduliCount)
}

/// Encodes `values` in the given format.
///
/// Encoding will use the top-level ciphertext context with all moduli.
Expand All @@ -44,15 +25,16 @@ extension Context {
/// - Returns: The plaintext encoding `values`.
/// - Throws: Error upon failure to encode.
@inlinable
public func encode<T: ScalarType>(values: [some ScalarType], format: EncodeFormat) throws -> Plaintext<Scheme, Eval>
where Scheme == Bfv<T>
{
let coeffPlaintext: Plaintext<Scheme, Coeff> = try encode(values: values, format: format)
return try coeffPlaintext.convertToEvalFormat(moduliCount: nil)
public func encode(values: [some ScalarType], format: EncodeFormat) throws -> Plaintext<Scheme, Coeff> {
try validDataForEncoding(values: values)
switch format {
case .coefficient:
return try encodeCoefficient(values: values)
case .simd:
return try encodeSimd(values: values)
}
}
}

extension Context {
/// Encodes `values` in the given format.
/// - Parameters:
/// - values: Values to encode.
Expand All @@ -62,43 +44,12 @@ extension Context {
/// - Returns: The plaintext encoding `values`.
/// - Throws: Error upon failure to encode.
@inlinable
public func encode(values: [some ScalarType], format: EncodeFormat,
moduliCount: Int? = nil) throws -> Plaintext<Scheme, Eval>
public func encode(
values: [some ScalarType],
format: EncodeFormat,
moduliCount: Int? = nil) throws -> Plaintext<Scheme, Eval>
{
let coeffPlaintext: Plaintext<Scheme, Coeff> = try encode(values: values, format: format)
return try coeffPlaintext.convertToEvalFormat(moduliCount: moduliCount)
}

/// Encodes `values` in the given format.
///
/// Encoding will use the top-level ciphertext context with all moduli.
/// - Parameters:
/// - values: Values to encode.
/// - format: Encoding format.
/// - Returns: The plaintext encoding `values`.
/// - Throws: Error upon failure to encode.
@inlinable
public func encode(values: [some ScalarType], format: EncodeFormat) throws -> Plaintext<Scheme, Eval> {
try encode(values: values, format: format, moduliCount: nil)
}

/// Encodes `values` in the given format.
///
/// Encoding will use the top-level ciphertext context with all moduli.
/// - Parameters:
/// - values: Values to encode.
/// - format: Encoding format.
/// - Returns: The plaintext encoding `values`.
/// - Throws: Error upon failure to encode.
@inlinable
public func encode(values: [some ScalarType], format: EncodeFormat) throws -> Plaintext<Scheme, Coeff> {
try validDataForEncoding(values: values)
switch format {
case .coefficient:
return try encodeCoefficient(values: values)
case .simd:
return try encodeSimd(values: values)
}
try Scheme.encode(context: self, values: values, format: format, moduliCount: moduliCount)
}

/// Decodes a plaintext with the given format.
Expand Down Expand Up @@ -131,7 +82,7 @@ extension Context {
public func decode<T: ScalarType>(plaintext: Plaintext<Scheme, Eval>,
format: EncodeFormat) throws -> [T]
{
try decode(plaintext: plaintext.convertToCoeffFormat(), format: format)
try Scheme.decode(plaintext: plaintext, format: format)
}

@inlinable
Expand Down
21 changes: 8 additions & 13 deletions Sources/HomomorphicEncryption/HeScheme.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ public enum EncodeFormat: CaseIterable {
}

/// Protocol for HE schemes.
///
/// The protocol should be implemented when adding a new HE scheme.
/// However, several functions have an alternative API which is more ergonomic and should be preferred.
public protocol HeScheme {
/// Coefficient type for each polynomial.
associatedtype Scalar: ScalarType
Expand Down Expand Up @@ -140,6 +143,7 @@ public protocol HeScheme {
/// - format: Encoding format.
/// - Returns: A plaintext encoding `values`.
/// - Throws: Error upon failure to encode.
/// - seealso: ``Context/encode(values:format:)`` for an alternative API.
static func encode(context: Context<Self>, values: [some ScalarType], format: EncodeFormat) throws -> CoeffPlaintext

/// Encodes values into a plaintext with evaluation format.
Expand All @@ -149,22 +153,13 @@ public protocol HeScheme {
/// - context: Context for HE computation.
/// - values: Values to encode.
/// - format: Encoding format.
/// - moduliCount: Number of coefficient moduli in the encoded plaintext.
/// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext with all the
/// moduli.
/// - Returns: A plaintext encoding `values`.
/// - Throws: Error upon failure to encode.
/// - seealso: ``Context/encode(values:format:moduliCount:)`` for an alternative API.
static func encode(context: Context<Self>, values: [some ScalarType], format: EncodeFormat,
moduliCount: Int) throws -> EvalPlaintext

/// Encodes `values` into a plaintext with evaluation format and with top-level ciphertext context with all moduli.
/// - seealso: ``HeScheme/encode(context:values:format:moduliCount:)``
/// for an alternative which allows specifying the `moduliCount`.
/// - Parameters:
/// - context: Context for HE computation.
/// - values: Values to encode.
/// - format: Encoding format.
/// - Returns: A plaintext encoding `values`.
/// - Throws: Error upon failure to encode.
static func encode(context: Context<Self>, values: [some ScalarType], format: EncodeFormat) throws -> EvalPlaintext
moduliCount: Int?) throws -> EvalPlaintext

/// Decodes a plaintext in ``Coeff`` format.
/// - Parameters:
Expand Down
11 changes: 3 additions & 8 deletions Sources/HomomorphicEncryption/NoOpScheme.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,10 @@ public enum NoOpScheme: HeScheme {
}

public static func encode(context: Context<NoOpScheme>, values: [some ScalarType],
format: EncodeFormat, moduliCount _: Int) throws -> EvalPlaintext
format: EncodeFormat, moduliCount _: Int?) throws -> EvalPlaintext
{
try encode(context: context, values: values, format: format).forwardNtt()
}

public static func encode(context: Context<NoOpScheme>, values: [some ScalarType],
format: EncodeFormat) throws -> EvalPlaintext
{
try encode(context: context, values: values, format: format, moduliCount: 1)
let coeffPlaintext = try Self.encode(context: context, values: values, format: format)
return try EvalPlaintext(context: context, poly: coeffPlaintext.poly.forwardNtt())
}

public static func decode<T>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T] where T: ScalarType {
Expand Down
8 changes: 4 additions & 4 deletions Sources/HomomorphicEncryption/Plaintext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ extension Plaintext {
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``HeScheme/decode(plaintext:format:)-h6vl`` for an alternative API.
@inlinable
public func decode(format: EncodeFormat) throws -> [Scheme.Scalar] where Format == Coeff {
try context.decode(plaintext: self, format: format)
public func decode<T: ScalarType>(format: EncodeFormat) throws -> [T] where Format == Coeff {
try Scheme.decode(plaintext: self, format: format)
}

/// Decodes a plaintext in ``Eval`` format.
Expand All @@ -186,8 +186,8 @@ extension Plaintext {
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``HeScheme/decode(plaintext:format:)-663x4`` for an alternative API.
@inlinable
public func decode(format: EncodeFormat) throws -> [Scheme.Scalar] where Format == Eval {
try context.decode(plaintext: self, format: format)
public func decode<T: ScalarType>(format: EncodeFormat) throws -> [T] where Format == Eval {
try Scheme.decode(plaintext: self, format: format)
}

/// Symmetric secret key encryption of the plaintext.
Expand Down
3 changes: 2 additions & 1 deletion Sources/HomomorphicEncryption/SerializedCiphertext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ extension Ciphertext {
/// - Parameters:
/// - serialized: Serialized ciphertext.
/// - context: Context to associate with the ciphertext.
/// - moduliCount: Number of moduli in the serialized ciphertext.
/// - moduliCount: Number of moduli in the serialized ciphertext. If not set, deserialization will use the
/// top-level ciphertext with all the moduli.
/// - Throws: Error upon failure to deserialize the ciphertext.
@inlinable
public init(
Expand Down
5 changes: 3 additions & 2 deletions Sources/HomomorphicEncryption/SerializedPlaintext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ extension Plaintext where Format == Eval {
/// - Parameters:
/// - serialized: Serialized plaintext.
/// - context: Context to associate with the plaintext.
/// - moduliCount: Optional number of moduli to associate with the plaintext. If `nil`, the deserialized plaintext
/// will have the ciphertext context with `moduliCount` moduli.
/// - moduliCount: Optional number of moduli to associate with the plaintext. If not set, the plaintext will have
/// the top-level ciphertext context with all the
/// moduli.
/// - Throws: Error upon failure to deserialize.
public init(deserialize serialized: SerializedPlaintext, context: Context<Scheme>, moduliCount: Int? = nil) throws {
self.context = context
Expand Down
4 changes: 2 additions & 2 deletions Sources/PrivateInformationRetrieval/MulPir.swift
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ extension MulPirServer {
if coefficients.allSatisfy({ $0 == 0 }) {
return nil
}
return try Scheme.encode(context: context, values: coefficients, format: .coefficient)
return try context.encode(values: coefficients, format: .coefficient)
}
}

Expand Down Expand Up @@ -502,7 +502,7 @@ extension MulPirServer {
if plaintextCoefficients.allSatisfy({ $0 == 0 }) {
return nil
}
return try Scheme.encode(context: context, values: plaintextCoefficients, format: .coefficient)
return try context.encode(values: plaintextCoefficients, format: .coefficient)
}
let perChunkPlaintextCount = IndexPir.computePerChunkPlaintextCount(for: parameter)
while plaintexts.count < perChunkPlaintextCount {
Expand Down
2 changes: 1 addition & 1 deletion Sources/PrivateInformationRetrieval/PirUtil.swift
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ enum PirUtil<Scheme: HeScheme> {
for index in nonZeroInputs {
rawData[index] = inverseInputCountCeilLog
}
return try Scheme.encode(context: context, values: rawData, format: .coefficient)
return try context.encode(values: rawData, format: .coefficient)
}

/// Generate the ciphertext based on the given non-zero positions.
Expand Down
4 changes: 1 addition & 3 deletions Sources/TestUtilities/TestUtilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,7 @@ extension TestUtils {
package static func getRandomPlaintextData<T: ScalarType>(count: Int,
in range: Range<T>) -> [T]
{
(0..<count).map { _ in
T.random(in: range)
}
(0..<count).map { _ in T.random(in: range) }
}

package static func uniformnessTest<T>(poly: PolyRq<T, some Any>) {
Expand Down
Loading