Skip to content

Commit

Permalink
feat: Update segcore for VECTOR_INT8 (#39415)
Browse files Browse the repository at this point in the history
Issue: #38666

Signed-off-by: Cai Yudong <[email protected]>
  • Loading branch information
cydrain authored Jan 21, 2025
1 parent 905c3b8 commit 341d6c1
Show file tree
Hide file tree
Showing 29 changed files with 332 additions and 50 deletions.
23 changes: 18 additions & 5 deletions internal/core/src/common/ChunkWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,13 +356,14 @@ create_chunk(const FieldMeta& field_meta,
}
case milvus::DataType::VECTOR_FLOAT: {
w = std::make_shared<
ChunkWriter<arrow::FixedSizeBinaryArray, float>>(dim, nullable);
ChunkWriter<arrow::FixedSizeBinaryArray, knowhere::fp32>>(
dim, nullable);
break;
}
case milvus::DataType::VECTOR_BINARY: {
w = std::make_shared<
ChunkWriter<arrow::FixedSizeBinaryArray, uint8_t>>(dim / 8,
nullable);
ChunkWriter<arrow::FixedSizeBinaryArray, knowhere::bin1>>(
dim / 8, nullable);
break;
}
case milvus::DataType::VECTOR_FLOAT16: {
Expand All @@ -377,6 +378,12 @@ create_chunk(const FieldMeta& field_meta,
dim, nullable);
break;
}
case milvus::DataType::VECTOR_INT8: {
w = std::make_shared<
ChunkWriter<arrow::FixedSizeBinaryArray, knowhere::int8>>(
dim, nullable);
break;
}
case milvus::DataType::VARCHAR:
case milvus::DataType::STRING: {
w = std::make_shared<StringChunkWriter>(nullable);
Expand Down Expand Up @@ -450,13 +457,13 @@ create_chunk(const FieldMeta& field_meta,
}
case milvus::DataType::VECTOR_FLOAT: {
w = std::make_shared<
ChunkWriter<arrow::FixedSizeBinaryArray, float>>(
ChunkWriter<arrow::FixedSizeBinaryArray, knowhere::fp32>>(
dim, file, file_offset, nullable);
break;
}
case milvus::DataType::VECTOR_BINARY: {
w = std::make_shared<
ChunkWriter<arrow::FixedSizeBinaryArray, uint8_t>>(
ChunkWriter<arrow::FixedSizeBinaryArray, knowhere::bin1>>(
dim / 8, file, file_offset, nullable);
break;
}
Expand All @@ -472,6 +479,12 @@ create_chunk(const FieldMeta& field_meta,
dim, file, file_offset, nullable);
break;
}
case milvus::DataType::VECTOR_INT8: {
w = std::make_shared<
ChunkWriter<arrow::FixedSizeBinaryArray, knowhere::int8>>(
dim, file, file_offset, nullable);
break;
}
case milvus::DataType::VARCHAR:
case milvus::DataType::STRING: {
w = std::make_shared<StringChunkWriter>(
Expand Down
1 change: 1 addition & 0 deletions internal/core/src/common/FieldData.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ FieldDataImpl<Type, is_type_entire_row>::FillFieldData(
case DataType::VECTOR_FLOAT:
case DataType::VECTOR_FLOAT16:
case DataType::VECTOR_BFLOAT16:
case DataType::VECTOR_INT8:
case DataType::VECTOR_BINARY: {
auto array_info =
GetDataInfoFromArray<arrow::FixedSizeBinaryArray,
Expand Down
11 changes: 11 additions & 0 deletions internal/core/src/common/FieldData.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,17 @@ class FieldData<SparseFloatVector> : public FieldDataSparseVectorImpl {
}
};

// FieldData specialization for Int8Vector. Storage and behavior come from
// FieldDataImpl<int8, false>; this class only supplies the constructor.
template <>
class FieldData<Int8Vector> : public FieldDataImpl<int8, false> {
 public:
    // Forwards dim, data_type and buffered_num_rows (0 unless given) to the
    // base implementation. The third base argument is always false
    // (presumably the nullable flag -- confirm against FieldDataImpl).
    explicit FieldData(int64_t dim,
                       DataType data_type,
                       int64_t buffered_num_rows = 0)
        : FieldDataImpl<int8, false>(
              dim, data_type, false, buffered_num_rows) {
    }
};

using FieldDataPtr = std::shared_ptr<FieldDataBase>;
using FieldDataChannel = Channel<FieldDataPtr>;
using FieldDataChannelPtr = std::shared_ptr<FieldDataChannel>;
Expand Down
17 changes: 11 additions & 6 deletions internal/core/src/common/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,12 @@ GetDataTypeSize(DataType data_type, int dim = 1) {
AssertInfo(dim % 8 == 0, "dim={}", dim);
return dim / 8;
}
case DataType::VECTOR_FLOAT16: {
case DataType::VECTOR_FLOAT16:
return sizeof(float16) * dim;
}
case DataType::VECTOR_BFLOAT16: {
case DataType::VECTOR_BFLOAT16:
return sizeof(bfloat16) * dim;
}
case DataType::VECTOR_INT8:
return sizeof(int8) * dim;
// Not supporting variable length types(such as VECTOR_SPARSE_FLOAT and
// VARCHAR) here intentionally. We can't easily estimate the size of
// them. Caller of this method must handle this case themselves and must
Expand Down Expand Up @@ -192,6 +192,8 @@ GetDataTypeName(DataType data_type) {
return "vector_bfloat16";
case DataType::VECTOR_SPARSE_FLOAT:
return "vector_sparse_float";
case DataType::VECTOR_INT8:
return "vector_int8";
default:
PanicInfo(DataTypeInvalid, "Unsupported DataType({})", data_type);
}
Expand Down Expand Up @@ -325,7 +327,7 @@ IsSparseFloatVectorDataType(DataType data_type) {
}

inline bool
IsInt8VectorDataType(DataType data_type) {
IsIntVectorDataType(DataType data_type) {
return data_type == DataType::VECTOR_INT8;
}

Expand All @@ -338,7 +340,7 @@ IsFloatVectorDataType(DataType data_type) {
inline bool
IsVectorDataType(DataType data_type) {
return IsBinaryVectorDataType(data_type) ||
IsFloatVectorDataType(data_type) || IsInt8VectorDataType(data_type);
IsFloatVectorDataType(data_type) || IsIntVectorDataType(data_type);
}

inline bool
Expand Down Expand Up @@ -642,6 +644,9 @@ struct fmt::formatter<milvus::DataType> : formatter<string_view> {
case milvus::DataType::VECTOR_SPARSE_FLOAT:
name = "VECTOR_SPARSE_FLOAT";
break;
case milvus::DataType::VECTOR_INT8:
name = "VECTOR_INT8";
break;
}
return formatter<string_view>::format(name, ctx);
}
Expand Down
8 changes: 0 additions & 8 deletions internal/core/src/common/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,6 @@ IsMetricType(const std::string_view str,
return !strcasecmp(str.data(), metric_type.c_str());
}

// Returns true when metric_type matches, according to IsMetricType, one of
// the knowhere metrics L2, IP, COSINE or BM25; returns false otherwise.
inline bool
IsFloatMetricType(const knowhere::MetricType& metric_type) {
    // Candidates are tested in the order L2, IP, COSINE, BM25 and the first
    // match short-circuits the remaining checks.
    if (IsMetricType(metric_type, knowhere::metric::L2)) {
        return true;
    }
    if (IsMetricType(metric_type, knowhere::metric::IP)) {
        return true;
    }
    if (IsMetricType(metric_type, knowhere::metric::COSINE)) {
        return true;
    }
    return IsMetricType(metric_type, knowhere::metric::BM25);
}

inline bool
PositivelyRelated(const knowhere::MetricType& metric_type) {
return IsMetricType(metric_type, knowhere::metric::IP) ||
Expand Down
36 changes: 27 additions & 9 deletions internal/core/src/common/VectorTrait.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,30 @@ namespace milvus {

#define GET_ELEM_TYPE_FOR_VECTOR_TRAIT \
using elem_type = std::conditional_t< \
std::is_same_v<TraitType, milvus::BinaryVector>, \
BinaryVector::embedded_type, \
std::is_same_v<TraitType, milvus::FloatVector>, \
milvus::FloatVector::embedded_type, \
std::conditional_t< \
std::is_same_v<TraitType, milvus::Float16Vector>, \
Float16Vector::embedded_type, \
milvus::Float16Vector::embedded_type, \
std::conditional_t< \
std::is_same_v<TraitType, milvus::BFloat16Vector>, \
BFloat16Vector::embedded_type, \
FloatVector::embedded_type>>>;
milvus::BFloat16Vector::embedded_type, \
std::conditional_t< \
std::is_same_v<TraitType, milvus::Int8Vector>, \
milvus::Int8Vector::embedded_type, \
milvus::BinaryVector::embedded_type>>>>;

#define GET_SCHEMA_DATA_TYPE_FOR_VECTOR_TRAIT \
auto schema_data_type = \
std::is_same_v<TraitType, milvus::FloatVector> \
? FloatVector::schema_data_type \
? milvus::FloatVector::schema_data_type \
: std::is_same_v<TraitType, milvus::Float16Vector> \
? Float16Vector::schema_data_type \
? milvus::Float16Vector::schema_data_type \
: std::is_same_v<TraitType, milvus::BFloat16Vector> \
? BFloat16Vector::schema_data_type \
: BinaryVector::schema_data_type;
? milvus::BFloat16Vector::schema_data_type \
: std::is_same_v<TraitType, milvus::Int8Vector> \
? milvus::Int8Vector::schema_data_type \
: milvus::BinaryVector::schema_data_type;

class VectorTrait {};

Expand Down Expand Up @@ -118,6 +123,19 @@ class SparseFloatVector : public VectorTrait {
proto::common::PlaceholderType::SparseFloatVector;
};

class Int8Vector : public VectorTrait {
public:
using embedded_type = int8;
static constexpr int32_t dim_factor = 1;
static constexpr auto data_type = DataType::VECTOR_INT8;
static constexpr auto c_data_type = CDataType::Int8Vector;
static constexpr auto schema_data_type =
proto::schema::DataType::Int8Vector;
static constexpr auto vector_type = proto::plan::VectorType::Int8Vector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::Int8Vector;
};

template <typename T>
constexpr bool IsVector = std::is_base_of_v<VectorTrait, T>;

Expand Down
1 change: 1 addition & 0 deletions internal/core/src/common/type_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ enum CDataType {
Float16Vector = 102,
BFloat16Vector = 103,
SparseFloatVector = 104,
Int8Vector = 105,
};
typedef enum CDataType CDataType;

Expand Down
13 changes: 13 additions & 0 deletions internal/core/src/index/IndexFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,16 @@ IndexFactory::VecIndexLoadResource(
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(
index_type, index_version, config);
break;
case milvus::DataType::VECTOR_INT8:
resource = knowhere::IndexStaticFaced<
knowhere::int8>::EstimateLoadResource(index_type,
index_version,
index_size_gb,
config);
has_raw_data =
knowhere::IndexStaticFaced<knowhere::int8>::HasRawData(
index_type, index_version, config);
break;
default:
LOG_ERROR("invalid data type to estimate index load resource: {}",
field_type);
Expand Down Expand Up @@ -426,6 +436,9 @@ IndexFactory::CreateVectorIndex(
return std::make_unique<VectorDiskAnnIndex<float>>(
index_type, metric_type, version, file_manager_context);
}
case DataType::VECTOR_INT8: {
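// TODO(caiyd): VECTOR_INT8 is not supported here yet; this empty case
// intentionally falls through to the default branch, which raises the error.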
}
default:
PanicInfo(
DataTypeInvalid,
Expand Down
5 changes: 5 additions & 0 deletions internal/core/src/query/ExecPlanNodeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,4 +222,9 @@ ExecPlanNodeVisitor::visit(SparseFloatVectorANNS& node) {
VectorVisitorImpl<SparseFloatVector>(node);
}

// Handles an Int8VectorANNS plan node by forwarding it to VectorVisitorImpl
// instantiated with the Int8Vector trait.
void
ExecPlanNodeVisitor::visit(Int8VectorANNS& node) {
VectorVisitorImpl<Int8Vector>(node);
}

} // namespace milvus::query
3 changes: 3 additions & 0 deletions internal/core/src/query/ExecPlanNodeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor {
void
visit(SparseFloatVectorANNS& node) override;

// Visits an Int8VectorANNS plan node; overrides the base-class visit().
void
visit(Int8VectorANNS& node) override;

void
visit(RetrievePlanNode& node) override;

Expand Down
5 changes: 5 additions & 0 deletions internal/core/src/query/PlanNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ SparseFloatVectorANNS::accept(PlanNodeVisitor& visitor) {
visitor.visit(*this);
}

// Passes this node to the visitor, which selects its visit() overload for
// Int8VectorANNS.
void
Int8VectorANNS::accept(PlanNodeVisitor& visitor) {
visitor.visit(*this);
}

void
RetrievePlanNode::accept(PlanNodeVisitor& visitor) {
visitor.visit(*this);
Expand Down
6 changes: 6 additions & 0 deletions internal/core/src/query/PlanNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ struct SparseFloatVectorANNS : VectorPlanNode {
accept(PlanNodeVisitor&) override;
};

// Vector plan node created for the Int8Vector vector type. It adds nothing to
// VectorPlanNode except its own accept() override.
struct Int8VectorANNS : VectorPlanNode {
    // Hands this node to the given visitor.
    void
    accept(PlanNodeVisitor& visitor) override;
};

struct RetrievePlanNode : PlanNode {
public:
void
Expand Down
3 changes: 3 additions & 0 deletions internal/core/src/query/PlanNodeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class PlanNodeVisitor {
virtual void
visit(SparseFloatVectorANNS&) = 0;

// Called for Int8VectorANNS plan nodes; pure virtual, so every concrete
// visitor must provide an implementation.
virtual void
visit(Int8VectorANNS&) = 0;

virtual void
visit(RetrievePlanNode&) = 0;
};
Expand Down
3 changes: 3 additions & 0 deletions internal/core/src/query/PlanProto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
} else if (anns_proto.vector_type() ==
milvus::proto::plan::VectorType::SparseFloatVector) {
return std::make_unique<SparseFloatVectorANNS>();
} else if (anns_proto.vector_type() ==
milvus::proto::plan::VectorType::Int8Vector) {
return std::make_unique<Int8VectorANNS>();
} else {
return std::make_unique<FloatVectorANNS>();
}
Expand Down
Loading

0 comments on commit 341d6c1

Please sign in to comment.