diff --git a/qdrant_client/async_qdrant_remote.py b/qdrant_client/async_qdrant_remote.py index 9befcd69..76067c3b 100644 --- a/qdrant_client/async_qdrant_remote.py +++ b/qdrant_client/async_qdrant_remote.py @@ -2765,16 +2765,17 @@ async def create_shard_key( if self._prefer_grpc: if isinstance(shard_key, get_args_subscribed(models.ShardKey)): shard_key = RestToGrpc.convert_shard_key(shard_key) - request = await grpc.CreateShardKey( - shard_key=shard_key, - shards_number=shards_number, - replication_factor=replication_factor, - placement=placement or [], - ) return ( await self.grpc_collections.CreateShardKey( grpc.CreateShardKeyRequest( - collection_name=collection_name, timeout=timeout, request=request + collection_name=collection_name, + timeout=timeout, + request=grpc.CreateShardKey( + shard_key=shard_key, + shards_number=shards_number, + replication_factor=replication_factor, + placement=placement or [], + ), ), timeout=self._timeout, ) diff --git a/qdrant_client/qdrant_remote.py b/qdrant_client/qdrant_remote.py index 243a3fe8..e3088fee 100644 --- a/qdrant_client/qdrant_remote.py +++ b/qdrant_client/qdrant_remote.py @@ -3138,18 +3138,16 @@ def create_shard_key( if isinstance(shard_key, get_args_subscribed(models.ShardKey)): shard_key = RestToGrpc.convert_shard_key(shard_key) - request = grpc.CreateShardKey( - shard_key=shard_key, - shards_number=shards_number, - replication_factor=replication_factor, - placement=placement or [], - ) - return self.grpc_collections.CreateShardKey( grpc.CreateShardKeyRequest( collection_name=collection_name, timeout=timeout, - request=request, + request=grpc.CreateShardKey( + shard_key=shard_key, + shards_number=shards_number, + replication_factor=replication_factor, + placement=placement or [], + ), ), timeout=self._timeout, ).result diff --git a/tests/test_async_qdrant_client.py b/tests/test_async_qdrant_client.py index 08647dcb..4174f6fd 100644 --- a/tests/test_async_qdrant_client.py +++ b/tests/test_async_qdrant_client.py @@ -609,3 +609,25 @@ def auth_token_provider(): await client.unlock_storage() assert sync_token == "token_2" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("prefer_grpc", [False, True]) +async def test_custom_sharding(prefer_grpc): + client = AsyncQdrantClient(prefer_grpc=prefer_grpc) + + if await client.collection_exists(COLLECTION_NAME): + await client.delete_collection(collection_name=COLLECTION_NAME) + await client.create_collection( + collection_name=COLLECTION_NAME, + vectors_config=models.VectorParams(size=DIM, distance=models.Distance.DOT), + sharding_method=models.ShardingMethod.CUSTOM, + ) + + await client.create_shard_key(collection_name=COLLECTION_NAME, shard_key="cats") + await client.create_shard_key(collection_name=COLLECTION_NAME, shard_key="dogs") + + collection_info = await client.get_collection(COLLECTION_NAME) + + assert collection_info.config.params.shard_number == 1 + # assert collection_info.config.params.sharding_method == models.ShardingMethod.CUSTOM # todo: fix in grpc