Refactor Hausdorff distance to header-only API #538
Merged: rapids-bot merged 112 commits into rapidsai:branch-22.08 from harrism:fea-header-only-hausdorff on Jun 23, 2022
Commits
314580d Create header-only refactoring of cuspatial::haversine_distance (harrism)
0736356 Merge branch 'branch-22.04' into fea-header-only-haversine (harrism)
5830b66 Apply suggestions from code review (harrism)
4c16cd6 require RandomAccessIterator (harrism)
7580661 Merge branch 'fea-header-only-haversine' of github.com:harrism/cuspat… (harrism)
073e2d7 Convert haversine API to use AOS inputs. (harrism)
e37d61e Revert cosmetic changes to top-level haversine.hpp (harrism)
9d8e3eb Align location_2d and remove unused location_3d and coord_2d. (harrism)
7f2bbad __device__ only (harrism)
2448677 "" --> <> (harrism)
0d954e0 Remove unused macro. (harrism)
7d100dc Add refactoring guide. (harrism)
daa82a8 Add refactoring guide. (harrism)
c4ba1f7 Merge branch 'branch-22.04' into fea-header-only-haversine (harrism)
454d967 Add fancy iterator test (harrism)
a5dab4a Merge branch 'branch-22.06' into fea-header-only-haversine (harrism)
08dfe95 .hpp->.cuh (harrism)
e373065 Add note about not making tests depend on libcudf_test (harrism)
49bb466 lonlat_to_cartesian declaration (harrism)
3ae133e gitignore (harrism)
f8947eb Add missing include and @ (harrism)
564cc4c Don't hide the stream parameter in the detail layer. (harrism)
8b3edae Merge branch 'fea-header-only-haversine' into fea-header-only-coordin… (harrism)
9e7c8f0 Stream parameter (harrism)
4a6976e header cleanup (harrism)
048df0d Implementation progress (not working) (harrism)
c208144 Simplify coordinate types to a single vec_2d (harrism)
67ab7b7 Merge branch 'fea-header-only-haversine' into fea-header-only-coordin… (harrism)
45953e6 lonlat_to_cartesian refactored and new tests added (harrism)
21db279 Clean up haversine_test.cu includes (harrism)
66a7e1f Merge branch 'fea-header-only-haversine' into fea-header-only-coordin… (harrism)
be1226c type-safe vectors (harrism)
8134f12 Fix typo (harrism)
42f909e vec_2d --> lonlat_2d in docs (harrism)
e5cb703 Merge branch 'branch-22.06' into fea-header-only-haversine (harrism)
7939b15 Merge branch 'fea-header-only-haversine' into fea-header-only-coordin… (harrism)
213299b Update for type-safe vector types (harrism)
0b78d87 Review suggestions (harrism)
c6a392c Merge branch 'fea-header-only-haversine' into fea-header-only-coordin… (harrism)
17569c6 Clarified documentation / refactoring guide. (harrism)
d5ef3b7 style (harrism)
2acba65 Merge branch 'fea-header-only-haversine' into fea-header-only-coordin… (harrism)
4883664 Merge branch 'branch-22.06' into fea-header-only-coordinate-transform (harrism)
7af5673 hausdorff header (harrism)
ae8ba04 .hpp --> .cuh (harrism)
f863e61 Doc fix (harrism)
4f859ba Document template parameters and preconditions (harrism)
2c42845 header-only Implementation (harrism)
434cbe9 Explicit lonlat_2d<T> type. (harrism)
6858cc4 Merge branch 'fea-header-only-coordinate-transform' into fea-header-o… (harrism)
50146a8 Initial conversion, passes compilation and test (isVoid)
1861dc9 Add docstring (isVoid)
0d5dc94 Add RAI specification (isVoid)
ca54b2c add default stream parameter (isVoid)
a427af6 Add first test and cast references around. (isVoid)
2c6ec85 Add more tests (isVoid)
9958f48 fix offset arrays (isVoid)
dc944c5 fix wrong gtest binary name (isVoid)
e36186a Add precommit hooks and script for cmake format/lint (isVoid)
62aa1ad update with optimized code (isVoid)
af38652 remove dependency on cudf atomics (isVoid)
ae11443 regroup includes (isVoid)
79d57aa Use size_t as index type. (isVoid)
e2a3db7 some fixes on tests (isVoid)
a1512d8 Revert cmake-format and precommit hooks (isVoid)
9b1687c Remove `Cart2dA` and `Cart2dB` (isVoid)
dea41f1 Documentation update (isVoid)
270cfe2 Style fix (isVoid)
37fe51e Update to use latest vec_2d changes. (harrism)
815a00b Improve vec_2d documentation (harrism)
2204661 Merge branch 'fea-header-only-coordinate-transform' into fea-header-o… (harrism)
cb5d301 fix broken compile (isVoid)
183c10c Removes `device_atomics` usage (isVoid)
b2fbaad Add `internal` marker to internal docstrings. (isVoid)
1752485 Move derived traits to traits.hpp (isVoid)
fe83ea2 add back raw_reference_cast (isVoid)
eebb9d6 Add libcudacxx cmake dependency (harrism)
13c8144 Convert cudf-based API to use header-only API (harrism)
7ee9b94 Revert "Removes `device_atomics` usage" (isVoid)
5926bcc add atomicMax (isVoid)
8194142 Address atomic operation review (isVoid)
4a81a5b style (isVoid)
0309fd9 Revert "Address atomic operation review" (isVoid)
ce6280f address device atomics reviews (isVoid)
2c5e8bf inline `addr` dereference (isVoid)
80fceba Reverting attempts to cast to `ll`, not `ull` (isVoid)
d849ce3 Add mutable requirement for OutputIterator (harrism)
e8fe80b Merge branch 'branch-22.06' into feature/header_only_linestring_distance (harrism)
d441462 Document atomics (harrism)
9a15132 Remove erroneous nested std::vector (harrism)
d242701 Responds to review feedback (harrism)
2c7ee53 Merge branch 'branch-22.06' into fea-header-only-hausdorff (harrism)
bda3226 Merge branch 'feature/header_only_linestring_distance' into fea-heade… (harrism)
b9c704c Enable multiple includes (harrism)
f8d10c4 Use deviceAtomics.cuh (harrism)
1a8132a Merge branch 'branch-22.08' into fea-header-only-hausdorff (harrism)
7a2ee08 Implement tests for header-only API (harrism)
76280e7 Remove non-cuDF-specific tests of cuDF-based API. (harrism)
a7a318a copyright (harrism)
d6a2955 Initial conversion, passes compilation and test (harrism)
1df7784 Remove get_libcudacxx.cmake (harrism)
f29fde4 Remove reference to types.hpp (harrism)
64a7e2a Assert that output iterator value_type is floating point. (harrism)
c7ba2ce Doc fixes/improvements based on review (harrism)
ff62071 Fix preconditions. (harrism)
2996a59 Remove invalid preconditions. (harrism)
76604d4 Add non-overlap precondition for coordinate_transform. (harrism)
c56a289 Improve readability with suggestion from @isVoid (harrism)
54857ba Merge branch 'branch-22.08' into fea-header-only-hausdorff (harrism)
7535336 Fix vec_2d includes (harrism)
1636ba2 Merge branch 'branch-22.08' into fea-header-only-hausdorff (harrism)
b8766e2 Documentation and vec_2d header location fix (harrism)
178 changes: 178 additions & 0 deletions
cpp/include/cuspatial/experimental/detail/hausdorff.cuh
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
/* | ||
* Copyright (c) 2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <cuspatial/error.hpp> | ||
#include <cuspatial/utility/device_atomics.cuh> | ||
#include <cuspatial/utility/traits.hpp> | ||
#include <cuspatial/utility/vec_2d.hpp> | ||
#include <rmm/cuda_stream_view.hpp> | ||
#include <rmm/exec_policy.hpp> | ||
|
||
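#include <thrust/binary_search.h> | ||
#include <thrust/distance.h> | ||
#include <thrust/execution_policy.h> | ||
#include <thrust/fill.h> | ||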
|
||
#include <cuda/atomic> | ||
|
||
#include <cmath> | ||
#include <iterator> | ||
#include <limits> | ||
#include <type_traits> | ||
|
||
namespace cuspatial { | ||
|
||
namespace detail { | ||
|
||
template <typename T> | ||
constexpr auto magnitude_squared(T a, T b) | ||
{ | ||
return a * a + b * b; | ||
} | ||
|
||
/** | ||
* @brief computes Hausdorff distance by equally dividing up work on a per-thread basis. | ||
* | ||
* Each thread is responsible for computing the distance from a single point in the input against | ||
* all other points in the input. Because points in the input can originate from different spaces, | ||
* each thread must know which spaces it is comparing. For the LHS argument, the point is always | ||
* the same for any given thread and is determined once for that thread using a binary search of | ||
* the provided space_offsets. Therefore if space 0 contains 10 points, the first 10 threads will | ||
* know that the LHS space is 0. The 11th thread will know the LHS space is 1, and so on depending | ||
* on the sizes/offsets of each space. Each thread then loops over each space, and uses an inner | ||
* loop to loop over each point within that space, thereby knowing the RHS space and RHS point. | ||
* The thread computes the minimum distance from its LHS point to _any_ point in the RHS space, as | ||
* this is the first step to computing Hausdorff distance. The second step of computing Hausdorff | ||
* distance is to determine the maximum of these minimums, which is done by each thread writing | ||
* its minimum to the output using atomicMax. This is done once per thread per RHS space. Once | ||
* all threads have run to completion, all "maximums of the minimum distances" (aka, directed | ||
* Hausdorff distances) reside in the output. | ||
* | ||
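* @tparam T type of coordinate, either float or double. | ||
* @tparam Index integral type used for point and space indices. | ||
* @tparam PointsIter device-accessible random-access iterator to the input points. | ||
* @tparam OffsetsIter device-accessible random-access iterator to the space offsets. | ||
* @tparam OutputIter device-accessible, mutable random-access iterator to the results. | ||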
* @param num_points number of total points in the input (sum of points from all spaces) | ||
* @param points x/y points to compute the distances between | ||
* @param num_spaces number of spaces in the input | ||
* @param space_offsets starting position of first point in each space | ||
* @param results directed Hausdorff distances computed by kernel | ||
*/ | ||
template <typename T, | ||
typename Index, | ||
typename PointsIter, | ||
typename OffsetsIter, | ||
typename OutputIter> | ||
__global__ void kernel_hausdorff(Index num_points, | ||
PointsIter points, | ||
Index num_spaces, | ||
OffsetsIter space_offsets, | ||
OutputIter results) | ||
{ | ||
using Point = typename std::iterator_traits<PointsIter>::value_type; | ||
|
||
// determine the LHS point this thread is responsible for. | ||
auto const thread_idx = blockIdx.x * blockDim.x + threadIdx.x; | ||
Index const lhs_p_idx = thread_idx; | ||
|
||
if (lhs_p_idx >= num_points) { return; } | ||
|
||
// determine the LHS space this point belongs to. | ||
Index const lhs_space_idx = | ||
thrust::distance( | ||
space_offsets, | ||
thrust::upper_bound(thrust::seq, space_offsets, space_offsets + num_spaces, lhs_p_idx)) - | ||
1; | ||
|
||
|
||
// get the coordinates of this LHS point. | ||
Point const lhs_p = points[lhs_p_idx]; | ||
|
||
// loop over each RHS space, as determined by space_offsets | ||
for (uint32_t rhs_space_idx = 0; rhs_space_idx < num_spaces; rhs_space_idx++) { | ||
// determine the begin/end offsets of points contained within this RHS space. | ||
Index const rhs_p_idx_begin = space_offsets[rhs_space_idx]; | ||
Index const rhs_p_idx_end = | ||
(rhs_space_idx + 1 == num_spaces) ? num_points : space_offsets[rhs_space_idx + 1]; | ||
|
||
// each space must contain at least one point, this initial value is just an identity value to | ||
// simplify calculations. If a space contains <= 0 points, then this initial value will be | ||
// written to the output, which can serve as a signal that the input is ill-formed. | ||
auto min_distance_squared = std::numeric_limits<T>::max(); | ||
|
||
// loop over each point in the current RHS space | ||
for (uint32_t rhs_p_idx = rhs_p_idx_begin; rhs_p_idx < rhs_p_idx_end; rhs_p_idx++) { | ||
// get the x and y coordinate of this RHS point | ||
Point const rhs_p = thrust::raw_reference_cast(points[rhs_p_idx]); | ||
|
||
// get distance between the LHS and RHS point | ||
auto const distance_squared = magnitude_squared(rhs_p.x - lhs_p.x, rhs_p.y - lhs_p.y); | ||
|
||
// remember only smallest distance from this LHS point to any RHS point. | ||
min_distance_squared = min(min_distance_squared, distance_squared); | ||
} | ||
|
||
// determine the output offset for this pair of spaces (LHS, RHS) | ||
Index output_idx = lhs_space_idx * num_spaces + rhs_space_idx; | ||
|
||
// use atomicMax to find the maximum of the minimum distance calculated for each space pair. | ||
atomicMax(&thrust::raw_reference_cast(*(results + output_idx)), | ||
static_cast<T>(std::sqrt(min_distance_squared))); | ||
} | ||
} | ||
|
||
} // namespace detail | ||
|
||
template <class PointIt, class OffsetIt, class OutputIt> | ||
OutputIt directed_hausdorff_distance(PointIt points_first, | ||
PointIt points_last, | ||
OffsetIt space_offsets_first, | ||
OffsetIt space_offsets_last, | ||
OutputIt distance_first, | ||
rmm::cuda_stream_view stream) | ||
{ | ||
using Point = typename std::iterator_traits<PointIt>::value_type; | ||
using Index = typename std::iterator_traits<OffsetIt>::value_type; | ||
using T = typename Point::value_type; | ||
using OutputT = typename std::iterator_traits<OutputIt>::value_type; | ||
|
||
static_assert(std::is_convertible_v<Point, cuspatial::vec_2d<T>>, | ||
"Input points must be convertible to cuspatial::vec_2d"); | ||
static_assert(detail::is_floating_point<T, OutputT>(), | ||
"Hausdorff supports only floating-point coordinates."); | ||
static_assert(std::is_integral_v<Index>, "Indices must be integral"); | ||
|
||
auto const num_points = std::distance(points_first, points_last); | ||
auto const num_spaces = std::distance(space_offsets_first, space_offsets_last); | ||
|
||
CUSPATIAL_EXPECTS(num_points >= num_spaces, "At least one point is required for each space"); | ||
CUSPATIAL_EXPECTS(num_spaces < (1 << 15), "Total number of spaces must be less than 2^15"); | ||
|
||
auto const num_results = num_spaces * num_spaces; | ||
|
||
if (num_results > 0) { | ||
// Because the Hausdorff kernel writes its output with `atomicMax`, the output must be | ||
// initialized to a value <= 0. It is initialized to -1 here, which the kernel should always | ||
// overwrite. A -1 remaining in the output indicates a bug: the kernel did not write to that | ||
// element. | ||
thrust::fill_n(rmm::exec_policy(stream), distance_first, num_results, -1); | ||
|
||
auto const threads_per_block = 64; | ||
auto const num_tiles = (num_points + threads_per_block - 1) / threads_per_block; | ||
|
||
detail::kernel_hausdorff<T, decltype(num_points)> | ||
<<<num_tiles, threads_per_block, 0, stream.value()>>>( | ||
num_points, points_first, num_spaces, space_offsets_first, distance_first); | ||
|
||
CUSPATIAL_CUDA_TRY(cudaGetLastError()); | ||
} | ||
|
||
return distance_first + num_results; | ||
|
||
} | ||
|
||
} // namespace cuspatial |
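As a reading aid for the kernel above, here is a minimal host-side sketch of the same max-of-min scheme in plain C++. It is not part of the PR: `point_2d` and `directed_hausdorff_reference` are hypothetical names, each outer loop iteration stands in for one GPU thread, and a plain `std::max` replaces `atomicMax` because the loop is sequential.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <limits>
#include <vector>

// Hypothetical stand-in for cuspatial::vec_2d<double>.
struct point_2d {
  double x;
  double y;
};

// Sequential reference: results[lhs_space * num_spaces + rhs_space] holds the
// directed Hausdorff distance from lhs_space to rhs_space.
std::vector<double> directed_hausdorff_reference(std::vector<point_2d> const& points,
                                                 std::vector<std::size_t> const& space_offsets)
{
  std::size_t const num_points = points.size();
  std::size_t const num_spaces = space_offsets.size();

  // Same initial value as the launcher: -1 marks an element that was never written.
  std::vector<double> results(num_spaces * num_spaces, -1.0);

  for (std::size_t lhs_p_idx = 0; lhs_p_idx < num_points; ++lhs_p_idx) {
    // upper_bound returns the first offset greater than the point index, so the
    // owning space is the entry just before it (assumes space_offsets[0] == 0).
    auto const it = std::upper_bound(space_offsets.begin(), space_offsets.end(), lhs_p_idx);
    auto const lhs_space_idx =
      static_cast<std::size_t>(std::distance(space_offsets.begin(), it) - 1);
    point_2d const lhs_p = points[lhs_p_idx];

    for (std::size_t rhs_space_idx = 0; rhs_space_idx < num_spaces; ++rhs_space_idx) {
      std::size_t const begin = space_offsets[rhs_space_idx];
      std::size_t const end =
        (rhs_space_idx + 1 == num_spaces) ? num_points : space_offsets[rhs_space_idx + 1];

      // Step 1: minimum squared distance from this LHS point to the RHS space.
      double min_distance_squared = std::numeric_limits<double>::max();
      for (std::size_t rhs_p_idx = begin; rhs_p_idx < end; ++rhs_p_idx) {
        double const dx = points[rhs_p_idx].x - lhs_p.x;
        double const dy = points[rhs_p_idx].y - lhs_p.y;
        min_distance_squared = std::min(min_distance_squared, dx * dx + dy * dy);
      }

      // Step 2: keep the maximum of the per-point minimums for this pair of spaces.
      std::size_t const output_idx = lhs_space_idx * num_spaces + rhs_space_idx;
      results[output_idx] = std::max(results[output_idx], std::sqrt(min_distance_squared));
    }
  }
  return results;
}
```

For the spaces [0 2 5], [9], [3 7] placed on the x axis with offsets [0 3 4], this sketch returns [0 9 3 4 0 2 2 6 0], following the `lhs_space_idx * num_spaces + rhs_space_idx` indexing used by the kernel.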
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
/* | ||
* Copyright (c) 2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <cudf/types.hpp> | ||
#include <memory> | ||
|
||
#include <rmm/cuda_stream_view.hpp> | ||
|
||
namespace cuspatial { | ||
|
||
/** | ||
* @ingroup distance | ||
* @brief Computes Hausdorff distances for all pairs in a collection of spaces | ||
* | ||
* https://en.wikipedia.org/wiki/Hausdorff_distance | ||
* | ||
* Example in 1D (this function operates in 2D): | ||
* ``` | ||
* spaces | ||
* [0 2 5] [9] [3 7] | ||
* | ||
* spaces represented as points per space and concatenation of all points | ||
* [0 2 5 9 3 7] [3 1 2] | ||
* | ||
* note: the following matrices are visually separated to highlight the relationship of a pair of | ||
* points with the pair of spaces from which it is produced | ||
* | ||
* cartesian product of all | ||
* points by pair of spaces distance between points | ||
* +----------+----+-------+ +---------+---+------+ | ||
* : 00 02 05 : 09 : 03 07 : : 0 2 5 : 9 : 3 7 : | ||
* : 20 22 25 : 29 : 23 27 : : 2 0 3 : 7 : 1 5 : | ||
* : 50 52 55 : 59 : 53 57 : : 5 3 0 : 4 : 2 2 : | ||
* +----------+----+-------+ +---------+---+------+ | ||
* : 90 92 95 : 99 : 93 97 : : 9 7 4 : 0 : 6 2 : | ||
* +----------+----+-------+ +---------+---+------+ | ||
* : 30 32 35 : 39 : 33 37 : : 3 1 2 : 6 : 0 4 : | ||
* : 70 72 75 : 79 : 73 77 : : 7 5 2 : 2 : 4 0 : | ||
* +----------+----+-------+ +---------+---+------+ | ||
* minimum distance from | ||
* every point in one Hausdorff distance is | ||
* space to any point in the maximum of the | ||
* the other space minimum distances | ||
* +----------+----+-------+ +---------+---+------+ | ||
* : 0 : 9 : 3 : : 0 : 9 : 3 : | ||
* : 0 : 7 : 1 : : : : : | ||
* : 0 : 4 : 2 : : : : : | ||
* +----------+----+-------+ +---------+---+------+ | ||
* : 4 : 0 : 2 : : 4 : 0 : 2 : | ||
* +----------+----+-------+ +---------+---+------+ | ||
* : 1 : 6 : 0 : : : 6 : 0 : | ||
* : 2 : 2 : 0 : : 2 : : : | ||
* +----------+----+-------+ +---------+---+------+ | ||
* | ||
* returned in row-major order, indexed as lhs_space * num_spaces + rhs_space | ||
* [0 9 3 4 0 2 2 6 0] | ||
* ``` | ||
* | ||
* @param[in] points_first: beginning of range of (x,y) points | ||
* @param[in] points_last: end of range of (x,y) points | ||
* @param[in] space_offsets_first: beginning of range of offsets to the first point of each space | ||
* @param[in] space_offsets_last: end of range of offsets to the first point of each space | ||
* @param[out] distance_first: beginning of range of output Hausdorff distances for each pair of | ||
* spaces | ||
* @param[in] stream: CUDA stream on which the computation is performed | ||
* | ||
* @tparam PointIt Iterator to input points. Points must be of a type that is convertible to | ||
* `cuspatial::vec_2d<T>`. Must meet the requirements of [LegacyRandomAccessIterator][LinkLRAI] and | ||
* be device-accessible. | ||
* @tparam OffsetIt Iterator to space offsets. Value type must be integral. Must meet the | ||
* requirements of [LegacyRandomAccessIterator][LinkLRAI] and be device-accessible. | ||
* @tparam OutputIt Output iterator. Must meet the requirements of | ||
* [LegacyRandomAccessIterator][LinkLRAI] and be device-accessible and mutable. | ||
* | ||
* @pre `PointIt` and `OutputIt` must have the same underlying floating-point value type; | ||
* `OffsetIt` must have an integral value type. | ||
* | ||
* @return Output iterator to the element past the last distance computed. | ||
* | ||
* @note Hausdorff distances are asymmetrical | ||
*/ | ||
template <class PointIt, class OffsetIt, class OutputIt> | ||
OutputIt directed_hausdorff_distance(PointIt points_first, | ||
PointIt points_last, | ||
OffsetIt space_offsets_first, | ||
OffsetIt space_offsets_last, | ||
OutputIt distance_first, | ||
rmm::cuda_stream_view stream = rmm::cuda_stream_default); | ||
|
||
} // namespace cuspatial | ||
|
||
#include <cuspatial/experimental/detail/hausdorff.cuh> |
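The docstring example describes the spaces as point counts ([3 1 2]), while `directed_hausdorff_distance` takes starting offsets. A minimal host-side sketch of the conversion follows; `counts_to_offsets` is a hypothetical helper for illustration and is not part of cuspatial.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: convert per-space point counts into the index of the
// first point of each space (an exclusive prefix sum).
std::vector<std::size_t> counts_to_offsets(std::vector<std::size_t> const& counts)
{
  std::vector<std::size_t> offsets;
  offsets.reserve(counts.size());
  std::size_t running_total = 0;
  for (std::size_t const count : counts) {
    offsets.push_back(running_total);  // this space starts where the previous ones end
    running_total += count;
  }
  return offsets;
}
```

For the counts [3 1 2] in the example, the offsets are [0 3 4]: one entry per space, with the end of the last space implied by the total number of points.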
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,8 @@ | |
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
namespace cuspatial { | ||
namespace detail { | ||
|
||
|
@@ -30,7 +32,7 @@ namespace detail { | |
* @param val The value to compare | ||
* @return The old value stored in `addr`. | ||
*/ | ||
-__device__ double atomicMin(double* addr, double val) | ||
+__device__ inline double atomicMin(double* addr, double val) | ||
Review comment: Shall we hold on to these changes until #561 is merged?
Reply: I needed them to get this PR working.
|
||
{ | ||
unsigned long long int* address_as_ll = reinterpret_cast<unsigned long long int*>(addr); | ||
unsigned long long int old = __double_as_longlong(*addr); | ||
|
@@ -60,7 +62,7 @@ __device__ double atomicMin(double* addr, double val) | |
* @param val The value to compare | ||
* @return The old value stored in `addr`. | ||
*/ | ||
-__device__ float atomicMin(float* addr, float val) | ||
+__device__ inline float atomicMin(float* addr, float val) | ||
{ | ||
unsigned int* address_as_ui = reinterpret_cast<unsigned int*>(addr); | ||
unsigned int old = __float_as_uint(*addr); | ||
|
@@ -90,7 +92,7 @@ __device__ float atomicMin(float* addr, float val) | |
* @param val The value to compare | ||
* @return The old value stored in `addr`. | ||
*/ | ||
-__device__ double atomicMax(double* addr, double val) | ||
+__device__ inline double atomicMax(double* addr, double val) | ||
{ | ||
unsigned long long int* address_as_ll = reinterpret_cast<unsigned long long int*>(addr); | ||
unsigned long long int old = __double_as_longlong(*addr); | ||
|
@@ -120,7 +122,7 @@ __device__ double atomicMax(double* addr, double val) | |
* @param val The value to compare | ||
* @return The old value stored in `addr`. | ||
*/ | ||
-__device__ float atomicMax(float* addr, float val) | ||
+__device__ inline float atomicMax(float* addr, float val) | ||
{ | ||
unsigned int* address_as_ui = reinterpret_cast<unsigned int*>(addr); | ||
unsigned int old = __float_as_uint(*addr); | ||
|
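The hunks above show only the signatures and the first lines of each atomic helper. The reinterpreted address and the saved old value suggest a compare-and-swap loop, but the loop bodies are not visible here, so the following host-side analogue is an assumption rather than the code from the PR. It uses `std::atomic<float>` from the C++ standard library, and `atomic_max_float` is a hypothetical name.

```cpp
#include <atomic>

// Host-side sketch of a floating-point atomic max built on compare-and-swap.
// Returns the value that was stored in `target` before the operation.
float atomic_max_float(std::atomic<float>& target, float val)
{
  float old = target.load();
  // On failure, compare_exchange_weak reloads `old` with the current value, so
  // the loop re-checks whether `val` is still larger before retrying.
  while (old < val && !target.compare_exchange_weak(old, val)) {}
  return old;
}
```

The same retry-until-unchanged pattern is what lets many threads fold their per-point minimums into one output slot without a lock.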
Do we usually omit the template parameters that are auto deduced?
Not that I noticed. Why omit it?
Oops, I approved too soon. The docs for template parameter here should be added.