From 8e99afdbfa80092fddf7fe4fe2dd3eb6d3b1489d Mon Sep 17 00:00:00 2001 From: leonardodalinky Date: Sun, 1 Sep 2024 02:30:19 +0800 Subject: [PATCH] fix: forced memory layout --- python/fpsample/wrapper.py | 11 +++++------ src/bucket_fps/mod.rs | 8 ++++---- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/python/fpsample/wrapper.py b/python/fpsample/wrapper.py index 44c52b0..ec4ce8e 100644 --- a/python/fpsample/wrapper.py +++ b/python/fpsample/wrapper.py @@ -30,9 +30,8 @@ def fps_sampling(pc: np.ndarray, n_samples: int, start_idx: Optional[int] = None assert ( start_idx is None or 0 <= start_idx < n_pts ), "start_idx should be None or 0 <= start_idx < n_pts" - pc = pc.astype(np.float32) # best performance with fortran array - pc = np.asfortranarray(pc) + pc = np.asfortranarray(pc, dtype=np.float32) # Random pick a start start_idx = np.random.randint(low=0, high=n_pts) if start_idx is None else start_idx return _fps_sampling(pc, n_samples, start_idx) @@ -60,7 +59,7 @@ def fps_npdu_sampling( assert ( start_idx is None or 0 <= start_idx < n_pts ), "start_idx should be None or 0 <= start_idx < n_pts" - pc = pc.astype(np.float32) + pc = np.ascontiguousarray(pc, dtype=np.float32) w = w or int(n_pts / n_samples * 16) if w >= n_pts - 1: warnings.warn(f"k is too large, set to {n_pts - 1}") @@ -93,7 +92,7 @@ def fps_npdu_kdtree_sampling( assert ( start_idx is None or 0 <= start_idx < n_pts ), "start_idx should be None or 0 <= start_idx < n_pts" - pc = pc.astype(np.float32) + pc = np.ascontiguousarray(pc, dtype=np.float32) w = w or int(n_pts / n_samples * 16) if w >= n_pts: warnings.warn(f"k is too large, set to {n_pts}") @@ -123,7 +122,7 @@ def bucket_fps_kdtree_sampling( assert ( start_idx is None or 0 <= start_idx < n_pts ), "start_idx should be None or 0 <= start_idx < n_pts" - pc = pc.astype(np.float32) + pc = np.ascontiguousarray(pc, dtype=np.float32) # Random pick a start start_idx = np.random.randint(low=0, high=n_pts) if start_idx is None else start_idx return _bucket_fps_kdtree_sampling(pc, n_samples, start_idx) @@ -155,7 +154,7 @@ def bucket_fps_kdline_sampling( assert ( start_idx is None or 0 <= start_idx < n_pts ), "start_idx should be None or 0 <= start_idx < n_pts" - pc = pc.astype(np.float32) + pc = np.ascontiguousarray(pc, dtype=np.float32) # Random pick a start start_idx = np.random.randint(low=0, high=n_pts) if start_idx is None else start_idx return _bucket_fps_kdline_sampling(pc, n_samples, h, start_idx) diff --git a/src/bucket_fps/mod.rs b/src/bucket_fps/mod.rs index 736697a..dc9c99d 100644 --- a/src/bucket_fps/mod.rs +++ b/src/bucket_fps/mod.rs @@ -7,12 +7,12 @@ pub fn bucket_fps_kdtree_sampling( start_idx: usize, ) -> Array1 { let[p, c] = points.shape() else {panic !("points must be a 2D array")}; - let raw_data = points.as_standard_layout().as_ptr(); + let raw_data = points.as_standard_layout(); let mut sampled_point_indices = vec![0; n_samples]; let ret_code; unsafe { ret_code = ffi::bucket_fps_kdtree( - raw_data, + raw_data.as_ptr(), *p, *c, n_samples, @@ -33,12 +33,12 @@ pub fn bucket_fps_kdline_sampling( start_idx: usize, ) -> Array1 { let[p, c] = points.shape() else {panic !("points must be a 2D array")}; - let raw_data = points.as_standard_layout().as_ptr(); + let raw_data = points.as_standard_layout(); let mut sampled_point_indices = vec![0; n_samples]; let ret_code; unsafe { ret_code = ffi::bucket_fps_kdline( - raw_data, + raw_data.as_ptr(), *p, *c, n_samples,