Bulkwriter set row group for parquet #1836

Merged (1 commit) on Dec 28, 2023
21 changes: 16 additions & 5 deletions examples/example_bulkwriter.py
@@ -78,7 +78,7 @@ def build_simple_collection():
print(f"Collection '{collection.name}' created")
return collection.schema

def build_all_type_schema(bin_vec: bool):
def build_all_type_schema(bin_vec: bool, has_array: bool):
print(f"\n===================== build all types schema ====================")
fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
@@ -93,6 +93,11 @@ def build_all_type_schema(bin_vec: bool):
FieldSchema(name="json", dtype=DataType.JSON),
FieldSchema(name="vector", dtype=DataType.BINARY_VECTOR, dim=DIM) if bin_vec else FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=DIM),
]

if has_array:
fields.append(FieldSchema(name="array_str", dtype=DataType.ARRAY, max_capacity=100, element_type=DataType.VARCHAR, max_length=128))
fields.append(FieldSchema(name="array_int", dtype=DataType.ARRAY, max_capacity=100, element_type=DataType.INT64))

schema = CollectionSchema(fields=fields, enable_dynamic_field=True)
return schema

@@ -118,8 +123,6 @@ def local_writer(schema: CollectionSchema, file_type: BulkFileType):
segment_size=128*1024*1024,
file_type=file_type,
) as local_writer:
# read data from csv
read_sample_data("./data/train_embeddings.csv", local_writer)

# append rows
for i in range(100000):
@@ -245,6 +248,9 @@ def all_types_writer(bin_vec: bool, schema: CollectionSchema, file_type: BulkFil
"json": {"dummy": i, "ok": f"name_{i}"},
"vector": gen_binary_vector() if bin_vec else gen_float_vector(),
f"dynamic_{i}": i,
# bulkinsert doesn't support importing npy files with array fields; the values below will be stored in the dynamic field
"array_str": [f"str_{k}" for k in range(5)],
"array_int": [k for k in range(10)],
}
remote_writer.append_row(row)

@@ -263,6 +269,9 @@ def all_types_writer(bin_vec: bool, schema: CollectionSchema, file_type: BulkFil
"json": json.dumps({"dummy": i, "ok": f"name_{i}"}),
"vector": gen_binary_vector() if bin_vec else gen_float_vector(),
f"dynamic_{i}": i,
# bulkinsert doesn't support importing npy files with array fields; the values below will be stored in the dynamic field
"array_str": np.array([f"str_{k}" for k in range(5)], np.dtype("str")),
"array_int": np.array([k for k in range(10)], np.dtype("int64")),
})

print(f"{remote_writer.total_row_count} rows appends")
@@ -383,15 +392,17 @@ def cloud_bulkinsert():
parallel_append(schema)

# float vectors + all scalar types
schema = build_all_type_schema(bin_vec=False)
for file_type in file_types:
# Note: bulkinsert doesn't support importing npy files with array fields
schema = build_all_type_schema(bin_vec=False, has_array=False if file_type==BulkFileType.NPY else True)
batch_files = all_types_writer(bin_vec=False, schema=schema, file_type=file_type)
call_bulkinsert(schema, batch_files)
retrieve_imported_data(bin_vec=False)

# binary vectors + all scalar types
schema = build_all_type_schema(bin_vec=True)
for file_type in file_types:
# Note: bulkinsert doesn't support importing npy files with array fields
schema = build_all_type_schema(bin_vec=True, has_array=False if file_type == BulkFileType.NPY else True)
batch_files = all_types_writer(bin_vec=True, schema=schema, file_type=file_type)
call_bulkinsert(schema, batch_files)
retrieve_imported_data(bin_vec=True)
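For reference, a condensed, standalone sketch of how the updated example exercises array fields with the parquet writer. It follows the usage pattern of this example file (LocalBulkWriter as a context manager, append_row, commit, batch_files); the import paths and the local_path parameter name are assumptions, not part of this diff.

from pymilvus import CollectionSchema, DataType, FieldSchema
from pymilvus.bulk_writer import BulkFileType, LocalBulkWriter  # assumed import path

schema = CollectionSchema(fields=[
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
    FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=8),
    FieldSchema(name="array_int", dtype=DataType.ARRAY, max_capacity=100, element_type=DataType.INT64),
])

with LocalBulkWriter(
    schema=schema,
    local_path="/tmp/bulk_data",        # assumed parameter name
    segment_size=128 * 1024 * 1024,
    file_type=BulkFileType.PARQUET,
) as writer:
    for i in range(1000):
        # ARRAY fields take plain Python lists (numpy values are normalized, see buffer.py below)
        writer.append_row({"id": i, "vector": [0.1] * 8, "array_int": list(range(10))})
    writer.commit()
    print(writer.batch_files)  # parquet files written with the row group size computed in buffer.py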
70 changes: 54 additions & 16 deletions pymilvus/bulk_writer/buffer.py
@@ -28,6 +28,7 @@

from .constants import (
DYNAMIC_FIELD_NAME,
MB,
NUMPY_TYPE_CREATOR,
BulkFileType,
)
@@ -74,6 +75,14 @@ def _throw(self, msg: str):
logger.error(msg)
raise MilvusException(message=msg)

def _raw_obj(self, x: object):
if isinstance(x, np.ndarray):
return x.tolist()
if isinstance(x, np.generic):
return x.item()

return x

def append_row(self, row: dict):
dynamic_values = {}
if DYNAMIC_FIELD_NAME in row and not isinstance(row[DYNAMIC_FIELD_NAME], dict):
@@ -85,14 +94,14 @@ def append_row(self, row: dict):
continue

if k not in self._buffer:
dynamic_values[k] = row[k]
dynamic_values[k] = self._raw_obj(row[k])
else:
self._buffer[k].append(row[k])

if DYNAMIC_FIELD_NAME in self._buffer:
self._buffer[DYNAMIC_FIELD_NAME].append(dynamic_values)
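A small standalone illustration of what the _raw_obj helper added above does: values routed to the dynamic field are unwrapped from numpy containers so they stay JSON-serializable, while plain Python values pass through unchanged.

import numpy as np

np.array([1, 2, 3]).tolist()   # -> [1, 2, 3], what _raw_obj returns for an np.ndarray
np.int64(7).item()             # -> 7, what _raw_obj returns for an np.generic scalar
"already plain"                # any other value is returned as-is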

def persist(self, local_path: str) -> list:
def persist(self, local_path: str, **kwargs) -> list:
# verify row count of fields are equal
row_count = -1
for k in self._buffer:
@@ -107,17 +116,18 @@

# output files
if self._file_type == BulkFileType.NPY:
return self._persist_npy(local_path)
return self._persist_npy(local_path, **kwargs)
if self._file_type == BulkFileType.JSON_RB:
return self._persist_json_rows(local_path)
return self._persist_json_rows(local_path, **kwargs)
if self._file_type == BulkFileType.PARQUET:
return self._persist_parquet(local_path)
return self._persist_parquet(local_path, **kwargs)

self._throw(f"Unsupported file tpye: {self._file_type}")
return []

def _persist_npy(self, local_path: str):
def _persist_npy(self, local_path: str, **kwargs):
file_list = []
row_count = len(next(iter(self._buffer.values())))
for k in self._buffer:
full_file_name = Path(local_path).joinpath(k + ".npy")
file_list.append(str(full_file_name))
@@ -127,7 +137,10 @@ def _persist_npy(self, local_path: str):
# numpy data type specify
dt = None
field_schema = self._fields[k]
if field_schema.dtype.name in NUMPY_TYPE_CREATOR:
if field_schema.dtype == DataType.ARRAY:
element_type = field_schema.element_type
dt = NUMPY_TYPE_CREATOR[element_type.name]
elif field_schema.dtype.name in NUMPY_TYPE_CREATOR:
dt = NUMPY_TYPE_CREATOR[field_schema.dtype.name]

# for JSON field, convert to string array
@@ -140,9 +153,9 @@ def _persist_npy(self, local_path: str):
arr = np.array(self._buffer[k], dtype=dt)
np.save(str(full_file_name), arr)
except Exception as e:
self._throw(f"Failed to persist column-based file {full_file_name}, error: {e}")
self._throw(f"Failed to persist file {full_file_name}, error: {e}")

logger.info(f"Successfully persist column-based file {full_file_name}")
logger.info(f"Successfully persist file {full_file_name}, row count: {row_count}")

if len(file_list) != len(self._buffer):
logger.error("Some of fields were not persisted successfully, abort the files")
@@ -154,7 +167,7 @@

return file_list

def _persist_json_rows(self, local_path: str):
def _persist_json_rows(self, local_path: str, **kwargs):
rows = []
row_count = len(next(iter(self._buffer.values())))
row_index = 0
@@ -173,12 +186,12 @@
with file_path.open("w") as json_file:
json.dump(data, json_file, indent=2)
except Exception as e:
self._throw(f"Failed to persist row-based file {file_path}, error: {e}")
self._throw(f"Failed to persist file {file_path}, error: {e}")

logger.info(f"Successfully persist row-based file {file_path}")
logger.info(f"Successfully persist file {file_path}, row count: {len(rows)}")
return [str(file_path)]

def _persist_parquet(self, local_path: str):
def _persist_parquet(self, local_path: str, **kwargs):
file_path = Path(local_path + ".parquet")

data = {}
@@ -203,10 +216,35 @@
elif field_schema.dtype.name in NUMPY_TYPE_CREATOR:
dt = NUMPY_TYPE_CREATOR[field_schema.dtype.name]
data[k] = pd.Series(self._buffer[k], dtype=dt)
else:
# dtype is None; let pandas deduce the type (this might not work)
data[k] = pd.Series(self._buffer[k])

# calculate a proper row group size
row_group_size_min = 1000
row_group_size = 10000
row_group_size_max = 1000000
if "buffer_size" in kwargs and "buffer_row_count" in kwargs:
row_group_bytes = kwargs.get(
"row_group_bytes", 32 * MB
) # 32MB is an empirical value that avoids high memory usage of the parquet reader on the server side
buffer_size = kwargs.get("buffer_size", 1)
buffer_row_count = kwargs.get("buffer_row_count", 1)
size_per_row = int(buffer_size / buffer_row_count) + 1
row_group_size = int(row_group_bytes / size_per_row)
if row_group_size < row_group_size_min:
row_group_size = row_group_size_min
if row_group_size > row_group_size_max:
row_group_size = row_group_size_max

# write to Parquet file
data_frame = pd.DataFrame(data=data)
data_frame.to_parquet(file_path, engine="pyarrow") # don't use fastparquet

logger.info(f"Successfully persist parquet file {file_path}")
data_frame.to_parquet(
file_path, row_group_size=row_group_size, engine="pyarrow"
) # don't use fastparquet

logger.info(
f"Successfully persist file {file_path}, total size: {buffer_size},"
f" row count: {buffer_row_count}, row group size: {row_group_size}"
)
return [str(file_path)]
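To make the new sizing logic concrete: the parquet path now targets roughly row_group_bytes (default 32 MB) of data per row group, derived from the average row size, and clamps the result to between 1,000 and 1,000,000 rows. The buffer_size and buffer_row_count keyword arguments are presumably forwarded by the calling writer; that call site is not part of this excerpt. A worked sketch with made-up numbers:

MB = 1024 * 1024

# hypothetical buffer: 50,000 rows occupying 64 MB in memory
buffer_size = 64 * MB
buffer_row_count = 50_000
row_group_bytes = 32 * MB                                    # default target per row group

size_per_row = int(buffer_size / buffer_row_count) + 1       # 1,343 bytes per row
row_group_size = int(row_group_bytes / size_per_row)         # 24,984 rows per row group
row_group_size = max(1_000, min(row_group_size, 1_000_000))  # clamp, same effect as the if-checks above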
94 changes: 68 additions & 26 deletions pymilvus/bulk_writer/bulk_writer.py
@@ -18,7 +18,7 @@

from pymilvus.client.types import DataType
from pymilvus.exceptions import MilvusException
from pymilvus.orm.schema import CollectionSchema
from pymilvus.orm.schema import CollectionSchema, FieldSchema

from .buffer import (
Buffer,
@@ -39,6 +39,7 @@ def __init__(
schema: CollectionSchema,
segment_size: int,
file_type: BulkFileType = BulkFileType.NPY,
**kwargs,
):
self._schema = schema
self._buffer_size = 0
@@ -107,6 +108,62 @@ def _throw(self, msg: str):
logger.error(msg)
raise MilvusException(message=msg)

def _verify_vector(self, x: object, field: FieldSchema):
dtype = DataType(field.dtype)
validator = TYPE_VALIDATOR[dtype.name]
dim = field.params["dim"]
if not validator(x, dim):
self._throw(
f"Illegal vector data for vector field: '{field.name}',"
f" dim is not {dim} or type mismatch"
)

return len(x) * 4 if dtype == DataType.FLOAT_VECTOR else len(x) / 8

def _verify_json(self, x: object, field: FieldSchema):
size = 0
validator = TYPE_VALIDATOR[DataType.JSON.name]
if isinstance(x, str):
size = len(x)
x = self._try_convert_json(field.name, x)
elif validator(x):
size = len(json.dumps(x))
else:
self._throw(f"Illegal JSON value for field '{field.name}', type mismatch")

return x, size

def _verify_varchar(self, x: object, field: FieldSchema):
max_len = field.params["max_length"]
validator = TYPE_VALIDATOR[DataType.VARCHAR.name]
if not validator(x, max_len):
self._throw(
f"Illegal varchar value for field '{field.name}',"
f" length exceeds {max_len} or type mismatch"
)

return len(x)

def _verify_array(self, x: object, field: FieldSchema):
max_capacity = field.params["max_capacity"]
element_type = field.element_type
validator = TYPE_VALIDATOR[DataType.ARRAY.name]
if not validator(x, max_capacity):
self._throw(
f"Illegal array value for field '{field.name}', length exceeds capacity or type mismatch"
)

row_size = 0
if element_type.name in TYPE_SIZE:
row_size = TYPE_SIZE[element_type.name] * len(x)
elif element_type == DataType.VARCHAR:
for ele in x:
row_size = row_size + self._verify_varchar(ele, field)
else:
self._throw(f"Unsupported element type for array field '{field.name}'")

return row_size

def _verify_row(self, row: dict):
if not isinstance(row, dict):
self._throw("The input row must be a dict object")
@@ -125,41 +182,26 @@ def _verify_row(self, row: dict):
self._throw(f"The field '{field.name}' is missed in the row")

dtype = DataType(field.dtype)
validator = TYPE_VALIDATOR[dtype.name]
if dtype in {DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR}:
if isinstance(row[field.name], np.ndarray):
row[field.name] = row[field.name].tolist()
dim = field.params["dim"]
if not validator(row[field.name], dim):
self._throw(
f"Illegal vector data for vector field: '{field.name}',"
f" dim is not {dim} or type mismatch"
)

vec_size = (
len(row[field.name]) * 4
if dtype == DataType.FLOAT_VECTOR
else len(row[field.name]) / 8
)
row_size = row_size + vec_size
row_size = row_size + self._verify_vector(row[field.name], field)
elif dtype == DataType.VARCHAR:
max_len = field.params["max_length"]
if not validator(row[field.name], max_len):
self._throw(
f"Illegal varchar value for field '{field.name}',"
f" length exceeds {max_len} or type mismatch"
)

row_size = row_size + len(row[field.name])
row_size = row_size + self._verify_varchar(row[field.name], field)
elif dtype == DataType.JSON:
row[field.name] = self._try_convert_json(field.name, row[field.name])
if not validator(row[field.name]):
self._throw(f"Illegal JSON value for field '{field.name}', type mismatch")
row[field.name], size = self._verify_json(row[field.name], field)
row_size = row_size + size
elif dtype == DataType.ARRAY:
if isinstance(row[field.name], np.ndarray):
row[field.name] = row[field.name].tolist()

row_size = row_size + len(row[field.name])
row_size = row_size + self._verify_array(row[field.name], field)
else:
if isinstance(row[field.name], np.generic):
row[field.name] = row[field.name].item()

validator = TYPE_VALIDATOR[dtype.name]
if not validator(row[field.name]):
self._throw(
f"Illegal scalar value for field '{field.name}', value overflow or type mismatch"
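For a concrete sense of the per-row size accounting the refactored _verify_* helpers perform, a standalone arithmetic sketch (the numbers mirror the code above; no pymilvus calls involved):

# FLOAT_VECTOR: 4 bytes per dimension
float_vec_bytes = 768 * 4                 # 3,072 bytes for dim=768

# BINARY_VECTOR: dim is in bits, so dim / 8 bytes
binary_vec_bytes = 512 / 8                # 64 bytes for dim=512

# ARRAY of a fixed-size element type: TYPE_SIZE[element] * len(x)
int64_array_bytes = 8 * 10                # 80 bytes for ten INT64 elements

# ARRAY of VARCHAR: element lengths are summed via _verify_varchar
varchar_array_bytes = sum(len(s) for s in ["abc", "de"])   # 5 bytes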
12 changes: 6 additions & 6 deletions pymilvus/bulk_writer/constants.py
@@ -26,11 +26,11 @@

TYPE_SIZE = {
DataType.BOOL.name: 1,
DataType.INT8.name: 8,
DataType.INT16.name: 8,
DataType.INT32.name: 8,
DataType.INT8.name: 1,
DataType.INT16.name: 2,
DataType.INT32.name: 4,
DataType.INT64.name: 8,
DataType.FLOAT.name: 8,
DataType.FLOAT.name: 4,
DataType.DOUBLE.name: 8,
}

@@ -43,10 +43,10 @@
DataType.FLOAT.name: lambda x: isinstance(x, float),
DataType.DOUBLE.name: lambda x: isinstance(x, float),
DataType.VARCHAR.name: lambda x, max_len: isinstance(x, str) and len(x) <= max_len,
DataType.JSON.name: lambda x: isinstance(x, dict) and len(x) <= 65535,
DataType.JSON.name: lambda x: isinstance(x, (dict, list)),
DataType.FLOAT_VECTOR.name: lambda x, dim: isinstance(x, list) and len(x) == dim,
DataType.BINARY_VECTOR.name: lambda x, dim: isinstance(x, list) and len(x) * 8 == dim,
DataType.ARRAY.name: lambda x: isinstance(x, list),
DataType.ARRAY.name: lambda x, cap: isinstance(x, list) and len(x) <= cap,
}
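A brief sketch of how the corrected tables read after this change (a standalone illustration; importing DataType from the pymilvus top-level package is an assumption):

from pymilvus import DataType  # assumed import
from pymilvus.bulk_writer.constants import TYPE_SIZE, TYPE_VALIDATOR

TYPE_SIZE[DataType.INT32.name]                          # 4 bytes per element (was 8 before this change)
TYPE_VALIDATOR[DataType.ARRAY.name]([1, 2, 3], 100)     # True: a list within max_capacity
TYPE_VALIDATOR[DataType.JSON.name](["a", "b"])          # True: lists are now accepted alongside dicts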

NUMPY_TYPE_CREATOR = {