Remove redundant type inference when insert data #1739

Merged · 1 commit · Oct 18, 2023
79 changes: 79 additions & 0 deletions examples/hello_milvus_array.py
@@ -0,0 +1,79 @@
from pymilvus import CollectionSchema, FieldSchema, Collection, connections, DataType, Partition, utility
import numpy as np
import random
import pandas as pd
connections.connect()

dim = 128
collection_name = "test_array"
arr_len = 100
nb = 10
if utility.has_collection(collection_name):
utility.drop_collection(collection_name)
# create collection
pk_field = FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True, description='pk')
vector_field = FieldSchema(name="float_vector", dtype=DataType.FLOAT_VECTOR, dim=dim)
int8_array = FieldSchema(name="int8_array", dtype=DataType.ARRAY, element_type=DataType.INT8, max_capacity=arr_len)
int16_array = FieldSchema(name="int16_array", dtype=DataType.ARRAY, element_type=DataType.INT16, max_capacity=arr_len)
int32_array = FieldSchema(name="int32_array", dtype=DataType.ARRAY, element_type=DataType.INT32, max_capacity=arr_len)
int64_array = FieldSchema(name="int64_array", dtype=DataType.ARRAY, element_type=DataType.INT64, max_capacity=arr_len)
bool_array = FieldSchema(name="bool_array", dtype=DataType.ARRAY, element_type=DataType.BOOL, max_capacity=arr_len)
float_array = FieldSchema(name="float_array", dtype=DataType.ARRAY, element_type=DataType.FLOAT, max_capacity=arr_len)
double_array = FieldSchema(name="double_array", dtype=DataType.ARRAY, element_type=DataType.DOUBLE, max_capacity=arr_len)
string_array = FieldSchema(name="string_array", dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_capacity=arr_len,
max_length=100)

fields = [pk_field, vector_field, int8_array, int16_array, int32_array, int64_array,
bool_array, float_array, double_array, string_array]

schema = CollectionSchema(fields=fields)
collection = Collection(collection_name, schema=schema)

# insert data
pk_value = [i for i in range(nb)]
vector_value = [[random.random() for _ in range(dim)] for i in range(nb)]
int8_value = [[np.int8(j) for j in range(arr_len)] for i in range(nb)]
int16_value = [[np.int16(j) for j in range(arr_len)] for i in range(nb)]
int32_value = [[np.int32(j) for j in range(arr_len)] for i in range(nb)]
int64_value = [[np.int64(j) for j in range(arr_len)] for i in range(nb)]
bool_value = [[np.bool_(j) for j in range(arr_len)] for i in range(nb)]
float_value = [[np.float32(j) for j in range(arr_len)] for i in range(nb)]
double_value = [[np.double(j) for j in range(arr_len)] for i in range(nb)]
string_value = [[str(j) for j in range(arr_len)] for i in range(nb)]

data = [pk_value, vector_value,
int8_value,int16_value, int32_value, int64_value,
bool_value,
float_value,
double_value,
string_value
]

#collection.insert(data)

data = pd.DataFrame({
'int64': pk_value,
'float_vector': vector_value,
"int8_array": int8_value,
"int16_array": int16_value,
"int32_array": int32_value,
"int64_array": int64_value,
"bool_array": bool_value,
"float_array": float_value,
"double_array": double_value,
"string_array": string_value
})
collection.insert(data)

index = {
"index_type": "IVF_FLAT",
"metric_type": "L2",
"params": {"nlist": 128},
}

collection.create_index("float_vector", index)
collection.load()

res = collection.query("int64 >= 0", output_fields=["int8_array"])
for hits in res:
print(hits)
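
As a follow-on to the new example above (not part of this PR), a vector search against the same collection could look like the sketch below; it reuses collection, dim, and random from hello_milvus_array.py and the standard Collection.search signature.

search_params = {"metric_type": "L2", "params": {"nprobe": 16}}
results = collection.search(
    data=[[random.random() for _ in range(dim)]],  # one random query vector
    anns_field="float_vector",
    param=search_params,
    limit=3,
    output_fields=["int64_array"],  # request an array field in the results
)
for hits in results:
    for hit in hits:
        print(hit.id, hit.distance)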
10 changes: 5 additions & 5 deletions pymilvus/client/entity_helper.py
@@ -70,8 +70,8 @@ def entity_to_json_arr(entity: Dict):
return convert_to_json_arr(entity.get("values", []))


def convert_to_array_arr(objs: List[Any]):
return [convert_to_array_arr(obj) for obj in objs]
def convert_to_array_arr(objs: List[Any], field_info: Any):
return [convert_to_array(obj, field_info) for obj in objs]


def convert_to_array(obj: List[Any], field_info: Any):
@@ -100,8 +100,8 @@ def convert_to_array(obj: List[Any], field_info: Any):
)


def entity_to_array_arr(entity: List[Any]):
return convert_to_array_arr(entity.get("values", []))
def entity_to_array_arr(entity: List[Any], field_info: Any):
return convert_to_array_arr(entity.get("values", []), field_info)


def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info: Any):
@@ -166,7 +166,7 @@ def entity_to_field_data(entity: Any, field_info: Any):
elif entity_type == DataType.JSON:
field_data.scalars.json_data.data.extend(entity_to_json_arr(entity))
elif entity_type == DataType.ARRAY:
field_data.scalars.array_data.data.extend(entity_to_array_arr(entity))
field_data.scalars.array_data.data.extend(entity_to_array_arr(entity, field_info))
else:
raise ParamError(message=f"UnSupported data type: {entity_type}")

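The hunk above fixes convert_to_array_arr: the old helper recursed into itself for every element instead of handing each row to convert_to_array with the field's schema info, so the ARRAY element type was never applied. A minimal standalone sketch of the old failure mode (illustrative only, not library code):

def convert_to_array_arr_old(objs):
    # pre-PR shape: recurses on every element rather than converting one row at a time
    return [convert_to_array_arr_old(obj) for obj in objs]

convert_to_array_arr_old([[1, 2, 3]])  # walks past the rows into the ints and raises TypeError

The replacement maps each row through convert_to_array(obj, field_info), so the element type comes from the field schema instead of being re-inferred.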
1 change: 0 additions & 1 deletion pymilvus/client/grpc_handler.py
@@ -570,7 +570,6 @@ def batch_insert(
m = MutationResult(response)
ts_utils.update_collection_ts(collection_name, m.timestamp)
return m

raise MilvusException(
response.status.code, response.status.reason, response.status.error_code
)
10 changes: 5 additions & 5 deletions pymilvus/milvus_client/milvus_client.py
@@ -493,19 +493,19 @@ def create_collection_with_schema(
self,
collection_name: str,
schema: CollectionSchema,
index_param: Dict,
index_params: Dict,
timeout: Optional[float] = None,
**kwargs,
):
schema.verify()
if kwargs.get("auto_id", True):
if kwargs.get("auto_id", False):
schema.auto_id = True
if kwargs.get("enable_dynamic_field", False):
schema.enable_dynamic_field = True
schema.verify()

index_param = index_param or {}
vector_field_name = index_param.pop("field_name", "")
index_params = index_params or {}
vector_field_name = index_params.pop("field_name", "")
if not vector_field_name:
schema_dict = schema.to_dict()
vector_field_name = self._get_vector_field_name(schema_dict)
@@ -520,7 +520,7 @@ def create_collection_with_schema(
logger.error("Failed to create collection: %s", collection_name)
raise ex from ex

self._create_index(collection_name, vector_field_name, index_param, timeout=timeout)
self._create_index(collection_name, vector_field_name, index_params, timeout=timeout)
self._load(collection_name, timeout=timeout)

def close(self):
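With the rename above, MilvusClient.create_collection_with_schema takes the index description as index_params, and auto_id is only enabled when explicitly requested (the default flips from True to False). A rough usage sketch; the URI, collection name, field names, and index settings below are placeholders, not taken from this PR:

from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType

schema = CollectionSchema([
    FieldSchema("pk", DataType.INT64, is_primary=True),
    FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=128),
])
index_params = {
    "field_name": "float_vector",  # popped by create_collection_with_schema to pick the vector field
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}
client = MilvusClient(uri="http://localhost:19530")
client.create_collection_with_schema("test_array", schema, index_params)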
151 changes: 49 additions & 102 deletions pymilvus/orm/schema.py
@@ -98,9 +98,10 @@ def _check_fields(self):
primary_field_name = self._kwargs.get("primary_field", None)
partition_key_field_name = self._kwargs.get("partition_key_field", None)
for field in self._fields:
if primary_field_name == field.name:
if primary_field_name and primary_field_name == field.name:
field.is_primary = True
if partition_key_field_name == field.name:

if partition_key_field_name and partition_key_field_name == field.name:
field.is_partition_key = True

if field.is_primary:
@@ -403,17 +404,58 @@ def check_is_row_based(data: Union[List[List], List[Dict], Dict, pd.DataFrame])
return False


def _check_insert_data(data: Union[List[List], pd.DataFrame]):
if not isinstance(data, (pd.DataFrame, list)):
raise DataTypeNotSupportException(
message="The type of data should be list or pandas.DataFrame"
)
is_dataframe = isinstance(data, pd.DataFrame)
for col in data:
if not is_dataframe and not is_list_like(col):
raise DataTypeNotSupportException(message="data should be a list of list")


def _check_data_schema_cnt(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
tmp_fields = copy.deepcopy(schema.fields)
for i, field in enumerate(tmp_fields):
if field.is_primary and field.auto_id:
tmp_fields.pop(i)

field_cnt = len(tmp_fields)
is_dataframe = isinstance(data, pd.DataFrame)
data_cnt = len(data.columns) if is_dataframe else len(data)
if field_cnt != data_cnt:
message = (
f"The data don't match with schema fields, expect {field_cnt} list, got {len(data)}"
)
if is_dataframe:
i_name = [f.name for f in tmp_fields]
t_name = list(data.columns)
message = f"The fields don't match with schema fields, expected: {i_name}, got {t_name}"

raise DataNotMatchException(message=message)

if is_dataframe:
for x, y in zip(list(data.columns), tmp_fields):
if x != y.name:
raise DataNotMatchException(
message=f"The name of field don't match, expected: {y.name}, got {x}"
)


def check_insert_schema(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
if schema is None:
raise SchemaNotReadyException(message="Schema shouldn't be None")
if schema.auto_id and isinstance(data, pd.DataFrame) and schema.primary_field.name in data:
if not data[schema.primary_field.name].isnull().all():
msg = f"Expect no data for auto_id primary field: {schema.primary_field.name}"
raise DataNotMatchException(message=msg)
data = data.drop(schema.primary_field.name, axis=1)
columns = list(data.columns)
columns.remove(schema.primary_field)
data = data[[columns]]

infer_fields, tmp_fields, is_data_frame = parse_fields_from_data(schema, data)
check_infer_fields_valid(infer_fields, tmp_fields, is_data_frame)
_check_data_schema_cnt(schema, data)
_check_insert_data(data)


def check_upsert_schema(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
@@ -422,78 +464,8 @@ def check_upsert_schema(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
if schema.auto_id:
raise UpsertAutoIDTrueException(message=ExceptionsMessage.UpsertAutoIDTrue)

infer_fields, tmp_fields, is_data_frame = parse_fields_from_data(schema, data)
check_infer_fields_valid(infer_fields, tmp_fields, is_data_frame)


def parse_fields_from_data(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
if not isinstance(data, (pd.DataFrame, list)):
raise DataTypeNotSupportException(
message="The type of data should be list or pandas.DataFrame"
)

if isinstance(data, pd.DataFrame):
return parse_fields_from_dataframe(schema, data)

tmp_fields = copy.deepcopy(schema.fields)
for i, field in enumerate(tmp_fields):
if field.is_primary and field.auto_id:
tmp_fields.pop(i)

infer_fields = []
for i, field in enumerate(tmp_fields):
try:
d = data[i]
if not is_list_like(d):
raise DataTypeNotSupportException(message="data should be a list of list")
try:
elem = d[0]
infer_fields.append(FieldSchema("", infer_dtype_bydata(elem)))
# if pass in [] or None, considering to be passed in order according to the schema
except IndexError:
infer_fields.append(FieldSchema("", field.dtype))
# the last missing part of data is also completed in order according to the schema
except IndexError:
infer_fields.append(FieldSchema("", field.dtype))

index = len(tmp_fields)
while index < len(data):
fields = FieldSchema("", infer_dtype_bydata(data[index][0]))
infer_fields.append(fields)
index = index + 1

return infer_fields, tmp_fields, False


def parse_fields_from_dataframe(schema: CollectionSchema, df: pd.DataFrame):
col_names, data_types, column_params_map = prepare_fields_from_dataframe(df)
tmp_fields = copy.deepcopy(schema.fields)
for i, field in enumerate(schema.fields):
if field.is_primary and field.auto_id:
tmp_fields.pop(i)
infer_fields = []
for field in tmp_fields:
# if no data pass in, considering to be passed in order according to the schema
if field.name not in col_names:
field_schema = FieldSchema(field.name, field.dtype)
col_names.append(field.name)
data_types.append(field.dtype)
infer_fields.append(field_schema)
else:
type_params = column_params_map.get(field.name, {})
field_schema = FieldSchema(
field.name, data_types[col_names.index(field.name)], **type_params
)
infer_fields.append(field_schema)

infer_name = [f.name for f in infer_fields]
for name, dtype in zip(col_names, data_types):
if name not in infer_name:
type_params = column_params_map.get(name, {})
field_schema = FieldSchema(name, dtype, **type_params)
infer_fields.append(field_schema)

return infer_fields, tmp_fields, True
_check_data_schema_cnt(schema, data)
_check_insert_data(data)


def construct_fields_from_dataframe(df: pd.DataFrame) -> List[FieldSchema]:
@@ -536,31 +508,6 @@ def prepare_fields_from_dataframe(df: pd.DataFrame):
return col_names, data_types, column_params_map


def check_infer_fields_valid(
infer_fields: List[FieldSchema],
tmp_fields: List,
is_data_frame: bool,
):
if len(infer_fields) != len(tmp_fields):
i_name = [f.name for f in infer_fields]
t_name = [f.name for f in tmp_fields]
raise DataNotMatchException(
message=f"The fields don't match with schema fields, expected: {t_name}, got {i_name}"
)

for x, y in zip(infer_fields, tmp_fields):
if is_data_frame and x.name != y.name:
raise DataNotMatchException(
message=f"The name of field don't match, expected: {y.name}, got {x.name}"
)
if x.dtype != y.dtype:
msg = (
f"The data type of field {y.name} doesn't match, "
f"expected: {y.dtype.name}, got {x.dtype.name}"
)
raise DataNotMatchException(message=msg)


def check_schema(schema: CollectionSchema):
if schema is None:
raise SchemaNotReadyException(message=ExceptionsMessage.NoSchema)
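Net effect of the schema.py changes above: check_insert_schema and check_upsert_schema no longer infer a dtype for every column (parse_fields_from_data, parse_fields_from_dataframe, and check_infer_fields_valid are removed); they now only verify that the data is a list or DataFrame, that every column is list-like, and that the column count (and, for DataFrames, the column names) line up with the schema. A rough sketch of what the check still rejects, assuming pymilvus.orm.schema is imported as s, as in the test suite:

from pymilvus import CollectionSchema, FieldSchema, DataType
from pymilvus.orm import schema as s

demo = CollectionSchema([
    FieldSchema("pk", DataType.INT64, is_primary=True),
    FieldSchema("vec", DataType.FLOAT_VECTOR, dim=2),
])
ok = [[1, 2], [[0.1, 0.2], [0.3, 0.4]]]  # two columns, both list-like
short = [[1, 2]]                         # one column missing
s.check_insert_schema(demo, ok)          # passes: only count and list checks run
s.check_insert_schema(demo, short)       # expected to raise DataNotMatchException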
28 changes: 0 additions & 28 deletions tests/test_schema.py
@@ -139,31 +139,3 @@ def test_to_dict(self, raw_dict_norm, raw_dict_float_vector, raw_dict_binary_vec
target = f.to_dict()
assert target == dicts[i]
assert target is not dicts[i]

# def test_parse_fields_from_dataframe(self, dataframe1):
# fields = parse_fields_from_dataframe(dataframe1)
# assert len(fields) == len(dataframe1.columns)
# for f in fields:
# if f.dtype == DataType.FLOAT_VECTOR:
# assert f.dim == len(dataframe1['float_vec'].values[0])


class TestCheckInsertDataSchema:
def test_check_insert_data_schema_issue1324(self):
schema = CollectionSchema([
FieldSchema(name="id", dtype=DataType.INT64, descrition="int64", is_primary=True, auto_id=True),
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, descrition="float vector", dim=2),
FieldSchema(name="work_id2", dtype=5, descrition="work id"),
FieldSchema(name='path', dtype=DataType.VARCHAR, description='path to image', max_length=200),
FieldSchema(name="uid", dtype=DataType.INT64, descrition="user id"),
])

data = [
[[0.003984056, 0.05035976]],
['15755403'],
['https://xxx.com/app/works/105149/2023-01-11/w_63be653c4643b/963be653c8aa8c.jpg'],
['105149'],
]

with pytest.raises(MilvusException):
s.check_insert_schema(schema, data)
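
The test deleted above (issue #1324) depended on the client-side dtype inference this PR removes: it passed string columns for INT64 fields and expected check_insert_schema to raise. After this change, a mismatch of that shape passes the schema check, since the column count is right and every column is list-like, and it is expected to surface later, when the values are packed or validated server-side. A rough illustration, reusing the imports and the s module alias from the sketch above:

bad_dtype = CollectionSchema([
    FieldSchema("id", DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema("embedding", DataType.FLOAT_VECTOR, dim=2),
    FieldSchema("uid", DataType.INT64),
])
data = [[[0.1, 0.2]], ["not-an-int"]]   # uid column holds a string
s.check_insert_schema(bad_dtype, data)  # no longer raises at this layer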