From 1bd12157518a1f87a311c73f24874ee01cf0e124 Mon Sep 17 00:00:00 2001 From: "zhenshan.cao" Date: Wed, 18 Oct 2023 16:44:00 +0800 Subject: [PATCH] Remove redundant type inference when insert data (#1739) Signed-off-by: zhenshan.cao --- examples/hello_milvus_array.py | 79 +++++++++++++ pymilvus/client/entity_helper.py | 10 +- pymilvus/client/grpc_handler.py | 1 - pymilvus/milvus_client/milvus_client.py | 10 +- pymilvus/orm/schema.py | 151 ++++++++---------------- tests/test_schema.py | 28 ----- 6 files changed, 138 insertions(+), 141 deletions(-) create mode 100644 examples/hello_milvus_array.py diff --git a/examples/hello_milvus_array.py b/examples/hello_milvus_array.py new file mode 100644 index 000000000..05f4165ea --- /dev/null +++ b/examples/hello_milvus_array.py @@ -0,0 +1,79 @@ +from pymilvus import CollectionSchema, FieldSchema, Collection, connections, DataType, Partition, utility +import numpy as np +import random +import pandas as pd +connections.connect() + +dim = 128 +collection_name = "test_array" +arr_len = 100 +nb = 10 +if utility.has_collection(collection_name): + utility.drop_collection(collection_name) +# create collection +pk_field = FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True, description='pk') +vector_field = FieldSchema(name="float_vector", dtype=DataType.FLOAT_VECTOR, dim=dim) +int8_array = FieldSchema(name="int8_array", dtype=DataType.ARRAY, element_type=DataType.INT8, max_capacity=arr_len) +int16_array = FieldSchema(name="int16_array", dtype=DataType.ARRAY, element_type=DataType.INT16, max_capacity=arr_len) +int32_array = FieldSchema(name="int32_array", dtype=DataType.ARRAY, element_type=DataType.INT32, max_capacity=arr_len) +int64_array = FieldSchema(name="int64_array", dtype=DataType.ARRAY, element_type=DataType.INT64, max_capacity=arr_len) +bool_array = FieldSchema(name="bool_array", dtype=DataType.ARRAY, element_type=DataType.BOOL, max_capacity=arr_len) +float_array = FieldSchema(name="float_array", dtype=DataType.ARRAY, element_type=DataType.FLOAT, max_capacity=arr_len) +double_array = FieldSchema(name="double_array", dtype=DataType.ARRAY, element_type=DataType.DOUBLE, max_capacity=arr_len) +string_array = FieldSchema(name="string_array", dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_capacity=arr_len, + max_length=100) + +fields = [pk_field, vector_field, int8_array, int16_array, int32_array, int64_array, + bool_array, float_array, double_array, string_array] + +schema = CollectionSchema(fields=fields) +collection = Collection(collection_name, schema=schema) + +# insert data +pk_value = [i for i in range(nb)] +vector_value = [[random.random() for _ in range(dim)] for i in range(nb)] +int8_value = [[np.int8(j) for j in range(arr_len)] for i in range(nb)] +int16_value = [[np.int16(j) for j in range(arr_len)] for i in range(nb)] +int32_value = [[np.int32(j) for j in range(arr_len)] for i in range(nb)] +int64_value = [[np.int64(j) for j in range(arr_len)] for i in range(nb)] +bool_value = [[np.bool_(j) for j in range(arr_len)] for i in range(nb)] +float_value = [[np.float32(j) for j in range(arr_len)] for i in range(nb)] +double_value = [[np.double(j) for j in range(arr_len)] for i in range(nb)] +string_value = [[str(j) for j in range(arr_len)] for i in range(nb)] + +data = [pk_value, vector_value, + int8_value,int16_value, int32_value, int64_value, + bool_value, + float_value, + double_value, + string_value + ] + +#collection.insert(data) + +data = pd.DataFrame({ + 'int64': pk_value, + 'float_vector': vector_value, + "int8_array": int8_value, + "int16_array": int16_value, + "int32_array": int32_value, + "int64_array": int64_value, + "bool_array": bool_value, + "float_array": float_value, + "double_array": double_value, + "string_array": string_value +}) +collection.insert(data) + +index = { + "index_type": "IVF_FLAT", + "metric_type": "L2", + "params": {"nlist": 128}, +} + +collection.create_index("float_vector", index) +collection.load() + +res = collection.query("int64 >= 0", output_fields=["int8_array"]) +for hits in res: + print(hits) diff --git a/pymilvus/client/entity_helper.py b/pymilvus/client/entity_helper.py index fb8debba7..f0b1fa50d 100644 --- a/pymilvus/client/entity_helper.py +++ b/pymilvus/client/entity_helper.py @@ -70,8 +70,8 @@ def entity_to_json_arr(entity: Dict): return convert_to_json_arr(entity.get("values", [])) -def convert_to_array_arr(objs: List[Any]): - return [convert_to_array_arr(obj) for obj in objs] +def convert_to_array_arr(objs: List[Any], field_info: Any): + return [convert_to_array(obj, field_info) for obj in objs] def convert_to_array(obj: List[Any], field_info: Any): @@ -100,8 +100,8 @@ def convert_to_array(obj: List[Any], field_info: Any): ) -def entity_to_array_arr(entity: List[Any]): - return convert_to_array_arr(entity.get("values", [])) +def entity_to_array_arr(entity: List[Any], field_info: Any): + return convert_to_array_arr(entity.get("values", []), field_info) def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info: Any): @@ -166,7 +166,7 @@ def entity_to_field_data(entity: Any, field_info: Any): elif entity_type == DataType.JSON: field_data.scalars.json_data.data.extend(entity_to_json_arr(entity)) elif entity_type == DataType.ARRAY: - field_data.scalars.array_data.data.extend(entity_to_array_arr(entity)) + field_data.scalars.array_data.data.extend(entity_to_array_arr(entity, field_info)) else: raise ParamError(message=f"UnSupported data type: {entity_type}") diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index 40fb7523d..6e4803708 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -570,7 +570,6 @@ def batch_insert( m = MutationResult(response) ts_utils.update_collection_ts(collection_name, m.timestamp) return m - raise MilvusException( response.status.code, response.status.reason, response.status.error_code ) diff --git a/pymilvus/milvus_client/milvus_client.py b/pymilvus/milvus_client/milvus_client.py index 15129335e..da367f2ab 100644 --- a/pymilvus/milvus_client/milvus_client.py +++ b/pymilvus/milvus_client/milvus_client.py @@ -493,19 +493,19 @@ def create_collection_with_schema( self, collection_name: str, schema: CollectionSchema, - index_param: Dict, + index_params: Dict, timeout: Optional[float] = None, **kwargs, ): schema.verify() - if kwargs.get("auto_id", True): + if kwargs.get("auto_id", False): schema.auto_id = True if kwargs.get("enable_dynamic_field", False): schema.enable_dynamic_field = True schema.verify() - index_param = index_param or {} - vector_field_name = index_param.pop("field_name", "") + index_params = index_params or {} + vector_field_name = index_params.pop("field_name", "") if not vector_field_name: schema_dict = schema.to_dict() vector_field_name = self._get_vector_field_name(schema_dict) @@ -520,7 +520,7 @@ def create_collection_with_schema( logger.error("Failed to create collection: %s", collection_name) raise ex from ex - self._create_index(collection_name, vector_field_name, index_param, timeout=timeout) + self._create_index(collection_name, vector_field_name, index_params, timeout=timeout) self._load(collection_name, timeout=timeout) def close(self): diff --git a/pymilvus/orm/schema.py b/pymilvus/orm/schema.py index bdab6b6f9..3f83a2e23 100644 --- a/pymilvus/orm/schema.py +++ b/pymilvus/orm/schema.py @@ -98,9 +98,10 @@ def _check_fields(self): primary_field_name = self._kwargs.get("primary_field", None) partition_key_field_name = self._kwargs.get("partition_key_field", None) for field in self._fields: - if primary_field_name == field.name: + if primary_field_name and primary_field_name == field.name: field.is_primary = True - if partition_key_field_name == field.name: + + if partition_key_field_name and partition_key_field_name == field.name: field.is_partition_key = True if field.is_primary: @@ -403,6 +404,45 @@ def check_is_row_based(data: Union[List[List], List[Dict], Dict, pd.DataFrame]) return False +def _check_insert_data(data: Union[List[List], pd.DataFrame]): + if not isinstance(data, (pd.DataFrame, list)): + raise DataTypeNotSupportException( + message="The type of data should be list or pandas.DataFrame" + ) + is_dataframe = isinstance(data, pd.DataFrame) + for col in data: + if not is_dataframe and not is_list_like(col): + raise DataTypeNotSupportException(message="data should be a list of list") + + +def _check_data_schema_cnt(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]): + tmp_fields = copy.deepcopy(schema.fields) + for i, field in enumerate(tmp_fields): + if field.is_primary and field.auto_id: + tmp_fields.pop(i) + + field_cnt = len(tmp_fields) + is_dataframe = isinstance(data, pd.DataFrame) + data_cnt = len(data.columns) if is_dataframe else len(data) + if field_cnt != data_cnt: + message = ( + f"The data don't match with schema fields, expect {field_cnt} list, got {len(data)}" + ) + if is_dataframe: + i_name = [f.name for f in tmp_fields] + t_name = list(data.columns) + message = f"The fields don't match with schema fields, expected: {i_name}, got {t_name}" + + raise DataNotMatchException(message=message) + + if is_dataframe: + for x, y in zip(list(data.columns), tmp_fields): + if x != y.name: + raise DataNotMatchException( + message=f"The name of field don't match, expected: {y.name}, got {x}" + ) + + def check_insert_schema(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]): if schema is None: raise SchemaNotReadyException(message="Schema shouldn't be None") @@ -410,10 +450,12 @@ def check_insert_schema(schema: CollectionSchema, data: Union[List[List], pd.Dat if not data[schema.primary_field.name].isnull().all(): msg = f"Expect no data for auto_id primary field: {schema.primary_field.name}" raise DataNotMatchException(message=msg) - data = data.drop(schema.primary_field.name, axis=1) + columns = list(data.columns) + columns.remove(schema.primary_field) + data = data[[columns]] - infer_fields, tmp_fields, is_data_frame = parse_fields_from_data(schema, data) - check_infer_fields_valid(infer_fields, tmp_fields, is_data_frame) + _check_data_schema_cnt(schema, data) + _check_insert_data(data) def check_upsert_schema(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]): @@ -422,78 +464,8 @@ def check_upsert_schema(schema: CollectionSchema, data: Union[List[List], pd.Dat if schema.auto_id: raise UpsertAutoIDTrueException(message=ExceptionsMessage.UpsertAutoIDTrue) - infer_fields, tmp_fields, is_data_frame = parse_fields_from_data(schema, data) - check_infer_fields_valid(infer_fields, tmp_fields, is_data_frame) - - -def parse_fields_from_data(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]): - if not isinstance(data, (pd.DataFrame, list)): - raise DataTypeNotSupportException( - message="The type of data should be list or pandas.DataFrame" - ) - - if isinstance(data, pd.DataFrame): - return parse_fields_from_dataframe(schema, data) - - tmp_fields = copy.deepcopy(schema.fields) - for i, field in enumerate(tmp_fields): - if field.is_primary and field.auto_id: - tmp_fields.pop(i) - - infer_fields = [] - for i, field in enumerate(tmp_fields): - try: - d = data[i] - if not is_list_like(d): - raise DataTypeNotSupportException(message="data should be a list of list") - try: - elem = d[0] - infer_fields.append(FieldSchema("", infer_dtype_bydata(elem))) - # if pass in [] or None, considering to be passed in order according to the schema - except IndexError: - infer_fields.append(FieldSchema("", field.dtype)) - # the last missing part of data is also completed in order according to the schema - except IndexError: - infer_fields.append(FieldSchema("", field.dtype)) - - index = len(tmp_fields) - while index < len(data): - fields = FieldSchema("", infer_dtype_bydata(data[index][0])) - infer_fields.append(fields) - index = index + 1 - - return infer_fields, tmp_fields, False - - -def parse_fields_from_dataframe(schema: CollectionSchema, df: pd.DataFrame): - col_names, data_types, column_params_map = prepare_fields_from_dataframe(df) - tmp_fields = copy.deepcopy(schema.fields) - for i, field in enumerate(schema.fields): - if field.is_primary and field.auto_id: - tmp_fields.pop(i) - infer_fields = [] - for field in tmp_fields: - # if no data pass in, considering to be passed in order according to the schema - if field.name not in col_names: - field_schema = FieldSchema(field.name, field.dtype) - col_names.append(field.name) - data_types.append(field.dtype) - infer_fields.append(field_schema) - else: - type_params = column_params_map.get(field.name, {}) - field_schema = FieldSchema( - field.name, data_types[col_names.index(field.name)], **type_params - ) - infer_fields.append(field_schema) - - infer_name = [f.name for f in infer_fields] - for name, dtype in zip(col_names, data_types): - if name not in infer_name: - type_params = column_params_map.get(name, {}) - field_schema = FieldSchema(name, dtype, **type_params) - infer_fields.append(field_schema) - - return infer_fields, tmp_fields, True + _check_data_schema_cnt(schema, data) + _check_insert_data(data) def construct_fields_from_dataframe(df: pd.DataFrame) -> List[FieldSchema]: @@ -536,31 +508,6 @@ def prepare_fields_from_dataframe(df: pd.DataFrame): return col_names, data_types, column_params_map -def check_infer_fields_valid( - infer_fields: List[FieldSchema], - tmp_fields: List, - is_data_frame: bool, -): - if len(infer_fields) != len(tmp_fields): - i_name = [f.name for f in infer_fields] - t_name = [f.name for f in tmp_fields] - raise DataNotMatchException( - message=f"The fields don't match with schema fields, expected: {t_name}, got {i_name}" - ) - - for x, y in zip(infer_fields, tmp_fields): - if is_data_frame and x.name != y.name: - raise DataNotMatchException( - message=f"The name of field don't match, expected: {y.name}, got {x.name}" - ) - if x.dtype != y.dtype: - msg = ( - f"The data type of field {y.name} doesn't match, " - f"expected: {y.dtype.name}, got {x.dtype.name}" - ) - raise DataNotMatchException(message=msg) - - def check_schema(schema: CollectionSchema): if schema is None: raise SchemaNotReadyException(message=ExceptionsMessage.NoSchema) diff --git a/tests/test_schema.py b/tests/test_schema.py index 15469797d..aabcff668 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -139,31 +139,3 @@ def test_to_dict(self, raw_dict_norm, raw_dict_float_vector, raw_dict_binary_vec target = f.to_dict() assert target == dicts[i] assert target is not dicts[i] - - # def test_parse_fields_from_dataframe(self, dataframe1): - # fields = parse_fields_from_dataframe(dataframe1) - # assert len(fields) == len(dataframe1.columns) - # for f in fields: - # if f.dtype == DataType.FLOAT_VECTOR: - # assert f.dim == len(dataframe1['float_vec'].values[0]) - - -class TestCheckInsertDataSchema: - def test_check_insert_data_schema_issue1324(self): - schema = CollectionSchema([ - FieldSchema(name="id", dtype=DataType.INT64, descrition="int64", is_primary=True, auto_id=True), - FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, descrition="float vector", dim=2), - FieldSchema(name="work_id2", dtype=5, descrition="work id"), - FieldSchema(name='path', dtype=DataType.VARCHAR, description='path to image', max_length=200), - FieldSchema(name="uid", dtype=DataType.INT64, descrition="user id"), - ]) - - data = [ - [[0.003984056, 0.05035976]], - ['15755403'], - ['https://xxx.com/app/works/105149/2023-01-11/w_63be653c4643b/963be653c8aa8c.jpg'], - ['105149'], - ] - - with pytest.raises(MilvusException): - s.check_insert_schema(schema, data)