diff --git a/Sources/PrivateNearestNeighborsSearch/CiphertextMatrix.swift b/Sources/PrivateNearestNeighborsSearch/CiphertextMatrix.swift index 4b138ba2..2005b3d9 100644 --- a/Sources/PrivateNearestNeighborsSearch/CiphertextMatrix.swift +++ b/Sources/PrivateNearestNeighborsSearch/CiphertextMatrix.swift @@ -16,17 +16,14 @@ import HomomorphicEncryption /// Stores a matrix of scalars as ciphertexts. struct CiphertextMatrix: Equatable, Sendable { - typealias Packing = PlaintextMatrixPacking - typealias Dimensions = MatrixDimensions - /// Dimensions of the scalars. - @usableFromInline let dimensions: Dimensions + @usableFromInline let dimensions: MatrixDimensions /// Dimensions of the scalar matrix in a SIMD-encoded plaintext. @usableFromInline let simdDimensions: SimdEncodingDimensions /// Plaintext packing with which the data is stored. - @usableFromInline let packing: Packing + @usableFromInline let packing: MatrixPacking /// Encrypted data. @usableFromInline let ciphertexts: [Ciphertext] @@ -59,7 +56,7 @@ struct CiphertextMatrix: Equatable, Sendab /// - ciphertexts: ciphertexts encrypting the data; must not be empty. /// - Throws: Error upon failure to initialize the ciphertext matrix. @inlinable - init(dimensions: Dimensions, packing: Packing, ciphertexts: [Ciphertext]) throws { + init(dimensions: MatrixDimensions, packing: MatrixPacking, ciphertexts: [Ciphertext]) throws { guard let context = ciphertexts.first?.context else { throw PnnsError.emptyCiphertextArray } @@ -128,3 +125,155 @@ extension CiphertextMatrix { ciphertexts: coeffCiphertexts) } } + +extension CiphertextMatrix { + /// Computes the evaluation key configuration for calling `extractDenseRow`. + /// - Parameters: + /// - encryptionParams: Encryption parameters; must support `.simd` encoding. + /// - dimensions: Dimensions of the matrix to call `extractDenseRow` on. + /// - Returns: The evaluation key configuration. + /// - Throws: Error upon failure to generate the evaluation key configuration. + @inlinable + static func extractDenseRowConfig(for encryptionParams: EncryptionParameters, + dimensions: MatrixDimensions) throws -> EvaluationKeyConfiguration + { + if dimensions.rowCount == 1 { + // extractDenseRow is a No-op, so no evaluation key required + return EvaluationKeyConfiguration() + } + guard let simdDimensions = encryptionParams.simdDimensions else { + throw PnnsError.simdEncodingNotSupported(for: encryptionParams) + } + let degree = encryptionParams.polyDegree + var galoisElements = [GaloisElement.swappingRows(degree: degree)] + let columnCountPowerOfTwo = dimensions.columnCount.nextPowerOfTwo + if columnCountPowerOfTwo != simdDimensions.columnCount { + try galoisElements.append(GaloisElement.rotatingColumns(by: columnCountPowerOfTwo, degree: degree)) + } + return EvaluationKeyConfiguration(galoisElements: galoisElements) + } + + /// Extracts a ciphertext matrix with a single row and `.denseRow` packing. + /// - Parameters: + /// - rowIndex: Row index to extract + /// - evaluationKey: Evaluation key; must have `CiphertextMatrix/extractDenseRow` configuration + /// - Returns: A ciphertext matrix in `.denseRow` format with 1 row + /// - Throws: Error upon failure to extract the row. + @inlinable + func extractDenseRow(rowIndex: Int, evaluationKey: EvaluationKey) throws -> Self + where Format == Scheme.CanonicalCiphertextFormat + { + precondition((0.. Range { + precondition((0.. simdDimensions.columnCount { + let padding = simdColumnCount % columnCountPowerOfTwo + batchIndices = (batchIndices.startIndex + padding.. 0 ? rowIndex - 1 : 0 + while firstRowIndexInBatch > 0, + simdSlotIndices(rowIndex: firstRowIndexInBatch).upperBound == batchIndices.upperBound + { + firstRowIndexInBatch -= 1 + } + return lastRowIndexInBatch - firstRowIndexInBatch + }() + + // First, we mask out just the ciphertext data row vector e.g., + // plaintextMask = [[0, 0, 1, 1, 0, 0, 1, 1], + // [0, 0, 0, 0, 0, 0, 0, 0]] + let (plaintextMask, copiesInMask) = try { + var repeatMask = Array(repeating: Scheme.Scalar(1), count: columnCountPowerOfTwo) + repeatMask += Array(repeating: 0, count: columnCountPowerOfTwo * (rowCountInBatch - 1)) + // pad to next power of two + repeatMask += Array(repeating: 0, count: repeatMask.count.nextPowerOfTwo - repeatMask.count) + + var mask = Array(repeating: Scheme.Scalar(0), count: batchIndices.lowerBound) + var repeatCountInMask = 0 + while mask.count < batchIndices.upperBound { + mask += repeatMask + repeatCountInMask += 1 + } + mask = Array(mask.prefix(degree)) + let plaintext: Plaintext = try context.encode(values: mask, format: .simd) + return (plaintext, repeatCountInMask) + }() + + var ciphertextEval = try ciphertexts[ciphertextIndex].convertToEvalFormat() + try ciphertextEval *= plaintextMask + var ciphertext = try ciphertextEval.convertToCanonicalFormat() + // e.g., `ciphertext` now encrypts + // [[0, 0, 3, 4, 0, 0, 3, 4], + // [0, 0, 0, 0, 0, 0, 0, 0]] + + // Replicate the values across one SIMD row by rotating and adding. + let rotateCount = simdColumnCount / (copiesInMask * columnCountPowerOfTwo) - 1 + var ciphertextCopyRight = ciphertext + for _ in 0..: Equatable, Sendable { - typealias Packing = PlaintextMatrixPacking - typealias Dimensions = MatrixDimensions - /// Dimensions of the scalars. - @usableFromInline let dimensions: Dimensions + @usableFromInline let dimensions: MatrixDimensions /// Dimensions of the scalar matrix in a SIMD-encoded plaintext. let simdDimensions: SimdEncodingDimensions /// Plaintext packing with which the data is stored. - @usableFromInline let packing: Packing + @usableFromInline let packing: MatrixPacking /// Plaintexts encoding the scalars. let plaintexts: [Plaintext] @@ -107,7 +104,7 @@ struct PlaintextMatrix: Equatable, Sendabl /// - plaintexts: Plaintexts encoding the data; must not be empty. /// - Throws: Error upon failure to initialize the plaintext matrix. @inlinable - init(dimensions: Dimensions, packing: Packing, plaintexts: [Plaintext]) throws { + init(dimensions: MatrixDimensions, packing: MatrixPacking, plaintexts: [Plaintext]) throws { guard !plaintexts.isEmpty else { throw PnnsError.emptyPlaintextArray } @@ -143,7 +140,11 @@ struct PlaintextMatrix: Equatable, Sendabl /// - values: The data values to store in the plaintext matrix; stored in row-major format. /// - Throws: Error upon failure to create the plaitnext matrix. @inlinable - init(context: Context, dimensions: Dimensions, packing: Packing, values: [some ScalarType]) throws + init( + context: Context, + dimensions: MatrixDimensions, + packing: MatrixPacking, + values: [some ScalarType]) throws where Format == Coeff { guard values.count == dimensions.count, !values.isEmpty else { @@ -182,8 +183,8 @@ struct PlaintextMatrix: Equatable, Sendabl @inlinable static func plaintextCount( encryptionParameters: EncryptionParameters, - dimensions: PlaintextMatrix.Dimensions, - packing: PlaintextMatrix.Packing) throws -> Int + dimensions: MatrixDimensions, + packing: MatrixPacking) throws -> Int { guard let simdDimensions = encryptionParameters.simdDimensions else { throw PnnsError.simdEncodingNotSupported(for: encryptionParameters) @@ -219,7 +220,7 @@ struct PlaintextMatrix: Equatable, Sendabl /// - Returns: The plaintexts for `denseColumn` packing. /// - Throws: Error upon plaintext to compute the plaintexts. @inlinable - static func denseColumnPlaintexts(context: Context, dimensions: Dimensions, + static func denseColumnPlaintexts(context: Context, dimensions: MatrixDimensions, values: [V]) throws -> [Scheme.CoeffPlaintext] { let degree = context.degree @@ -277,14 +278,14 @@ struct PlaintextMatrix: Equatable, Sendabl @inlinable static func denseRowPlaintexts( context: Context, - dimensions: Dimensions, + dimensions: MatrixDimensions, values: [V]) throws -> [Plaintext] { let encryptionParameters = context.encryptionParameters guard let simdDimensions = context.simdDimensions else { throw PnnsError.simdEncodingNotSupported(for: encryptionParameters) } - precondition(simdDimensions.rowCount == 2, "simdRowCount must be 2") + precondition(simdDimensions.rowCount == 2, "SIMD row count must be 2") let simdColumnCount = simdDimensions.columnCount guard dimensions.columnCount <= simdColumnCount else { throw PnnsError.invalidMatrixDimensions(dimensions) @@ -351,8 +352,8 @@ struct PlaintextMatrix: Equatable, Sendabl @inlinable static func diagonalPlaintexts( context: Context, - dimensions: Dimensions, - packing: PlaintextMatrixPacking, + dimensions: MatrixDimensions, + packing: MatrixPacking, values: [V]) throws -> [Scheme.CoeffPlaintext] { let encryptionParameters = context.encryptionParameters @@ -367,8 +368,7 @@ struct PlaintextMatrix: Equatable, Sendabl } guard case let .diagonal(bsgs) = packing else { let expectedBsgs = BabyStepGiantStep(vectorDimension: dimensions.columnCount) - throw PnnsError - .wrongPlaintextMatrixPacking(got: packing, expected: .diagonal(babyStepGiantStep: expectedBsgs)) + throw PnnsError.wrongMatrixPacking(got: packing, expected: .diagonal(babyStepGiantStep: expectedBsgs)) } let data = Array2d(data: values, rowCount: dimensions.rowCount, columnCount: dimensions.columnCount) @@ -439,11 +439,11 @@ struct PlaintextMatrix: Equatable, Sendabl @inlinable func unpackDenseColumn() throws -> [V] where Format == Coeff { guard case packing = .denseColumn else { - throw PnnsError.wrongPlaintextMatrixPacking(got: packing, expected: .denseColumn) + throw PnnsError.wrongMatrixPacking(got: packing, expected: .denseColumn) } let simdColumnCount = simdDimensions.columnCount let simdRowCount = simdDimensions.rowCount - precondition(simdRowCount == 2, "simdRowCount must be 2") + precondition(simdRowCount == 2, "SIMD row count must be 2") let columnsPerPlaintextCount = simdRowCount * (simdColumnCount / rowCount) var valuesColumnMajor: [V] = [] @@ -483,7 +483,7 @@ struct PlaintextMatrix: Equatable, Sendabl @inlinable func unpackDenseRow() throws -> [V] where Format == Coeff { guard case packing = .denseRow else { - throw PnnsError.wrongPlaintextMatrixPacking(got: packing, expected: Packing.denseRow) + throw PnnsError.wrongMatrixPacking(got: packing, expected: MatrixPacking.denseRow) } let simdColumnCount = simdDimensions.columnCount diff --git a/Tests/PrivateNearestNeighborsSearchTests/CiphertextMatrixTests.swift b/Tests/PrivateNearestNeighborsSearchTests/CiphertextMatrixTests.swift index 342e8ad1..b7b26110 100644 --- a/Tests/PrivateNearestNeighborsSearchTests/CiphertextMatrixTests.swift +++ b/Tests/PrivateNearestNeighborsSearchTests/CiphertextMatrixTests.swift @@ -17,6 +17,15 @@ import HomomorphicEncryption import TestUtilities import XCTest +func increasingData(dimensions: MatrixDimensions, modulus: T) -> [[T]] { + (0..(for _: Scheme.Type) throws { @@ -25,12 +34,9 @@ final class CiphertextMatrixTests: XCTestCase { XCTAssert(encryptionParams.supportsSimdEncoding) let context = try Context(encryptionParameters: encryptionParams) let dimensions = try MatrixDimensions(rowCount: 10, columnCount: 4) - let encodeValues: [[Scheme.Scalar]] = (0..( context: context, dimensions: dimensions, @@ -53,12 +59,9 @@ final class CiphertextMatrixTests: XCTestCase { XCTAssert(encryptionParams.supportsSimdEncoding) let context = try Context(encryptionParameters: encryptionParams) let dimensions = try MatrixDimensions(rowCount: 10, columnCount: 4) - let encodeValues: [[Scheme.Scalar]] = (0..( context: context, dimensions: dimensions, @@ -75,4 +78,80 @@ final class CiphertextMatrixTests: XCTestCase { try runTest(for: Bfv.self) try runTest(for: Bfv.self) } + + func testExtractDenseRow() throws { + func runTest(for _: Scheme.Type) throws { + let degree = 16 + let plaintextModulus = try Scheme.Scalar.generatePrimes( + significantBitCounts: [9], + preferringSmall: true, + nttDegree: degree)[0] + let coefficientModuli = try Scheme.Scalar.generatePrimes( + significantBitCounts: Array( + repeating: Scheme.Scalar.bitWidth - 4, + count: 2), + preferringSmall: false, + nttDegree: degree) + let encryptionParams = try EncryptionParameters( + polyDegree: degree, + plaintextModulus: plaintextModulus, + coefficientModuli: coefficientModuli, + errorStdDev: .stdDev32, + securityLevel: .unchecked) + XCTAssert(encryptionParams.supportsSimdEncoding) + let context = try Context(encryptionParameters: encryptionParams) + + for rowCount in 1..<(2 * degree) { + for columnCount in 1..( + context: context, + dimensions: dimensions, + packing: .denseRow, + values: encodeValues.flatMap { $0 }) + let secretKey = try context.generateSecretKey() + let ciphertextMatrix: CiphertextMatrix = try plaintextMatrix.encrypt(using: secretKey) + + let evaluationKeyConfig = try CiphertextMatrix.extractDenseRowConfig( + for: encryptionParams, + dimensions: dimensions) + let evaluationKey = try context.generateEvaluationKey( + configuration: evaluationKeyConfig, + using: secretKey) + + for rowIndex in 0...self) + try runTest(for: Bfv.self) + } } diff --git a/Tests/PrivateNearestNeighborsSearchTests/PlaintextMatrixTests.swift b/Tests/PrivateNearestNeighborsSearchTests/PlaintextMatrixTests.swift index f129e34d..b6c1b8b7 100644 --- a/Tests/PrivateNearestNeighborsSearchTests/PlaintextMatrixTests.swift +++ b/Tests/PrivateNearestNeighborsSearchTests/PlaintextMatrixTests.swift @@ -36,7 +36,7 @@ final class PlaintextMatrixTests: XCTestCase { let rowCount = encryptionParams.polyDegree let columnCount = 2 let dims = try MatrixDimensions(rowCount: rowCount, columnCount: columnCount) - let packing = PlaintextMatrixPacking.denseRow + let packing = MatrixPacking.denseRow let context = try Context(encryptionParameters: encryptionParams) let values = TestUtils.getRandomPlaintextData( count: encryptionParams.polyDegree, @@ -89,7 +89,7 @@ final class PlaintextMatrixTests: XCTestCase { let values = TestUtils.getRandomPlaintextData( count: encryptionParams.polyDegree, in: 0..( context: Context, dimensions: MatrixDimensions, - packing: PlaintextMatrixPacking, + packing: MatrixPacking, expected: [[Int]]) throws { guard context.supportsSimdEncoding else { @@ -456,12 +456,9 @@ final class PlaintextMatrixTests: XCTestCase { XCTAssert(encryptionParams.supportsSimdEncoding) let context = try Context(encryptionParameters: encryptionParams) let dimensions = try MatrixDimensions(rowCount: 10, columnCount: 4) - let encodeValues: [[Scheme.Scalar]] = (0..( context: context, dimensions: dimensions,