Skip to content

Commit

Permalink
1. Fix the bug of sort node return empty block if child eos is true a…
Browse files Browse the repository at this point in the history
…nd add some comment (apache#23)

2. Use SIMD to speed up has_null() in column nullable
3. Support UDAF of avg

Change-Id: I13846d7275e1cc37085d3afbf41d60261e296662

Co-authored-by: lihaopeng <[email protected]>
  • Loading branch information
HappenLee and lihaopeng committed Jul 13, 2021
1 parent 206fd97 commit 8259751
Show file tree
Hide file tree
Showing 11 changed files with 212 additions and 37 deletions.
1 change: 1 addition & 0 deletions be/src/vec/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ set(VEC_FILES
aggregate_functions/aggregate_function_null.cpp
aggregate_functions/aggregate_function_sum.cpp
aggregate_functions/aggregate_function_min_max.cpp
aggregate_functions/aggregate_function_avg.cpp
columns/collator.cpp
columns/column.cpp
columns/column_const.cpp
Expand Down
50 changes: 50 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_avg.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/aggregate_function_avg.h"
#include "vec/aggregate_functions/helpers.h"
#include "vec/aggregate_functions/factory_helpers.h"

namespace doris::vectorized
{

namespace
{

template <typename T>
struct Avg
{
using FieldType = std::conditional_t<IsDecimalNumber<T>, Decimal128, NearestFieldType<T>>;
using Function = AggregateFunctionAvg<T, AggregateFunctionAvgData<FieldType>>;
};

template <typename T>
using AggregateFuncAvg = typename Avg<T>::Function;

AggregateFunctionPtr createAggregateFunctionAvg(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);

AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (isDecimal(data_type))
res.reset(createWithDecimalType<AggregateFuncAvg>(*data_type, *data_type, argument_types));
else
res.reset(createWithNumericType<AggregateFuncAvg>(*data_type, argument_types));

if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}

}

//void registerAggregateFunctionAvg(AggregateFunctionFactory & factory)
//{
// factory.registerFunction("avg", createAggregateFunctionAvg, AggregateFunctionFactory::CaseInsensitive);
//}

void registerAggregateFunctionAvg(AggregateFunctionSimpleFactory& factory) {
factory.registerFunction("avg", createAggregateFunctionAvg);
}
}
115 changes: 115 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_avg.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#pragma once

#include "vec/data_types/data_types_number.h"
#include "vec/data_types/data_types_decimal.h"
#include "vec/columns/columns_number.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/io/io_helper.h"

namespace doris::vectorized
{

namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}

template <typename T>
struct AggregateFunctionAvgData
{
T sum = 0;
UInt64 count = 0;

template <typename ResultT>
ResultT NO_SANITIZE_UNDEFINED result() const
{
if constexpr (std::is_floating_point_v<ResultT>)
if constexpr (std::numeric_limits<ResultT>::is_iec559)
return static_cast<ResultT>(sum) / count; /// allow division by zero

if (!count)
throw Exception("AggregateFunctionAvg with zero values", ErrorCodes::LOGICAL_ERROR);
return static_cast<ResultT>(sum) / count;
}

void write(std::ostream& buf) const {
writeBinary(sum, buf);
writeBinary(count, buf);
}

void read(std::istream& buf) {
readBinary(sum, buf);
readBinary(count, buf);
}
};


/// Calculates arithmetic mean of numbers.
template <typename T, typename Data>
class AggregateFunctionAvg final : public IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>
{
public:
using ResultType = std::conditional_t<IsDecimalNumber<T>, Decimal128, Float64>;
using ResultDataType = std::conditional_t<IsDecimalNumber<T>, DataTypeDecimal<Decimal128>, DataTypeNumber<Float64>>;
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128>, ColumnVector<Float64>>;

/// ctor for native types
AggregateFunctionAvg(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>(argument_types_, {})
, scale(0)
{}

/// ctor for Decimals
AggregateFunctionAvg(const IDataType & data_type, const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>(argument_types_, {})
, scale(getDecimalScale(data_type))
{}

String getName() const override { return "avg"; }

DataTypePtr getReturnType() const override
{
if constexpr (IsDecimalNumber<T>)
return std::make_shared<ResultDataType>(ResultDataType::maxPrecision(), scale);
else
return std::make_shared<ResultDataType>();
}

void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
const auto & column = static_cast<const ColVecType &>(*columns[0]);
this->data(place).sum += column.getData()[row_num];
++this->data(place).count;
}

void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
{
this->data(place).sum += this->data(rhs).sum;
this->data(place).count += this->data(rhs).count;
}

void serialize(ConstAggregateDataPtr place, std::ostream& buf) const override
{
this->data(place).write(buf);
}

void deserialize(AggregateDataPtr place, std::istream& buf, Arena *) const override
{
this->data(place).read(buf);
}

void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
auto & column = static_cast<ColVecResult &>(to);
column.getData().push_back(this->data(place).template result<ResultType>());
}

const char * getHeaderFilePath() const override { return __FILE__; }

private:
UInt32 scale;
};


}
1 change: 1 addition & 0 deletions be/src/vec/aggregate_functions/aggregate_function_null.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ void registerAggregateFunctionCombinatorNull(AggregateFunctionSimpleFactory& fac
factory.registerFunction("sum", creator, true);
factory.registerFunction("max", creator, true);
factory.registerFunction("min", creator, true);
factory.registerFunction("avg", creator, true);
}

} // namespace doris::vectorized
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class AggregateFunctionSimpleFactory;
void registerAggregateFunctionSum(AggregateFunctionSimpleFactory& factory);
void registerAggregateFunctionCombinatorNull(AggregateFunctionSimpleFactory& factory);
void registerAggregateFunctionMinMax(AggregateFunctionSimpleFactory& factory);
void registerAggregateFunctionAvg(AggregateFunctionSimpleFactory& factory);

