diff --git a/jax_md/partition.py b/jax_md/partition.py index 0760d8f1..f1aca0e3 100644 --- a/jax_md/partition.py +++ b/jax_md/partition.py @@ -206,11 +206,7 @@ def _estimate_cell_capacity(position: Array, return int(cell_capacity * buffer_size_multiplier) -<<<<<<< Updated upstream -def _shift_array(arr: Array, dindex: Array) -> Array: -======= def shift_array(arr: Array, dindex: Array) -> Array: ->>>>>>> Stashed changes if len(dindex) == 2: dx, dy = dindex dz = 0 @@ -235,11 +231,7 @@ def shift_array(arr: Array, dindex: Array) -> Array: return arr -<<<<<<< Updated upstream -def _unflatten_cell_buffer(arr: Array, -======= def unflatten_cell_buffer(arr: Array, ->>>>>>> Stashed changes cells_per_side: Array, dim: int) -> Array: if (isinstance(cells_per_side, int) or @@ -388,23 +380,14 @@ def cell_list_fn(position: Array, cell_position = cell_position.at[sorted_cell_id].set(sorted_position) sorted_id = jnp.reshape(sorted_id, (N, 1)) cell_id = cell_id.at[sorted_cell_id].set(sorted_id) -<<<<<<< Updated upstream - cell_position = _unflatten_cell_buffer(cell_position, cells_per_side, dim) - cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim) -======= cell_position = unflatten_cell_buffer(cell_position, cells_per_side, dim) cell_id = unflatten_cell_buffer(cell_id, cells_per_side, dim) ->>>>>>> Stashed changes for k, v in sorted_kwargs.items(): if v.ndim == 1: v = jnp.reshape(v, v.shape + (1,)) cell_kwargs[k] = cell_kwargs[k].at[sorted_cell_id].set(v) -<<<<<<< Updated upstream - cell_kwargs[k] = _unflatten_cell_buffer( -======= cell_kwargs[k] = unflatten_cell_buffer( ->>>>>>> Stashed changes cell_kwargs[k], cells_per_side, dim) occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count)