Skip to content

Commit

Permalink
fix ip functions
Browse files Browse the repository at this point in the history
  • Loading branch information
amorynan committed Nov 30, 2024
1 parent c01b79b commit 96f60c3
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 51 deletions.
62 changes: 19 additions & 43 deletions be/src/vec/functions/function_ip.h
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,11 @@ class FunctionIPv4CIDRToRange : public IFunction {
}
};

/**
* this function accepts two arguments: an IPv6 address and a CIDR mask
* IPv6 address can be either ipv6 type or string type as ipv6 string address
* FE: PropagateNullable is used to handle nullable columns
*/
class FunctionIPv6CIDRToRange : public IFunction {
public:
static constexpr auto name = "ipv6_cidr_to_range";
Expand Down Expand Up @@ -900,9 +905,11 @@ class FunctionIPv6CIDRToRange : public IFunction {
col_res = execute_impl<ColumnIPv6>(*ipv6_addr_column, *cidr_col, input_rows_count,
add_col_const, col_const);
} else if (addr_type.is_string()) {
const auto* str_addr_column = assert_cast<const ColumnString*>(addr_column.get());
col_res = execute_impl<ColumnString>(*str_addr_column, *cidr_col, input_rows_count,
add_col_const, col_const);
ColumnPtr col_ipv6 =
convert_to_ipv6<IPConvertExceptionMode::Throw>(addr_column, nullptr);
const auto* ipv6_addr_column = assert_cast<const ColumnIPv6*>(col_ipv6.get());
col_res = execute_impl<ColumnIPv6>(*ipv6_addr_column, *cidr_col, input_rows_count,
add_col_const, col_const);
} else {
return Status::RuntimeError(
"Illegal column {} of argument of function {}, Expected IPv6 or String",
Expand All @@ -923,19 +930,8 @@ class FunctionIPv6CIDRToRange : public IFunction {
auto& vec_res_upper_range = col_res_upper_range->get_data();

static constexpr UInt8 max_cidr_mask = IPV6_BINARY_LENGTH * 8;
unsigned char ipv6_address_data[IPV6_BINARY_LENGTH];

if (is_addr_const) {
StringRef str_ref = from_column.get_data_at(0);
const char* value = str_ref.data;
size_t value_size = str_ref.size;
if (value_size > IPV6_BINARY_LENGTH || value == nullptr || value_size == 0 ||
!IPv6Value::is_valid_string(value, value_size)) {
throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal ipv6 address '{}'",
std::string(value, value_size));
}
memcpy(ipv6_address_data, value, value_size);
memset(ipv6_address_data + value_size, 0, IPV6_BINARY_LENGTH - value_size);
for (size_t i = 0; i < input_rows_count; ++i) {
auto cidr = cidr_column.get_int(i);
if (cidr < 0 || cidr > max_cidr_mask) {
Expand All @@ -945,9 +941,9 @@ class FunctionIPv6CIDRToRange : public IFunction {
if constexpr (std::is_same_v<FromColumn, ColumnString>) {
// 16 bytes ipv6 string is stored in big-endian byte order
// so transfer to little-endian firstly
std::reverse(ipv6_address_data, ipv6_address_data + IPV6_BINARY_LENGTH);
apply_cidr_mask(reinterpret_cast<const char*>(&ipv6_address_data),
reinterpret_cast<char*>(&vec_res_lower_range[i]),
auto* src_data = const_cast<char*>(from_column.get_data_at(0).data);
std::reverse(src_data, src_data + IPV6_BINARY_LENGTH);
apply_cidr_mask(src_data, reinterpret_cast<char*>(&vec_res_lower_range[i]),
reinterpret_cast<char*>(&vec_res_upper_range[i]),
cast_set<UInt8>(cidr));
} else {
Expand All @@ -967,19 +963,9 @@ class FunctionIPv6CIDRToRange : public IFunction {
if constexpr (std::is_same_v<FromColumn, ColumnString>) {
// 16 bytes ipv6 string is stored in big-endian byte order
// so transfer to little-endian firstly
StringRef str_ref = from_column.get_data_at(i);
const char* value = str_ref.data;
size_t value_size = str_ref.size;
if (value_size > IPV6_BINARY_LENGTH || value == nullptr || value_size == 0 ||
!IPv6Value::is_valid_string(value, value_size)) {
throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal ipv6 address '{}'",
std::string(value, value_size));
}
memcpy(ipv6_address_data, value, value_size);
memset(ipv6_address_data + value_size, 0, IPV6_BINARY_LENGTH - value_size);
std::reverse(ipv6_address_data, ipv6_address_data + IPV6_BINARY_LENGTH);
apply_cidr_mask(reinterpret_cast<const char*>(&ipv6_address_data),
reinterpret_cast<char*>(&vec_res_lower_range[i]),
auto* src_data = const_cast<char*>(from_column.get_data_at(i).data);
std::reverse(src_data, src_data + IPV6_BINARY_LENGTH);
apply_cidr_mask(src_data, reinterpret_cast<char*>(&vec_res_lower_range[i]),
reinterpret_cast<char*>(&vec_res_upper_range[i]),
cast_set<UInt8>(cidr));
} else {
Expand All @@ -999,19 +985,9 @@ class FunctionIPv6CIDRToRange : public IFunction {
if constexpr (std::is_same_v<FromColumn, ColumnString>) {
// 16 bytes ipv6 string is stored in big-endian byte order
// so transfer to little-endian firstly
StringRef str_ref = from_column.get_data_at(i);
const char* value = str_ref.data;
size_t value_size = str_ref.size;
if (value_size > IPV6_BINARY_LENGTH || value == nullptr || value_size == 0 ||
!IPv6Value::is_valid_string(value, value_size)) {
throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal ipv6 address '{}'",
std::string(value, value_size));
}
memcpy(ipv6_address_data, value, value_size);
memset(ipv6_address_data + value_size, 0, IPV6_BINARY_LENGTH - value_size);
std::reverse(ipv6_address_data, ipv6_address_data + IPV6_BINARY_LENGTH);
apply_cidr_mask(reinterpret_cast<const char*>(&ipv6_address_data),
reinterpret_cast<char*>(&vec_res_lower_range[i]),
auto* src_data = const_cast<char*>(from_column.get_data_at(i).data);
std::reverse(src_data, src_data + IPV6_BINARY_LENGTH);
apply_cidr_mask(src_data, reinterpret_cast<char*>(&vec_res_lower_range[i]),
reinterpret_cast<char*>(&vec_res_upper_range[i]),
cast_set<UInt8>(cidr));
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ suite("nereids_scalar_fn_IP") {
qt_sql_cidr_ipv6_nullable_ "select id, ipv6_cidr_to_range(to_ipv6('::'), 32) from fn_test_ip_nullable order by id;"
test {
sql "select id, ipv6_cidr_to_range(nullable(''), 32) from fn_test_ip_nullable order by id"
exception "Illegal ipv6 address"
exception "Invalid IPv6 value"
}
test {
sql "select id, ipv6_cidr_to_range(nullable('abc'), 32) from fn_test_ip_not_nullable order by id"
exception "Invalid IPv6 value"
}
// test IPV4_STRING_TO_NUM/IPV6_STRING_TO_NUM (we have null value in ip4 and ip6 column in fn_test_ip_nullable table)
test {
Expand Down Expand Up @@ -162,7 +166,12 @@ suite("nereids_scalar_fn_IP") {
qt_sql_not_null_cidr_ipv6_nullable_ "select id, ipv6_cidr_to_range(to_ipv6('::'), 32) from fn_test_ip_nullable order by id;"
test {
sql "select id, ipv6_cidr_to_range(nullable(''), 32) from fn_test_ip_not_nullable order by id"
exception "Illegal ipv6 address"
exception "Invalid IPv6 value"
}

test {
sql "select id, ipv6_cidr_to_range(nullable('abc'), 32) from fn_test_ip_not_nullable order by id"
exception "Invalid IPv6 value"
}
// test IPV4_STRING_TO_NUM/IPV6_STRING_TO_NUM
qt_sql_not_null_ipv6_string_to_num 'select id, hex(ipv6_string_to_num(ip6)) from fn_test_ip_not_nullable order by id'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ suite("test_ipv6_cidr_to_range_function") {
(9, 'ffff:0000:0000:0000:0000:0000:0000:0000', NULL)
"""

qt_sql "select id, struct_element(ipv6_cidr_to_range(ipv6_string_to_num_or_null(addr), cidr), 'min') as min_range, struct_element(ipv6_cidr_to_range(ipv6_string_to_num_or_null(addr), cidr), 'max') as max_range from test_str_cidr_to_range_function order by id"
qt_sql "select id, struct_element(ipv6_cidr_to_range(ipv6_num_to_string(ipv6_string_to_num_or_null(addr)), cidr), 'min') as min_range, struct_element(ipv6_cidr_to_range(ipv6_num_to_string(ipv6_string_to_num_or_null(addr)), cidr), 'max') as max_range from test_str_cidr_to_range_function order by id"

sql """ DROP TABLE IF EXISTS test_str_cidr_to_range_function """

qt_sql "select ipv6_cidr_to_range(ipv6_string_to_num('2001:0db8:0000:85a3:0000:0000:ac1f:8001'), 0)"
qt_sql "select ipv6_cidr_to_range(ipv6_string_to_num('2001:0db8:0000:85a3:0000:0000:ac1f:8001'), 128)"
qt_sql "select ipv6_cidr_to_range(ipv6_string_to_num('ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff'), 64)"
qt_sql "select ipv6_cidr_to_range(ipv6_string_to_num('0000:0000:0000:0000:0000:0000:0000:0000'), 8)"
qt_sql "select ipv6_cidr_to_range(ipv6_string_to_num('ffff:0000:0000:0000:0000:0000:0000:0000'), 4)"
qt_sql "select ipv6_cidr_to_range(ipv6_num_to_string(ipv6_string_to_num('2001:0db8:0000:85a3:0000:0000:ac1f:8001')), 0)"
qt_sql "select ipv6_cidr_to_range(ipv6_num_to_string(ipv6_string_to_num('2001:0db8:0000:85a3:0000:0000:ac1f:8001')), 128)"
qt_sql "select ipv6_cidr_to_range(ipv6_num_to_string(ipv6_string_to_num('ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff')), 64)"
qt_sql "select ipv6_cidr_to_range(ipv6_num_to_string(ipv6_string_to_num('0000:0000:0000:0000:0000:0000:0000:0000')), 8)"
qt_sql "select ipv6_cidr_to_range(ipv6_num_to_string(ipv6_string_to_num('ffff:0000:0000:0000:0000:0000:0000:0000')), 4)"
}

0 comments on commit 96f60c3

Please sign in to comment.