using DataTypePtr = std::shared_ptr<const IDataType>;
using DataTypes = std::vector<DataTypePtr>;
Expand Down Expand Up @@ -82,6 +83,7 @@ class AggregateFunctionSimpleFactory {
std::call_once(oc, [&]() {
registerAggregateFunctionSum(instance);
registerAggregateFunctionMinMax(instance);
registerAggregateFunctionAvg(instance);
registerAggregateFunctionCombinatorNull(instance);
});
return instance;
Expand Down
33 changes: 27 additions & 6 deletions be/src/vec/columns/column_nullable.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,34 @@ class ColumnNullable final : public COWHelper<IColumn, ColumnNullable> {
void checkConsistency() const;

bool has_null() const {
auto begin = getNullMapData().begin();
auto end = getNullMapData().end();
while (begin < end) {
if (*begin != 0) {
return *begin;
size_t size = getNullMapData().size();
const UInt8* null_pos = getNullMapData().data();
const UInt8* null_pos_end = getNullMapData().data() + size;
#ifdef __SSE2__
/** A slightly more optimized version.
* Based on the assumption that often pieces of consecutive values
* completely pass or do not pass the filter.
* Therefore, we will optimistically check the parts of `SIMD_BYTES` values.
*/
static constexpr size_t SIMD_BYTES = 16;
const __m128i zero16 = _mm_setzero_si128();
const UInt8* null_end_sse = null_pos + size / SIMD_BYTES * SIMD_BYTES;

while (null_pos < null_end_sse) {
int mask = _mm_movemask_epi8(_mm_cmpgt_epi8(
_mm_loadu_si128(reinterpret_cast<const __m128i*>(null_pos)), zero16));

if (0 != mask) {
return true;
}
++begin;
null_pos += SIMD_BYTES;
}
#endif
while (null_pos < null_pos_end) {
if (*null_pos != 0) {
return true;
}
null_pos++;
}
return false;
}
Expand Down
2 changes: 1 addition & 1 deletion be/src/vec/core/sort_cursor.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ struct ReceiveQueueSortCursorImpl : public SortCursorImpl {
desc[i].direction = is_asc_order[i] ? 1 : -1;
desc[i].nulls_direction = nulls_first[i] ? 1 : -1;
}
has_next_block();
_is_eof = !has_next_block();
}

bool has_next_block() override {
Expand Down
5 changes: 3 additions & 2 deletions be/src/vec/exec/vaggregation_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ struct AggregationMethodSerialized {

Data data;

AggregationMethodSerialized() {}
AggregationMethodSerialized() = default;

template <typename Other>
AggregationMethodSerialized(const Other& other) : data(other.data) {}
explicit AggregationMethodSerialized(const Other& other) : data(other.data) {}

using State = ColumnsHashing::HashMethodSerialized<typename Data::value_type, Mapped>;

Expand Down Expand Up @@ -153,6 +153,7 @@ class AggregationNode : public ::doris::ExecNode {
using vectorized_execute = std::function<Status(Block* block)>;
using vectorized_get_result =
std::function<Status(RuntimeState* state, Block* block, bool* eos)>;

struct executor {
vectorized_execute execute;
vectorized_get_result get_result;
Expand Down
2 changes: 1 addition & 1 deletion be/src/vec/exec/vsort_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ Status VSortNode::sort_input(RuntimeState* state) {
do {
Block block;
RETURN_IF_ERROR(child(0)->get_next(state, &block, &eos));
if (!eos && block.rows() != 0) {
if ( block.rows() != 0) {
RETURN_IF_ERROR(pretreat_block(block));
_sorted_blocks.emplace_back(std::move(block));
RETURN_IF_CANCELLED(state);
Expand Down
18 changes: 5 additions & 13 deletions be/src/vec/exec/vsort_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,9 @@
namespace doris {

namespace vectorized {
// Node that implements a full sort of its input with a fixed memory budget, spilling
// to disk if the input is larger than available memory.
// Uses SpillSorter and BufferedBlockMgr for the external sort implementation.
// Input rows to SortNode are materialized by the SpillSorter into a single tuple
// using the expressions specified in _sort_exec_exprs.
// In get_next(), SortNode passes in the output batch to the sorter instance created
// in open() to fill it with sorted rows.
// If a merge phase was performed in the sort, sorted rows are deep copied into
// the output batch. Otherwise, the sorter instance owns the sorted data.
// Node that implements a full sort of its input with a fixed memory budget
// In open() the input Block to VSortNode will sort firstly, using the expressions specified in _sort_exec_exprs.
// In get_next(), VSortNode do the merge sort to gather data to a new block

// support spill to disk in the future
class VSortNode : public doris::ExecNode {
Expand Down Expand Up @@ -75,7 +69,7 @@ class VSortNode : public doris::ExecNode {
// Number of rows to skip.
int64_t _offset;

// Expressions and parameters used for tuple materialization and tuple comparison.
// Expressions and parameters used for build _sort_description
VSortExecExprs _vsort_exec_exprs;
std::vector<bool> _is_asc_order;
std::vector<bool> _nulls_first;
Expand All @@ -85,11 +79,9 @@ class VSortNode : public doris::ExecNode {
std::vector<Block> _sorted_blocks;
std::priority_queue<SortCursor> _priority_queue;

// TODO: Not using now, maybe should be delete
// Keeps track of the number of rows skipped for handling _offset.
int64_t _num_rows_skipped;

// END: Members that must be reset()
/////////////////////////////////////////
};

}
Expand Down
20 changes: 6 additions & 14 deletions be/src/vec/runtime/vsorted_run_merger.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,12 @@ class RuntimeProfile;

namespace vectorized {
class Block;
// VSortedRunMerger is used to merge multiple sorted runs of tuples. A run is a sorted
// sequence of row batches, which are fetched from a BlockSupplier function object.
// VSortedRunMerger is used to merge multiple sorted runs of blocks. A run is a sorted
// sequence of blocks, which are fetched from a BlockSupplier function object.
// Merging is implemented using a binary min-heap that maintains the run with the next
// tuple in sorted order at the top of the heap.
// rows in sorted order at the top of the heap.
//
// Merged batches of rows are retrieved from VSortedRunMerger via calls to get_next().
// The merger is constructed with a boolean flag deep_copy_input.
// If true, sorted output rows are deep copied into the data pool of the output batch.
// If false, get_next() only copies tuple pointers (TupleRows) into the output batch,
// and transfers resource ownership from the input batches to the output batch when
// an input batch is processed.
// Merged block of rows are retrieved from VSortedRunMerger via calls to get_next().
class VSortedRunMerger {
public:
// Function that returns the next block of rows from an input sorted run. The batch
Expand All @@ -57,10 +52,10 @@ class VSortedRunMerger {
// the priority queue.
Status prepare(const std::vector<BlockSupplier>& input_runs, bool parallel = false);

// Return the next batch of sorted rows from this merger.
// Return the next block of sorted rows from this merger.
Status get_next(Block* output_block, bool *eos);

// Only Child class implement this Method, Return the next batch of sorted rows from this merger.
// Do not support now
virtual Status get_batch(RowBatch **output_batch) {
return Status::InternalError("no support method get_batch(RowBatch** output_batch)");
}
Expand All @@ -80,9 +75,6 @@ class VSortedRunMerger {

Block _empty_block;

// Pool of BatchedRowSupplier instances.
ObjectPool _pool;

// Times calls to get_next().
RuntimeProfile::Counter *_get_next_timer;

Expand Down

0 comments on commit 8259751

Please sign in to comment.