Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds a few things needed for a Pnns service. #81

Merged
merged 1 commit into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions Sources/PrivateNearestNeighborsSearch/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,15 @@ public struct Client<Scheme: HeScheme> {
}
self.config = config

if !contexts.isEmpty {
precondition(contexts.count == config.encryptionParameters.count)
for (context, encryptionParameters) in zip(contexts, config.encryptionParameters) {
guard context.encryptionParameters == encryptionParameters else {
throw PnnsError.wrongEncryptionParameters(
got: context.encryptionParameters,
expected: encryptionParameters)
}
}
self.contexts = contexts
} else {
self.contexts = try config.encryptionParameters.map { encryptionParams in
var contexts = contexts
if contexts.isEmpty {
contexts = try config.encryptionParameters.map { encryptionParams in
try Context(encryptionParameters: encryptionParams)
}
}
try config.validateContexts(contexts: contexts)
self.contexts = contexts

self.plaintextContext = try PolyContext(
degree: config.encryptionParameters[0].polyDegree,
moduli: config.plaintextModuli)
Expand Down
49 changes: 47 additions & 2 deletions Sources/PrivateNearestNeighborsSearch/Config.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public struct ClientConfig<Scheme: HeScheme>: Codable, Equatable, Hashable, Send
public let scalingFactor: Int
/// Packing for the query.
public let queryPacking: MatrixPacking
/// Number of entries in each vector vector.
/// Number of entries in each vector.
public let vectorDimension: Int
/// Evaluation key configuration for nearest neighbors computation.
public let evaluationKeyConfig: EvaluationKeyConfiguration
Expand All @@ -47,7 +47,7 @@ public struct ClientConfig<Scheme: HeScheme>: Codable, Equatable, Hashable, Send
/// - encryptionParams: Encryption parameters.
/// - scalingFactor: Factor by which to scale floating-point entries before rounding to integers.
/// - queryPacking: Packing for the query.
/// - vectorDimension: Number of entries in each vector vector.
/// - vectorDimension: Number of entries in each vector.
/// - evaluationKeyConfig: Evaluation key configuration for nearest neighbors computation.
/// - distanceMetric: Metric for nearest neighbors computation
/// - extraPlaintextModuli: For plaintext CRT, the list of extra plaintext moduli. The first plaintext modulus
Expand Down Expand Up @@ -88,25 +88,62 @@ public struct ClientConfig<Scheme: HeScheme>: Codable, Equatable, Hashable, Send
let scalingFactor = (((t - 1) / 2).squareRoot() - Float(vectorDimension).squareRoot() / 2).rounded(.down)
return Int(scalingFactor)
}

/// Validates the contexts are suitable for computing with this configuration.
/// - Parameter contexts: Contexts; one per plaintext modulus.
/// - Throws: Error if the contexts are not valid.
@inlinable
func validateContexts(contexts: [Context<Scheme>]) throws {
guard contexts.count == encryptionParameters.count else {
throw PnnsError.wrongContextsCount(got: contexts.count, expected: encryptionParameters.count)
}
for (context, params) in zip(contexts, encryptionParameters) {
guard context.encryptionParameters == params else {
throw PnnsError.wrongEncryptionParameters(got: context.encryptionParameters, expected: params)
}
}
}
}

/// Server configuration.
public struct ServerConfig<Scheme: HeScheme>: Codable, Equatable, Hashable, Sendable {
/// Configuration shared with the client.
public let clientConfig: ClientConfig<Scheme>

/// Packing for the plaintext database.
public let databasePacking: MatrixPacking

/// Factor by which to scale floating-point entries before rounding to integers.
public var scalingFactor: Int { clientConfig.scalingFactor }

/// The plaintext CRT moduli.
public var plaintextModuli: [Scheme.Scalar] { clientConfig.plaintextModuli }

/// For plaintext CRT, the list of extra plaintext moduli.
///
/// The first plaintext modulus will be the one in ``ClientConfig/encryptionParams``.
public var extraPlaintextModuli: [Scheme.Scalar] {
clientConfig.extraPlaintextModuli
}

/// Distance metric.
public var distanceMetric: DistanceMetric { clientConfig.distanceMetric }

/// The encryption parameters, one per plaintext modulus.
public var encryptionParameters: [EncryptionParameters<Scheme>] {
clientConfig.encryptionParameters
}

/// Number of entries in each vector.
public var vectorDimension: Int {
clientConfig.vectorDimension
}

/// Packing for the query.
public var queryPacking: MatrixPacking {
clientConfig.queryPacking
}

/// Creates a new ``ServerConfig``.
/// - Parameters:
/// - clientConfig: Configuration shared with the client.
Expand All @@ -118,4 +155,12 @@ public struct ServerConfig<Scheme: HeScheme>: Codable, Equatable, Hashable, Send
self.clientConfig = clientConfig
self.databasePacking = databasePacking
}

