Skip to content

Commit

Permalink
Creating mesh shader for ocean. Sloppy and not used but it's a start
Browse files Browse the repository at this point in the history
  • Loading branch information
Honeybunch committed Nov 21, 2024
1 parent 7f29f4b commit c95c326
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 2 deletions.
15 changes: 14 additions & 1 deletion addons/water/include/tb_ocean.slangh
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
typedef float4 TbOceanWave; // xy = dir, z = steep, w = wavelength

TB_GPU_STRUCT_DECL_NOPACK(TbOceanData, {
float4 time_waves; // x = time, y = wave count
uint32_t mesh_idx; // Index of the ocean patch mesh in the mesh list
float4 time_waves; // x = time, y = wave count
TbOceanWave wave[TB_WAVE_MAX];
});
TB_GPU_STRUCT_DECL_NOPACK(TbOceanDrawData, {
float4 instance_pos;
});
TB_GPU_STRUCT_DECL(TbOceanPushConstants, {
float4x4 m;
});
Expand All @@ -37,6 +41,15 @@ _Static_assert(sizeof(TbOceanPushConstants) <= TB_PUSH_CONSTANT_BYTES,
[[vk::push_constant]] \
ConstantBuffer<TbOceanPushConstants> consts

#define OCEAN_DRAW_SET(b) \
[[vk::binding(0, b)]] \
StructuredBuffer<TbOceanDrawData> draw_data;

TbOceanDrawData tb_get_ocean_draw_data(int32_t draw,
StructuredBuffer<TbOceanDrawData> data) {
return data[draw];
}

void gerstner_wave(TbOceanWave wave, float time, inout float3 pos,
inout float3 tangent, inout float3 binormal) {
float steepness = wave.z;
Expand Down
183 changes: 183 additions & 0 deletions addons/water/source/tb_ocean_two.slangm
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
// Adapted heavily from https://catlikecoding.com/unity/tutorials/flow/waves/

#include "tb_gltf.slangh"
#include "tb_lighting.slangh"
#include "tb_ocean.slangh"
#include "tb_pbr.slangh"

OCEAN_SET(0);
OCEAN_DRAW_SET(1);
GLTF_VIEW_SET(2);
TB_MESHLET_SET(3);
TB_MESHLET_TRI_SET(4);
TB_MESHLET_VERT_SET(5);
TB_POS_SET(6);

// TODO: Can we drive these from specialization constants?
#define NUM_THREADS_X TB_MESHLET_THREADS
#define NUM_THREADS_Y 1
#define NUM_THREADS_Z 1

struct VertexOut {
float4 clip_pos : SV_POSITION;
float3 world_pos : POSITION0;
float3 view_pos : POSITION1;
float4 screen_pos : POSITION2;
float3 tangent : TANGENT0;
float3 binormal : BINORMAL0;
float4 clip : TEXCOORD0;
};

[shader("mesh")]
[outputtopology("triangle")]
[numthreads(NUM_THREADS_X, NUM_THREADS_Y, NUM_THREADS_Z)]
void mesh(uint32_t gtid: SV_GroupThreadID, uint32_t gid: SV_GroupID,
out vertices VertexOut verts[TB_MESHLET_MAX_VERTICES],
out indices uint3 triangles[TB_MESHLET_MAX_TRIANGLES]) {
uint32_t draw_idx = tb_get_draw_index();
uint32_t mesh_idx = ocean_data.mesh_idx;
TbOceanDrawData draw = tb_get_ocean_draw_data(draw_idx, draw_data);
uint32_t meshlet_idx = gid;
TbMeshlet meshlet = tb_get_meshlet(meshlets[mesh_idx], meshlet_idx);

// Set number of outputs
SetMeshOutputCounts(meshlet.vert_count, meshlet.prim_count);

if (gtid < meshlet.vert_count) {
uint32_t idx = tb_get_meshlet_vertex(meshlet_verts[mesh_idx], meshlet, gtid);

int3 local_pos =
tb_vert_get_local_pos(TB_INPUT_PERM_POSITION, idx, mesh_idx, pos_buffers);
float3 pos = mul(consts.m, float4(local_pos, 1)).xyz + draw.instance_pos.xyz;

float3 tangent = float3(0, 0, 1);
float3 binormal = float3(1, 0, 0);
pos = calc_wave_pos(pos, ocean_data, tangent, binormal);
float4 world_pos = float4(pos, 1.0);
float4 clip_pos = mul(camera_data.vp, world_pos);

VertexOut vert = {};
vert.clip_pos = clip_pos;
vert.world_pos = world_pos.xyz;
vert.view_pos = mul(camera_data.v, world_pos).xyz;
vert.screen_pos = clip_to_screen(clip_pos);
vert.tangent = tangent;
vert.binormal = binormal;
vert.clip = clip_pos;
verts[gtid] = vert;
}

if (gtid < meshlet.prim_count) {
triangles[gtid] =
tb_get_meshlet_primitive(meshlet_prims[mesh_idx], meshlet, gtid);
}
}

float4 frag(VertexOut vert) : SV_TARGET {
float3 view_dir_vec = camera_data.view_pos - vert.world_pos;

// Calculate normal after interpolation
float3 N = normalize(cross(normalize(vert.tangent), normalize(vert.binormal)));
float3 V = normalize(view_dir_vec);
float3 R = reflect(-V, N);
float3 L = light_data.light_dir;
float2 screen_uv = (vert.clip.xy / vert.clip.w) * 0.5 + 0.5;

float3 albedo = float3(0, 0, 0);

float2 uv = (vert.screen_pos.xy) / vert.screen_pos.w;

// Underwater fog
{
const float near = camera_data.proj_params.x;
const float far = camera_data.proj_params.y;
// TODO: Paramaterize
const float fog_density = 0.078f;
const float3 fog_color = float3(0.095, 0.163, 0.282);

// World position depth
float scene_eye_depth =
linear_depth(depth_map.Sample(material_sampler, uv).r, near, far);
float fragment_eye_depth = -vert.view_pos.z;
float3 world_pos = camera_data.view_pos -
((view_dir_vec / fragment_eye_depth) * scene_eye_depth);
float depth_diff = world_pos.y - vert.world_pos.y;

float fog = saturate(exp(fog_density * depth_diff));
float3 background_color = color_map.Sample(material_sampler, uv).rgb;
albedo = lerp(fog_color, background_color, fog);
}

// Add a bit of a fresnel effect
{
// TODO: Parameterize
const float horizon_dist = 5.0f;
const float3 horizon_color = float3(0.8, 0.9, 0.8);

float fresnel = dot(N, V);
fresnel = pow(saturate(1 - fresnel), horizon_dist);
albedo = lerp(albedo, horizon_color, fresnel);
}

// PBR Lighting
float3 color = float3(0, 0, 0);
{
float metallic = 0.0;
float roughness = 0.0;

// Calculate shadow first
float shadow = 1.0f;

{
Light l;
l.light = light_data;
l.shadow_map = shadow_map;
l.shadow_sampler = shadow_sampler;

Surface s;
s.base_color = float4(albedo, 1);
s.view_pos = vert.view_pos;
s.world_pos = vert.world_pos;
s.screen_uv = screen_uv;
s.metallic = metallic;
s.roughness = roughness;
s.N = N;
s.V = V;
s.R = R;
s.emissives = 0;
shadow = shadow_visibility(l, s);
}

// Lighting
{
float2 brdf =
brdf_lut.Sample(brdf_sampler, float2(max(dot(N, V), 0.0), roughness))
.rg;
float3 reflection = prefiltered_reflection(
prefiltered_map, filtered_env_sampler, R, roughness);
float3 irradiance =
irradiance_map.SampleLevel(filtered_env_sampler, N, 0).rgb;
color = pbr_lighting(shadow, 1, albedo, metallic, roughness, brdf,
reflection, irradiance, light_data.color, L, V, N);
}
}

// Subsurface Scattering
if (L.y > 0) {
float distortion = 0.4f;
float power = 2.0f;
float scale = 4.0f;
float3 attenuation = 0.3f;
float3 ambient = 0.1f;
float3 sss_color = float3(0.13f, 0.69f, 0.67f);
// Without handling thickness

float3 H = normalize(L + N * distortion);
float VdotH = pow(saturate(dot(V, -H)), power) * scale;
float3 I = attenuation * (VdotH * ambient);

color += (sss_color * light_data.color * I);
}

return float4(color, 1);
}
1 change: 1 addition & 0 deletions include/tb_common.slangh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "tb_pi.h"
#include "tb_simd.h"
#include "tb_meshlet.h"
#include "tb_intrin.slangh"

#define TB_PUSH_CONSTANT_BYTES 128

Expand Down
4 changes: 4 additions & 0 deletions include/tb_intrin.slangh
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#ifdef TB_SHADER

// See https://github.com/shader-slang/slang/issues/4352
uint32_t tb_get_draw_index() {
return spirv_asm {
Expand All @@ -8,3 +10,5 @@ uint32_t tb_get_draw_index() {
result:$$uint = OpLoad builtin(DrawIndex:uint);
};
}

#endif
1 change: 0 additions & 1 deletion source/tb_gltf_two.slangm
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "tb_gltf.slangh"
#include "tb_intrin.slangh"
#include "tb_lighting.slangh"

GLTF_VIEW_SET(0)
Expand Down

0 comments on commit c95c326

Please sign in to comment.