Skip to content

Commit

Permalink
Change sparse related errors to ParamError (#2066)
Browse files Browse the repository at this point in the history
so that the error messages can correctly propagate

Signed-off-by: Buqian Zheng <[email protected]>
  • Loading branch information
zhengbuqian authored May 7, 2024
1 parent 0a7327d commit 97f12ae
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions pymilvus/client/entity_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ def is_float_type(v: Any):
# parses plain bytes to a sparse float vector(SparseRowOutputType)
def sparse_parse_single_row(data: bytes) -> SparseRowOutputType:
if len(data) % 8 != 0:
msg = f"The length of data must be a multiple of 8, got {len(data)}"
raise ValueError(msg)
raise ParamError(message=f"The length of data must be a multiple of 8, got {len(data)}")

return {
struct.unpack("I", data[i : i + 4])[0]: struct.unpack("f", data[i + 4 : i + 8])[0]
Expand All @@ -129,16 +128,17 @@ def sparse_rows_to_proto(data: SparseMatrixInputType) -> schema_types.SparseFloa
# milvus interprets/persists the data.
def sparse_float_row_to_bytes(indices: Iterable[int], values: Iterable[float]):
if len(indices) != len(values):
msg = f"length of indices and values must be the same, got {len(indices)} and {len(values)}"
raise ValueError(msg)
raise ParamError(
message=f"length of indices and values must be the same, got {len(indices)} and {len(values)}"
)
data = b""
for i, v in sorted(zip(indices, values), key=lambda x: x[0]):
if not (0 <= i < 2**32 - 1):
msg = f"sparse vector index must be positive and less than 2^32-1: {i}"
raise ValueError(msg)
raise ParamError(
message=f"sparse vector index must be positive and less than 2^32-1: {i}"
)
if math.isnan(v):
msg = "sparse vector value must not be NaN"
raise ValueError(msg)
raise ParamError(message="sparse vector value must not be NaN")
data += struct.pack("I", i)
data += struct.pack("f", v)
return data
Expand All @@ -163,8 +163,7 @@ def unify_sparse_input(data: SparseMatrixInputType) -> sparse.csr_array:
return sparse.csr_array((values, (row_indices, col_indices)))

if not entity_is_sparse_matrix(data):
msg = "input must be a sparse matrix in supported format"
raise TypeError(msg)
raise ParamError(message="input must be a sparse matrix in supported format")
csr = unify_sparse_input(data)
result = schema_types.SparseFloatArray()
result.dim = csr.shape[1]
Expand All @@ -180,8 +179,7 @@ def sparse_proto_to_rows(
sfv: schema_types.SparseFloatArray, start: Optional[int] = None, end: Optional[int] = None
) -> Iterable[SparseRowOutputType]:
if not isinstance(sfv, schema_types.SparseFloatArray):
msg = "Vector must be a sparse float vector"
raise TypeError(msg)
raise ParamError(message="Vector must be a sparse float vector")
start = start or 0
end = end or len(sfv.contents)
return [sparse_parse_single_row(row_bytes) for row_bytes in sfv.contents[start:end]]
Expand Down

0 comments on commit 97f12ae

Please sign in to comment.