[Lang] Merge irpass::half2_vectorize() with irpass::scalarize() (#8102)
Issue: #

### Brief Summary

<!-- copilot:summary -->
### <samp>🤖 Generated by Copilot at 44b862c</samp>

This pull request enhances the support and optimization for matrices and vectors in the IR and the code generation, especially for f16 data types. It adds `ndarray` methods to `MatrixType` and `VectorType` classes, fixes code generation bugs and data flow analysis for matrix and vector operations, simplifies and improves the scalarization and vectorization of matrices in CUDA offloads, and adds a half2 vectorization optimization for f16 matrices and vectors. It also updates the `IRBuilder` class, the `Ndarray` class, and the test cases accordingly.

### Walkthrough

<!-- copilot:walkthrough -->
### <samp>🤖 Generated by Copilot at 44b862c</samp>

* Add `ndarray` methods to `MatrixType` and `VectorType` classes to create `Matrix` and `Vector` objects from IR types ([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-5913c0a6b6a5e279414150955f30b96ea6b9676a1f5b1931ca4bcb39f19c81e9R1504-R1508),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-5913c0a6b6a5e279414150955f30b96ea6b9676a1f5b1931ca4bcb39f19c81e9R1606-R1608))
* Fix bug in `defined` function that used incorrect type to check signedness of binary operands involving `MatrixType` or `VectorType` ([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-3c663c78745adcd3f6a7ac81fe99e628decc3040f292ea1e20ecd4b85a7f4313L614-R614),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-3c663c78745adcd3f6a7ac81fe99e628decc3040f292ea1e20ecd4b85a7f4313L661-R661))
* Modify `reaching_definition_analysis`, `live_variable_analysis` and `dead_store_elimination` functions to handle `MatrixPtrStmt` as local variables in data dependency analysis ([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fL387-R390),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fR558-R559),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fR584-R585),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fR717-R718),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fR818-R819),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fR846-R847))
* Add `taichi/ir/statements.h` header file to `taichi/ir/ir_builder.h` and change return type of `get_constant` method to `Stmt *` to handle `MatrixType` and `VectorType` constants ([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-1894085b261e833e3e66924fc5b1cf63b9dd8b8aa0b3e78ec64366396131470dR5),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-1894085b261e833e3e66924fc5b1cf63b9dd8b8aa0b3e78ec64366396131470dL140-R141))
* Add `half2_optimization_enabled` parameter to `scalarize` function and `Scalarize` class to control half2 vectorization optimization for `MatrixType` and `VectorType` operands with two f16 elements ([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-448ac6e85e192a27e5ec7c54cd8a91545dc7c83f62d030eafb9c190383cfe934L33-R33),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L19-R23),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L417-R503),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L841-R928),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L1195-R1285))
* Add `fp16.h` header file to `taichi/program/ndarray.cpp` and modify `read` and `write` methods of `Ndarray` class to handle f16 data types correctly ([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-c88c6764ffa952681c8b0db12b376c473a8422cb7bb0243a10cc643cc245a5a1R5),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-c88c6764ffa952681c8b0db12b376c473a8422cb7bb0243a10cc643cc245a5a1L171-R185))
* Remove redundant call to `scalarize` function with `config.real_matrix_scalarize` flag and modify call to `scalarize` function with `half2_optimization_enabled` flag in `offload_to_executable` function ([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bL234-L241),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bL296-R297))
* Add `transform_pow_op_impl` method to `DemoteOperations` class to transform power operation with scalar operands and modify `visit` method to handle `floordiv`, `bit_sar` and `pow` operations with scalar and tensor operands ([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bR19-R129),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bR135-R146),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bR165),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL67-R198),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL99-R228),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL109-L150),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL159-R254),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL176-R265),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL191-R287))
* Modify `half2_vectorization_test.cpp` to use tensor type operands with two f16 elements and call `scalarize` function with `half2_optimization_enabled` flag instead of `vectorize_half2` function ([link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-5136d70f7a32456bee3938daca1066aa3d380aecd7d18257fc893b83dfd72a79L34-R68),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-5136d70f7a32456bee3938daca1066aa3d380aecd7d18257fc893b83dfd72a79L105-R88),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-5136d70f7a32456bee3938daca1066aa3d380aecd7d18257fc893b83dfd72a79L144-R115),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-5136d70f7a32456bee3938daca1066aa3d380aecd7d18257fc893b83dfd72a79L187-R146),[link](https://github.com/taichi-dev/taichi/pull/8102/files?diff=unified&w=0#diff-5136d70f7a32456bee3938daca1066aa3d380aecd7d18257fc893b83dfd72a79L234-L258))
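
The effect of merging half2 vectorization into `scalarize` can be pictured with a small toy model. The sketch below is not Taichi's pass: `TensorType`, `scalarize_binary_op`, and the textual "IR" it emits are hypothetical names invented for illustration, and only the `half2_optimization_enabled` flag name comes from the description above. It assumes the simplest reading of that description: a tensor operation is normally split into one scalar operation per element, but when the flag is set and the operand has exactly two f16 elements, the pair is kept together so a backend could map it to a single half2 instruction. Which statement kinds the real pass treats this way is not stated here.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class TensorType:
    """Hypothetical stand-in for a matrix/vector IR type."""
    dtype: str          # element type name, e.g. "f16" or "f32"
    num_elements: int   # flattened element count


def scalarize_binary_op(op: str, lhs: str, rhs: str, ttype: TensorType,
                        half2_optimization_enabled: bool = False) -> List[str]:
    """Lower `lhs op rhs` on a tensor type into toy IR lines.

    By default every element gets its own scalar operation. With the flag
    set, an operand of exactly two f16 elements is left as one paired
    operation (assumption: this mirrors the half2 case described above).
    """
    is_half2_candidate = ttype.dtype == "f16" and ttype.num_elements == 2
    if half2_optimization_enabled and is_half2_candidate:
        # Keep the pair intact instead of splitting it into two scalar ops.
        return [f"{op}.half2 {lhs}, {rhs}"]
    # Plain scalarization: one scalar op per element index.
    return [f"{op}.{ttype.dtype} {lhs}[{i}], {rhs}[{i}]"
            for i in range(ttype.num_elements)]


if __name__ == "__main__":
    f16x2 = TensorType("f16", 2)
    for enabled in (False, True):
        print(enabled, scalarize_binary_op("add", "a", "b", f16x2, enabled))
```

Folding the decision into one function mirrors the motivation given in the walkthrough: a single pass that takes a flag replaces a separate `vectorize_half2` step, so callers such as the test only need to call `scalarize` with `half2_optimization_enabled` set.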