/// Validates the contexts are suitable for computing with this configuration.
/// - Parameter contexts: Contexts; one per plaintext modulus.
/// - Throws: Error if the contexts are not valid.
@inlinable
func validateContexts(contexts: [Context<Scheme>]) throws {
try clientConfig.validateContexts(contexts: contexts)
}
}
3 changes: 3 additions & 0 deletions Sources/PrivateNearestNeighborsSearch/Error.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public enum PnnsError: Error, Equatable {
case simdEncodingNotSupported(_ description: String)
case wrongCiphertextCount(got: Int, expected: Int)
case wrongContext(gotDescription: String, expectedDescription: String)
case wrongContextsCount(got: Int, expected: Int)
case wrongDistanceMetric(got: DistanceMetric, expected: DistanceMetric)
case wrongEncodingValuesCount(got: Int, expected: Int)
case wrongEncryptionParameters(gotDescription: String, expectedDescription: String)
Expand Down Expand Up @@ -79,6 +80,8 @@ extension PnnsError: LocalizedError {
"Invalid query due to \(reason)"
case let .wrongCiphertextCount(got, expected):
"Wrong ciphertext count \(got), expected \(expected)"
case let .wrongContextsCount(got, expected):
"Wrong contexts count \(got), expected \(expected)"
case let .wrongContext(gotDescription, expectedDescription):
"Wrong context \(gotDescription), expected \(expectedDescription)"
case let .wrongDistanceMetric(got, expected):
Expand Down
2 changes: 1 addition & 1 deletion Sources/PrivateNearestNeighborsSearch/PnnsProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import HomomorphicEncryption

/// A nearest neighbor search query.
public struct Query<Scheme: HeScheme>: Sendable {
public struct Query<Scheme: HeScheme>: Equatable, Sendable {
/// Encrypted query; one matrix per plaintext CRT modulus.
public let ciphertextMatrices: [CiphertextMatrix<Scheme, Coeff>]

Expand Down
42 changes: 30 additions & 12 deletions Sources/PrivateNearestNeighborsSearch/ProcessedDatabase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,42 @@ public struct ProcessedDatabase<Scheme: HeScheme>: Equatable, Sendable {
plaintextMatrices: [PlaintextMatrix<Scheme, Eval>],
entryIds: [UInt64],
entryMetadatas: [[UInt8]],
serverConfig: ServerConfig<Scheme>)
serverConfig: ServerConfig<Scheme>) throws
{
precondition(contexts.count == plaintextMatrices.count)
try serverConfig.validateContexts(contexts: contexts)
self.contexts = contexts
self.plaintextMatrices = plaintextMatrices
self.entryIds = entryIds
self.entryMetadatas = entryMetadatas
self.serverConfig = serverConfig
}

/// Initializes a ``ProcessedDatabase`` from a ``SerializedProcessedDatabase``.
/// - Parameters:
/// - serialized: Serialized processed database.
/// - contexts: Contexts for HE computation, one per plaintext modulus.
/// - Throws: Error upon failure to load the database.
public init(from serialized: SerializedProcessedDatabase<Scheme>, contexts: [Context<Scheme>] = []) throws {
var contexts = contexts
if contexts.isEmpty {
contexts = try serialized.serverConfig.encryptionParameters.map { encryptionParams in
try Context(encryptionParameters: encryptionParams)
}
}
try serialized.serverConfig.validateContexts(contexts: contexts)

let plaintextMatrices = try zip(serialized.plaintextMatrices, contexts)
.map { matrix, context in
try PlaintextMatrix<Scheme, Eval>(deserialize: matrix, context: context)
}
try self.init(
contexts: contexts,
plaintextMatrices: plaintextMatrices,
entryIds: serialized.entryIds,
entryMetadatas: serialized.entryMetadatas,
serverConfig: serialized.serverConfig)
}

/// Serializes the processed database.
/// - Returns: The serialized processed database.
/// - Throws: Error upon failure to serialize.
Expand Down Expand Up @@ -78,16 +104,8 @@ extension Database {
contexts = try config.encryptionParameters.map { encryptionParams in
try Context(encryptionParameters: encryptionParams)
}
} else {
precondition(contexts.count == config.encryptionParameters.count)
for (context, encryptionParameters) in zip(contexts, config.encryptionParameters) {
guard context.encryptionParameters == encryptionParameters else {
throw PnnsError.wrongEncryptionParameters(
got: context.encryptionParameters,
expected: encryptionParameters)
}
}
}
try config.validateContexts(contexts: contexts)

let vectors = Array2d(data: rows.map { row in row.vector })
let roundedVectors: Array2d<Scheme.SignedScalar> = vectors.normalizedScaledAndRounded(
Expand All @@ -104,7 +122,7 @@ extension Database {
reduce: shouldReduce).convertToEvalFormat()
}
let hasMetadata = rows.contains { row in !row.entryMetadata.isEmpty }
return ProcessedDatabase(
return try ProcessedDatabase(
contexts: contexts,
plaintextMatrices: plaintextMatrices,
entryIds: rows.map { row in row.entryId },
Expand Down
2 changes: 1 addition & 1 deletion Sources/PrivateNearestNeighborsSearch/Server.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import HomomorphicEncryption

/// Private nearest neighbors server.
public struct Server<Scheme: HeScheme> {
public struct Server<Scheme: HeScheme>: Sendable {
/// The database.
public let database: ProcessedDatabase<Scheme>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,19 @@ extension Query {
}
}

extension [Apple_SwiftHomomorphicEncryption_Pnns_V1_SerializedCiphertextMatrix] {
/// Converts the native object into a protobuf object.
/// - Returns: The converted protobuf object.
/// - Throws: Error upon unsupported object.
public func native<Scheme: HeScheme>(context: Context<Scheme>) throws -> Query<Scheme> {
let matrices: [CiphertextMatrix<Scheme, Coeff>] = try map { matrix in
let native: SerializedCiphertextMatrix<Scheme.Scalar> = try matrix.native()
return try CiphertextMatrix(deserialize: native, context: context)
}
return Query(ciphertextMatrices: matrices)
}
}

extension Query {
package func size() throws -> Int {
try proto().map { matrix in try matrix.serializedData().count }.sum()
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,33 @@ class ConversionTests: XCTestCase {
try runTest(Bfv<UInt64>.self)
}

func testQuery() throws {
func runTest<Scheme: HeScheme>(_: Scheme.Type) throws {
let encryptionParams = try EncryptionParameters<Scheme>(from: .insecure_n_8_logq_5x18_logt_5)
let context = try Context<Scheme>(encryptionParameters: encryptionParams)
let secretKey = try context.generateSecretKey()

let dimensions = try MatrixDimensions(rowCount: 5, columnCount: 4)
let scalars: [[Scheme.Scalar]] = increasingData(
dimensions: dimensions,
modulus: encryptionParams.plaintextModulus)
let plaintextMatrix = try PlaintextMatrix(
context: context,
dimensions: dimensions,
packing: .denseColumn,
values: scalars.flatMap { $0 })
let ciphertextMatrices = try (0...3).map { _ in
try plaintextMatrix.encrypt(using: secretKey).convertToCoeffFormat()
}

let query = Query(ciphertextMatrices: ciphertextMatrices)
let roundtrip = try query.proto().native(context: context)
XCTAssertEqual(roundtrip, query)
}
try runTest(Bfv<UInt32>.self)
try runTest(Bfv<UInt64>.self)
}

func testSerializedProcessedDatabase() throws {
func runTest<Scheme: HeScheme>(_: Scheme.Type) throws {
let encryptionParams = try EncryptionParameters<Scheme>(from: .insecure_n_8_logq_5x18_logt_5)
Expand Down
60 changes: 60 additions & 0 deletions Tests/PrivateNearestNeighborsSearchTests/DatabaseTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import HomomorphicEncryption
@testable import PrivateNearestNeighborsSearch
import TestUtilities
import XCTest

final class DatabaseTests: XCTestCase {
func testSerializedProcessedDatabase() throws {
func runTest<Scheme: HeScheme>(_: Scheme.Type) throws {
let encryptionParams = try EncryptionParameters<Scheme>(from: .insecure_n_8_logq_5x18_logt_5)
let vectorDimension = 4

let rows = (0...10).map { rowIndex in
DatabaseRow(
entryId: rowIndex,
entryMetadata: rowIndex.littleEndianBytes,
vector: Array(repeating: Float(rowIndex), count: vectorDimension))
}
let database = Database(rows: rows)

let clientConfig = try ClientConfig<Scheme>(
encryptionParams: encryptionParams,
scalingFactor: 123,
queryPacking: .denseRow,
vectorDimension: vectorDimension,
evaluationKeyConfig: EvaluationKeyConfiguration(galoisElements: [3]),
distanceMetric: .cosineSimilarity,
extraPlaintextModuli: Scheme.Scalar
.generatePrimes(
significantBitCounts: [7],
preferringSmall: true,
nttDegree: encryptionParams.polyDegree))
let serverConfig = ServerConfig<Scheme>(
clientConfig: clientConfig,
databasePacking: MatrixPacking
.diagonal(
babyStepGiantStep: BabyStepGiantStep(vectorDimension: vectorDimension)))

let processed = try database.process(config: serverConfig)
let serialized = try processed.serialize()
let deserialized = try ProcessedDatabase(from: serialized, contexts: processed.contexts)
XCTAssertEqual(deserialized, processed)
}
try runTest(Bfv<UInt32>.self)
try runTest(Bfv<UInt64>.self)
}
}
Loading