Skip to content

Commit

Permalink
add sph_base
Browse files Browse the repository at this point in the history
  • Loading branch information
taehoon-yoon committed Feb 21, 2023
1 parent e59aef4 commit c444884
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 4 deletions.
3 changes: 2 additions & 1 deletion data/scenes/dragon_bath.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
"exponent": 7,
"exportFrame": false,
"exportPly": false,
"exportObj": false
"exportObj": false,
"dt": 1e-4
},
"RigidBodies": [
{
Expand Down
27 changes: 24 additions & 3 deletions particle_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,15 +315,23 @@ def add_cube(self, object_id, box_start, box_end, velocity, density, color, is_d
density_arr, pressure_arr, material_arr, color_arr, is_dynamic_arr)

@ti.func
def get_grid_idx_from_pos(self, position):
grid_idx = (position / self.grid_size).cast(ti.i32) # floor operation
def pos2index(self, position):
return (position / self.grid_size).cast(ti.i32)

@ti.func
def flatten_grid_index(self, grid_idx):
if self.dim == 3:
flatten_grid_idx = grid_idx[0] * self.grid_num[1] * self.grid_num[2] + grid_idx[1] * self.grid_num[2] + \
grid_idx[2]
else:
flatten_grid_idx = grid_idx[0] * self.grid_num[1] + grid_idx[1]
return flatten_grid_idx

@ti.func
def get_grid_idx_from_pos(self, position):
grid_idx = self.pos2index(position) # floor operation
return self.flatten_grid_index(grid_idx)

@ti.kernel
def update_grid_id(self):
self.counting_sort_accumulatedArray.fill(0)
Expand All @@ -337,7 +345,8 @@ def update_grid_id(self):
def counting_sort(self):
for i in range(self.total_particle_num):
grid_idx = self.grid_id[i]
self.grid_id_for_sort[i] = ti.atomic_sub(self.counting_sort_accumulatedArray[grid_idx], 1) - 1
base_offset = 0 if grid_idx == 0 else self.counting_sort_accumulatedArray[grid_idx - 1]
self.grid_id_for_sort[i] = ti.atomic_sub(self.counting_sort_countArray[grid_idx], 1) + base_offset - 1
for i in self.grid_id_for_sort:
new_idx = self.grid_id_for_sort[i]
self.grid_id_buffer[new_idx] = self.grid_id[i]
Expand Down Expand Up @@ -375,6 +384,18 @@ def counting_sort(self):
self.color[i] = self.color_buffer[i]
self.is_dynamic[i] = self.is_dynamic_buffer[i]

@ti.func
def for_all_neighbors(self, idx_i, task: ti.template(), ret: ti.template()):
center_cell_grid_idx = self.pos2index(self.position[idx_i])
for offset in ti.grouped(ti.ndrange(*(((-1, 2),) * self.dim))):
neighbor_grid_flatten_idx = self.flatten_grid_index(offset + center_cell_grid_idx)
start_idx = 0 if neighbor_grid_flatten_idx == 0 else self.counting_sort_accumulatedArray[
neighbor_grid_flatten_idx - 1]
# TODO: can we somewhat modify to enable using ti.static?
for idx_j in range(start_idx, self.counting_sort_accumulatedArray[neighbor_grid_flatten_idx]):
if idx_i[0] != idx_j and (self.position[idx_i] - self.position[idx_j]).norm() < self.support_length:
task(idx_i, idx_j, ret)

def update_particle_system(self):
self.update_grid_id()
self.prefix_sum_executor.run(self.counting_sort_accumulatedArray)
Expand Down
41 changes: 41 additions & 0 deletions sph_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import taichi as ti
import numpy as np


@ti.data_oriented
class SPHBase:
def __init__(self, particle_system):
self.ps = particle_system
self.g = self.ps.config['gravitation']
self.dt = ti.field(ti.f32, shape=())
self.dt[None] = self.ps.config['dt']

@ti.func
def cubic_spline_kernel(self, r_norm):
h = self.ps.support_length # 4r
coeff = 8 / np.pi if self.ps.dim == 3 else 40 / 7 / np.pi
coeff /= (h ** self.ps.dim)
q = r_norm / h
kernel_val = 0.0
if q <= 1.0:
if q <= 0.5:
kernel_val = coeff * (1 - 6 * (q ** 2) + 6 * (q ** 3))
else:
kernel_val = coeff * (2 * (1 - q) ** 3)
return kernel_val

@ti.func
def cubic_spline_kernel_derivative(self, r):
h = self.ps.support_length
coeff = 16 / np.pi if self.ps.dim == 3 else 80 / 7 / np.pi
coeff /= (h ** (self.ps.dim + 1))
derivative = ti.Vector.zero(ti.f32, self.ps.dim)
r_norm = r.norm()
r_hat = r / (r_norm + 1e-6)
q = r_norm / h
if q <= 1.0:
if q <= 0.5:
derivative = coeff * (9 * q ** 2 - 6 * q) * r_hat
else:
derivative = coeff * (-3 * (1 - q) ** 2) * r_hat
return derivative

0 comments on commit c444884

Please sign in to comment.