Skip to content

Commit

Permalink
revise operator using latest string lib
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 27, 2024
1 parent 46bede7 commit da96b48
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 308 deletions.
49 changes: 17 additions & 32 deletions onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,38 +29,23 @@ Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const {
// TODO: remove output_indices.
const auto& output_indices = shader.AddIndices("output_indices", false);
const auto interleaved_str = interleaved_ ? "true" : "false";
shader.SetMainFunctionBody(
" let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n"
" let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n"
" let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n",
" if (global_idx >= size) { return; }\n"
" if (bsnh[3] < half_rotary_emb_dim) {\n"
" let position_ids_idx = " +
position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) + ";\n" +
" let position_id = u32(" +
position_ids.GetByOffset("position_ids_idx") + ")" +
" + select(0, bsnh[1], position_ids_idx == 0);\n"
" let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " +
interleaved_str +
");\n"
" let j = i + select(half_rotary_emb_dim, 1, " +
interleaved_str +
");\n"
" let re = " +
input.GetByOffset("i") + " * " + cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") + "-" +
input.GetByOffset("j") + " * " + sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") + ";\n" +
" " + output.SetByOffset("i", "re") + "\n" +
" let im = " + input.GetByOffset("i") + " * " +
sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") +
"+ " + input.GetByOffset("j") +
" * " + cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") +
";\n " + output.SetByOffset("j", "im") +
"\n"
" } else { \n"
" let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n" +
" " + output.SetByOffset("k", input.GetByOffset("k")) +
"\n"
" }");
shader.MainFunctionBody() << " let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n"
" let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n"
" let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n"
" if (global_idx >= size) { return; }\n"
" if (bsnh[3] < half_rotary_emb_dim) {\n"
<< " let position_ids_idx = " << position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) << ";\n"

Check warning on line 37 in onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc:37: Lines should be <= 120 characters long [whitespace/line_length] [2]
<< " let position_id = u32(" << position_ids.GetByOffset("position_ids_idx") << ") + select(0, bsnh[1], position_ids_idx == 0);\n"

Check warning on line 38 in onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc:38: Lines should be <= 120 characters long [whitespace/line_length] [2]
<< " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " << interleaved_str << ");\n"

Check warning on line 39 in onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc:39: Lines should be <= 120 characters long [whitespace/line_length] [2]
<< " let j = i + select(half_rotary_emb_dim, 1, " << interleaved_str << ");\n"
<< " let re = " << input.GetByOffset("i") << " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " - " << input.GetByOffset("j") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"

Check warning on line 41 in onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc:41: Lines should be <= 120 characters long [whitespace/line_length] [2]
<< " " << output.SetByOffset("i", "re") << "\n"
<< " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " + " << input.GetByOffset("j") + " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"

Check warning on line 43 in onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc:43: Lines should be <= 120 characters long [whitespace/line_length] [2]
<< " " << output.SetByOffset("j", "im") << "\n"
<< " } else { \n"
" let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n"
<< " " << output.SetByOffset("k", input.GetByOffset("k")) << "\n"
<< " }";

return Status::OK();
}
Expand Down
Loading

0 comments on commit da96b48

Please sign in to comment.