From eacf34862cf08ba045f238bface3c118d92fb6b1 Mon Sep 17 00:00:00 2001 From: Karl Tarbe Date: Tue, 30 Jul 2024 23:10:50 -0700 Subject: [PATCH] Fix sharding There was a mistake in sharding. I added a test and made the sharding compatible with the device code. --- .../PrivateInformationRetrieval/KeywordDatabase.swift | 11 ++++++----- .../KeywordDatabaseTests.swift | 11 +++++++++++ 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/Sources/PrivateInformationRetrieval/KeywordDatabase.swift b/Sources/PrivateInformationRetrieval/KeywordDatabase.swift index 87d72ee9..79647fbe 100644 --- a/Sources/PrivateInformationRetrieval/KeywordDatabase.swift +++ b/Sources/PrivateInformationRetrieval/KeywordDatabase.swift @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +import Crypto import Foundation import HomomorphicEncryption @@ -52,11 +53,11 @@ extension KeywordValuePair.Keyword { /// - Returns: The shard index. @inlinable func shardIndex(shardCount: Int) -> Int { - HashKeyword - .indexFromHash( - keywordHash: HashKeyword.hash(keyword: self), - bucketCount: shardCount, - counter: 0) + let digest = SHA256.hash(data: self) + let truncatedHash = digest.withUnsafeBytes { buffer in + buffer.load(as: UInt64.self) + } + return Int(truncatedHash % UInt64(shardCount)) } } diff --git a/Tests/PrivateInformationRetrievalTests/KeywordDatabaseTests.swift b/Tests/PrivateInformationRetrievalTests/KeywordDatabaseTests.swift index a5f93222..acfa891f 100644 --- a/Tests/PrivateInformationRetrievalTests/KeywordDatabaseTests.swift +++ b/Tests/PrivateInformationRetrievalTests/KeywordDatabaseTests.swift @@ -39,4 +39,15 @@ class KeywordDatabaseTests: XCTestCase { XCTAssert(database.shards.contains { shard in shard.value[row.keyword] == row.value }) } } + + func testShardingKnownAnswerTest() throws { + func checkKeywordShard(_ keyword: KeywordValuePair.Keyword, shardCount: Int, expectedShard: Int) { + XCTAssertEqual(keyword.shardIndex(shardCount: shardCount), expectedShard) + } + + checkKeywordShard([0, 0, 0, 0], shardCount: 41, expectedShard: 2) + checkKeywordShard([0, 0, 0, 0], shardCount: 1001, expectedShard: 635) + checkKeywordShard([1, 2, 3], shardCount: 1001, expectedShard: 903) + checkKeywordShard([3, 2, 1], shardCount: 1001, expectedShard: 842) + } }