Skip to content

Commit

Permalink
Expose streams in public round APIs (#16925)
Browse files Browse the repository at this point in the history
Contributes to #13744

Authors:
  - Matthew Murray (https://github.com/Matt711)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - Bradley Dice (https://github.com/bdice)

URL: #16925
  • Loading branch information
Matt711 authored Nov 1, 2024
1 parent b5b47fe commit 0a87284
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 1 deletion.
3 changes: 3 additions & 0 deletions cpp/include/cudf/round.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <cudf/column/column.hpp>
#include <cudf/utilities/default_stream.hpp>
#include <cudf/utilities/export.hpp>
#include <cudf/utilities/memory_resource.hpp>

Expand Down Expand Up @@ -66,6 +67,7 @@ enum class rounding_method : int32_t { HALF_UP, HALF_EVEN };
* @param decimal_places Number of decimal places to round to (default 0). If negative, this
* specifies the number of positions to the left of the decimal point.
* @param method Rounding method
* @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned column's device memory
*
* @return Column with each of the values rounded
Expand All @@ -74,6 +76,7 @@ std::unique_ptr<column> round(
column_view const& input,
int32_t decimal_places = 0,
rounding_method method = rounding_method::HALF_UP,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

/** @} */ // end of group
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/round/round.cu
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,11 @@ std::unique_ptr<column> round(column_view const& input,
std::unique_ptr<column> round(column_view const& input,
int32_t decimal_places,
rounding_method method,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();
return detail::round(input, decimal_places, method, cudf::get_default_stream(), mr);
return detail::round(input, decimal_places, method, stream, mr);
}

} // namespace cudf
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,7 @@ ConfigureTest(STREAM_REDUCTION_TEST streams/reduction_test.cpp STREAM_MODE testi
ConfigureTest(STREAM_REPLACE_TEST streams/replace_test.cpp STREAM_MODE testing)
ConfigureTest(STREAM_RESHAPE_TEST streams/reshape_test.cpp STREAM_MODE testing)
ConfigureTest(STREAM_ROLLING_TEST streams/rolling_test.cpp STREAM_MODE testing)
ConfigureTest(STREAM_ROUND_TEST streams/round_test.cpp STREAM_MODE testing)
ConfigureTest(STREAM_SEARCH_TEST streams/search_test.cpp STREAM_MODE testing)
ConfigureTest(STREAM_SORTING_TEST streams/sorting_test.cpp STREAM_MODE testing)
ConfigureTest(STREAM_STREAM_COMPACTION_TEST streams/stream_compaction_test.cpp STREAM_MODE testing)
Expand Down
40 changes: 40 additions & 0 deletions cpp/tests/streams/round_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cudf_test/base_fixture.hpp>
#include <cudf_test/column_wrapper.hpp>
#include <cudf_test/default_stream.hpp>

#include <cudf/column/column_view.hpp>
#include <cudf/round.hpp>

#include <vector>

class RoundTest : public cudf::test::BaseFixture {};

TEST_F(RoundTest, RoundHalfToEven)
{
std::vector<double> vals = {1.729, 17.29, 172.9, 1729};
cudf::test::fixed_width_column_wrapper<double> input(vals.begin(), vals.end());
cudf::round(input, 0, cudf::rounding_method::HALF_UP, cudf::test::get_default_stream());
}

TEST_F(RoundTest, RoundHalfAwayFromEven)
{
std::vector<double> vals = {1.5, 2.5, 1.35, 1.45, 15, 25};
cudf::test::fixed_width_column_wrapper<double> input(vals.begin(), vals.end());
cudf::round(input, -1, cudf::rounding_method::HALF_EVEN, cudf::test::get_default_stream());
}

0 comments on commit 0a87284

Please sign in to comment.