diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 16db75a75fdca..0c144f2eb4082 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -1,5 +1,5 @@ --- -name: Bug report +name: Bug Report about: Create a report to help us improve title: '' labels: potential bug diff --git a/.github/ISSUE_TEMPLATE/document.md b/.github/ISSUE_TEMPLATE/document.md new file mode 100644 index 0000000000000..ddb5001154da3 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/document.md @@ -0,0 +1,12 @@ +--- +name: Doc +about: Have doubts or suggestions about Taichi Docs +title: '' +labels: doc +assignees: '' + +--- + + diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index 9a7b1c1797446..1bf3415e010e7 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -1,5 +1,5 @@ --- -name: Feature request +name: Feature Request about: Suggest an idea for this project title: '' labels: feature request diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md new file mode 100644 index 0000000000000..accca01a54df9 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/question.md @@ -0,0 +1,17 @@ +--- +name: Ask a Question +about: Ask anything about Taichi +title: '' +labels: question +assignees: '' + +--- + + diff --git a/.github/workflows/cancel.yml b/.github/workflows/cancel.yml deleted file mode 100644 index da36ae34e9405..0000000000000 --- a/.github/workflows/cancel.yml +++ /dev/null @@ -1,14 +0,0 @@ -name: Cancel -on: - workflow_run: - workflows: ["Presubmit Checks"] - types: - - requested - -jobs: - cancel: - runs-on: ubuntu-latest - steps: - - uses: styfle/cancel-workflow-action@0.9.0 - with: - workflow_id: ${{ github.event.workflow.id }} diff --git a/.github/workflows/issue_comment.yml b/.github/workflows/issue_comment.yml index 7398d128a61fe..5f8f46081f383 100644 --- a/.github/workflows/issue_comment.yml +++ b/.github/workflows/issue_comment.yml @@ -18,3 +18,5 @@ jobs: commands: | format benchmark + rebase + rerun diff --git a/.github/workflows/performance_monitoring.yml b/.github/workflows/performance_monitoring.yml new file mode 100644 index 0000000000000..e2bbded40c117 --- /dev/null +++ b/.github/workflows/performance_monitoring.yml @@ -0,0 +1,33 @@ +name: Performance Monitoring +on: + push: + branches: + - master + +jobs: + gpu_backends: + name: Performance monitoring (NVIDIA GPU) + timeout-minutes: 60 + # Disable this workflow on forks + if: github.repository_owner == 'taichi-dev' + runs-on: [self-hosted, x64, cuda, linux, benchmark] + steps: + - uses: actions/checkout@v2 + with: + submodules: "recursive" + + - name: Build & Install + run: | + .github/workflows/scripts/unix_build.sh + python3 -m pip install dist/*.whl + + - name: Run performance-monitoring + run: | + cd .. + rm -rf performance-monitoring + git clone git@github.com:taichi-dev/performance-monitoring.git + cd performance-monitoring + export WORKFLOW_MODE=postsubmit + ./run.sh + env: + GITHUB_CONTEXT: ${{ toJson(github) }} diff --git a/.github/workflows/postsubmit.yml b/.github/workflows/postsubmit.yml deleted file mode 100644 index e7ad3b5bebc28..0000000000000 --- a/.github/workflows/postsubmit.yml +++ /dev/null @@ -1,233 +0,0 @@ -name: Postsubmit Checks -on: - push: - branches: - - master - -jobs: - build_and_test_cpu_required: - # This job will be required to pass before merging to master branch. - name: Required Build and Test (CPU) - timeout-minutes: 60 - strategy: - matrix: - include: - - os: ubuntu-latest - python: 3.6 - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v2 - with: - submodules: 'recursive' - - - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python }} - - - name: Download Pre-Built LLVM 10.0.0 - run: | - python misc/ci_download.py - mkdir taichi-llvm - cd taichi-llvm - unzip ../taichi-llvm.zip - env: - CI_PLATFORM: ${{ matrix.os }} - - - name: Build & Install - run: .github/workflows/scripts/unix_build.sh - env: - CI_SETUP_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=OFF -DTI_WITH_CC:BOOL=ON -DTI_WITH_VULKAN:BOOL=OFF -DTI_BUILD_TESTS:BOOL=ON - CXX: clang++ - - - name: Test - run: .github/workflows/scripts/unix_test.sh - - build_and_test_cpu: - name: Build and Test (CPU) - needs: build_and_test_cpu_required - timeout-minutes: 60 - strategy: - matrix: - include: - - os: macos-latest - python: 3.7 - with_cc: OFF - with_cpp_tests: ON - - os: ubuntu-latest - python: 3.9 - with_cc: OFF - with_cpp_tests: OFF - - os: ubuntu-latest - python: 3.8 - with_cc: ON - with_cpp_tests: OFF - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v2 - with: - submodules: 'recursive' - - - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python }} - - - name: Download Pre-Built LLVM 10.0.0 - run: | - python misc/ci_download.py - mkdir taichi-llvm - cd taichi-llvm - unzip ../taichi-llvm.zip - env: - CI_PLATFORM: ${{ matrix.os }} - - - name: Build & Install - run: .github/workflows/scripts/unix_build.sh - env: - CI_SETUP_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=OFF -DTI_WITH_CC:BOOL=${{ matrix.with_cc }} -DTI_WITH_VULKAN:BOOL=OFF -DTI_BUILD_TESTS:BOOL=${{ matrix.with_cpp_tests }} - CXX: clang++ - # [DEBUG] Copy this step around to enable debugging inside Github Action instances. - #- name: Setup tmate session - # uses: mxschmitt/action-tmate@v3 - # with: - # limit-access-to-actor: true - - - name: Test - run: .github/workflows/scripts/unix_test.sh - env: - RUN_CPP_TESTS: ${{ matrix.with_cpp_tests }} - - build_and_test_gpu_linux: - name: Build and Test (GPU) - runs-on: [self-hosted, cuda, vulkan, cn] - timeout-minutes: 60 - steps: - - uses: actions/checkout@v2 - with: - submodules: 'recursive' - - - name: Build - run: | - export PATH=$PATH:/usr/local/cuda/bin - .github/workflows/scripts/unix_build.sh - env: - LLVM_LIB_ROOT_DIR: /opt/taichi-llvm-10.0.0 - CI_SETUP_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=ON -DTI_WITH_CC:BOOL=OFF -DTI_WITH_VULKAN:BOOL=ON - BUILD_NUM_THREADS: 8 - LLVM_PATH: /opt/taichi-llvm-10.0.0/bin - LLVM_DIR: /opt/taichi-llvm-10.0.0/lib/cmake/llvm - CXX: clang++-8 - - - name: Test - run: .github/workflows/scripts/unix_test.sh - env: - DISPLAY: :1 - GPU_TEST: ON - - build_and_test_windows: - name: Build and Test (Windows) - runs-on: windows-latest - timeout-minutes: 60 - steps: - - - name: Install 7Zip PowerShell - shell: powershell - run: Install-Module 7Zip4PowerShell -Force -Verbose - - - uses: actions/checkout@v2 - with: - submodules: 'recursive' - - - uses: actions/setup-python@v2 - with: - python-version: 3.7 - - - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v1.0.2 - - - name: Download And Install Vulkan - shell: powershell - run: | - Invoke-WebRequest -Uri "https://sdk.lunarg.com/sdk/download/1.2.189.0/windows/VulkanSDK-1.2.189.0-Installer.exe" -OutFile VulkanSDK.exe - $installer = Start-Process -FilePath VulkanSDK.exe -Wait -PassThru -ArgumentList @("/S"); - $installer.WaitForExit(); - - - name: Build - shell: powershell - run: | - $env:Path += ";C:/VulkanSDK/1.2.189.0/Bin" - cd C:\ - Remove-item alias:curl - curl --retry 10 --retry-delay 5 https://github.com/taichi-dev/taichi_assets/releases/download/llvm10/taichi-llvm-10.0.0-msvc2019.zip -LO - 7z x taichi-llvm-10.0.0-msvc2019.zip -otaichi_llvm - curl --retry 10 --retry-delay 5 https://github.com/taichi-dev/taichi_assets/releases/download/llvm10/clang-10.0.0-win.zip -LO - 7z x clang-10.0.0-win.zip -otaichi_clang - $env:PATH = ";C:\taichi_llvm\bin;C:\taichi_clang\bin;" + $env:PATH - clang --version - cd D:\a\taichi\taichi - python -m pip install -r requirements_dev.txt - cd python - git fetch origin master - $env:TAICHI_CMAKE_ARGS = $env:CI_SETUP_CMAKE_ARGS - python build.py build - cd ..\dist - $env:WHL = $(dir *.whl) - python -m pip install $env:WHL - env: - PYTHON: C:\hostedtoolcache\windows\Python\3.7.9\x64\python.exe - CI_SETUP_CMAKE_ARGS: -G "Visual Studio 16 2019" -A x64 -DLLVM_DIR=C:\taichi_llvm\lib\cmake\llvm -DTI_WITH_VULKAN:BOOL=ON - VULKAN_SDK: C:/VulkanSDK/1.2.189.0 - - - name: Test - shell: powershell - run: | - $env:PATH = ";C:\taichi_llvm\bin;C:\taichi_clang\bin;" + $env:PATH - python -c "import taichi" - python examples/algorithm/laplace.py - python bin/taichi diagnose - python bin/taichi changelog - python bin/taichi test -vr2 -t2 - env: - PYTHON: C:\hostedtoolcache\windows\Python\3.7.9\x64\python.exe - - build_and_test_m1: - name: Build and Test (Apple M1) - timeout-minutes: 60 - strategy: - matrix: - include: - - os: macos-latest - python: 3.8 - defaults: - run: - # https://github.com/actions/runner/issues/805#issuecomment-844426478 - shell: "/usr/bin/arch -arch arm64e /bin/bash --noprofile --norc -eo pipefail {0}" - runs-on: [self-hosted, m1] - steps: - - uses: actions/checkout@v2 - with: - submodules: 'recursive' - - - name: Build - run: | - python3 -m pip uninstall taichi -y - rm -rf $HOME/Library/Python/3.8/lib/python/site-packages/taichi - git --version - export CXX=clang++ - python3 -m pip install -r requirements_dev.txt - cd python - git fetch origin master - TAICHI_CMAKE_ARGS=$CI_SETUP_CMAKE_ARGS python3 build.py build - cd .. - export NUM_WHL=`ls dist/*.whl | wc -l` - if [ $NUM_WHL -ne 1 ]; then echo 'ERROR: created more than 1 whl.' && exit 1; fi - python3 -m pip install dist/*.whl - env: - CI_SETUP_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=OFF -DTI_WITH_CUDA:BOOL=OFF -DTI_WITH_CC:BOOL=OFF -DTI_WITH_VULKAN:BOOL=OFF -DTI_BUILD_TESTS:BOOL=ON - - - name: Test - run: | - export PATH=$PATH:$HOME/Library/Python/3.8/bin - python3 examples/algorithm/laplace.py - TI_LIB_DIR=`python3 -c "import taichi;print(taichi.__path__[0])" | tail -1` - TI_LIB_DIR="$TI_LIB_DIR/lib" ./build/taichi_cpp_tests - ti test -vr2 -t4 -x diff --git a/.github/workflows/presubmit.yml b/.github/workflows/presubmit.yml deleted file mode 100644 index 398bfed6631aa..0000000000000 --- a/.github/workflows/presubmit.yml +++ /dev/null @@ -1,298 +0,0 @@ -name: Presubmit Checks -on: - pull_request: - types: [opened, synchronize, reopened] - -jobs: - title_format: - name: Check PR Title - if: ${{ github.event.pull_request }} - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - with: - python-version: 3.8 - - - name: Run PR Title Checker - run: | - pip install semver GitPython - python misc/ci_check_pr_title.py "$PR_TITLE" - env: - PR_TITLE: ${{ github.event.pull_request.title }} - - check_code_format: - name: Check Code Format - runs-on: ubuntu-latest - # This job will be required to pass before merging to master branch. - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: 3.8 - - - name: Setup git & clang-format - run: | - git config user.email "taichigardener@gmail.com" - git config user.name "Taichi Gardener" - git checkout -b _fake_squash - git remote add upstream https://github.com/taichi-dev/taichi.git - git fetch upstream master - sudo apt install clang-format-10 - - - name: Cache PIP - uses: actions/cache@v2 - with: - path: ~/.cache/pip - key: ${{ hashFiles('setup.py') }}-${{ hashFiles('requirements_dev.txt') }} - - - name: Install requirements - run: | - python3 -m pip install --user -r requirements_dev.txt - - - name: Check code format - run: | - python3 misc/code_format.py - git checkout -b _enforced_format - git commit -am "enforce code format" || true - # exit with 1 if there were differences: - git diff _fake_squash _enforced_format --exit-code - - - name: Pylint - run: | - # Make sure pylint doesn't regress - pylint python/taichi/ --disable=all --enable=C0121,C0415 - if [ $? -eq 0 ] - then - echo "PASSED: pylint is happy" - exit 0 - else - echo "FAILED: please run the pylint command above and make sure it passes" - exit 1 - fi - - build_and_test_cpu_required: - # This job will be required to pass before merging to master branch. - name: Required Build and Test (CPU) - needs: check_code_format - timeout-minutes: 60 - strategy: - matrix: - include: - - os: ubuntu-latest - python: 3.6 - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v2 - with: - submodules: 'recursive' - - - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python }} - - - name: Download Pre-Built LLVM 10.0.0 - run: | - python misc/ci_download.py - mkdir taichi-llvm - cd taichi-llvm - unzip ../taichi-llvm.zip - env: - CI_PLATFORM: ${{ matrix.os }} - - - name: Build & Install - run: .github/workflows/scripts/unix_build.sh - env: - CI_SETUP_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=OFF -DTI_WITH_CC:BOOL=ON -DTI_WITH_VULKAN:BOOL=OFF -DTI_BUILD_TESTS:BOOL=ON - CXX: clang++ - - - name: Test - run: .github/workflows/scripts/unix_test.sh - env: - RUN_CPP_TESTS: ON - - build_and_test_cpu: - name: Build and Test (CPU) - needs: build_and_test_cpu_required - timeout-minutes: 60 - strategy: - matrix: - include: - - os: macos-latest - python: 3.7 - with_cc: OFF - with_cpp_tests: ON - - os: ubuntu-latest - python: 3.9 - with_cc: OFF - with_cpp_tests: OFF - - os: ubuntu-latest - python: 3.8 - with_cc: ON - with_cpp_tests: OFF - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v2 - with: - submodules: 'recursive' - - - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python }} - - - name: Download Pre-Built LLVM 10.0.0 - run: | - python misc/ci_download.py - mkdir taichi-llvm - cd taichi-llvm - unzip ../taichi-llvm.zip - env: - CI_PLATFORM: ${{ matrix.os }} - - - name: Build & Install - run: .github/workflows/scripts/unix_build.sh - env: - CI_SETUP_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=OFF -DTI_WITH_CC:BOOL=${{ matrix.with_cc }} -DTI_WITH_VULKAN:BOOL=OFF -DTI_BUILD_TESTS:BOOL=${{ matrix.with_cpp_tests }} - CXX: clang++ - # [DEBUG] Copy this step around to enable debugging inside Github Action instances. - #- name: Setup tmate session - # uses: mxschmitt/action-tmate@v3 - # with: - # limit-access-to-actor: true - - - name: Test - run: .github/workflows/scripts/unix_test.sh - env: - RUN_CPP_TESTS: ${{ matrix.with_cpp_tests }} - - build_and_test_gpu_linux: - name: Build and Test (GPU) - needs: check_code_format - runs-on: [self-hosted, cuda, vulkan, cn] - timeout-minutes: 60 - steps: - - uses: actions/checkout@v2 - with: - submodules: 'recursive' - - - name: Build - run: | - export PATH=$PATH:/usr/local/cuda/bin - .github/workflows/scripts/unix_build.sh - env: - LLVM_LIB_ROOT_DIR: /opt/taichi-llvm-10.0.0 - CI_SETUP_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=ON -DTI_WITH_CC:BOOL=OFF -DTI_WITH_VULKAN:BOOL=ON -DTI_BUILD_TESTS:BOOL=ON - BUILD_NUM_THREADS: 8 - LLVM_PATH: /opt/taichi-llvm-10.0.0/bin - LLVM_DIR: /opt/taichi-llvm-10.0.0/lib/cmake/llvm - CXX: clang++-8 - - - name: Test - run: .github/workflows/scripts/unix_test.sh - env: - DISPLAY: :1 - GPU_TEST: ON - RUN_CPP_TESTS: ON - - build_and_test_windows: - name: Build and Test (Windows) - needs: check_code_format - runs-on: windows-latest - timeout-minutes: 90 - steps: - - - name: Install 7Zip PowerShell - shell: powershell - run: Install-Module 7Zip4PowerShell -Force -Verbose - - - uses: actions/checkout@v2 - with: - submodules: 'recursive' - - - uses: actions/setup-python@v2 - with: - python-version: 3.7 - - - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v1.0.2 - - - name: Download And Install Vulkan - shell: powershell - run: | - Invoke-WebRequest -Uri "https://sdk.lunarg.com/sdk/download/1.2.189.0/windows/VulkanSDK-1.2.189.0-Installer.exe" -OutFile VulkanSDK.exe - $installer = Start-Process -FilePath VulkanSDK.exe -Wait -PassThru -ArgumentList @("/S"); - $installer.WaitForExit(); - - - name: Build - shell: powershell - run: | - $env:Path += ";C:/VulkanSDK/1.2.189.0/Bin" - cd C:\ - Remove-item alias:curl - curl --retry 10 --retry-delay 5 https://github.com/taichi-dev/taichi_assets/releases/download/llvm10/taichi-llvm-10.0.0-msvc2019.zip -LO - 7z x taichi-llvm-10.0.0-msvc2019.zip -otaichi_llvm - curl --retry 10 --retry-delay 5 https://github.com/taichi-dev/taichi_assets/releases/download/llvm10/clang-10.0.0-win.zip -LO - 7z x clang-10.0.0-win.zip -otaichi_clang - $env:PATH = ";C:\taichi_llvm\bin;C:\taichi_clang\bin;" + $env:PATH - clang --version - cd D:\a\taichi\taichi - python -m pip install -r requirements_dev.txt - cd python - git fetch origin master - $env:TAICHI_CMAKE_ARGS = $env:CI_SETUP_CMAKE_ARGS - python build.py build - cd ..\dist - $env:WHL = $(dir *.whl) - python -m pip install $env:WHL - env: - PYTHON: C:\hostedtoolcache\windows\Python\3.7.9\x64\python.exe - CI_SETUP_CMAKE_ARGS: -G "Visual Studio 16 2019" -A x64 -DLLVM_DIR=C:\taichi_llvm\lib\cmake\llvm -DTI_WITH_VULKAN:BOOL=ON - VULKAN_SDK: C:/VulkanSDK/1.2.189.0 - - - name: Test - shell: powershell - run: | - $env:PATH = ";C:\taichi_llvm\bin;C:\taichi_clang\bin;" + $env:PATH - python -c "import taichi" - python examples/algorithm/laplace.py - python bin/taichi diagnose - python bin/taichi changelog - python bin/taichi test -vr2 -t2 - env: - PYTHON: C:\hostedtoolcache\windows\Python\3.7.9\x64\python.exe - - build_and_test_m1: - name: Build and Test (Apple M1) - needs: check_code_format - timeout-minutes: 60 - strategy: - matrix: - include: - - os: macos-latest - python: 3.8 - defaults: - run: - # https://github.com/actions/runner/issues/805#issuecomment-844426478 - shell: "/usr/bin/arch -arch arm64e /bin/bash --noprofile --norc -eo pipefail {0}" - runs-on: [self-hosted, m1] - steps: - - uses: actions/checkout@v2 - with: - submodules: 'recursive' - - - name: Build - run: | - rm -rf $HOME/Library/Python/3.8/lib/python/site-packages/taichi - .github/workflows/scripts/unix_build.sh - env: - CI_SETUP_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=OFF -DTI_WITH_CUDA:BOOL=OFF -DTI_WITH_CC:BOOL=OFF -DTI_WITH_VULKAN:BOOL=OFF -DTI_BUILD_TESTS:BOOL=ON - CXX: clang++ - - - name: Test - run: | - export PATH=$PATH:$HOME/Library/Python/3.8/bin - python3 examples/algorithm/laplace.py - TI_LIB_DIR=`python3 -c "import taichi;print(taichi.__path__[0])" | tail -1` - TI_LIB_DIR="$TI_LIB_DIR/lib" ./build/taichi_cpp_tests - ti test -vr2 -t4 -x diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml new file mode 100644 index 0000000000000..290a2961bf0a7 --- /dev/null +++ b/.github/workflows/pull_request.yml @@ -0,0 +1,30 @@ +name: Presubmit Title Checks +on: + pull_request_target: + types: [opened, synchronize, reopened, edited] + +jobs: + pre_submit: + name: Presubmit Title Checks + if: ${{ github.event.pull_request }} + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: 3.8 + - name: Install Dependencies + run: pip install semver GitPython PyGithub + + - name: Run PR Title Checker + run: | + python misc/ci_check_pr_title.py "$PR_TITLE" + env: + PR_TITLE: ${{ github.event.pull_request.title }} + + - name: PR Project Card Creation + if: github.event.action == 'opened' || github.event.action == 'edited' + run: python misc/ci_create_pr_card.py + env: + GITHUB_TOKEN: ${{ secrets.GARDENER_PAT }} + GH_EVENT: ${{ toJson(github.event) }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ee651749ca5a8..a448e4ab1d939 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,35 +1,65 @@ name: Publishing Release on: - release: - # https://docs.github.com/en/free-pro-team@latest/actions/reference/events-that-trigger-workflows#release - types: [published] - # When triggered by schedule and workflow_dispatch, github.event.action is an empty string. - # We use this to distinguish which taichi to release. schedule: - cron: "0 0 * * *" workflow_dispatch: + # Manually trigger the release workflow, a version must be provided + inputs: + version: + description: "The version to release (e.g. v0.8.0)" + type: string + required: true + +env: + PROD_PWD: ${{ secrets.PYPI_PWD_PROD }} + NIGHT_PWD: ${{ secrets.PYPI_PWD_NIGHTLY }} + METADATA_USERNAME: ${{ secrets.METADATA_USERNAME }} + METADATA_PASSWORD: ${{ secrets.METADATA_PASSWORD }} + METADATA_URL: ${{ secrets.METADATA_URL }} + RELEASE_VERSION: ${{ github.event.inputs.version }} jobs: + add_version_to_database: + name: Add version to database + # Skip running release workflow on forks + if: github.repository_owner == 'taichi-dev' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Save new version + run: | + if [ -z "$RELEASE_VERSION" ]; then + echo "Not a production release" + exit 0 + fi + python3 -m pip install requests==2.26 + python3 misc/save_new_version.py + # This job set environment matrix with respect to production release and nightly release. matrix_prep: runs-on: ubuntu-latest + needs: add_version_to_database outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} matrix_osx: ${{ steps.set-matrix.outputs.matrix_osx }} steps: - id: set-matrix run: | - # For nightly release, we only run on python 3.8 - [ -z "${{ github.event.action }}" ] && matrix="[{\"name\":\"taichi-nightly\",\"python\":\"3.8\"}]" - # For production release, we run on four python versions. - [ -z "${{ github.event.action }}" ] || matrix="[{\"name\":\"taichi\",\"python\":\"3.6\"},{\"name\":\"taichi\",\"python\":\"3.7\"},{\"name\":\"taichi\",\"python\":\"3.8\"},{\"name\":\"taichi\",\"python\":\"3.9\"}]" - echo ::set-output name=matrix::{\"include\":$(echo $matrix)}\" - # M1 only supports py38 and py39(conda), so change matrix. - [ -z "${{ github.event.action }}" ] && matrix_osx="[{\"name\":\"taichi-nightly\",\"python\":\"3.8\"}]" - [ -z "${{ github.event.action }}" ] || matrix_osx="[{\"name\":\"taichi\",\"python\":\"3.8\"},{\"name\":\"taichi\",\"python\":\"3.9\"}]" - echo ::set-output name=matrix_osx::{\"include\":$(echo $matrix_osx)}\" - - build_and_upload_linux: + if [[ "$GITHUB_EVENT_NAME" == "workflow_dispatch" ]]; then + # For production release, we run on four python versions. + echo '::set-output name=matrix::{"include":[{"name":"taichi","python":"3.6","conda_python":"py36"},{"name":"taichi","python":"3.7","conda_python":"py37"},{"name":"taichi","python":"3.8","conda_python":"py38"},{"name":"taichi","python":"3.9","conda_python":"py39"}]}"' + + echo '::set-output name=matrix_osx::{"include":[{"name":"taichi","python":"3.8"},{"name":"taichi","python":"3.9"}]}"' + else + # For nightly release, we only run on python 3.8 + echo '::set-output name=matrix::{"include":[{"name":"taichi-nightly","python":"3.8","conda_python":"py38"},{"name":"taichi-nightly","python":"3.10","conda_python":"py310"}]}"' + + # M1 only supports py38 and py39(conda), so change matrix. + echo '::set-output name=matrix_osx::{"include":[{"name":"taichi-nightly","python":"3.8"},{"name":"taichi-nightly","python":"3.10"}]}"' + fi + + build_and_test_linux: name: Build and Upload (linux only) needs: matrix_prep strategy: @@ -39,138 +69,115 @@ jobs: steps: - uses: actions/checkout@v2 with: - submodules: 'recursive' + submodules: "recursive" - - name: Create Python Wheel + - name: Get sccache cache + uses: actions/cache@v2 + with: + path: sccache_cache + key: sccache-linux-gpu-${{ github.sha }} + restore-keys: | + sccache-linux-gpu- + + - name: Build run: | - # We hacked here because conda activate in CI won't update python PATH - # automatically. So we don't activate and use desired python version - # directly. - export PATH=/home/buildbot/miniconda3/envs/$PYTHON/bin:$PATH - TAICHI_REPO_DIR=`pwd` - export PATH=$LLVM_LIB_ROOT_DIR/bin:/usr/local/cuda/bin:$PATH - export LLVM_DIR=$LLVM_LIB_ROOT_DIR/lib/cmake/llvm - export CXX=clang++-8 - python3 -m pip uninstall taichi taichi-nightly -y - python3 -m pip install -r requirements_dev.txt - python3 -m pip install twine - cd python - git fetch origin master - export TAICHI_CMAKE_ARGS=$CI_SETUP_CMAKE_ARGS - python3 build.py build --project_name $PROJECT_NAME - cd .. - NUM_WHL=`ls dist/*.whl | wc -l` - if [ $NUM_WHL -ne 1 ]; then echo 'ERROR: created more than 1 whl.' && exit 1; fi - python3 -m pip install dist/*.whl + mkdir -m777 shared + docker create --user dev --name taichi_build --gpus all -v /tmp/.X11-unix:/tmp/.X11-unix \ + -e DISPLAY -e PY -e GPU_BUILD -e TAICHI_CMAKE_ARGS -e PROJECT_NAME \ + registry.taichigraphics.com/taichidev-ubuntu18.04:v0.2.1 \ + /home/dev/${{ github.event.repository.name }}/.github/workflows/scripts/unix_build.sh + tar -cf - ../${{ github.event.repository.name }} --mode u=+rwx,g=+rwx,o=+rwx --owner 1000 --group 1000 | docker cp - taichi_build:/home/dev/ + docker start -a taichi_build + docker cp taichi_build:/home/dev/${{ github.event.repository.name }}/dist shared/dist + docker cp taichi_build:/home/dev/${{ github.event.repository.name }}/build shared/build env: - LLVM_LIB_ROOT_DIR: /opt/taichi-llvm-10.0.0 - BUILD_NUM_THREADS: 8 - CI_SETUP_CMAKE_ARGS: -DTI_WITH_VULKAN:BOOL=ON -DTI_WITH_OPENGL:BOOL=ON -DTI_WITH_CC:BOOL=OFF -DTI_BUILD_TESTS:BOOL=${{ matrix.with_cpp_tests }} + PY: ${{ matrix.conda_python }} + GPU_BUILD: ON + TAICHI_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=ON -DTI_WITH_CC:BOOL=OFF -DTI_WITH_VULKAN:BOOL=ON -DTI_BUILD_TESTS:BOOL=ON -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache PROJECT_NAME: ${{ matrix.name }} - PYTHON: ${{ matrix.python }} + DISPLAY: ":1" - name: Archive Wheel Artifacts - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: - name: ${{ matrix.name }}-py${{ matrix.python }}-linux.whl - path: dist/*.whl + name: dist + path: shared/dist/*.whl + retention-days: 20 - name: Test run: | - export PATH=/home/buildbot/miniconda3/envs/$PYTHON/bin:$PATH - python3 examples/algorithm/laplace.py - export DISPLAY=:1 - hash -r - glewinfo - ti diagnose - ti changelog - ti test -vr2 -t2 -k "not ndarray and not torch" - # ndarray test might OOM if run with -t2. - # FIXME: unify this with presubmit.yml to avoid further divergence - ti test -vr2 -t1 -k "ndarray or torch" + docker create --user dev --name taichi_test --gpus all -v /tmp/.X11-unix:/tmp/.X11-unix \ + -e DISPLAY -e PY -e GPU_TEST registry.taichigraphics.com/taichidev-ubuntu18.04:v0.2.1 \ + /home/dev/unix_test.sh + docker cp .github/workflows/scripts/unix_test.sh taichi_test:/home/dev/unix_test.sh + docker cp ./requirements_test.txt taichi_test:/home/dev/requirements_test.txt + docker cp shared/dist/ taichi_test:/home/dev/ + docker cp shared/build/ taichi_test:/home/dev/ + docker cp tests/ taichi_test:/home/dev/ + docker start -a taichi_test env: - PYTHON: ${{ matrix.python }} + PY: ${{ matrix.conda_python }} + GPU_TEST: ON + DISPLAY: ":1" - - name: Upload PyPI - env: - # https://docs.github.com/en/free-pro-team@latest/actions/reference/encrypted-secrets#using-encrypted-secrets-in-a-workflow - PROD_PWD: ${{ secrets.PYPI_PWD_PROD }} - NIGHT_PWD: ${{ secrets.PYPI_PWD_NIGHTLY }} - PROJECT_NAME: ${{ matrix.name }} - PYTHON: ${{ matrix.python }} + - name: clean docker container + if: always() run: | - export PATH=/home/buildbot/miniconda3/envs/$PYTHON/bin:$PATH - cd python - if [ $PROJECT_NAME == "taichi-nightly" ]; then export PYPI_PWD="$NIGHT_PWD" && python3 build.py upload --skip_build --testpypi --project_name $PROJECT_NAME - elif [ $PROJECT_NAME == "taichi" ]; then export PYPI_PWD="$PROD_PWD" && python3 build.py upload --skip_build; fi + docker rm taichi_build taichi_test -f - build_and_upload_mac: + build_and_test_mac: name: Build and Upload (macOS only) needs: matrix_prep strategy: fail-fast: false matrix: ${{ fromJson(needs.matrix_prep.outputs.matrix) }} - runs-on: macos-latest + runs-on: macos-10.15 steps: - uses: actions/checkout@v2 with: - submodules: 'recursive' + submodules: "recursive" + + - name: Get sccache cache + uses: actions/cache@v2 + with: + path: sccache_cache + key: sccache-mac-${{ github.sha }} + restore-keys: | + sccache-mac- - uses: actions/setup-python@v2 with: python-version: ${{ matrix.python }} - name: Download Pre-Built LLVM 10.0.0 - run: | - python misc/ci_download.py - mkdir taichi-llvm - cd taichi-llvm - unzip ../taichi-llvm.zip + run: python misc/ci_download.py env: - CI_PLATFORM: macos-latest + CI_PLATFORM: macos-10.15 - name: Create Python Wheel run: | - TAICHI_REPO_DIR=`pwd` - export PATH=$TAICHI_REPO_DIR/taichi-llvm/bin/:$PATH - export CXX=clang++ - python -m pip install -r requirements_dev.txt - cd python - git fetch origin master - export TAICHI_CMAKE_ARGS=$CI_SETUP_CMAKE_ARGS - python build.py build --project_name $PROJECT_NAME - cd .. - NUM_WHL=`ls dist/*.whl | wc -l` - if [ $NUM_WHL -ne 1 ]; then echo 'ERROR: created more than 1 whl.' && exit 1; fi - pip install dist/*.whl + brew install molten-vk + export PATH=$(pwd)/taichi-llvm/bin/:$PATH + bash .github/workflows/scripts/unix_build.sh + brew uninstall molten-vk env: - CI_SETUP_CMAKE_ARGS: -DTI_WITH_VULKAN:BOOL=OFF -DTI_WITH_OPENGL:BOOL=OFF -DTI_WITH_CC:BOOL=OFF -DTI_BUILD_TESTS:BOOL=${{ matrix.with_cpp_tests }} + TAICHI_CMAKE_ARGS: -DTI_WITH_VULKAN:BOOL=ON -DTI_WITH_OPENGL:BOOL=OFF -DTI_WITH_CC:BOOL=OFF -DTI_BUILD_TESTS:BOOL=ON -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache PROJECT_NAME: ${{ matrix.name }} + CXX: clang++ - name: Archive Wheel Artifacts - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: - name: ${{ matrix.name }}-py${{ matrix.python }}-macos.whl + name: dist path: dist/*.whl + retention-days: 20 - name: Test - run: | - python examples/algorithm/laplace.py - ti diagnose - ti test -vr2 -t2 - - - name: Upload PyPI + run: .github/workflows/scripts/unix_test.sh env: - # https://docs.github.com/en/free-pro-team@latest/actions/reference/encrypted-secrets#using-encrypted-secrets-in-a-workflow - PROD_PWD: ${{ secrets.PYPI_PWD_PROD }} - NIGHT_PWD: ${{ secrets.PYPI_PWD_NIGHTLY }} - PROJECT_NAME: ${{ matrix.name }} - run: | - cd python - if [ $PROJECT_NAME == "taichi-nightly" ]; then export PYPI_PWD="$NIGHT_PWD" && python build.py upload --skip_build --testpypi --project_name $PROJECT_NAME - elif [ $PROJECT_NAME == "taichi" ]; then export PYPI_PWD="$PROD_PWD" && python build.py upload --skip_build; fi + TI_WANTED_ARCHS: "cpu" - build_and_upload_m1: + build_and_test_m1: name: Build and Upload (Apple M1) needs: matrix_prep strategy: @@ -183,51 +190,48 @@ jobs: steps: - uses: actions/checkout@v2 with: - submodules: 'recursive' + submodules: "recursive" + + - name: Get sccache cache + uses: actions/cache@v2 + with: + path: sccache_cache + key: sccache-m1-${{ github.sha }} + restore-keys: | + sccache-m1- - name: Build run: | + brew install molten-vk # We hacked here because conda activate in CI won't update python PATH # automatically. So we don't activate and use desired python version # directly. export PATH=/Users/github/miniforge3/envs/$PYTHON/bin:$PATH - python3 -m pip uninstall taichi taichi-nightly -y - git --version - export CXX=clang++ - python3 -m pip install -r requirements_dev.txt - cd python - git fetch origin master - export TAICHI_CMAKE_ARGS=$CI_SETUP_CMAKE_ARGS - python3 build.py build --project_name $PROJECT_NAME - cd .. - export NUM_WHL=`ls dist/*.whl | wc -l` - if [ $NUM_WHL -ne 1 ]; then echo 'ERROR: created more than 1 whl.' && exit 1; fi - python3 -m pip install dist/*.whl + bash .github/workflows/scripts/unix_build.sh + brew uninstall molten-vk env: - CI_SETUP_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=OFF -DTI_WITH_CUDA:BOOL=OFF -DTI_WITH_CC:BOOL=OFF -DTI_WITH_VULKAN:BOOL=OFF -DTI_WITH_TESTS:BOOL=ON + TAICHI_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=OFF -DTI_WITH_CUDA:BOOL=OFF -DTI_WITH_CC:BOOL=OFF -DTI_WITH_VULKAN:BOOL=ON -DTI_BUILD_TESTS:BOOL=ON -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache PROJECT_NAME: ${{ matrix.name }} PYTHON: ${{ matrix.python }} + CXX: clang++ - name: Archive Wheel Artifacts - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: - name: ${{ matrix.name }}-py${{ matrix.python }}-macos-m1.whl + name: dist path: dist/*.whl + retention-days: 20 - - name: Upload PyPI - env: - # https://docs.github.com/en/free-pro-team@latest/actions/reference/encrypted-secrets#using-encrypted-secrets-in-a-workflow - PROD_PWD: ${{ secrets.PYPI_PWD_PROD }} - NIGHT_PWD: ${{ secrets.PYPI_PWD_NIGHTLY }} - PROJECT_NAME: ${{ matrix.name }} - PYTHON: ${{ matrix.python }} + - name: Test run: | export PATH=/Users/github/miniforge3/envs/$PYTHON/bin:$PATH - cd python - if [ $PROJECT_NAME == "taichi-nightly" ]; then export PYPI_PWD="$NIGHT_PWD" && python3 build.py upload --skip_build --testpypi --project_name $PROJECT_NAME - elif [ $PROJECT_NAME == "taichi" ]; then export PYPI_PWD="$PROD_PWD" && python3 build.py upload --skip_build; fi + .github/workflows/scripts/unix_test.sh + env: + TI_WANTED_ARCHS: "metal,vulkan,cpu" + PYTHON: ${{ matrix.python }} + GPU_TEST: ON - build_and_upload_macos_1014: + build_and_test_macos_1014: name: Build and Upload (macos 1014) needs: matrix_prep strategy: @@ -237,7 +241,15 @@ jobs: steps: - uses: actions/checkout@v2 with: - submodules: 'recursive' + submodules: "recursive" + + - name: Get sccache cache + uses: actions/cache@v2 + with: + path: sccache_cache + key: sccache-1014-${{ github.sha }} + restore-keys: | + sccache-1014- - name: Build run: | @@ -247,53 +259,29 @@ jobs: export PATH=/Users/buildbot6/miniconda3/envs/$PYTHON/bin:$PATH export LLVM_DIR=/Users/buildbot6/taichi-llvm-10.0.0-macos export PATH=$LLVM_DIR/bin:$PATH - python3 -m pip uninstall taichi taichi-nightly -y - git --version - export CXX=clang++ - python3 -m pip install -r requirements_dev.txt - cd python - git fetch origin master - export TAICHI_CMAKE_ARGS=$CI_SETUP_CMAKE_ARGS - python3 build.py build --project_name $PROJECT_NAME - cd .. - export NUM_WHL=`ls dist/*.whl | wc -l` - if [ $NUM_WHL -ne 1 ]; then echo 'ERROR: created more than 1 whl.' && exit 1; fi - python3 -m pip install dist/*.whl + bash .github/workflows/scripts/unix_build.sh env: - CI_SETUP_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=OFF -DTI_WITH_CUDA:BOOL=OFF -DTI_WITH_CC:BOOL=OFF -DTI_WITH_VULKAN:BOOL=OFF -DTI_WITH_TESTS:BOOL=ON + TAICHI_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=OFF -DTI_WITH_CUDA:BOOL=OFF -DTI_WITH_CC:BOOL=OFF -DTI_WITH_VULKAN:BOOL=OFF -DTI_BUILD_TESTS:BOOL=ON -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache PROJECT_NAME: ${{ matrix.name }} PYTHON: ${{ matrix.python }} + CXX: clang++ - name: Archive Wheel Artifacts - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: - name: ${{ matrix.name }}-py${{ matrix.python }}-macos-1014.whl + name: dist path: dist/*.whl + retention-days: 20 - name: Test run: | export PATH=/Users/buildbot6/miniconda3/envs/$PYTHON/bin:$PATH - python examples/algorithm/laplace.py - ti diagnose - ti test -vr2 -t2 -a cpu + .github/workflows/scripts/unix_test.sh env: + TI_WANTED_ARCHS: "cpu" PYTHON: ${{ matrix.python }} - - name: Upload PyPI - env: - # https://docs.github.com/en/free-pro-team@latest/actions/reference/encrypted-secrets#using-encrypted-secrets-in-a-workflow - PROD_PWD: ${{ secrets.PYPI_PWD_PROD }} - NIGHT_PWD: ${{ secrets.PYPI_PWD_NIGHTLY }} - PROJECT_NAME: ${{ matrix.name }} - PYTHON: ${{ matrix.python }} - run: | - export PATH=/Users/buildbot6/miniconda3/envs/$PYTHON/bin:$PATH - cd python - if [ $PROJECT_NAME == "taichi-nightly" ]; then export PYPI_PWD="$NIGHT_PWD" && python3 build.py upload --skip_build --testpypi --project_name $PROJECT_NAME - elif [ $PROJECT_NAME == "taichi" ]; then export PYPI_PWD="$PROD_PWD" && python3 build.py upload --skip_build; fi - - - build_and_upload_windows: + build_and_test_windows: name: Build and Upload (Windows only) needs: matrix_prep strategy: @@ -301,79 +289,138 @@ jobs: matrix: ${{ fromJson(needs.matrix_prep.outputs.matrix) }} runs-on: windows-latest steps: - - name: Install 7Zip PowerShell - shell: powershell - run: Install-Module 7Zip4PowerShell -Force -Verbose - - uses: actions/checkout@v2 with: - submodules: 'recursive' + submodules: "recursive" - uses: actions/setup-python@v2 with: python-version: ${{ matrix.python }} - - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v1.0.2 + - name: Add Visual Studio Shell to ENV + uses: egor-tensin/vs-shell@v2 + with: + arch: x64 - - name: Download And Install Vulkan - shell: powershell - run: | - Invoke-WebRequest -Uri "https://sdk.lunarg.com/sdk/download/1.2.189.0/windows/VulkanSDK-1.2.189.0-Installer.exe" -OutFile VulkanSDK.exe - $installer = Start-Process -FilePath VulkanSDK.exe -Wait -PassThru -ArgumentList @("/S"); - $installer.WaitForExit(); + - name: Get sccache cache + uses: actions/cache@v2 + with: + path: ccache_cache + key: ccache-win64-clang-${{ github.sha }} + restore-keys: | + ccache-win64-clang- - name: Build Python Wheel shell: powershell run: | - $env:Path += ";C:/VulkanSDK/1.2.189.0/Bin" - cd C:\ - Remove-item alias:curl - curl --retry 10 --retry-delay 5 https://github.com/taichi-dev/taichi_assets/releases/download/llvm10/taichi-llvm-10.0.0-msvc2019.zip -LO - 7z x taichi-llvm-10.0.0-msvc2019.zip -otaichi_llvm - curl --retry 10 --retry-delay 5 https://github.com/taichi-dev/taichi_assets/releases/download/llvm10/clang-10.0.0-win.zip -LO - 7z x clang-10.0.0-win.zip -otaichi_clang - $env:PATH = ";C:\taichi_llvm\bin;C:\taichi_clang\bin;" + $env:PATH - clang --version - cd D:\a\taichi\taichi - python -m pip install -r requirements_dev.txt - cd python - git fetch origin master - $env:TAICHI_CMAKE_ARGS = $env:CI_SETUP_CMAKE_ARGS - python build.py build --project_name $env:PROJECT_NAME - cd ..\dist - $env:WHL = $(dir *.whl) - python -m pip install $env:WHL + .\.github\workflows\scripts\win_build.ps1 -installVulkan -libsDir C:\ + venv\Scripts\python -m pip install $(dir dist\*.whl) env: - CI_SETUP_CMAKE_ARGS: -G "Visual Studio 16 2019" -A x64 -DLLVM_DIR=C:\taichi_llvm\lib\cmake\llvm -DTI_WITH_VULKAN:BOOL=ON PROJECT_NAME: ${{ matrix.name }} - VULKAN_SDK: C:/VulkanSDK/1.2.189.0 - name: Archive Wheel Artifacts uses: actions/upload-artifact@v2 with: - name: ${{ matrix.name }}-py${{ matrix.python }}-windows.whl + name: dist path: dist/*.whl + retention-days: 20 - name: Test shell: powershell run: | $env:PATH = ";C:\taichi_llvm\bin;C:\taichi_clang\bin;" + $env:PATH + . venv\Scripts\activate.ps1 python -c "import taichi" - python examples/algorithm/laplace.py - python bin/taichi diagnose - python bin/taichi test -vr2 -t2 - - - name: Upload PyPI - shell: powershell + pip install torch + ti diagnose + python tests/run_tests.py -vr2 -t2 env: - # https://docs.github.com/en/free-pro-team@latest/actions/reference/encrypted-secrets#using-encrypted-secrets-in-a-workflow - PROD_PWD: ${{ secrets.PYPI_PWD_PROD }} - NIGHT_PWD: ${{ secrets.PYPI_PWD_NIGHTLY }} - PROJECT_NAME: ${{ matrix.name }} + TI_SKIP_VERSION_CHECK: ON + + upload_to_pypi: + name: Upload release to PyPI + needs: + [ + build_and_test_linux, + build_and_test_mac, + build_and_test_m1, + build_and_test_macos_1014, + build_and_test_windows, + ] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: 3.9 + + - name: Get dist files + uses: actions/download-artifact@v3 + with: + name: dist + path: dist + + - name: Upload to PyPI run: | - cd python - if ( $env:PROJECT_NAME -eq "taichi-nightly" ) {$env:PYPI_PWD = "$env:NIGHT_PWD"} - if ( $env:PROJECT_NAME -eq "taichi-nightly" ) {python build.py upload --skip_build --testpypi --project_name $env:PROJECT_NAME} - if ( $env:PROJECT_NAME -eq "taichi" ) {$env:PYPI_PWD = "$env:PROD_PWD"} - if ( $env:PROJECT_NAME -eq "taichi" ) {python build.py upload --skip_build} + ls -l dist/ + if [ -z "$RELEASE_VERSION" ]; then + export PROJECT_NAME="taichi-nightly" + else + export PROJECT_NAME="taichi" + fi + python -m pip install requests twine + python misc/upload_release.py + + create_release: + name: Create tag and publish release + needs: upload_to_pypi + runs-on: ubuntu-latest + if: github.event_name == 'workflow_dispatch' + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: 3.9 + + - name: Generate Changelog + id: changelog + run: | + pip3 install gitpython + content=$(python3 misc/make_changelog.py) + echo $content + # Escape multiline strings: + # https://renehernandez.io/snippets/multiline-strings-as-a-job-output-in-github-actions/ + content="${content//'%'/'%25'}" + content="${content//$'\n'/'%0A'}" + content="${content//$'\r'/'%0D'}" + echo "::set-output name=content::$content" + + - name: Create tag + run: | + git config user.email "taichigardener@gmail.com" + git config user.name "Taichi Gardener" + git tag -a ${RELEASE_VERSION} -m "Release ${RELEASE_VERSION}" + git push origin --tags + + - name: Publish release + uses: softprops/action-gh-release@v1 + with: + body: ${{ steps.changelog.outputs.content }} + tag_name: ${{ github.event.inputs.version }} + + - name: Bump version + run: | + version_parts=(${RELEASE_VERSION//./ }) + version_parts[2]=$(expr ${version_parts[2]} + 1) + next_version=$(IFS=.; echo "${version_parts[*]}") + # Update version.txt + git checkout -b "bump/$next_version" + echo "$next_version" > version.txt + git add version.txt + # Commit and push changes + git commit -m "Bump version to $next_version" + git push origin "bump/$next_version" + # Create pull request + gh pr create -B master -t "[misc] Bump version to $next_version" + env: + GITHUB_TOKEN: ${{ secrets.GARDENER_PAT }} diff --git a/.github/workflows/scripts/check_clang_tidy.sh b/.github/workflows/scripts/check_clang_tidy.sh new file mode 100755 index 0000000000000..d9db1c9a3433f --- /dev/null +++ b/.github/workflows/scripts/check_clang_tidy.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +CI_SETUP_CMAKE_ARGS=$1 + +cd taichi +python3 -m pip install -r requirements_dev.txt + +rm -rf build && mkdir build && cd build +cmake $CI_SETUP_CMAKE_ARGS .. + +cd .. +python3 ./scripts/run_clang_tidy.py $PWD/taichi -clang-tidy-binary clang-tidy-10 -checks=-*,performance-inefficient-string-concatenation,readability-identifier-naming -header-filter=$PWD/taichi -p $PWD/build -j2 diff --git a/.github/workflows/scripts/unix_build.sh b/.github/workflows/scripts/unix_build.sh index f715a8146bf13..a4960f1d2e1ff 100755 --- a/.github/workflows/scripts/unix_build.sh +++ b/.github/workflows/scripts/unix_build.sh @@ -1,11 +1,76 @@ +#!/bin/bash set -ex -export PATH=`pwd`/taichi-llvm/bin/:$LLVM_PATH:$PATH -python3 -m pip uninstall taichi taichi-nightly -y -python3 -m pip install -r requirements_dev.txt -cd python -git fetch origin master -TAICHI_CMAKE_ARGS=$CI_SETUP_CMAKE_ARGS python3 build.py build -cd .. -export NUM_WHL=`ls dist/*.whl | wc -l` -if [ $NUM_WHL -ne 1 ]; then echo `ERROR: created more than 1 whl.` && exit 1; fi -python3 -m pip install dist/*.whl + +check_in_docker() { + # This is a temporary solution to detect in a docker, but it should work + if [[ $(whoami) == "dev" ]]; then + echo "true" + else + echo "false" + fi +} + +IN_DOCKER=$(check_in_docker) +[[ "$IN_DOCKER" == "true" ]] && cd taichi + +setup_sccache() { + export SCCACHE_DIR=$(pwd)/sccache_cache + export SCCACHE_CACHE_SIZE="128M" + export SCCACHE_LOG=error + export SCCACHE_ERROR_LOG=$(pwd)/sccache_error.log + mkdir -p "$SCCACHE_DIR" + echo "sccache dir: $SCCACHE_DIR" + ls -la "$SCCACHE_DIR" + + if [[ $OSTYPE == "linux-"* ]]; then + wget https://github.com/mozilla/sccache/releases/download/v0.2.15/sccache-v0.2.15-x86_64-unknown-linux-musl.tar.gz + tar -xzf sccache-v0.2.15-x86_64-unknown-linux-musl.tar.gz + chmod +x sccache-v0.2.15-x86_64-unknown-linux-musl/sccache + export PATH=$(pwd)/sccache-v0.2.15-x86_64-unknown-linux-musl:$PATH + elif [[ $(uname -m) == "arm64" ]]; then + wget https://github.com/mozilla/sccache/releases/download/v0.2.15/sccache-v0.2.15-aarch64-apple-darwin.tar.gz + tar -xzf sccache-v0.2.15-aarch64-apple-darwin.tar.gz + chmod +x sccache-v0.2.15-aarch64-apple-darwin/sccache + export PATH=$(pwd)/sccache-v0.2.15-aarch64-apple-darwin:$PATH + else + wget https://github.com/mozilla/sccache/releases/download/v0.2.15/sccache-v0.2.15-x86_64-apple-darwin.tar.gz + tar -xzf sccache-v0.2.15-x86_64-apple-darwin.tar.gz + chmod +x sccache-v0.2.15-x86_64-apple-darwin/sccache + export PATH=$(pwd)/sccache-v0.2.15-x86_64-apple-darwin:$PATH + fi +} + +setup_python() { + if [[ "$IN_DOCKER" == "true" ]]; then + source $HOME/miniconda/etc/profile.d/conda.sh + conda activate "$PY" + fi + python3 -m pip uninstall taichi taichi-nightly -y + python3 -m pip install -r requirements_dev.txt +} + +build() { + git fetch origin master + PROJECT_TAGS="" + EXTRA_ARGS="" + if [ "$PROJECT_NAME" = "taichi-nightly" ]; then + PROJECT_TAGS="egg_info --tag-date" + fi + + if [[ $OSTYPE == "linux-"* ]]; then + EXTRA_ARGS="-p manylinux1_x86_64" + fi + python3 misc/make_changelog.py origin/master ./ True + python3 setup.py $PROJECT_TAGS bdist_wheel $EXTRA_ARGS + sccache -s +} + +setup_sccache +setup_python +build +cat "$SCCACHE_ERROR_LOG" +NUM_WHL=$(ls dist/*.whl | wc -l) +if [ $NUM_WHL -ne 1 ]; then echo "ERROR: created more than 1 whl." && exit 1; fi + +chmod -R 777 "$SCCACHE_DIR" +rm -f python/CHANGELOG.md diff --git a/.github/workflows/scripts/unix_test.sh b/.github/workflows/scripts/unix_test.sh index 907c52160b8be..fca398ae663af 100755 --- a/.github/workflows/scripts/unix_test.sh +++ b/.github/workflows/scripts/unix_test.sh @@ -1,15 +1,63 @@ +#!/bin/bash set -ex -TAICHI_REPO_DIR=`pwd` -TI_LIB_DIR=`python3 -c "import taichi;print(taichi.__path__[0])" | tail -1` -[[ $RUN_CPP_TESTS == "ON" ]] && TI_LIB_DIR="$TI_LIB_DIR/lib" ./build/taichi_cpp_tests -export PATH=$TAICHI_REPO_DIR/taichi-llvm/bin/:$PATH -## Only GPU machine uses system python. -[ -z $GPU_TEST ] || export PATH=$PATH:$HOME/.local/bin -hash -r -python3 examples/algorithm/laplace.py + +check_in_docker() { + # This is a temporary solution to detect in a docker, but it should work + if [[ $(whoami) == "dev" ]]; then + echo "true" + else + echo "false" + fi +} + +export TI_SKIP_VERSION_CHECK=ON +export TI_IN_DOCKER=$(check_in_docker) + +if [[ "$TI_IN_DOCKER" == "true" ]]; then + source $HOME/miniconda/etc/profile.d/conda.sh + conda activate "$PY" +fi +python3 -m pip install dist/*.whl +if [ -z "$GPU_TEST" ]; then + python3 -m pip install -r requirements_test.txt + python3 -m pip install "torch; python_version < '3.10'" +else + ## Only GPU machine uses system python. + export PATH=$PATH:$HOME/.local/bin + # pip will skip packages if already installed + python3 -m pip install -r requirements_test.txt +fi ti diagnose ti changelog -[ -z $GPU_TEST ] && ti test -vr2 -t2 +echo "wanted archs: $TI_WANTED_ARCHS" + +TI_PATH=$(python3 -c "import taichi;print(taichi.__path__[0])" | tail -1) +TI_LIB_DIR="$TI_PATH/_lib/runtime" ./build/taichi_cpp_tests -[ -z $GPU_TEST ] || ti test -vr2 -t2 -k "not ndarray and not torch" -[ -z $GPU_TEST ] || ti test -vr2 -t1 -k "ndarray or torch" +if [ -z "$GPU_TEST" ]; then + if [[ $PLATFORM == *"m1"* ]]; then + # Split per arch to avoid flaky test + python3 tests/run_tests.py -vr2 -t4 -k "not torch" -a cpu + # Run metal and vulkan separately so that they don't use M1 chip simultaneously. + python3 tests/run_tests.py -vr2 -t4 -k "not torch" -a vulkan + python3 tests/run_tests.py -vr2 -t2 -k "not torch" -a metal + python3 tests/run_tests.py -vr2 -t1 -k "torch" -a "$TI_WANTED_ARCHS" + else + python3 tests/run_tests.py -vr2 -t4 -a "$TI_WANTED_ARCHS" + fi +else + # Split per arch to increase parallelism for linux GPU tests + if [[ $TI_WANTED_ARCHS == *"cuda"* ]]; then + python3 tests/run_tests.py -vr2 -t4 -k "not torch" -a cuda + fi + if [[ $TI_WANTED_ARCHS == *"cpu"* ]]; then + python3 tests/run_tests.py -vr2 -t8 -k "not torch" -a cpu + fi + if [[ $TI_WANTED_ARCHS == *"vulkan"* ]]; then + python3 tests/run_tests.py -vr2 -t8 -k "not torch" -a vulkan + fi + if [[ $TI_WANTED_ARCHS == *"opengl"* ]]; then + python3 tests/run_tests.py -vr2 -t4 -k "not torch" -a opengl + fi + python3 tests/run_tests.py -vr2 -t1 -k "torch" -a "$TI_WANTED_ARCHS" +fi diff --git a/.github/workflows/scripts/win_build.ps1 b/.github/workflows/scripts/win_build.ps1 new file mode 100644 index 0000000000000..e58c179ee6952 --- /dev/null +++ b/.github/workflows/scripts/win_build.ps1 @@ -0,0 +1,101 @@ +# Build script for windows + +param ( + [switch]$clone = $false, + [switch]$installVulkan = $false, + [switch]$develop = $false, + [switch]$install = $false, + [string]$libsDir = "." +) + +$ErrorActionPreference = "Stop" + +$RepoURL = 'https://github.com/taichi-dev/taichi' + +function WriteInfo($text) { + Write-Host -ForegroundColor Green "[BUILD] $text" +} + +# Get sccache +$env:CCACHE_DIR="${pwd}/ccache_cache" +$env:CCACHE_MAXSIZE="128M" +$env:CCACHE_LOGFILE="${pwd}/ccache_error.log" +WriteInfo("ccache dir: $Env:CCACHE_DIR") +md "$Env:CCACHE_DIR" -ea 0 +if (-not (Test-Path "ccache-4.5.1-windows-64")) { + curl.exe --retry 10 --retry-delay 5 https://github.com/ccache/ccache/releases/download/v4.5.1/ccache-4.5.1-windows-64.zip -LO + 7z x ccache-4.5.1-windows-64.zip + $env:PATH += ";${pwd}/ccache-4.5.1-windows-64" +} +ccache -v -s + +# WriteInfo("Install 7Zip") +# Install-Module 7Zip4PowerShell -Force -Verbose -Scope CurrentUser + +if ($clone) { + WriteInfo("Clone the repository") + git clone --recurse-submodules $RepoURL + Set-Location .\taichi +} + +$libsDir = (Resolve-Path $libsDir).Path + +if (-not (Test-Path $libsDir)) { + New-Item -ItemType Directory -Path $libsDir +} +Push-Location $libsDir +if (-not (Test-Path "taichi_llvm")) { + WriteInfo("Download and extract LLVM") + curl.exe --retry 10 --retry-delay 5 https://github.com/taichi-dev/taichi_assets/releases/download/llvm10/taichi-llvm-10.0.0-msvc2019.zip -LO + 7z x taichi-llvm-10.0.0-msvc2019.zip -otaichi_llvm +} +if (-not (Test-Path "taichi_clang")) { + WriteInfo("Download and extract Clang") + curl.exe --retry 10 --retry-delay 5 https://github.com/taichi-dev/taichi_assets/releases/download/llvm10/clang-10.0.0-win.zip -LO + 7z x clang-10.0.0-win.zip -otaichi_clang +} +$env:LLVM_DIR = "$libsDir\taichi_llvm" +$env:TAICHI_CMAKE_ARGS += " -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang" +if ($installVulkan) { + WriteInfo("Download and install Vulkan") + if (-not (Test-Path "VulkanSDK")) { + curl.exe --retry 10 --retry-delay 5 https://sdk.lunarg.com/sdk/download/1.2.189.0/windows/VulkanSDK-1.2.189.0-Installer.exe -Lo VulkanSDK.exe + $installer = Start-Process -FilePath VulkanSDK.exe -Wait -PassThru -ArgumentList @("/S"); + $installer.WaitForExit(); + } + $env:VULKAN_SDK = "$libsDir\VulkanSDK\1.2.189.0" + $env:PATH += ";$env:VULKAN_SDK\Bin" + $env:TAICHI_CMAKE_ARGS += " -DTI_WITH_VULKAN:BOOL=ON" +} + +Pop-Location +clang --version + +WriteInfo("Setting up Python environment") +python -m venv venv +. venv\Scripts\activate.ps1 +python -m pip install wheel +python -m pip install -r requirements_dev.txt +python -m pip install -r requirements_test.txt +if (-not $?) { exit 1 } +WriteInfo("Building Taichi") +$env:TAICHI_CMAKE_ARGS += " -DCLANG_EXECUTABLE=$libsDir\\taichi_clang\\bin\\clang++.exe" +$env:TAICHI_CMAKE_ARGS += " -DLLVM_AS_EXECUTABLE=$libsDir\\taichi_llvm\\bin\\llvm-as.exe" +if ($install) { + if ($develop) { + python setup.py develop + } else { + python setup.py install + } + if (-not $?) { exit 1 } + WriteInfo("Build and install finished") +} else { + if ($env:PROJECT_NAME -eq "taichi-nightly") { + python setup.py egg_info --tag-date bdist_wheel + } else { + python setup.py bdist_wheel + } + if (-not $?) { exit 1 } + WriteInfo("Build finished") +} +ccache -s -v diff --git a/.github/workflows/scripts/win_test.ps1 b/.github/workflows/scripts/win_test.ps1 new file mode 100644 index 0000000000000..40ab79826257d --- /dev/null +++ b/.github/workflows/scripts/win_test.ps1 @@ -0,0 +1,28 @@ +$ErrorActionPreference = "Stop" + +. venv\Scripts\activate.ps1 +python -c "import taichi" +ti diagnose +ti changelog +echo wanted arch: $env:TI_WANTED_ARCHS +pip install -r requirements_test.txt +# TODO relax this when torch supports 3.10 +if ("$env:TI_WANTED_ARCHS".Contains("cuda")) { + pip install "torch==1.10.1+cu113; python_version < '3.10'" -f https://download.pytorch.org/whl/cu113/torch_stable.html +} else { + pip install "torch; python_version < '3.10'" +} +if ("$env:TI_WANTED_ARCHS".Contains("cuda")) { + python tests/run_tests.py -vr2 -t4 -k "not torch" -a cuda + if (-not $?) { exit 1 } +} +if ("$env:TI_WANTED_ARCHS".Contains("cpu")) { + python tests/run_tests.py -vr2 -t6 -k "not torch" -a cpu + if (-not $?) { exit 1 } +} +if ("$env:TI_WANTED_ARCHS".Contains("opengl")) { + python tests/run_tests.py -vr2 -t4 -k "not torch" -a opengl + if (-not $?) { exit 1 } +} +python tests/run_tests.py -vr2 -t2 -k "torch" -a "$env:TI_WANTED_ARCHS" +if (-not $?) { exit 1 } diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml new file mode 100644 index 0000000000000..dc8b63a3f86ec --- /dev/null +++ b/.github/workflows/testing.yml @@ -0,0 +1,471 @@ +name: Build and Test +on: + pull_request: + types: [opened, synchronize, reopened] + push: + branches: [master] + +concurrency: + group: ${{ github.event.number || github.run_id }} + cancel-in-progress: true + +jobs: + check_files: + name: Check files + # Disable this workflow on forks + if: github.repository_owner == 'taichi-dev' + outputs: + run_job: ${{ steps.check_files.outputs.run_job }} + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + with: + fetch-depth: 2 + + - name: check modified files + id: check_files + run: | + echo "Concurrency group: ${{ github.event.number || github.run_id }}" + echo "=============== list modified files ===============" + git diff --name-only @^ + + chore_files=( LICENSE CONTRIBUTING.md README.md netlify.toml ) + chore_dirs=( docs ) + run_job=false + + for file in $(git diff --name-only @^); do + is_chore=false + + for chore_file in ${chore_files[*]}; do + [[ ${file} == ${chore_file} ]] && is_chore=true && break + done + + for chore_dir in ${chore_dirs[*]}; do + [[ ${file} == ${chore_dir}/* ]] && is_chore=true && break + done + + if ! ${is_chore}; then + run_job=true + break + fi + done + + if ${run_job}; then + echo "::set-output name=run_job::true" + else + echo "::set-output name=run_job::false" + fi + + check_code_format: + name: Check Code Format + runs-on: ubuntu-latest + needs: check_files + # This job will be required to pass before merging to master branch. + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: 3.9 + + - name: Setup git & clang-format + run: | + git config user.email "taichigardener@gmail.com" + git config user.name "Taichi Gardener" + git checkout -b _fake_squash + git remote add upstream https://github.com/taichi-dev/taichi.git + git fetch upstream master + sudo apt install clang-format-10 + + - name: Cache PIP + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ hashFiles('setup.py') }}-${{ hashFiles('requirements_dev.txt') }} + + - name: Install requirements + run: | + python3 -m pip install --user -r requirements_dev.txt + + - name: Check code format + run: | + python3 misc/code_format.py + git checkout -b _enforced_format + git commit -am "enforce code format" || true + # exit with 1 if there were differences: + git diff _fake_squash _enforced_format --exit-code + + check_static_analyzer: + name: Check Static Analyzer + runs-on: ubuntu-latest + needs: check_files + steps: + - uses: actions/checkout@v2 + with: + submodules: "recursive" + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: 3.9 + + - name: Pylint + run: | + if [[ ${{needs.check_files.outputs.run_job}} == false ]]; then + exit 0 + fi + python3 -m pip install --user pylint + # Make sure pylint doesn't regress + pylint python/taichi/ --disable=all --enable=$(python scripts/generate_pylint_tags.py) + if [ $? -eq 0 ] + then + echo "PASSED: pylint is happy" + exit 0 + else + echo "FAILED: please run the pylint command above and make sure it passes" + exit 1 + fi + + - name: clang-tidy + run: | + if [[ ${{needs.check_files.outputs.run_job}} == false ]]; then + exit 0 + fi + # https://docs.github.com/en/packages/managing-github-packages-using-github-actions-workflows/publishing-and-installing-a-package-with-github-actions#upgrading-a-workflow-that-accesses-ghcrio + echo $CR_PAT | docker login ghcr.io -u ${{ github.actor }} --password-stdin + docker pull ghcr.io/taichi-dev/taichidev-cpu-ubuntu18.04:v0.2.2 + docker run -id --user dev --name check_clang_tidy ghcr.io/taichi-dev/taichidev-cpu-ubuntu18.04:v0.2.2 /bin/bash + tar -cf - ../${{ github.event.repository.name }} --mode u=+rwx,g=+rwx,o=+rwx --owner 1000 --group 1000 | docker cp - check_clang_tidy:/home/dev/ + docker exec --user root check_clang_tidy apt install -y clang-tidy-10 + docker exec --user dev check_clang_tidy /home/dev/taichi/.github/workflows/scripts/check_clang_tidy.sh "$CI_SETUP_CMAKE_ARGS" + env: + CR_PAT: ${{ secrets.GITHUB_TOKEN }} + CI_SETUP_CMAKE_ARGS: -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DTI_WITH_OPENGL:BOOL=OFF -DTI_WITH_CC:BOOL=ON -DTI_WITH_VULKAN:BOOL=OFF -DTI_BUILD_TESTS:BOOL=OFF + + build_and_test_cpu_linux: + name: Build and Test linux (CPU) + needs: [check_code_format, check_files] + timeout-minutes: 60 + strategy: + matrix: + include: + - os: ubuntu-latest + python: py39 + with_cc: ON + wanted_archs: "cpu,cc" + - os: ubuntu-latest + python: py310 + with_cc: ON + wanted_archs: "cpu,cc" + runs-on: ${{ matrix.os }} + permissions: + packages: read + contents: read + steps: + - uses: actions/checkout@v2 + with: + submodules: "recursive" + + - name: Get sccache cache + uses: actions/cache@v2 + with: + path: sccache_cache + key: sccache-linux-${{matrix.with_cc}}-${{ github.sha }} + restore-keys: | + sccache-linux-${{matrix.with_cc}}- + + - name: Get docker images + run: | + if [[ ${{needs.check_files.outputs.run_job}} == false ]]; then + exit 0 + fi + # https://docs.github.com/en/packages/managing-github-packages-using-github-actions-workflows/publishing-and-installing-a-package-with-github-actions#upgrading-a-workflow-that-accesses-ghcrio + echo $CR_PAT | docker login ghcr.io -u ${{ github.actor }} --password-stdin + docker pull ghcr.io/taichi-dev/taichidev-cpu-ubuntu18.04:v0.2.2 + env: + CR_PAT: ${{ secrets.GITHUB_TOKEN }} + + - name: Build + run: | + if [[ ${{needs.check_files.outputs.run_job}} == false ]]; then + exit 0 + fi + mkdir -m777 shared + docker create --user dev --name taichi_build \ + -e PY -e PROJECT_NAME -e TAICHI_CMAKE_ARGS \ + ghcr.io/taichi-dev/taichidev-cpu-ubuntu18.04:v0.2.2 \ + /home/dev/taichi/.github/workflows/scripts/unix_build.sh + # A tarball is needed because sccache needs some permissions that only the file owner has. + # 1000 is the uid and gid of user "dev" in the container. + # If the uid or gid of the user inside the docker changes, please change the uid and gid in the following line. + tar -cf - ../${{ github.event.repository.name }} --mode u=+rwx,g=+rwx,o=+rwx --owner 1000 --group 1000 | docker cp - taichi_build:/home/dev/ + docker start -a taichi_build + rm -rf sccache_cache + docker cp taichi_build:/home/dev/taichi/sccache_cache sccache_cache + docker cp taichi_build:/home/dev/taichi/dist shared/dist + docker cp taichi_build:/home/dev/taichi/build shared/build + env: + PY: ${{ matrix.python }} + PROJECT_NAME: taichi + TAICHI_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=OFF -DTI_WITH_CC:BOOL=${{ matrix.with_cc }} -DTI_WITH_VULKAN:BOOL=OFF -DTI_BUILD_TESTS:BOOL=ON -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache + + - name: Test + run: | + if [[ ${{needs.check_files.outputs.run_job}} == false ]]; then + exit 0 + fi + docker create --user dev --name taichi_test -e PY -e TI_WANTED_ARCHS ghcr.io/taichi-dev/taichidev-cpu-ubuntu18.04:v0.2.2 /home/dev/unix_test.sh + docker cp .github/workflows/scripts/unix_test.sh taichi_test:/home/dev/unix_test.sh + docker cp shared/dist/ taichi_test:/home/dev/ + docker cp shared/build/ taichi_test:/home/dev/ + docker cp ./requirements_test.txt taichi_test:/home/dev/requirements_test.txt + docker cp tests/ taichi_test:/home/dev/ + docker start -a taichi_test + env: + PY: ${{ matrix.python }} + TI_WANTED_ARCHS: ${{ matrix.wanted_archs }} + + - name: clean docker container + if: always() + run: | + docker rm taichi_build taichi_test -f + + build_and_test_cpu_mac: + name: Build and Test macos (CPU) + needs: [check_code_format, check_files] + timeout-minutes: 60 + strategy: + matrix: + include: + - os: macos-10.15 + python: 3.7 + with_cc: OFF + with_cpp_tests: ON + wanted_archs: "cpu" + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + with: + submodules: "recursive" + + - uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + + - name: Get sccache cache + uses: actions/cache@v2 + with: + path: sccache_cache + key: sccache-mac-${{ github.sha }} + restore-keys: | + sccache-mac- + + - name: Download Pre-Built LLVM 10.0.0 + run: | + if [[ ${{needs.check_files.outputs.run_job}} == false ]]; then + exit 0 + fi + python misc/ci_download.py + env: + CI_PLATFORM: ${{ matrix.os }} + + - name: Build & Install + run: | + brew install molten-vk + if [[ ${{needs.check_files.outputs.run_job}} == false ]]; then + exit 0 + fi + mkdir -p sccache_cache + export PATH=`pwd`/taichi-llvm/bin/:$PATH + .github/workflows/scripts/unix_build.sh + brew uninstall molten-vk + env: + TAICHI_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=OFF -DTI_WITH_CC:BOOL=${{ matrix.with_cc }} -DTI_WITH_VULKAN:BOOL=ON -DTI_BUILD_TESTS:BOOL=${{ matrix.with_cpp_tests }} -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache + CXX: clang++ + # [DEBUG] Copy this step around to enable debugging inside Github Action instances. + #- name: Setup tmate session + # uses: mxschmitt/action-tmate@v3 + # with: + # limit-access-to-actor: true + + - name: Test + run: | + if [[ ${{needs.check_files.outputs.run_job}} == false ]]; then + exit 0 + fi + .github/workflows/scripts/unix_test.sh + env: + TI_WANTED_ARCHS: ${{ matrix.wanted_archs }} + + build_and_test_gpu_linux: + name: Build and Test (GPU) + needs: [check_code_format, check_files] + runs-on: [self-hosted, cuda, vulkan, cn] + timeout-minutes: 60 + steps: + - uses: actions/checkout@v2 + with: + submodules: "recursive" + + - name: Get sccache cache + uses: actions/cache@v2 + with: + path: sccache_cache + key: sccache-linux-gpu-${{ github.sha }} + restore-keys: | + sccache-linux-gpu- + + - name: Build & Install + run: | + if [[ ${{needs.check_files.outputs.run_job}} == false ]]; then + exit 0 + fi + mkdir -m777 shared + docker create --user dev --name taichi_build --gpus all -v /tmp/.X11-unix:/tmp/.X11-unix \ + -e PY -e GPU_BUILD -e PROJECT_NAME -e TAICHI_CMAKE_ARGS -e DISPLAY \ + registry.taichigraphics.com/taichidev-ubuntu18.04:v0.2.1 \ + /home/dev/taichi/.github/workflows/scripts/unix_build.sh + # A tarball is needed because sccache needs some permissions that only the file owner has. + # 1000 is the uid and gid of user "dev" in the container. + # If the uid or gid of the user inside the docker changes, please change the uid and gid in the following line. + tar -cf - ../${{ github.event.repository.name }} --mode u=+rwx,g=+rwx,o=+rwx --owner 1000 --group 1000 | docker cp - taichi_build:/home/dev/ + docker start -a taichi_build + rm -rf sccache_cache + docker cp taichi_build:/home/dev/taichi/sccache_cache sccache_cache + docker cp taichi_build:/home/dev/taichi/dist shared/dist + docker cp taichi_build:/home/dev/taichi/build shared/build + env: + PY: py38 + GPU_BUILD: ON + PROJECT_NAME: taichi + TAICHI_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=ON -DTI_WITH_CC:BOOL=OFF -DTI_WITH_VULKAN:BOOL=ON -DTI_BUILD_TESTS:BOOL=ON -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache + DISPLAY: :1 + + - name: Test + run: | + if [[ ${{needs.check_files.outputs.run_job}} == false ]]; then + exit 0 + fi + docker create --user dev --name taichi_test --gpus all -v /tmp/.X11-unix:/tmp/.X11-unix \ + -e DISPLAY -e PY -e GPU_TEST -e TI_WANTED_ARCHS \ + registry.taichigraphics.com/taichidev-ubuntu18.04:v0.2.1 \ + /home/dev/unix_test.sh + docker cp .github/workflows/scripts/unix_test.sh taichi_test:/home/dev/unix_test.sh + docker cp shared/dist/ taichi_test:/home/dev/ + docker cp shared/build/ taichi_test:/home/dev/ + docker cp tests/ taichi_test:/home/dev/ + docker cp requirements_test.txt taichi_test:/home/dev/requirements_test.txt + docker start -a taichi_test + env: + PY: py38 + GPU_TEST: ON + DISPLAY: :1 + TI_WANTED_ARCHS: "cpu,cuda,vulkan,opengl" + + - name: clean docker container + if: always() + run: | + docker rm taichi_build taichi_test -f + + build_and_test_windows: + name: Build and Test Windows + needs: [check_code_format, check_files] + runs-on: [self-hosted, windows, gpu] + timeout-minutes: 90 + steps: + # See also https://github.com/taichi-dev/taichi/issues/4161 + - name: Cleanup + shell: powershell + run: | + remove-item '${{ github.workspace }}\*' -recurse -force + + - uses: actions/checkout@v2 + with: + submodules: "recursive" + + - uses: actions/setup-python@v2 + with: + python-version: 3.7 + + - name: Add Visual Studio Shell to ENV + uses: egor-tensin/vs-shell@v2 + with: + arch: x64 + + - name: Get sccache cache + uses: actions/cache@v2 + with: + path: ccache_cache + key: ccache-win64-${{ github.sha }} + restore-keys: | + ccache-win64- + + - name: Build + shell: powershell + if: ${{ needs.check_files.outputs.run_job != 'false' }} + run: | + .\.github\workflows\scripts\win_build.ps1 -installVulkan -install -libsDir C:\ + + - name: Test + shell: powershell + if: ${{ needs.check_files.outputs.run_job != 'false' }} + run: | + .\.github\workflows\scripts\win_test.ps1 + env: + TI_WANTED_ARCHS: cpu,cuda,opengl + TAICHI_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=ON -DTI_WITH_CC:BOOL=OFF + TI_SKIP_VERSION_CHECK: ON + PYTHON: "3.7" + + build_and_test_m1: + name: Build and Test (Apple M1) + needs: [check_code_format, check_files] + timeout-minutes: 60 + strategy: + matrix: + include: + - os: macos-latest + python: 3.8 + defaults: + run: + # https://github.com/actions/runner/issues/805#issuecomment-844426478 + shell: "/usr/bin/arch -arch arm64e /bin/bash --noprofile --norc -eo pipefail {0}" + runs-on: [self-hosted, m1] + steps: + - uses: actions/checkout@v2 + with: + submodules: "recursive" + + - name: Get sccache cache + uses: actions/cache@v2 + with: + path: sccache_cache + key: sccache-m1-${{ github.sha }} + restore-keys: | + sccache-m1- + + - name: Build + run: | + if [[ ${{needs.check_files.outputs.run_job}} == false ]]; then + exit 0 + fi + export PATH=/Users/github/miniforge3/envs/$PY/bin:$PATH + brew install molten-vk + .github/workflows/scripts/unix_build.sh + env: + TAICHI_CMAKE_ARGS: -DTI_WITH_OPENGL:BOOL=OFF -DTI_WITH_CUDA:BOOL=OFF -DTI_WITH_CC:BOOL=OFF -DTI_WITH_VULKAN:BOOL=ON -DTI_BUILD_TESTS:BOOL=ON -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache + PY: ${{ matrix.python }} + CXX: clang++ + + - name: Test + run: | + if [[ ${{needs.check_files.outputs.run_job}} == false ]]; then + exit 0 + fi + export PATH=/Users/github/miniforge3/envs/$PY/bin:$PATH + .github/workflows/scripts/unix_test.sh + env: + TI_WANTED_ARCHS: "metal,vulkan,cpu" + PY: ${{ matrix.python }} + PLATFORM: "m1" diff --git a/.gitignore b/.gitignore index 902a589b6ca7c..fd39d08f9acea 100644 --- a/.gitignore +++ b/.gitignore @@ -57,6 +57,7 @@ __pycache__ *.jpg !docs/**/*.jpg !docs/**/*.png +!tests/python/expected/*.png *.egg-info .tlang_cache /taichi/common/version.h @@ -77,6 +78,10 @@ _build *.ll *.bc *.yml +!.github/**/*.yml *.dot *.json +!tests/**/*.json !docs/**/*.json +imgui.ini +/venv/ diff --git a/.gitmodules b/.gitmodules index e9d6d00209671..be8c04fa168b0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -50,3 +50,9 @@ [submodule "external/SPIRV-Cross"] path = external/SPIRV-Cross url = https://github.com/KhronosGroup/SPIRV-Cross +[submodule "external/Vulkan-Headers"] + path = external/Vulkan-Headers + url = https://github.com/KhronosGroup/Vulkan-Headers +[submodule "external/FP16"] + path = external/FP16 + url = https://github.com/Maratyszcza/FP16 diff --git a/CMakeLists.txt b/CMakeLists.txt index 5b5beab468380..ecfb0117f3e0f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,9 +6,12 @@ cmake_minimum_required(VERSION 3.12) project(taichi) -SET(TI_VERSION_MAJOR ${TI_VERSION_MAJOR}) -SET(TI_VERSION_MINOR ${TI_VERSION_MINOR}) -SET(TI_VERSION_PATCH ${TI_VERSION_PATCH}) +if (NOT DEFINED TI_VERSION_MAJOR) + message(WARNING "It seems that you are running cmake manually, which may cause issues. Please use setup.py to build taichi from source, see https://docs.taichi.graphics/lang/articles/contribution/dev_install for more details.") + set(TI_VERSION_MAJOR 0) + set(TI_VERSION_MINOR 0) + set(TI_VERSION_PATCH 0) +endif() set(CMAKE_CXX_STANDARD 17) @@ -48,10 +51,17 @@ endif () set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build") set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build") -include(cmake/PythonNumpyPybind11.cmake) +find_program(CCACHE_PROGRAM ccache) +if(CCACHE_PROGRAM) + set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") +endif() + +# No support of Python for Android build +if (NOT ANDROID) + include(cmake/PythonNumpyPybind11.cmake) +endif() include(cmake/TaichiCXXFlags.cmake) include(cmake/TaichiCore.cmake) -include(cmake/TaichiMain.cmake) option(TI_BUILD_TESTS "Build the CPP tests" OFF) @@ -60,6 +70,12 @@ if (TI_BUILD_TESTS) include(cmake/TaichiTests.cmake) endif() +option(TI_BUILD_EXAMPLES "Build the CPP examples" ON) + +if (TI_BUILD_EXAMPLES) + include(cmake/TaichiExamples.cmake) +endif() + include_directories(${PROJECT_SOURCE_DIR}/external/eigen) message("C++ Flags: ${CMAKE_CXX_FLAGS}") @@ -70,47 +86,66 @@ if (NOT TI_WITH_CUDA) set(CUDA_TOOLKIT_ROOT_DIR "") endif() -message("python=${PYTHON_EXECUTABLE}") - -add_custom_target( - generate_commit_hash - COMMAND ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_LIST_DIR}/misc/generate_commit_hash.py" -) -add_dependencies(${CORE_LIBRARY_NAME} generate_commit_hash) - if (TI_WITH_CUDA) set(CUDA_ARCH "cuda") endif() -find_program(CLANG_EXECUTABLE NAMES clang clang-7 clang-8 clang-9 clang-10) +if (CLANG_EXECUTABLE) + message("Clang executable ${CLANG_EXECUTABLE}") +elseif ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") + set (CLANG_EXECUTABLE ${CMAKE_CXX_COMPILER}) + message("Clang executable using host compiler ${CLANG_EXECUTABLE}") +else() + find_program(CLANG_EXECUTABLE NAMES clang clang-10 clang-11 clang-9 clang-8 clang-7) + message("Clang executable found at ${CLANG_EXECUTABLE}") +endif() + if (NOT CLANG_EXECUTABLE) - message(FATAL_ERROR "Cannot find any clang executable.") + message(FATAL_ERROR "Cannot find any clang executable.") +endif() + +macro(check_clang_version) + execute_process(COMMAND ${CLANG_EXECUTABLE} --version OUTPUT_VARIABLE CLANG_VERSION_OUTPUT) + string(REGEX MATCH "([0-9]+)\\.[0-9]+(\\.[0-9]+)?" CLANG_VERSION "${CLANG_VERSION_OUTPUT}") + message("${CLANG_EXECUTABLE} --version: ${CLANG_VERSION}") + + set(CLANG_VERSION_MAJOR "${CMAKE_MATCH_1}") +endmacro() + +if (APPLE) + set(CLANG_OSX_FLAGS "-isysroot${CMAKE_OSX_SYSROOT}") + set(CLANG_HIGHEST_VERSION "13") +else() + set(CLANG_HIGHEST_VERSION "11") endif() -find_program(LLVM_AS_EXECUTABLE NAMES llvm-as) -if (NOT LLVM_AS_EXECUTABLE) - message(FATAL_ERROR "Cannot find llvm-as executable.") +check_clang_version() + +if (${CLANG_VERSION_MAJOR} VERSION_GREATER ${CLANG_HIGHEST_VERSION}) + unset(CLANG_EXECUTABLE) + find_program(CLANG_EXECUTABLE NAMES clang-10 clang-11 clang-9 clang-8 clang-7) + if (NOT CLANG_EXECUTABLE) + message(FATAL_ERROR "${CLANG_EXECUTABLE} version: ${CLANG_VERSION}, required: <=${CLANG_HIGHEST_VERSION}. Condider passing -DCLANG_PATH=/path/to/clang to cmake to use a specific clang.") + else() + check_clang_version() + if (${CLANG_VERSION_MAJOR} VERSION_GREATER ${CLANG_HIGHEST_VERSION}) + message(FATAL_ERROR "${CLANG_EXECUTABLE} version: ${CLANG_VERSION}, required: <=${CLANG_HIGHEST_VERSION}. Condider passing -DCLANG_PATH=/path/to/clang to cmake to use a specific clang.") + endif() + endif() endif() # Build llvm-runtime for host arch and cuda (if available) foreach(arch IN LISTS HOST_ARCH CUDA_ARCH) add_custom_target( "generate_llvm_runtime_${arch}" - COMMAND ${CLANG_EXECUTABLE} -S runtime.cpp -o "runtime_${arch}.ll" -fno-exceptions -emit-llvm -std=c++17 -D "ARCH_${arch}" -I ${PROJECT_SOURCE_DIR}; - COMMAND ${LLVM_AS_EXECUTABLE} "runtime_${arch}.ll" -o "runtime_${arch}.bc" + COMMAND ${CLANG_EXECUTABLE} ${CLANG_OSX_FLAGS} -c runtime.cpp -o "runtime_${arch}.bc" -fno-exceptions -emit-llvm -std=c++17 -D "ARCH_${arch}" -I ${PROJECT_SOURCE_DIR}; WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}/taichi/runtime/llvm" ) add_dependencies(${CORE_LIBRARY_NAME} "generate_llvm_runtime_${arch}") endforeach() - -FILE(WRITE ${CMAKE_CURRENT_LIST_DIR}/taichi/common/version.h - "#pragma once\n" - "#define TI_VERSION_MAJOR \"${TI_VERSION_MAJOR}\"\n" - "#define TI_VERSION_MINOR \"${TI_VERSION_MINOR}\"\n" - "#define TI_VERSION_PATCH \"${TI_VERSION_PATCH}\"\n" - "#define TI_CUDAVERSION \"${CUDA_VERSION}\"\n" - ) +configure_file(taichi/common/version.h.in ${CMAKE_SOURCE_DIR}/taichi/common/version.h) +configure_file(taichi/common/commit_hash.h.in ${CMAKE_SOURCE_DIR}/taichi/common/commit_hash.h) option(TI_EXPORT_CORE "export taichi core" OFF) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000..1c6fa35ec9f4c --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,13 @@ +# Contributing Guide + +Thank you for your interest in contributing to Taichi! Please check out the [Contribution Guidelines](https://docs.taichi.graphics/lang/articles/contribution/contributor_guide) for how to make a contribution. + +## Developer installation + +Taichi is developed mainly in C++17 and Python3. Please check out the [Developer Installation](https://docs.taichi.graphics/lang/articles/contribution/dev_install) to build Taichi from source. Note that Taichi is LLVM-10.0.0 dependent and that we recommend installing [our pre-built LLVM libraries](https://docs.taichi.graphics/lang/articles/contribution/dev_install#installing-dependencies) for your platform. + +## Contribution opportunities + +Issues marked with ["welcome contribution"](https://github.com/taichi-dev/taichi/issues?q=is%3Aopen+is%3Aissue+label%3A%22welcome+contribution%22) are great places for starters. You can quickly get an idea of the entire workflow and how to join the community. + +**RFC**: We use the `RFC` (Request for Comments) mechanism to discuss and organize some of the more advanced and self-contained features. These are the projects that we would like to work on but still lack a concrete design or implementation roadmap for because of their complexity. We document these requests and the threaded proposals in the hope that we could provide the community with a good enough context and draw upon insights from the potentially passionate minds. You can find all the ongoing RFCs [here](https://github.com/taichi-dev/taichi/issues?q=is%3Aissue+is%3Aopen+label%3ARFC+), and you are also welcome to file new RFCs with us! diff --git a/MANIFEST.in b/MANIFEST.in index 0ca0ff3bd4272..cbc8a16b7dc89 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,14 +1,15 @@ include MANIFEST.in +include version.txt include python/*.txt include python/*.py include *.cfg include python/taichi/*.md include python/taichi/assets/* recursive-include python/taichi/examples *.py -include python/taichi/tests/* -include python/taichi/lib/*.so -include python/taichi/lib/*.pyd -include python/taichi/lib/*.bc +include python/taichi/_lib/core/*.so +include python/taichi/_lib/core/*.pyd +include python/taichi/_lib/runtime/*.bc +include python/taichi/_lib/runtime/*.dylib include python/taichi/shaders/*.spv include python/taichi/shaders/*.vert include python/taichi/shaders/*.frag diff --git a/README.md b/README.md index f6d03d805bd9d..d07cd9dc8345c 100644 --- a/README.md +++ b/README.md @@ -1,89 +1,178 @@ +
- -

Tutorial | Examples | Forum

-

Documentation | 简体中文文档 | Contributor Guidelines

+
-[![AppVeyor Status](https://img.shields.io/appveyor/build/yuanming-hu/taichi?logo=AppVeyor&label=AppVeyor)](https://ci.appveyor.com/project/yuanming-hu/taichi/branch/master) +--- + +[![Latest Release](https://img.shields.io/github/v/release/taichi-dev/taichi?color=blue&label=Latest%20Release)](https://github.com/taichi-dev/taichi/releases/latest) +[![downloads](https://pepy.tech/badge/taichi)](https://pepy.tech/project/taichi) +[![CI](https://github.com/taichi-dev/taichi/actions/workflows/testing.yml/badge.svg)](https://github.com/taichi-dev/taichi/actions/workflows/postsubmit.yml) [![Docker Cloud Build Status](https://img.shields.io/docker/cloud/build/taichidev/taichi?label=Docker%20Image&logo=docker)](https://hub.docker.com/r/taichidev/taichi) [![Python Codecov Status](https://img.shields.io/codecov/c/github/taichi-dev/taichi?label=Python%20Coverage&logo=codecov)](https://codecov.io/gh/taichi-dev/taichi/src/master) -[![Latest Release](https://img.shields.io/github/v/release/taichi-dev/taichi?color=blue&label=Latest%20Release)](https://github.com/taichi-dev/taichi/releases/latest) -## Overview +```py +import taichi as ti +``` -**Taichi** (太极) is a parallel programming language for high-performance numerical computations. It is embedded in **Python**, and its **just-in-time compiler** offloads compute-intensive tasks to multi-core CPUs and massively parallel GPUs. +- [Getting Started](#getting-started) + - [Installation](#installation) + - [Your first Taichi program](#your-first-taichi-program) +- [Documentation](#documentation) +- [Contacts](#contacts) +- [Contributing](#contributing) +- [Resources](#resources) + - [Demos](#demos) + - [Lectures & talks](#lectures--talks) - +**Taichi (太极)** is an open-source, imperative, parallel programming language for high-performance numerical computation. It is embedded in Python and uses just-in-time (JIT) compiler frameworks (e.g. LLVM) to offload compute-intensive Python code to the native GPU or CPU instructions. -Advanced features of Taichi include [spatially sparse computing](https://docs.taichi.graphics/lang/articles/advanced/sparse), [differentiable programming](https://docs.taichi.graphics/lang/articles/advanced/differentiable_programming) [[examples]](https://github.com/yuanming-hu/difftaichi), and [quantized computation](https://github.com/taichi-dev/quantaichi). +Advantages of Taichi: -**Please check out our SIGGRAPH 2020 course on Taichi basics:** [YouTube](https://youtu.be/Y0-76n3aZFA), [Bilibili](https://www.bilibili.com/video/BV1kA411n7jk/), [slides (pdf)](https://yuanming.taichi.graphics/publication/2020-taichi-tutorial/taichi-tutorial.pdf). +- Built around Python: Taichi shares almost the same syntax with Python, allowing you to write algorithms with minimal language barrier. It is also well integrated into the Python ecosystem, such as NumPy and PyTorch. +- Flexibility: Taichi provides a set of generic data containers known as *SNode* (/ˈsnoʊd/), an effective mechanism for composing hierarchical, multi-dimensional fields. This can cover many use patterns in numerical simulation (e.g. [spatially sparse computing](https://docs.taichi.graphics/lang/articles/advanced/sparse)). +- Performance: Through the `@ti.kernel` decorator, Taichi's JIT compiler automatically compiles your Python functions into efficient GPU or CPU machine code for parallel execution. +- Portability: Write your code once and run it everywhere. Currently, Taichi supports most mainstream GPU APIs, such as CUDA and Vulkan. +- ... and many more features! A cross-platform, Vulkan-based 3D visualizer, [differentiable programming](https://docs.taichi.graphics/lang/articles/advanced/differentiable_programming), [quantized computation](https://github.com/taichi-dev/quantaichi) (experimental), etc. -**中文视频教程:** [[哔哩哔哩]](https://www.bilibili.com/video/BV1gA411j7H5), [[幻灯片]](https://yuanming.taichi.graphics/publication/2020-taichi-tutorial/taichi-tutorial.pdf) +# Getting Started -## Examples ([More...](misc/examples.md)) +## Installation - - - - +You can easily install Taichi with Python's package installer `pip`: - - +```bash +pip install taichi +``` -## Installation [![Downloads](https://pepy.tech/badge/taichi)](https://pepy.tech/project/taichi) +If you want to try out the latest features, we also provide a nightly package: ```bash -python3 -m pip install taichi +pip install -i https://test.pypi.org/simple/ taichi-nightly ``` -**Supported OS**: Windows, Linux, Mac OS X; **Python**: 3.6-3.9 (64-bit only); **Backends**: x64 CPUs, CUDA, Apple Metal, Vulkan, OpenGL Compute Shaders. +*The nightly package can and will break from time to time!* -Please build from source for other configurations (e.g., your CPU is ARM, or you want to try out our experimental C backend). +**Supported environments** -**Note:** - - The PyPI package supports x64 CPU, CUDA 10/11, Metal, and OpenGL Compute Shader backends. - - On Ubuntu 19.04+, please `sudo apt install libtinfo5`. - - On Windows, please install [Microsoft Visual C++ Redistributable](https://aka.ms/vs/16/release/vc_redist.x64.exe) if you haven't. - - [[All releases]](https://github.com/taichi-dev/taichi/releases) + +- Operating systems + - Windows[1](#win-note) + - Linux + - macOS +- Python: 3.6 ~ 3.9 (64-bit only) +- Compute backends + - x64/ARM CPUs + - CUDA + - Vulkan + - OpenGL (4.3+) + - Apple Metal + - WebAssembly (experiemental) -|| **Linux (CUDA)** | **OS X (10.14+)** | **Windows** | **Documentation**| -|:------|:-----|:-----|:-----|:-----| -|**Build**|[![Build Status](http://f11.csail.mit.edu:8080/job/taichi/badge/icon)](http://f11.csail.mit.edu:8080/job/taichi/)| [![Build Status](https://travis-ci.com/taichi-dev/taichi.svg?branch=master)](https://travis-ci.com/taichi-dev/taichi) | [![Build status](https://ci.appveyor.com/api/projects/status/yxm0uniin8xty4j7/branch/master?svg=true)](https://ci.appveyor.com/project/yuanming-hu/taichi/branch/master)| [![Netlify Status](https://api.netlify.com/api/v1/badges/6825e411-c5f7-4148-ab43-023663f41b6a/deploy-status)](https://app.netlify.com/sites/docs-taichi-graphics/deploys)| -|**PyPI**|[![Build Status](https://travis-ci.com/yuanming-hu/taichi-wheels-test.svg?branch=master)](https://travis-ci.com/yuanming-hu/taichi-wheels-test)|[![Build Status](https://travis-ci.com/yuanming-hu/taichi-wheels-test.svg?branch=master)](https://travis-ci.com/yuanming-hu/taichi-wheels-test)|[![Build status](https://ci.appveyor.com/api/projects/status/39ar9wa8yd49je7o?svg=true)](https://ci.appveyor.com/project/yuanming-hu/taichi-wheels-test) | +1. On Windows, please install [Microsoft Visual C++ Redistributable](https://aka.ms/vs/16/release/vc_redist.x64.exe) first. -## Developer Installation +## Your first Taichi program -Please follow [this doc](https://docs.taichi.graphics/lang/articles/contribution/dev_install) to learn how to build Taichi from source. Note that Taichi requires LLVM-10.0.0, and it is recommneded to use [our prebuilt LLVM libraries](https://docs.taichi.graphics/lang/articles/contribution/dev_install#installing-dependencies) for each platform. +Here's how you can program a 2D fractal in Taichi: -## Contributors +```py +# python/taichi/examples/simulation/fractal.py - +import taichi as ti -*Note: contributor avatars above are randomly shuffled.* +ti.init(arch=ti.gpu) -------------------------------- +n = 320 +pixels = ti.field(dtype=float, shape=(n * 2, n)) -We welcome feedback and comments. If you would like to contribute to Taichi, please check out our [Contributor Guidelines](https://docs.taichi.graphics/lang/articles/contribution/contributor_guide). -If you use Taichi in your research, please cite related papers: +@ti.func +def complex_sqr(z): + return ti.Vector([z[0]**2 - z[1]**2, z[1] * z[0] * 2]) + + +@ti.kernel +def paint(t: float): + for i, j in pixels: # Parallelized over all pixels + c = ti.Vector([-0.8, ti.cos(t) * 0.2]) + z = ti.Vector([i / n - 1, j / n - 0.5]) * 2 + iterations = 0 + while z.norm() < 20 and iterations < 50: + z = complex_sqr(z) + c + iterations += 1 + pixels[i, j] = 1 - iterations * 0.02 + + +gui = ti.GUI("Julia Set", res=(n * 2, n)) + +for i in range(1000000): + paint(i * 0.03) + gui.set_image(pixels) + gui.show() +``` + +If Taichi is properly installed, you should get the animation below 🎉: + + -- [**(SIGGRAPH Asia 2019) Taichi: High-Performance Computation on Sparse Data Structures**](https://yuanming.taichi.graphics/publication/2019-taichi/taichi-lang.pdf) [[Video]](https://youtu.be/wKw8LMF3Djo) [[BibTex]](https://raw.githubusercontent.com/taichi-dev/taichi/master/misc/taichi_bibtex.txt) [[Code]](https://github.com/taichi-dev/taichi) -- [**(ICLR 2020) DiffTaichi: Differentiable Programming for Physical Simulation**](https://arxiv.org/abs/1910.00935) [[Video]](https://www.youtube.com/watch?v=Z1xvAZve9aE) [[BibTex]](https://raw.githubusercontent.com/taichi-dev/taichi/master/misc/difftaichi_bibtex.txt) [[Code]](https://github.com/yuanming-hu/difftaichi) -- [**(SIGGRAPH 2021) QuanTaichi: A Compiler for Quantized Simulations**](https://yuanming.taichi.graphics/publication/2021-quantaichi/quantaichi.pdf) [[Video]](https://www.youtube.com/watch?v=0jdrAQOxJlY) [[BibTex]](https://raw.githubusercontent.com/taichi-dev/taichi/master/misc/quantaichi_bibtex.txt) [[Code]](https://github.com/taichi-dev/quantaichi) -## Links -- [TaichiCon](https://github.com/taichi-dev/taichicon): Taichi developer conferences. -- [GAMES 201 Lectures](https://github.com/taichi-dev/games201): (Chinese) A hands-on course on building advanced physics engines, based on Taichi. -- [TaichiZoo](https://zoo.taichi.graphics): Running Taichi code in your browser [1](#zoo-disclaimer). -- [加入太极图形](https://app.mokahr.com/apply/taichi/41024#/). -- [太极图形课](https://github.com/taichiCourse01). +# Documentation + +Taichi's documentation is available at https://docs.taichi.graphics. + +# Contacts + +We use these channels to report bugs, discuss design, show off demos and send announcements on a daily basis: + +- [GitHub Issues](https://github.com/taichi-dev/taichi/issues) +- [GitHub Discussions](https://github.com/taichi-dev/taichi/discussions) +- [Twitter](https://twitter.com/taichigraphics) +- [Taichi 中文论坛](https://forum.taichi.graphics/) +- Slack & Wechat groups: please send us a message at contact@taichi.graphics first, thanks! + +Should you spot any security issue, do not hesitate to report it by mailing to security@taichi.graphics. + +# Contributing + +If you would like to contribute to Taichi, please check out the [Contribution Guidelines](CONTRIBUTING.md). + +A huge thanks to all of our amazing contributors! + + + +*Contributor avatars are randomly shuffled.* + +# Resources + +## Demos + +- [Taichi examples](https://github.com/taichi-dev/taichi/tree/master/python/taichi/examples) +- [Advanced Taichi examples](https://github.com/taichi-dev/advanced_examples) +- [DiffTaichi](https://github.com/taichi-dev/difftaichi) +- [Taichi elements](https://github.com/taichi-dev/taichi_elements) +- [Taichi houdini](https://github.com/taichi-dev/taichi_houdini) - [More...](misc/links.md) -## Security + + + + + + + -Please disclose security issues responsibly to contact@taichi.graphics. +## Lectures & talks +- **SIGGRAPH 2020 course on Taichi basics**: [YouTube](https://youtu.be/Y0-76n3aZFA), [Bilibili](https://www.bilibili.com/video/BV1kA411n7jk/), [slides (pdf)](https://yuanming.taichi.graphics/publication/2020-taichi-tutorial/taichi-tutorial.pdf). +- Chinagraph 2020 用太极编写物理引擎: [哔哩哔哩](https://www.bilibili.com/video/BV1gA411j7H5) +- GAMES 201 高级物理引擎实战指南2020: [课件](https://github.com/taichi-dev/games201) +- 太极图形课第一季:[课件](https://github.com/taichiCourse01) +- [TaichiCon](https://github.com/taichi-dev/taichicon): Taichi developer conferences +- More to come... --- -1. TaichiZoo is still in its Beta version. If you've encountered any issue, please do not hesitate to [file a bug](https://github.com/taichi-dev/taichi-zoo-issue-tracker/issues/new/choose). +If you use Taichi in your research, please cite related papers: + +- [**(SIGGRAPH Asia 2019) Taichi: High-Performance Computation on Sparse Data Structures**](https://yuanming.taichi.graphics/publication/2019-taichi/taichi-lang.pdf) [[Video]](https://youtu.be/wKw8LMF3Djo) [[BibTex]](https://raw.githubusercontent.com/taichi-dev/taichi/master/misc/taichi_bibtex.txt) [[Code]](https://github.com/taichi-dev/taichi) +- [**(ICLR 2020) DiffTaichi: Differentiable Programming for Physical Simulation**](https://arxiv.org/abs/1910.00935) [[Video]](https://www.youtube.com/watch?v=Z1xvAZve9aE) [[BibTex]](https://raw.githubusercontent.com/taichi-dev/taichi/master/misc/difftaichi_bibtex.txt) [[Code]](https://github.com/yuanming-hu/difftaichi) +- [**(SIGGRAPH 2021) QuanTaichi: A Compiler for Quantized Simulations**](https://yuanming.taichi.graphics/publication/2021-quantaichi/quantaichi.pdf) [[Video]](https://www.youtube.com/watch?v=0jdrAQOxJlY) [[BibTex]](https://raw.githubusercontent.com/taichi-dev/taichi/master/misc/quantaichi_bibtex.txt) [[Code]](https://github.com/taichi-dev/quantaichi) diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index cb426cd204ac6..0000000000000 --- a/appveyor.yml +++ /dev/null @@ -1,65 +0,0 @@ -#---------------------------------# -# general configuration # -#---------------------------------# - -# version format -version: 0.0.{build}-{branch} - -#---------------------------------# -# environment configuration # -#---------------------------------# - -image: Visual Studio 2019 -clone_folder: C:\taichi - -#---------------------------------# -# build configuration # -#---------------------------------# - -platform: x64 -configuration: Release - -environment: - matrix: - - PYTHON: C:\Python36-x64\python.exe - - PYTHON: C:\Python37-x64\python.exe - - PYTHON: C:\Python38-x64\python.exe - - PYTHON: C:\Python39-x64\python.exe - -skip_commits: - files: - - '*.md' - - '*.rst' - - docs - - benchmarks - - examples - - misc - - '.*' - -cache: - - build -> CMakeLists.txt, cmake/* - -build_script: - - set TAICHI_REPO_DIR=C:\taichi - - "%PYTHON% %TAICHI_REPO_DIR%/misc/appveyor_filter.py || appveyor exit 0" - - cd C:\ - - curl --retry 10 --retry-delay 5 https://github.com/taichi-dev/taichi_assets/releases/download/llvm10/taichi-llvm-10.0.0-msvc2019.zip -LO - - 7z x taichi-llvm-10.0.0-msvc2019.zip -otaichi_llvm - - curl --retry 10 --retry-delay 5 https://github.com/taichi-dev/taichi_assets/releases/download/llvm10/clang-10.0.0-win.zip -LO - - "echo \"%APPVEYOR_REPO_COMMIT_MESSAGE%\" | grep '^\\[format\\]' && curl http://kun.csail.mit.edu:31415/%APPVEYOR_PULL_REQUEST_NUMBER% -LO || true" - - 7z x clang-10.0.0-win.zip -otaichi_clang - - set PATH=C:\taichi_llvm\bin;%PATH%; - - set PATH=C:\taichi_clang\bin;%PATH% - - clang --version - - cd C:\taichi - - set CI_SETUP_CMAKE_ARGS=-G "Visual Studio 16 2019" -A x64 -DLLVM_DIR=C:\taichi_llvm\lib\cmake\llvm - # This is to fix the bug caused by execnet 1.2, the library xdist uses to schedule tests. - # Reverting to v1.1 solves this bug(Python Fatal Error: IOError). - - "%PYTHON% -m pip install -U execnet==1.1" - - "%PYTHON% misc/ci_setup.py ci" - - '%PYTHON% -c "import taichi"' - - "%PYTHON% examples/algorithm/laplace.py" - - "%PYTHON% bin/taichi diagnose" - - "%PYTHON% bin/taichi test -Cvr2 -t2" - - "cd python && %PYTHON% build.py try_upload" - - "cd %TAICHI_REPO_DIR% && bash <(curl -s https://codecov.io/bash)" diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000000000..23af1a0f89dea --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,41 @@ +Install a few of extra requirements: +```bash +python3 -m pip install -r requirements.txt +``` + +## Run + +To run all benchmarks: +```bash +python3 run.py +``` + +## Result + +The benchmark results will be stored in the `results` folder in your current directory. +If you wish to save the results as a single json file (`./results/results.json`): +```bash +python3 deserialize.py +``` +Or you can specify the input and output path: +```bash +python3 deserialize.py --folder PATH_OF_RESULTS_FOLDER --output_path PATH_YOU_WIHS_TO_STORE +``` + +## Tools + +After getting benchmark results (`./results`), you can use a visualization tool to profile performance problems: +```bash +python3 visualization.py +``` + +You can specify the results file path: +```bash +python3 visualization.py --folder PATH_OF_RESULTS_FOLDER +``` + +The default host and port is `localhost:5006\visualization`. +If you want to enable remote access, take the following steps: +```bash +python3 visualization.py --host YOUR_IP_ADDRESS --port PORT_YOU_WISH_TO_USE +``` diff --git a/benchmarks/async_advection.py b/benchmarks/async_advection.py deleted file mode 100644 index 916fd8314357a..0000000000000 --- a/benchmarks/async_advection.py +++ /dev/null @@ -1,126 +0,0 @@ -import math - -from utils import benchmark_async - -import taichi as ti - -# TODO: staggerred grid - - -@benchmark_async -def simple_advection(scale): - n = 256 * 2**int((math.log(scale, 2)) // 2) - x = ti.Vector.field(3, dtype=ti.f32, shape=(n, n)) - new_x = ti.Vector.field(3, dtype=ti.f32, shape=(n, n)) - v = ti.Vector.field(2, dtype=ti.f32, shape=(n, n)) - dx = 1 / n - inv_dx = 1 / dx - dt = 0.01 - - stagger = ti.Vector([0.5, 0.5]) - - @ti.func - def Vector2(x, y): - return ti.Vector([x, y]) - - @ti.kernel - def init(): - for i, j in v: - v[i, j] = ti.Vector([j / n - 0.5, 0.5 - i / n]) - - for i, j in ti.ndrange(n * 4, n * 4): - ret = ti.taichi_logo(ti.Vector([i, j]) / (n * 4)) - x[i // 4, j // 4][0] += ret / 16 - x[i // 4, j // 4][1] += ret / 16 - x[i // 4, j // 4][2] += ret / 16 - - @ti.func - def vec(x, y): - return ti.Vector([x, y]) - - @ti.func - def clamp(p): - for d in ti.static(range(p.n)): - p[d] = min(1 - 1e-4 - dx + stagger[d] * dx, - max(p[d], stagger[d] * dx)) - return p - - @ti.func - def sample_bilinear(x, p): - p = clamp(p) - - p_grid = p * inv_dx - stagger - - I = ti.cast(ti.floor(p_grid), ti.i32) - f = p_grid - I - g = 1 - f - - return x[I] * (g[0] * g[1]) + x[I + vec(1, 0)] * (f[0] * g[1]) + x[ - I + vec(0, 1)] * (g[0] * f[1]) + x[I + vec(1, 1)] * (f[0] * f[1]) - - @ti.func - def velocity(p): - return sample_bilinear(v, p) - - @ti.func - def sample_min(x, p): - p = clamp(p) - p_grid = p * inv_dx - stagger - I = ti.cast(ti.floor(p_grid), ti.i32) - - return min(x[I], x[I + vec(1, 0)], x[I + vec(0, 1)], x[I + vec(1, 1)]) - - @ti.func - def sample_max(x, p): - p = clamp(p) - p_grid = p * inv_dx - stagger - I = ti.cast(ti.floor(p_grid), ti.i32) - - return max(x[I], x[I + vec(1, 0)], x[I + vec(0, 1)], x[I + vec(1, 1)]) - - @ti.func - def backtrace(I, dt): # RK3 - p = (I + stagger) * dx - v1 = velocity(p) - p1 = p - 0.5 * dt * v1 - v2 = velocity(p1) - p2 = p - 0.75 * dt * v2 - v3 = velocity(p2) - p -= dt * (2 / 9 * v1 + 1 / 3 * v2 + 4 / 9 * v3) - return p - - @ti.func - def semi_lagrangian(x, new_x, dt): - for I in ti.grouped(x): - new_x[I] = sample_bilinear(x, backtrace(I, dt)) - - @ti.kernel - def advect(): - semi_lagrangian(x(0), new_x(0), dt) - semi_lagrangian(x(1), new_x(1), dt) - semi_lagrangian(x(2), new_x(2), dt) - - for I in ti.grouped(x): - x[I] = new_x[I] - - init() - - def task(): - for i in range(10): - advect() - - ti.benchmark(task, repeat=100) - - visualize = False - - if visualize: - gui = ti.GUI('Advection schemes', (n, n)) - for i in range(10): - for _ in range(10): - advect() - gui.set_image(x.to_numpy()) - gui.show() - - -if __name__ == '__main__': - simple_advection() diff --git a/benchmarks/async_cases.py b/benchmarks/async_cases.py deleted file mode 100644 index af141f424ddb6..0000000000000 --- a/benchmarks/async_cases.py +++ /dev/null @@ -1,374 +0,0 @@ -import math -import os -import sys - -import taichi as ti - -sys.path.append(os.path.join(ti.core.get_repo_dir(), 'tests', 'python')) - -from fuse_test_template import (template_fuse_dense_x2y2z, - template_fuse_reduction) -from utils import * - - -@benchmark_async -def chain_copy(scale): - template_fuse_dense_x2y2z(size=scale * 1024**2, - repeat=1, - benchmark_repeat=100, - benchmark=True) - - -@benchmark_async -def increments(scale): - template_fuse_reduction(size=scale * 1024**2, - repeat=10, - benchmark_repeat=10, - benchmark=True) - - -@benchmark_async -def fill_array(scale): - a = ti.field(dtype=ti.f32, shape=scale * 1024**2) - - @ti.kernel - def fill(): - for i in a: - a[i] = 1.0 - - def repeated_fill(): - for _ in range(10): - fill() - - ti.benchmark(repeated_fill, repeat=10) - - -@benchmark_async -def fill_scalar(scale): - a = ti.field(dtype=ti.f32, shape=()) - - @ti.kernel - def fill(): - a[None] = 1.0 - - def repeated_fill(): - for _ in range(1000): - fill() - - ti.benchmark(repeated_fill, repeat=5) - - -@benchmark_async -def sparse_saxpy(scale): - a = ti.field(dtype=ti.f32) - b = ti.field(dtype=ti.f32) - - block_count = 2**int((math.log(scale, 2)) // 2) * 4 - block_size = 32 - # a, b always share the same sparsity - ti.root.pointer(ti.ij, block_count).dense(ti.ij, block_size).place(a, b) - - @ti.kernel - def initialize(): - for i, j in ti.ndrange(block_count * block_size, - block_count * block_size): - if (i // block_size + j // block_size) % 4 == 0: - a[i, j] = i + j - - @ti.kernel - def saxpy(x: ti.template(), y: ti.template(), alpha: ti.f32): - for i, j in x: - y[i, j] = alpha * x[i, j] + y[i, j] - - def task(): - initialize() - saxpy(a, b, 2) - saxpy(b, a, 1.1) - saxpy(b, a, 1.1) - saxpy(a, b, 1.1) - saxpy(a, b, 1.1) - saxpy(a, b, 1.1) - - ti.benchmark(task, repeat=10) - - -@benchmark_async -def autodiff(scale): - - n = 1024**2 * scale - - a = ti.field(dtype=ti.f32, shape=n, needs_grad=True) - b = ti.field(dtype=ti.f32, shape=n) - loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True) - - @ti.kernel - def compute_loss(): - for i in a: - loss[None] += a[i] - - @ti.kernel - def accumulate_grad(): - for i in a: - b[i] += a.grad[i] - - def task(): - for i in range(10): - with ti.Tape(loss=loss): - # The forward kernel of compute_loss should be completely eliminated (except for the last one) - compute_loss() - - accumulate_grad() - - ti.benchmark(task, repeat=10) - - -@benchmark_async -def stencil_reduction(scale): - a = ti.field(dtype=ti.f32) - b = ti.field(dtype=ti.f32) - total = ti.field(dtype=ti.f32, shape=()) - - block_count = scale * 64 - block_size = 1024 - # a, b always share the same sparsity - ti.root.pointer(ti.i, block_count).dense(ti.i, block_size).place(a, b) - - @ti.kernel - def initialize(): - for i in range(block_size, block_size * (block_count - 1)): - a[i] = i - - @ti.kernel - def stencil(): - for i in a: - b[i] = a[i - 1] + a[i] + a[i + 1] - - @ti.kernel - def reduce(): - for i in a: - total[None] += b[i] - - @ti.kernel - def clear_b(): - for i in a: - b[i] = 0 - - def task(): - initialize() - for i in range(3): - stencil() - reduce() - clear_b() - - ti.benchmark(task, repeat=5) - - -@benchmark_async -def mpm_splitted(scale): - quality = int(3 * scale**(1 / 3)) - # Use a larger value for higher-res simulations - - n_particles, n_grid = 9000 * quality**2, 128 * quality - dx, inv_dx = 1 / n_grid, float(n_grid) - dt = 1e-4 / quality - p_vol, p_rho = (dx * 0.5)**2, 1 - p_mass = p_vol * p_rho - E, nu = 0.1e4, 0.2 # Young's modulus and Poisson's ratio - mu_0, lambda_0 = E / (2 * (1 + nu)), E * nu / ( - (1 + nu) * (1 - 2 * nu)) # Lame parameters - x = ti.Vector.field(2, dtype=float, shape=n_particles) # position - v = ti.Vector.field(2, dtype=float, shape=n_particles) # velocity - C = ti.Matrix.field(2, 2, dtype=float, - shape=n_particles) # affine velocity field - F = ti.Matrix.field(2, 2, dtype=float, - shape=n_particles) # deformation gradient - material = ti.field(dtype=int, shape=n_particles) # material id - Jp = ti.field(dtype=float, shape=n_particles) # plastic deformation - grid_v = ti.Vector.field(2, dtype=float, - shape=(n_grid, - n_grid)) # grid node momentum/velocity - grid_m = ti.field(dtype=float, shape=(n_grid, n_grid)) # grid node mass - - @ti.kernel - def substep(): - for i, j in grid_m: - grid_v[i, j] = [0, 0] - grid_m[i, j] = 0 - for p in x: - F[p] = (ti.Matrix.identity(float, 2) + - dt * C[p]) @ F[p] # deformation gradient update - for p in x: # Particle state update and scatter to grid (P2G) - base = (x[p] * inv_dx - 0.5).cast(int) - fx = x[p] * inv_dx - base.cast(float) - # Quadratic kernels [http://mpm.graphics Eqn. 123, with x=fx, fx-1,fx-2] - w = [0.5 * (1.5 - fx)**2, 0.75 - (fx - 1)**2, 0.5 * (fx - 0.5)**2] - h = ti.exp( - 10 * (1.0 - Jp[p]) - ) # Hardening coefficient: snow gets harder when compressed - if material[p] == 1: # jelly, make it softer - h = 0.3 - mu, la = mu_0 * h, lambda_0 * h - if material[p] == 0: # liquid - mu = 0.0 - U, sig, V = ti.svd(F[p]) - J = 1.0 - for d in ti.static(range(2)): - new_sig = sig[d, d] - if material[p] == 2: # Snow - new_sig = min(max(sig[d, d], 1 - 2.5e-2), - 1 + 4.5e-3) # Plasticity - Jp[p] *= sig[d, d] / new_sig - sig[d, d] = new_sig - J *= new_sig - if material[ - p] == 0: # Reset deformation gradient to avoid numerical instability - F[p] = ti.Matrix.identity(float, 2) * ti.sqrt(J) - elif material[p] == 2: - F[p] = U @ sig @ V.transpose( - ) # Reconstruct elastic deformation gradient after plasticity - stress = 2 * mu * (F[p] - U @ V.transpose()) @ F[p].transpose( - ) + ti.Matrix.identity(float, 2) * la * J * (J - 1) - stress = (-dt * p_vol * 4 * inv_dx * inv_dx) * stress - affine = stress + p_mass * C[p] - for i, j in ti.static(ti.ndrange( - 3, 3)): # Loop over 3x3 grid node neighborhood - offset = ti.Vector([i, j]) - dpos = (offset.cast(float) - fx) * dx - weight = w[i][0] * w[j][1] - grid_v[base + - offset] += weight * (p_mass * v[p] + affine @ dpos) - grid_m[base + offset] += weight * p_mass - for i, j in grid_m: - if grid_m[i, j] > 0: # No need for epsilon here - grid_v[i, j] = ( - 1 / grid_m[i, j]) * grid_v[i, j] # Momentum to velocity - grid_v[i, j][1] -= dt * 50 # gravity - for i, j in grid_m: - if grid_m[i, j] > 0: # No need for epsilon here - if i < 3 and grid_v[i, j][0] < 0: - grid_v[i, j][0] = 0 # Boundary conditions - for i, j in grid_m: - if grid_m[i, j] > 0: # No need for epsilon here - if i > n_grid - 3 and grid_v[i, j][0] > 0: grid_v[i, j][0] = 0 - for i, j in grid_m: - if grid_m[i, j] > 0: # No need for epsilon here - if j < 3 and grid_v[i, j][1] < 0: grid_v[i, j][1] = 0 - for i, j in grid_m: - if grid_m[i, j] > 0: # No need for epsilon here - if j > n_grid - 3 and grid_v[i, j][1] > 0: grid_v[i, j][1] = 0 - for p in x: # grid to particle (G2P) - base = (x[p] * inv_dx - 0.5).cast(int) - fx = x[p] * inv_dx - base.cast(float) - w = [ - 0.5 * (1.5 - fx)**2, 0.75 - (fx - 1.0)**2, 0.5 * (fx - 0.5)**2 - ] - new_v = ti.Vector.zero(float, 2) - new_C = ti.Matrix.zero(float, 2, 2) - for i, j in ti.static(ti.ndrange( - 3, 3)): # loop over 3x3 grid node neighborhood - dpos = ti.Vector([i, j]).cast(float) - fx - g_v = grid_v[base + ti.Vector([i, j])] - weight = w[i][0] * w[j][1] - new_v += weight * g_v - new_C += 4 * inv_dx * weight * g_v.outer_product(dpos) - v[p], C[p] = new_v, new_C - for p in x: - x[p] += dt * v[p] # advection - - group_size = n_particles // 3 - - @ti.kernel - def initialize(): - for i in range(n_particles): - x[i] = [ - ti.random() * 0.2 + 0.3 + 0.10 * (i // group_size), - ti.random() * 0.2 + 0.05 + 0.32 * (i // group_size) - ] - material[i] = i // group_size # 0: fluid 1: jelly 2: snow - v[i] = ti.Matrix([0, 0]) - F[i] = ti.Matrix([[1, 0], [0, 1]]) - Jp[i] = 1 - - initialize() - - def task(): - for s in range(int(2e-3 // dt)): - substep() - - ti.benchmark(task, repeat=5) - - -@benchmark_async -def multires(scale): - num_levels = 4 - - x = [] - for i in range(num_levels): - x.append(ti.field(dtype=ti.f32)) - - # TODO: Using 1024 instead of 512 hangs the CUDA case. Need to figure out why. - n = 512 * 1024 * scale - - block_size = 16 - assert n % block_size**2 == 0 - - for i in range(num_levels): - ti.root.pointer(ti.i, n // 2**i // block_size**2).pointer( - ti.i, block_size).dense(ti.i, block_size).place(x[i]) - - @ti.kernel - def initialize(): - for i in range(n): - x[0][i] = i - - @ti.kernel - def downsample(l: ti.template()): - for i in x[l]: - if i % 2 == 0: - x[l + 1][i // 2] = x[l][i] - - initialize() - - def task(): - for l in range(num_levels - 1): - downsample(l) - - ti.benchmark(task, repeat=5) - - -@benchmark_async -def deep_hierarchy(scale): - branching = 4 - num_levels = 8 + int(math.log(scale, branching)) - - x = ti.field(dtype=ti.f32) - - n = 256 * 1024 * scale - - assert n % (branching**num_levels) == 0 - - snode = ti.root - for i in range(num_levels): - snode = snode.pointer(ti.i, branching) - - snode.dense(ti.i, n // (branching**num_levels)).place(x) - - @ti.kernel - def initialize(): - for i in range(n): - x[i] = 0 - - initialize() - - # Not fusible, but no modification to the mask/list of x either - @ti.kernel - def jitter(): - for i in x: - if i % 2 == 0: - x[i] += x[i + 1] - - def task(): - for i in range(5): - jitter() - - ti.benchmark(task, repeat=5) diff --git a/benchmarks/benchmark_async.py b/benchmarks/benchmark_async.py deleted file mode 100644 index 9054631ff4fc2..0000000000000 --- a/benchmarks/benchmark_async.py +++ /dev/null @@ -1,36 +0,0 @@ -from async_advection import * -from async_cases import * - -import taichi as ti - -rerun = True - -cases = [ - chain_copy, increments, fill_array, sparse_saxpy, autodiff, - stencil_reduction, mpm_splitted, simple_advection, multires, deep_hierarchy -] - -if rerun: - for c in cases: - print('*' * 30) - print(f'* Running {c.__name__}') - print('*' * 30) - c() - -case_names = [c.__name__ for c in cases] - -ti.benchmark_plot(fn='benchmark.yml', - cases=case_names, - columns=[ - 'wall_clk_t', 'exec_t', 'launched_tasks', - 'compiled_inst', 'compiled_tasks' - ], - column_titles=[ - 'Wall-clock time', 'Backend time', 'Tasks launched', - 'Instructions emitted', 'Tasks compiled' - ], - archs=['cuda', 'x64'], - title='Whole-Program Optimization Microbenchmarks', - bars='sync_vs_async', - left_margin=0.2, - size=(11.5, 9)) diff --git a/benchmarks/deserialize.py b/benchmarks/deserialize.py new file mode 100644 index 0000000000000..7c09d9dd72ffb --- /dev/null +++ b/benchmarks/deserialize.py @@ -0,0 +1,99 @@ +import argparse +import json +import os +from copy import deepcopy + +from utils import dump2json + + +class ResultsBuilder(): + def __init__(self, results_file_path: str): + self._suites_result = {} + self._file_path = results_file_path + self.load_suites_result() + + def load_suites_result(self): + # benchmark info + info_path = os.path.join(self._file_path, '_info.json') + with open(info_path, 'r') as f: + info_dict = json.load(f)['suites'] + # suite info + for suite_name, attrs in info_dict.items(): + self._suites_result[suite_name] = {} + for arch in attrs['archs']: + self._suites_result[suite_name][arch] = {} + suite_info_path = os.path.join(self._file_path, suite_name, + arch, "_info.json") + with open(suite_info_path, 'r') as f: + suite_info_dict = json.load(f) + # case info + for case_name in suite_info_dict: + items = suite_info_dict[case_name] + items.pop('name') + items['metrics'] = items.pop('get_metric') + self._suites_result[suite_name][arch][case_name] = { + 'items': items + } + # cases result + for suite_name in self._suites_result: + for arch in self._suites_result[suite_name]: + for case_name in self._suites_result[suite_name][arch]: + case_info_path = os.path.join(self._file_path, suite_name, + arch, case_name + ".json") + with open(case_info_path, 'r') as f: + case_results = json.load(f) + remove_none_list = [] + for name, data in case_results.items(): + # remove case_name + data['tags'] = data['tags'][1:] + if data['result'] is None: + remove_none_list.append(name) + for name in remove_none_list: + case_results.pop(name) + self._suites_result[suite_name][arch][case_name][ + 'results'] = case_results + + def get_suites_result(self): + return self._suites_result + + def save_results_as_json(self, costomized_dir=None): + file_path = os.path.join(self._file_path, 'results.json') + if costomized_dir != None: + file_path = os.path.join(costomized_dir, 'results.json') + with open(file_path, 'w') as f: + print(dump2json(self._suites_result), file=f) + + def print_info(self): + # remove 'results' in self._suites_result, then print + info_dict = deepcopy(self._suites_result) + for suite_name in info_dict: + for arch in info_dict[suite_name]: + for case in info_dict[suite_name][arch]: + info_dict[suite_name][arch][case].pop('results') + print(dump2json(info_dict)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('-f', + '--folder', + default='./results', + dest='folder', + type=str, + help='Path of result folder. Defaults to ./results') + + parser.add_argument('-o', + '--output_path', + default='./results', + dest='output_path', + type=str, + help='Path of result folder. Defaults to ./results') + + args = parser.parse_args() + result_folder = args.folder + output_path = args.output_path + + results = ResultsBuilder(result_folder) + results.save_results_as_json(output_path) + results.print_info() diff --git a/benchmarks/fill_dense.py b/benchmarks/fill_dense.py deleted file mode 100644 index 452c5075fe5d8..0000000000000 --- a/benchmarks/fill_dense.py +++ /dev/null @@ -1,124 +0,0 @@ -import taichi as ti - -# originally by @KLozes - - -@ti.test() -def benchmark_flat_struct(): - N = 4096 - a = ti.field(dtype=ti.f32, shape=(N, N)) - - @ti.kernel - def fill(): - for i, j in a: - a[i, j] = 2.0 - - return ti.benchmark(fill, repeat=500) - - -@ti.test() -def benchmark_flat_range(): - N = 4096 - a = ti.field(dtype=ti.f32, shape=(N, N)) - - @ti.kernel - def fill(): - for i, j in ti.ndrange(N, N): - a[i, j] = 2.0 - - return ti.benchmark(fill, repeat=700) - - -@ti.test() -def benchmark_nested_struct(): - a = ti.field(dtype=ti.f32) - N = 512 - - ti.root.dense(ti.ij, [N, N]).dense(ti.ij, [8, 8]).place(a) - - @ti.kernel - def fill(): - for i, j in a: - a[i, j] = 2.0 - - return ti.benchmark(fill, repeat=700) - - -@ti.test() -def benchmark_nested_struct_listgen_8x8(): - a = ti.field(dtype=ti.f32) - ti.cfg.demote_dense_struct_fors = False - N = 512 - - ti.root.dense(ti.ij, [N, N]).dense(ti.ij, [8, 8]).place(a) - - @ti.kernel - def fill(): - for i, j in a: - a[i, j] = 2.0 - - return ti.benchmark(fill, repeat=1000) - - -@ti.test() -def benchmark_nested_struct_listgen_16x16(): - a = ti.field(dtype=ti.f32) - ti.cfg.demote_dense_struct_fors = False - N = 256 - - ti.root.dense(ti.ij, [N, N]).dense(ti.ij, [16, 16]).place(a) - - @ti.kernel - def fill(): - for i, j in a: - a[i, j] = 2.0 - - return ti.benchmark(fill, repeat=700) - - -@ti.test() -def benchmark_nested_range_blocked(): - a = ti.field(dtype=ti.f32) - N = 512 - - ti.root.dense(ti.ij, [N, N]).dense(ti.ij, [8, 8]).place(a) - - @ti.kernel - def fill(): - for X in range(N * N): - for Y in range(64): - a[X // N * 8 + Y // 8, X % N * 8 + Y % 8] = 2.0 - - return ti.benchmark(fill, repeat=800) - - -@ti.test() -def benchmark_nested_range(): - a = ti.field(dtype=ti.f32) - N = 512 - - ti.root.dense(ti.ij, [N, N]).dense(ti.ij, [8, 8]).place(a) - - @ti.kernel - def fill(): - for j in range(N * 8): - for i in range(N * 8): - a[i, j] = 2.0 - - return ti.benchmark(fill, repeat=1000) - - -@ti.test() -def benchmark_root_listgen(): - a = ti.field(dtype=ti.f32) - ti.cfg.demote_dense_struct_fors = False - N = 512 - - ti.root.dense(ti.ij, [N, N]).dense(ti.ij, [8, 8]).place(a) - - @ti.kernel - def fill(): - for i, j in a.parent(): - a[i, j] = 2.0 - - return ti.benchmark(fill, repeat=800) diff --git a/benchmarks/fill_sparse.py b/benchmarks/fill_sparse.py deleted file mode 100644 index 64d15a65cb001..0000000000000 --- a/benchmarks/fill_sparse.py +++ /dev/null @@ -1,42 +0,0 @@ -import taichi as ti - - -@ti.archs_support_sparse -def benchmark_nested_struct(): - a = ti.field(dtype=ti.f32) - N = 512 - - ti.root.pointer(ti.ij, [N, N]).dense(ti.ij, [8, 8]).place(a) - - @ti.kernel - def fill(): - for i, j in ti.ndrange(N * 8, N * 8): - a[i, j] = 2.0 - - fill() - - return ti.benchmark(fill) - - -@ti.archs_support_sparse -def benchmark_nested_struct_fill_and_clear(): - a = ti.field(dtype=ti.f32) - N = 512 - - ti.root.pointer(ti.ij, [N, N]).dense(ti.ij, [8, 8]).place(a) - - @ti.kernel - def fill(): - for i, j in ti.ndrange(N * 8, N * 8): - a[i, j] = 2.0 - - @ti.kernel - def clear(): - for i, j in a.parent(): - ti.deactivate(a.parent().parent(), [i, j]) - - def task(): - fill() - clear() - - return ti.benchmark(task, repeat=30) diff --git a/benchmarks/memory_bound.py b/benchmarks/memory_bound.py deleted file mode 100644 index 6bcd200099c12..0000000000000 --- a/benchmarks/memory_bound.py +++ /dev/null @@ -1,59 +0,0 @@ -import taichi as ti - -N = 1024**3 // 4 # 1 GB per buffer - - -# 4 B/it -@ti.test(exclude=ti.opengl) -def benchmark_memset(): - a = ti.field(dtype=ti.f32, shape=N) - - @ti.kernel - def memset(): - for i in a: - a[i] = 1.0 - - return ti.benchmark(memset, repeat=10) - - -# 8 B/it -@ti.test(exclude=ti.opengl) -def benchmark_sscal(): - a = ti.field(dtype=ti.f32, shape=N) - - @ti.kernel - def task(): - for i in a: - a[i] = 0.5 * a[i] - - return ti.benchmark(task, repeat=10) - - -# 8 B/it -@ti.test(exclude=ti.opengl) -def benchmark_memcpy(): - a = ti.field(dtype=ti.f32, shape=N) - b = ti.field(dtype=ti.f32, shape=N) - - @ti.kernel - def memcpy(): - for i in a: - a[i] = b[i] - - return ti.benchmark(memcpy, repeat=10) - - -# 12 B/it -@ti.test(exclude=ti.opengl) -def benchmark_saxpy(): - x = ti.field(dtype=ti.f32, shape=N) - y = ti.field(dtype=ti.f32, shape=N) - z = ti.field(dtype=ti.f32, shape=N) - - @ti.kernel - def task(): - for i in x: - a = 123 - z[i] = a * x[i] + y[i] - - return ti.benchmark(task, repeat=10) diff --git a/benchmarks/microbenchmarks/__init__.py b/benchmarks/microbenchmarks/__init__.py new file mode 100644 index 0000000000000..adbfa19ade092 --- /dev/null +++ b/benchmarks/microbenchmarks/__init__.py @@ -0,0 +1,12 @@ +from .atomic_ops import AtomicOpsPlan +from .fill import FillPlan +from .math_opts import MathOpsPlan +from .matrix_ops import MatrixOpsPlan +from .memcpy import MemcpyPlan +from .saxpy import SaxpyPlan +from .stencil2d import Stencil2DPlan + +benchmark_plan_list = [ + AtomicOpsPlan, FillPlan, MathOpsPlan, MatrixOpsPlan, MemcpyPlan, SaxpyPlan, + Stencil2DPlan +] diff --git a/benchmarks/microbenchmarks/_items.py b/benchmarks/microbenchmarks/_items.py new file mode 100644 index 0000000000000..af9b6747f46b6 --- /dev/null +++ b/benchmarks/microbenchmarks/_items.py @@ -0,0 +1,116 @@ +from microbenchmarks._utils import size2tag + +import taichi as ti + + +class BenchmarkItem: + name = 'item' + + def __init__(self): + self._items = {} # {'tag': impl, ...} + + def get(self): + return self._items + + def get_tags(self): + return list(self._items.keys()) + + def impl(self, tag: str): + return self._items[tag] + + def remove(self, tags: list): + for tag in tags: + self._items.pop(tag) + + def update(self, adict: dict): + self._items.update(adict) + + +class DataType(BenchmarkItem): + name = 'dtype' + integer_list = ['i32', 'i64'] + + def __init__(self): + self._items = { + str(ti.i32): ti.i32, + str(ti.i64): ti.i64, + str(ti.f32): ti.f32, + str(ti.f64): ti.f64 + } + + def remove_integer(self): + self.remove(self.integer_list) + + @staticmethod + def is_integer(dtype: str): + integer_list = ['i32', 'u32', 'i64', 'u64'] + return True if dtype in integer_list else False + + +class DataSize(BenchmarkItem): + name = 'dsize' + + def __init__(self): + self._items = {} + for i in range(2, 10, 2): # [16KB,256KB,4MB,64MB] + size_bytes = (4**i) * 1024 # kibibytes(KiB) = 1024 + self._items[size2tag(size_bytes)] = size_bytes + + +class Container(BenchmarkItem): + name = 'container' + + def __init__(self): + self._items = {'field': ti.field, 'ndarray': ti.ndarray} + + +class MathOps(BenchmarkItem): + name = 'math_op' + + #reference: https://docs.taichi.graphics/lang/articles/basic/operator + def __init__(self): + self._items = { + # Trigonometric + 'sin': ti.sin, + 'cos': ti.cos, + 'tan': ti.tan, + 'asin': ti.asin, + 'acos': ti.acos, + 'tanh': ti.tanh, + # Other arithmetic + 'sqrt': ti.sqrt, + 'rsqrt': ti.rsqrt, # A fast version for `1 / ti.sqrt(x)`. + 'exp': ti.exp, + 'log': ti.log, + 'round': ti.round, + 'floor': ti.floor, + 'ceil': ti.ceil, + 'abs': ti.abs, + } + + +class AtomicOps(BenchmarkItem): + name = 'atomic_op' + + def __init__(self): + self._items = { + 'atomic_add': ti.atomic_add, + 'atomic_sub': ti.atomic_sub, + 'atomic_and': ti.atomic_and, + 'atomic_or': ti.atomic_or, + 'atomic_xor': ti.atomic_xor, + 'atomic_max': ti.atomic_max, + 'atomic_min': ti.atomic_min + } + + @staticmethod + def is_logical_op(op: str): + logical_op_list = ['atomic_and', 'atomic_or', 'atomic_xor'] + return True if op in logical_op_list else False + + @staticmethod + def is_supported_type(op: str, dtype: str): + if AtomicOps.is_logical_op(op) and not DataType.is_integer(dtype): + return False + else: + return True diff --git a/benchmarks/microbenchmarks/_metric.py b/benchmarks/microbenchmarks/_metric.py new file mode 100644 index 0000000000000..46c008c1e7949 --- /dev/null +++ b/benchmarks/microbenchmarks/_metric.py @@ -0,0 +1,47 @@ +from microbenchmarks._items import BenchmarkItem +from microbenchmarks._utils import End2EndTimer, get_ti_arch + +import taichi as ti + + +def end2end_executor(repeat, func, *args): + # compile & warmup + for i in range(repeat): + func(*args) + + timer = End2EndTimer() + timer.tick() + for i in range(repeat): + func(*args) + time_in_s = timer.tock() + return time_in_s * 1000 / repeat #ms + + +def kernel_executor(repeat, func, *args): + # compile & warmup + for i in range(repeat): + func(*args) + ti.clear_kernel_profile_info() + for i in range(repeat): + func(*args) + return ti.kernel_profiler_total_time() * 1000 / repeat #ms + + +class MetricType(BenchmarkItem): + name = 'get_metric' + + def __init__(self): + self._items = { + 'kernel_elapsed_time_ms': kernel_executor, + 'end2end_time_ms': end2end_executor + } + + @staticmethod + def init_taichi(arch: str, tag_list: list): + if set(['kernel_elapsed_time_ms']).issubset(tag_list): + ti.init(kernel_profiler=True, arch=get_ti_arch(arch)) + elif set(['end2end_time_ms']).issubset(tag_list): + ti.init(kernel_profiler=False, arch=get_ti_arch(arch)) + else: + return False + return True diff --git a/benchmarks/microbenchmarks/_plan.py b/benchmarks/microbenchmarks/_plan.py new file mode 100644 index 0000000000000..52af1002c6530 --- /dev/null +++ b/benchmarks/microbenchmarks/_plan.py @@ -0,0 +1,89 @@ +import itertools + +from microbenchmarks._items import AtomicOps, DataType +from microbenchmarks._metric import MetricType +from microbenchmarks._utils import get_ti_arch, tags2name + +import taichi as ti + + +class Funcs(): + def __init__(self): + self._funcs = {} + + def add_func(self, tag_list: list, func): + self._funcs[tags2name(tag_list)] = {'tags': tag_list, 'func': func} + + def get_func(self, tags): + for name, item in self._funcs.items(): + if set(item['tags']).issubset(tags): + return item['func'] + return None + + +class BenchmarkPlan: + def __init__(self, name='plan', arch='x64', basic_repeat_times=1): + self.name = name + self.arch = arch + self.basic_repeat_times = basic_repeat_times + self.info = {'name': self.name} + self.plan = {} # {'tags': [...], 'result': None} + self.items = {} + self.funcs = Funcs() + + def create_plan(self, *items): + items_list = [[self.name]] + for item in items: + self.items[item.name] = item + items_list.append(item.get_tags()) + self.info[item.name] = item.get_tags() + case_list = list(itertools.product(*items_list)) #items generate cases + for tags in case_list: + self.plan[tags2name(tags)] = {'tags': tags, 'result': None} + self._remove_conflict_items() + + def add_func(self, tag_list, func): + self.funcs.add_func(tag_list, func) + + def run(self): + for case, plan in self.plan.items(): + tag_list = plan['tags'] + MetricType.init_taichi(self.arch, tag_list) + _ms = self.funcs.get_func(tag_list)(self.arch, + self.basic_repeat_times, + **self._get_kwargs(tag_list)) + plan['result'] = _ms + print(f'{tag_list}={_ms}') + ti.reset() + rdict = {'results': self.plan, 'info': self.info} + return rdict + + def _get_kwargs(self, tags, impl=True): + kwargs = {} + tags = tags[1:] # tags = [case_name, item1_tag, item2_tag, ...] + for item, tag in zip(self.items.values(), tags): + kwargs[item.name] = item.impl(tag) if impl == True else tag + return kwargs + + def _remove_conflict_items(self): + remove_list = [] + #logical_atomic with float_type + if set([AtomicOps.name, DataType.name]).issubset(self.items.keys()): + for name, case in self.plan.items(): + kwargs_tag = self._get_kwargs(case['tags'], impl=False) + atomic_tag = kwargs_tag[AtomicOps.name] + dtype_tag = kwargs_tag[DataType.name] + if not AtomicOps.is_supported_type(atomic_tag, dtype_tag): + remove_list.append(name) + #remove + for name in remove_list: + self.plan.pop(name) + + def remove_cases_with_tags(self, tags: list): + remove_list = [] + for case, plan in self.plan.items(): + if set(tags).issubset(plan['tags']): + remove_list.append(case) + #remove + for case in remove_list: + self.plan.pop(case) diff --git a/benchmarks/microbenchmarks/_utils.py b/benchmarks/microbenchmarks/_utils.py new file mode 100644 index 0000000000000..8bea2f2821ea2 --- /dev/null +++ b/benchmarks/microbenchmarks/_utils.py @@ -0,0 +1,87 @@ +from time import perf_counter + +from taichi._lib import core as ti_core + +import taichi as ti + + +class End2EndTimer: + def __init__(self): + self._ts1 = 0 + self._ts2 = 0 + + def tick(self): + ti.sync() + self._ts1 = perf_counter() + return self._ts1 + + def tock(self): + ti.sync() + self._ts2 = perf_counter() + return self._ts2 - self._ts1 + + +def size2tag(size_in_byte): + size_subsection = [(0.0, 'B'), (1024.0, 'KB'), (1048576.0, 'MB'), + (1073741824.0, 'GB'), + (float('inf'), 'INF')] #B KB MB GB + for dsize, unit in reversed(size_subsection): + if size_in_byte >= dsize: + return str(int(size_in_byte / dsize)) + unit + + +def tags2name(tag_list): + return '_'.join(tag_list) + + +def dtype_size(ti_dtype): + dtype_size_dict = {ti.i32: 4, ti.i64: 8, ti.f32: 4, ti.f64: 8} + if ti_dtype not in dtype_size_dict: + raise RuntimeError('Unsupported ti.dtype: ' + str(type(ti_dtype))) + else: + return dtype_size_dict[ti_dtype] + + +def get_ti_arch(arch: str): + arch_dict = { + 'cuda': ti.cuda, + 'vulkan': ti.vulkan, + 'opengl': ti.opengl, + 'metal': ti.metal, + 'x64': ti.x64, + 'cc': ti.cc + } + return arch_dict[arch] + + +def scaled_repeat_times(arch: str, datasize, repeat=1): + if (arch == 'cuda') | (arch == 'vulkan') | (arch == 'opengl'): + repeat *= 10 + if datasize <= 4 * 1024 * 1024: + repeat *= 10 + return repeat + + +def fill_random(dst, dtype, container): + @ti.kernel + def fill_template(dst: ti.template()): + for I in ti.grouped(dst): + dst[I] = ti.random(dtype) + + @ti.kernel + def fill_1d_array(dst: ti.any_arr()): + for i in dst: + dst[i] = ti.random(dtype) + + @ti.kernel + def fill_2d_array(dst: ti.any_arr()): + for i, j in dst: + dst[i, j] = ti.random(dtype) + + if container == ti.ndarray: + if len(dst.shape) == 1: + fill_1d_array(dst) + elif len(dst.shape) == 2: + fill_2d_array(dst) + else: + fill_template(dst) diff --git a/benchmarks/microbenchmarks/atomic_ops.py b/benchmarks/microbenchmarks/atomic_ops.py new file mode 100644 index 0000000000000..9b78728bb4445 --- /dev/null +++ b/benchmarks/microbenchmarks/atomic_ops.py @@ -0,0 +1,42 @@ +from microbenchmarks._items import AtomicOps, Container, DataSize, DataType +from microbenchmarks._metric import MetricType +from microbenchmarks._plan import BenchmarkPlan +from microbenchmarks._utils import dtype_size, fill_random, scaled_repeat_times + +import taichi as ti + + +def reduction_default(arch, repeat, atomic_op, container, dtype, dsize, + get_metric): + repeat = scaled_repeat_times(arch, dsize, repeat) + num_elements = dsize // dtype_size(dtype) + + x = container(dtype, shape=num_elements) + y = container(dtype, shape=()) + y[None] = 0 + + @ti.kernel + def reduction_field(y: ti.template(), x: ti.template()): + for i in x: + atomic_op(y[None], x[i]) + + @ti.kernel + def reduction_array(y: ti.any_arr(), x: ti.any_arr()): + for i in x: + atomic_op(y[None], x[i]) + + fill_random(x, dtype, container) + func = reduction_field if container == ti.field else reduction_array + return get_metric(repeat, func, y, x) + + +class AtomicOpsPlan(BenchmarkPlan): + def __init__(self, arch: str): + super().__init__('atomic_ops', arch, basic_repeat_times=10) + atomic_ops = AtomicOps() + atomic_ops.remove( + ['atomic_sub', 'atomic_and', 'atomic_xor', 'atomic_max']) + self.create_plan(atomic_ops, Container(), DataType(), DataSize(), + MetricType()) + self.add_func(['field'], reduction_default) + self.add_func(['ndarray'], reduction_default) diff --git a/benchmarks/microbenchmarks/fill.py b/benchmarks/microbenchmarks/fill.py new file mode 100644 index 0000000000000..5d1dad91e8c82 --- /dev/null +++ b/benchmarks/microbenchmarks/fill.py @@ -0,0 +1,60 @@ +from microbenchmarks._items import BenchmarkItem, Container, DataSize, DataType +from microbenchmarks._metric import MetricType +from microbenchmarks._plan import BenchmarkPlan +from microbenchmarks._utils import dtype_size, scaled_repeat_times + +import taichi as ti + + +def fill_default(arch, repeat, container, dtype, dsize, get_metric): + @ti.kernel + def fill_field(dst: ti.template()): + for I in ti.grouped(dst): + dst[I] = ti.cast(0.7, dtype) + + @ti.kernel + def fill_array(dst: ti.any_arr()): + for i in dst: + dst[i] = ti.cast(0.7, dtype) + + repeat = scaled_repeat_times(arch, dsize, repeat) + num_elements = dsize // dtype_size(dtype) + x = container(dtype, num_elements) + func = fill_field if container == ti.field else fill_array + return get_metric(repeat, func, x) + + +def fill_sparse(arch, repeat, container, dtype, dsize, get_metric): + repeat = scaled_repeat_times(arch, dsize, repeat=1) + # basic_repeat_time = 1: sparse-specific parameter + num_elements = dsize // dtype_size(dtype) // 8 + + block = ti.root.pointer(ti.i, num_elements) + x = ti.field(dtype) + block.dense(ti.i, 8).place(x) + + @ti.kernel + def active_all(): + for i in ti.ndrange(num_elements): + ti.activate(block, [i]) + + active_all() + + @ti.kernel + def fill_const(dst: ti.template()): + for i in x: + dst[i] = ti.cast(0.7, dtype) + + return get_metric(repeat, fill_const, x) + + +class FillPlan(BenchmarkPlan): + def __init__(self, arch: str): + super().__init__('fill', arch, basic_repeat_times=10) + fill_container = Container() + fill_container.update({'sparse': None}) # None: implement by feature + self.create_plan(fill_container, DataType(), DataSize(), MetricType()) + # use tag_list to label the customized implementation (funcs). + self.add_func(['field'], fill_default) + self.add_func(['ndarray'], fill_default) + self.add_func(['sparse'], fill_sparse) diff --git a/benchmarks/microbenchmarks/math_opts.py b/benchmarks/microbenchmarks/math_opts.py new file mode 100644 index 0000000000000..cf06d0bbcc500 --- /dev/null +++ b/benchmarks/microbenchmarks/math_opts.py @@ -0,0 +1,62 @@ +from microbenchmarks._items import BenchmarkItem, DataType, MathOps +from microbenchmarks._metric import MetricType +from microbenchmarks._plan import BenchmarkPlan +from microbenchmarks._utils import dtype_size, scaled_repeat_times + +import taichi as ti + + +def unary_ops_throughput_default(arch, repeat, math_op, dtype, element_num, + thread_for_loop, get_metric): + local_data_num = 16 #enough data to fill the instruction pipeline + global_vector = ti.Vector.field(local_data_num, dtype, element_num) + + @ti.kernel + def op_throughput(): + for e in global_vector: + local_vector = global_vector[e] + #loop + for j in range(thread_for_loop): + for k in ti.static(range(local_data_num)): + local_vector[k] = math_op(local_vector[k]) + #epilogue + global_vector[e] = local_vector + + @ti.kernel + def fill_random(): + for e in global_vector: + for i in ti.static(range(local_data_num)): + global_vector[e][i] = ti.random() + + fill_random() + return get_metric(repeat, op_throughput) + + +class ElementNum(BenchmarkItem): + name = 'element_num' + + def __init__(self): + self._items = { + 'element16384': 16384, + #enough threads for filling CUDA cores + } + + +class ForLoopCycle(BenchmarkItem): + name = 'thread_for_loop' + + def __init__(self): + self._items = {} + for i in range(1, 7): + cycles = 4 * pow(2, i) # [8 16 32 64 128 256] + self._items['threadloop' + str(cycles)] = cycles + + +class MathOpsPlan(BenchmarkPlan): + def __init__(self, arch: str): + super().__init__('math_ops', arch, basic_repeat_times=10) + math_dtype = DataType() + math_dtype.remove_integer() + self.create_plan(MathOps(), math_dtype, ElementNum(), ForLoopCycle(), + MetricType()) + self.add_func(['element16384'], unary_ops_throughput_default) diff --git a/benchmarks/microbenchmarks/matrix_ops.py b/benchmarks/microbenchmarks/matrix_ops.py new file mode 100644 index 0000000000000..df163b959917f --- /dev/null +++ b/benchmarks/microbenchmarks/matrix_ops.py @@ -0,0 +1,109 @@ +from microbenchmarks._items import BenchmarkItem, DataType +from microbenchmarks._metric import MetricType +from microbenchmarks._plan import BenchmarkPlan + +import taichi as ti + + +def matrix_operations_default(arch, repeat, matrix_op, block_mn, element_num, + dtype, get_metric): + m, n = block_mn + global_matrixA = ti.Matrix.field(m, n, dtype, shape=element_num) + global_matrixB = ti.Matrix.field(m, n, dtype, shape=element_num) + global_matrixC = ti.Matrix.field(m, n, dtype, shape=element_num) + + @ti.kernel + def fill_matrixA(): + for e in global_matrixA: + for i, j in ti.static(range(m, n)): + global_matrixA[e][i, j] = ti.random(dtype) + + @ti.kernel + def fill_matrixB(): + for e in global_matrixB: + for i, j in ti.static(range(m, n)): + global_matrixB[e][i, j] = ti.random(dtype) + + @ti.kernel + def op_throughput(): + for e in range(element_num): + #prelogue + A = global_matrixA[e] + B = global_matrixB[e] + C = ti.Matrix.zero(dtype, m, n) + C = matrix_op(C, A, B) #C += A@B + #loop + for i in range(2048): + for j in ti.static(range(4)): #16*4*4=256 + A = matrix_op(A, C, B) #A += C@B + C = matrix_op(C, A, B) #C += A@B + B = matrix_op(B, A, C) #B += A@C + C = matrix_op(C, A, B) #C += A@B + #epilogue + global_matrixC[e] = C + + fill_matrixA() + fill_matrixB() + return get_metric(repeat, op_throughput) + + +@ti.func +def matrix_add(C, A, B): + C = A + B + return C + + +@ti.func +def matrix_mul(C, A, B): + C = A @ B + return C + + +@ti.func +def matrix_mma(C, A, B): + """matrix multiply and add""" + C = A @ B + C + return C + + +class MatrixOps(BenchmarkItem): + name = 'matrix_op' + + def __init__(self): + self._items = { + 'mat_add': matrix_add, + 'mat_mul': matrix_mul, + 'mat_mma': matrix_mma, + } + + +class BlockMN(BenchmarkItem): + name = 'block_mn' + + def __init__(self): + self._items = { + 'block_mn11': (1, 1), + 'block_mn22': (2, 2), + 'block_mn33': (3, 3), + 'block_mn44': (4, 4), + } + + +class ElementNum(BenchmarkItem): + name = 'element_num' + + def __init__(self): + self._items = { + 'element16384': 16384, + #enough threads for filling CUDA cores + } + + +class MatrixOpsPlan(BenchmarkPlan): + def __init__(self, arch: str): + super().__init__('matrix_ops', arch, basic_repeat_times=10) + dtype = DataType() + dtype.remove(['i64', 'f64']) + self.create_plan(MatrixOps(), BlockMN(), ElementNum(), dtype, + MetricType()) + self.add_func(['element16384'], matrix_operations_default) diff --git a/benchmarks/microbenchmarks/memcpy.py b/benchmarks/microbenchmarks/memcpy.py new file mode 100644 index 0000000000000..ee7b654bf9c2a --- /dev/null +++ b/benchmarks/microbenchmarks/memcpy.py @@ -0,0 +1,36 @@ +from microbenchmarks._items import Container, DataSize, DataType +from microbenchmarks._metric import MetricType +from microbenchmarks._plan import BenchmarkPlan +from microbenchmarks._utils import dtype_size, fill_random, scaled_repeat_times + +import taichi as ti + + +def memcpy_default(arch, repeat, container, dtype, dsize, get_metric): + @ti.kernel + def memcpy_field(dst: ti.template(), src: ti.template()): + for I in ti.grouped(dst): + dst[I] = src[I] + + @ti.kernel + def memcpy_array(dst: ti.any_arr(), src: ti.any_arr()): + for I in ti.grouped(dst): + dst[I] = src[I] + + repeat = scaled_repeat_times(arch, dsize, repeat) + num_elements = dsize // dtype_size(dtype) // 2 # y=x + + x = container(dtype, num_elements) + y = container(dtype, num_elements) + + func = memcpy_field if container == ti.field else memcpy_array + fill_random(x, dtype, container) + return get_metric(repeat, func, y, x) + + +class MemcpyPlan(BenchmarkPlan): + def __init__(self, arch: str): + super().__init__('memcpy', arch, basic_repeat_times=10) + self.create_plan(Container(), DataType(), DataSize(), MetricType()) + self.add_func(['field'], memcpy_default) + self.add_func(['ndarray'], memcpy_default) diff --git a/benchmarks/microbenchmarks/saxpy.py b/benchmarks/microbenchmarks/saxpy.py new file mode 100644 index 0000000000000..de8b32909af51 --- /dev/null +++ b/benchmarks/microbenchmarks/saxpy.py @@ -0,0 +1,39 @@ +from microbenchmarks._items import Container, DataSize, DataType +from microbenchmarks._metric import MetricType +from microbenchmarks._plan import BenchmarkPlan +from microbenchmarks._utils import dtype_size, fill_random, scaled_repeat_times + +import taichi as ti + + +def saxpy_default(arch, repeat, container, dtype, dsize, get_metric): + + repeat = scaled_repeat_times(arch, dsize, repeat) + num_elements = dsize // dtype_size(dtype) // 3 #z=x+y + + x = container(dtype, num_elements) + y = container(dtype, num_elements) + z = container(dtype, num_elements) + + @ti.kernel + def saxpy_field(z: ti.template(), x: ti.template(), y: ti.template()): + for i in z: + z[i] = 17 * x[i] + y[i] + + @ti.kernel + def saxpy_array(z: ti.any_arr(), x: ti.any_arr(), y: ti.any_arr()): + for i in z: + z[i] = 17 * x[i] + y[i] + + fill_random(x, dtype, container) + fill_random(y, dtype, container) + func = saxpy_field if container == ti.field else saxpy_array + return get_metric(repeat, func, z, x, y) + + +class SaxpyPlan(BenchmarkPlan): + def __init__(self, arch: str): + super().__init__('saxpy', arch, basic_repeat_times=10) + self.create_plan(Container(), DataType(), DataSize(), MetricType()) + self.add_func(['field'], saxpy_default) + self.add_func(['ndarray'], saxpy_default) diff --git a/benchmarks/microbenchmarks/stencil2d.py b/benchmarks/microbenchmarks/stencil2d.py new file mode 100644 index 0000000000000..02cf21b2a3f66 --- /dev/null +++ b/benchmarks/microbenchmarks/stencil2d.py @@ -0,0 +1,135 @@ +from microbenchmarks._items import BenchmarkItem, Container, DataType +from microbenchmarks._metric import MetricType +from microbenchmarks._plan import BenchmarkPlan +from microbenchmarks._utils import (dtype_size, fill_random, + scaled_repeat_times, size2tag) + +import taichi as ti + +stencil_common = [(0, 0), (0, -1), (0, 1), (1, 0)] + + +def stencil_2d_default(arch, repeat, scatter, bls, container, dtype, dsize_2d, + get_metric): + + dsize = dsize_2d[0] * dsize_2d[1] + repeat = scaled_repeat_times(arch, dsize, repeat) + num_elements_2d = (dsize_2d[0] // dtype_size(dtype), dsize_2d[1] // 2) + + y = container(dtype, shape=num_elements_2d) + x = container(dtype, shape=num_elements_2d) + + @ti.kernel + def stencil_2d_field(y: ti.template(), x: ti.template()): + for I in ti.grouped(x): + if ti.static(scatter): + for offset in ti.static(stencil_common): + y[I + ti.Vector(offset)] += x[I] + else: # gather + s = ti.cast(0.0, dtype) + for offset in ti.static(stencil_common): + s = s + x[I + ti.Vector(offset)] + y[I] = s + + @ti.kernel + def stencil_2d_array(y: ti.any_arr(), x: ti.any_arr()): + for I in ti.grouped(x): + if ti.static(scatter): + for offset in ti.static(stencil_common): + y[I + ti.Vector(offset)] += x[I] + else: # gather + s = ti.cast(0.0, dtype) + for offset in ti.static(stencil_common): + s = s + x[I + ti.Vector(offset)] + y[I] = s + + fill_random(x, dtype, container) + func = stencil_2d_field if container == ti.field else stencil_2d_array + return get_metric(repeat, func, y, x) + + +def stencil_2d_sparse_bls(arch, repeat, scatter, bls, container, dtype, + dsize_2d, get_metric): + + dsize = dsize_2d[0] * dsize_2d[1] + if dsize <= 4096 or dsize > 67108864: # 16KB <= dsize <= 64 MB: Sparse-specific parameters + return None + repeat = scaled_repeat_times( + arch, dsize, 1) # basic_repeat_time = 1: Sparse-specific parameters + block_elements_2d = (dsize_2d[0] // dtype_size(dtype) // 8, + dsize_2d[1] // 2 // 8) + + block = ti.root.pointer(ti.ij, block_elements_2d) + y = ti.field(dtype) + x = ti.field(dtype) + block.dense(ti.ij, 8).place(y) + block.dense(ti.ij, 8).place(x) + + @ti.kernel + def active_all(): + for i, j in ti.ndrange(block_elements_2d[0], block_elements_2d[0]): + ti.activate(block, [i, j]) + + active_all() + + @ti.kernel + def stencil_2d(y: ti.template(), x: ti.template()): + #reference: tests/python/bls_test_template.py + if ti.static(bls and not scatter): + ti.block_local(x) + if ti.static(bls and scatter): + ti.block_local(y) + ti.block_dim(64) # 8*8=64 + + for I in ti.grouped(x): + if ti.static(scatter): + for offset in ti.static(stencil_common): + y[I + ti.Vector(offset)] += x[I] + else: # gather + s = ti.cast(0.0, dtype) + for offset in ti.static(stencil_common): + s = s + x[I + ti.Vector(offset)] + y[I] = s + + fill_random(x, dtype, container) + return get_metric(repeat, stencil_2d, y, x) + + +class Scatter(BenchmarkItem): + name = 'scatter' + + def __init__(self): + self._items = {'scatter': True, 'gether': False} + + +class BloclLocalStorage(BenchmarkItem): + name = 'bls' + + def __init__(self): + self._items = {'bls_on': True, 'bls_off': False} + + +class DataSize2D(BenchmarkItem): + name = 'dsize_2d' + + def __init__(self): + self._items = {} + for i in range(2, 10, 2): # [16KB,256KB,4MB,64MB] + size_bytes_2d = 32 * (2**i), 32 * (2**i) + size_bytes = size_bytes_2d[0] * size_bytes_2d[1] + self._items[size2tag(size_bytes)] = size_bytes_2d + + +class Stencil2DPlan(BenchmarkPlan): + def __init__(self, arch: str): + super().__init__('stencil_2d', arch, basic_repeat_times=10) + container = Container() + container.update({'sparse': None}) # None: implement by feature + self.create_plan(Scatter(), BloclLocalStorage(), container, DataType(), + DataSize2D(), MetricType()) + # no use for field & ndarray + self.remove_cases_with_tags(['field', 'bls1']) + self.remove_cases_with_tags(['ndarray', 'bls1']) + self.add_func(['field'], stencil_2d_default) + self.add_func(['ndarray'], stencil_2d_default) + self.add_func(['sparse'], stencil_2d_sparse_bls) diff --git a/benchmarks/minimal.py b/benchmarks/minimal.py deleted file mode 100644 index fa75cc6a9513c..0000000000000 --- a/benchmarks/minimal.py +++ /dev/null @@ -1,12 +0,0 @@ -import taichi as ti - - -@ti.test() -def benchmark_fill_scalar(): - a = ti.field(dtype=ti.f32, shape=()) - - @ti.kernel - def fill(): - a[None] = 1.0 - - return ti.benchmark(fill, repeat=1000) diff --git a/benchmarks/misc/membound.py b/benchmarks/misc/membound.py deleted file mode 100644 index 55bfd3c9942e3..0000000000000 --- a/benchmarks/misc/membound.py +++ /dev/null @@ -1,107 +0,0 @@ -import time - -from membound_cases import fill, reduction, saxpy -from utils import * - -import taichi as ti - -test_cases = [fill, saxpy, reduction] -test_archs = [ti.cuda] -test_dtype = [ti.i32, ti.i64, ti.f32, ti.f64] -test_dsize = [(4**i) * kibibyte for i in range(1, 11)] #[4KB,16KB...1GB] -test_repeat = 10 -results_evaluation = [geometric_mean] - - -class BenchmarkResult: - def __init__(self, name, arch, dtype, dsize, results_evaluation): - self.test_name = name - self.test_arch = arch - self.data_type = dtype - self.data_size = dsize - self.min_time_in_us = [] - self.results_evaluation = results_evaluation - - def time2mdtableline(self): - string = '|' + self.test_name + '.' + dtype2str[self.data_type] + '|' - string += ''.join( - str(round(time, 4)) + '|' for time in self.min_time_in_us) - string += ''.join( - str(round(item(self.min_time_in_us), 4)) + '|' - for item in self.results_evaluation) - return string - - -class BenchmarkImpl: - def __init__(self, func, archs, data_type, data_size): - self.func = func - self.name = func.__name__ - self.env = None - self.device = None - self.archs = archs - self.data_type = data_type - self.data_size = data_size - self.benchmark_results = [] - - def run(self): - for arch in self.archs: - for dtype in self.data_type: - ti.init(kernel_profiler=True, arch=arch) - print("TestCase[%s.%s.%s]" % - (self.func.__name__, ti.core.arch_name(arch), - dtype2str[dtype])) - result = BenchmarkResult(self.name, arch, dtype, - self.data_size, results_evaluation) - for size in self.data_size: - print("data_size = %s" % (size2str(size))) - result.min_time_in_us.append( - self.func(arch, dtype, size, test_repeat)) - time.sleep(0.2) - self.benchmark_results.append(result) - - def print(self): - i = 0 - for arch in self.archs: - for dtype in self.data_type: - for idx in range(len(self.data_size)): - print( - " test_case:[%s] arch:[%s] dtype:[%s] dsize:[%7s] >>> time:[%4.4f]" - % - (self.name, ti.core.arch_name(arch), dtype2str[dtype], - size2str(self.benchmark_results[i].data_size[idx]), - self.benchmark_results[i].min_time_in_us[idx])) - i = i + 1 - - def save2markdown(self, arch): - header = '|kernel elapsed time(ms)' + ''.join( - '|' for i in range(len(self.data_size) + len(results_evaluation))) - lines = [header] - for result in self.benchmark_results: - if (result.test_arch == arch): - lines.append(result.time2mdtableline()) - return lines - - -class Membound: - benchmark_imps = [] - - def __init__(self): - for case in test_cases: - self.benchmark_imps.append( - BenchmarkImpl(case, test_archs, test_dtype, test_dsize)) - - def run(self): - for case in self.benchmark_imps: - case.run() - - def mdlines(self, arch): - lines = [] - lines += md_table_header(self.__class__.__name__, arch, test_dsize, - test_repeat, results_evaluation) - for case in self.benchmark_imps: - if arch in case.archs: - lines += case.save2markdown(arch) - else: - continue - lines.append('') - return lines diff --git a/benchmarks/misc/membound_cases.py b/benchmarks/misc/membound_cases.py deleted file mode 100644 index b02b76382f008..0000000000000 --- a/benchmarks/misc/membound_cases.py +++ /dev/null @@ -1,76 +0,0 @@ -from utils import dtype_size, scale_repeat, size2str - -import taichi as ti - - -def init_const(x, dtype, num_elements): - @ti.kernel - def init_const(x: ti.template(), n: ti.i32): - for i in range(n): - x[i] = ti.cast(0.7, dtype) - - init_const(x, num_elements) - - -def membound_benchmark(func, num_elements, repeat): - # compile the kernel first - func(num_elements) - ti.clear_kernel_profile_info() - for i in range(repeat): - func(num_elements) - kernelname = func.__name__ - quering_result = ti.query_kernel_profile_info(kernelname) - return quering_result.min - - -def fill(arch, dtype, dsize, repeat=10): - - repeat = scale_repeat(arch, dsize, repeat) - num_elements = dsize // dtype_size[dtype] - - x = ti.field(dtype, shape=num_elements) - - @ti.kernel - def fill_const(n: ti.i32): - for i in range(n): - x[i] = ti.cast(0.7, dtype) - - return membound_benchmark(fill_const, num_elements, repeat) - - -def saxpy(arch, dtype, dsize, repeat=10): - - repeat = scale_repeat(arch, dsize, repeat) - num_elements = dsize // dtype_size[dtype] // 3 #z=x+y - - x = ti.field(dtype, shape=num_elements) - y = ti.field(dtype, shape=num_elements) - z = ti.field(dtype, shape=num_elements) - - @ti.kernel - def saxpy(n: ti.i32): - for i in range(n): - z[i] = 17 * x[i] + y[i] - - init_const(x, dtype, num_elements) - init_const(y, dtype, num_elements) - - return membound_benchmark(saxpy, num_elements, repeat) - - -def reduction(arch, dtype, dsize, repeat=10): - - repeat = scale_repeat(arch, dsize, repeat) - num_elements = dsize // dtype_size[dtype] - - x = ti.field(dtype, shape=num_elements) - y = ti.field(dtype, shape=()) - y[None] = 0 - - @ti.kernel - def reduction(n: ti.i32): - for i in range(n): - y[None] += x[i] - - init_const(x, dtype, num_elements) - return membound_benchmark(reduction, num_elements, repeat) diff --git a/benchmarks/misc/run.py b/benchmarks/misc/run.py deleted file mode 100644 index 69fff53d82a5c..0000000000000 --- a/benchmarks/misc/run.py +++ /dev/null @@ -1,33 +0,0 @@ -from membound import Membound - -import taichi as ti - -test_suites = [Membound] -test_archs = [ti.cuda] - - -class PerformanceMonitoring: - suites = [] - - def __init__(self): - for s in test_suites: - self.suites.append(s()) - - def run(self): - print("Running...") - for s in self.suites: - s.run() - - def write_md(self): - filename = f'performance_result.md' - with open(filename, 'w') as f: - for arch in test_archs: - for s in self.suites: - lines = s.mdlines(arch) - for line in lines: - print(line, file=f) - - -p = PerformanceMonitoring() -p.run() -p.write_md() diff --git a/benchmarks/misc/utils.py b/benchmarks/misc/utils.py deleted file mode 100644 index 5e6206116b34d..0000000000000 --- a/benchmarks/misc/utils.py +++ /dev/null @@ -1,59 +0,0 @@ -import taichi as ti - -kibibyte = 1024 - -bls2str = {False: "BLS_off", True: "BLS_on"} -dense2str = {False: "Struct_for", True: "Range_for"} - -dtype2str = {ti.i32: "i32", ti.i64: "i64", ti.f32: "f32", ti.f64: "f64"} -dtype_size = {ti.i32: 4, ti.i64: 8, ti.f32: 4, ti.f64: 8} - -# for output string -size_subsection = [(0.0, 'B'), (1024.0, 'KB'), (1048576.0, 'MB'), - (1073741824.0, 'GB'), (float('inf'), 'INF')] #B KB MB GB - - -def size2str(size_in_byte): - for dsize, units in reversed(size_subsection): - if size_in_byte >= dsize: - return str(round(size_in_byte / dsize, 4)) + units - - -def scale_repeat(arch, datasize, repeat=10): - scaled = repeat - if (arch == ti.gpu) | (arch == ti.opengl) | (arch == ti.cuda): - scaled *= 10 - if datasize <= 4 * 1024 * 1024: - scaled *= 10 - return scaled - - -def geometric_mean(data_array): - product = 1 - for data in data_array: - product *= data - return pow(product, 1.0 / len(data_array)) - - -def md_table_header(suite_name, arch, test_dsize, test_repeat, - results_evaluation): - header = '|' + suite_name + '.' + ti.core.arch_name(arch) + '|' - header += ''.join('|' for i in range(len(test_dsize))) - header += ''.join(item.__name__ + '|' for item in results_evaluation) - - layout = '|:--:|' - layout += ''.join( - ':--:|' for i in range(len(test_dsize) + len(results_evaluation))) - - size = '|**data size**|' - size += ''.join(size2str(size) + '|' for size in test_dsize) - size += ''.join('|' for i in range(len(results_evaluation))) - - repeat = '|**repeat**|' - repeat += ''.join( - str(scale_repeat(arch, size, test_repeat)) + '|' - for size in test_dsize) - repeat += ''.join('|' for i in range(len(results_evaluation))) - - lines = [header, layout, size, repeat] - return lines diff --git a/benchmarks/mpm2d.py b/benchmarks/mpm2d.py deleted file mode 100644 index 4daa34d63e960..0000000000000 --- a/benchmarks/mpm2d.py +++ /dev/null @@ -1,241 +0,0 @@ -import time - -import numpy as np - -import taichi as ti - - -@ti.test() -def benchmark_range(): - quality = 1 # Use a larger value for higher-res simulations - n_particles, n_grid = 9000 * quality**2, 128 * quality - dx, inv_dx = 1 / n_grid, float(n_grid) - dt = 1e-4 / quality - p_vol, p_rho = (dx * 0.5)**2, 1 - p_mass = p_vol * p_rho - E, nu = 0.1e4, 0.2 # Young's modulus and Poisson's ratio - mu_0, lambda_0 = E / (2 * (1 + nu)), E * nu / ( - (1 + nu) * (1 - 2 * nu)) # Lame parameters - - x = ti.Vector.field(2, dtype=ti.f32, shape=n_particles) # position - v = ti.Vector.field(2, dtype=ti.f32, shape=n_particles) # velocity - C = ti.Matrix.field(2, 2, dtype=ti.f32, - shape=n_particles) # affine velocity field - F = ti.Matrix.field(2, 2, dtype=ti.f32, - shape=n_particles) # deformation gradient - material = ti.field(dtype=int, shape=n_particles) # material id - Jp = ti.field(dtype=ti.f32, shape=n_particles) # plastic deformation - grid_v = ti.Vector.field(2, dtype=ti.f32, - shape=(n_grid, - n_grid)) # grid node momemtum/velocity - grid_m = ti.field(dtype=ti.f32, shape=(n_grid, n_grid)) # grid node mass - - @ti.kernel - def substep(): - for i, j in ti.ndrange(n_grid, n_grid): - grid_v[i, j] = [0, 0] - grid_m[i, j] = 0 - for p in range(n_particles - ): # Particle state update and scatter to grid (P2G) - base = (x[p] * inv_dx - 0.5).cast(int) - fx = x[p] * inv_dx - base.cast(float) - # Quadratic kernels [http://mpm.graphics Eqn. 123, with x=fx, fx-1,fx-2] - w = [0.5 * (1.5 - fx)**2, 0.75 - (fx - 1)**2, 0.5 * (fx - 0.5)**2] - F[p] = (ti.Matrix.identity(ti.f32, 2) + - dt * C[p]) @ F[p] # deformation gradient update - h = ti.exp( - 10 * (1.0 - Jp[p]) - ) # Hardening coefficient: snow gets harder when compressed - if material[p] == 1: # jelly, make it softer - h = 0.3 - mu, la = mu_0 * h, lambda_0 * h - if material[p] == 0: # liquid - mu = 0.0 - U, sig, V = ti.svd(F[p]) - J = 1.0 - for d in ti.static(range(2)): - new_sig = sig[d, d] - if material[p] == 2: # Snow - new_sig = min(max(sig[d, d], 1 - 2.5e-2), - 1 + 4.5e-3) # Plasticity - Jp[p] *= sig[d, d] / new_sig - sig[d, d] = new_sig - J *= new_sig - if material[ - p] == 0: # Reset deformation gradient to avoid numerical instability - F[p] = ti.Matrix.identity(ti.f32, 2) * ti.sqrt(J) - elif material[p] == 2: - F[p] = U @ sig @ V.T( - ) # Reconstruct elastic deformation gradient after plasticity - stress = 2 * mu * (F[p] - U @ V.T()) @ F[p].T( - ) + ti.Matrix.identity(ti.f32, 2) * la * J * (J - 1) - stress = (-dt * p_vol * 4 * inv_dx * inv_dx) * stress - affine = stress + p_mass * C[p] - for i, j in ti.static(ti.ndrange( - 3, 3)): # Loop over 3x3 grid node neighborhood - offset = ti.Vector([i, j]) - dpos = (offset.cast(float) - fx) * dx - weight = w[i][0] * w[j][1] - grid_v[base + - offset] += weight * (p_mass * v[p] + affine @ dpos) - grid_m[base + offset] += weight * p_mass - for i, j in ti.ndrange(n_grid, n_grid): - if grid_m[i, j] > 0: # No need for epsilon here - grid_v[i, j] = ( - 1 / grid_m[i, j]) * grid_v[i, j] # Momentum to velocity - grid_v[i, j][1] -= dt * 50 # gravity - if i < 3 and grid_v[i, j][0] < 0: - grid_v[i, j][0] = 0 # Boundary conditions - if i > n_grid - 3 and grid_v[i, j][0] > 0: grid_v[i, j][0] = 0 - if j < 3 and grid_v[i, j][1] < 0: grid_v[i, j][1] = 0 - if j > n_grid - 3 and grid_v[i, j][1] > 0: grid_v[i, j][1] = 0 - for p in range(n_particles): # grid to particle (G2P) - base = (x[p] * inv_dx - 0.5).cast(int) - fx = x[p] * inv_dx - base.cast(float) - w = [ - 0.5 * (1.5 - fx)**2, 0.75 - (fx - 1.0)**2, 0.5 * (fx - 0.5)**2 - ] - new_v = ti.Vector.zero(ti.f32, 2) - new_C = ti.Matrix.zero(ti.f32, 2, 2) - for i, j in ti.static(ti.ndrange( - 3, 3)): # loop over 3x3 grid node neighborhood - dpos = ti.Vector([i, j]).cast(float) - fx - g_v = grid_v[base + ti.Vector([i, j])] - weight = w[i][0] * w[j][1] - new_v += weight * g_v - new_C += 4 * inv_dx * weight * g_v.outer_product(dpos) - v[p], C[p] = new_v, new_C - x[p] += dt * v[p] # advection - - import random - group_size = n_particles // 3 - for i in range(n_particles): - x[i] = [ - random.random() * 0.2 + 0.3 + 0.10 * (i // group_size), - random.random() * 0.2 + 0.05 + 0.32 * (i // group_size) - ] - material[i] = i // group_size # 0: fluid 1: jelly 2: snow - v[i] = [0, 0] - F[i] = [[1, 0], [0, 1]] - Jp[i] = 1 - - ti.benchmark(substep, repeat=4000) - - -@ti.test(exclude=ti.opengl) -def benchmark_struct(): - quality = 1 # Use a larger value for higher-res simulations - n_particles, n_grid = 9000 * quality**2, 128 * quality - dx, inv_dx = 1 / n_grid, float(n_grid) - dt = 1e-4 / quality - p_vol, p_rho = (dx * 0.5)**2, 1 - p_mass = p_vol * p_rho - E, nu = 0.1e4, 0.2 # Young's modulus and Poisson's ratio - mu_0, lambda_0 = E / (2 * (1 + nu)), E * nu / ( - (1 + nu) * (1 - 2 * nu)) # Lame parameters - - x = ti.Vector.field(2, dtype=ti.f32, shape=n_particles) # position - v = ti.Vector.field(2, dtype=ti.f32, shape=n_particles) # velocity - C = ti.Matrix.field(2, 2, dtype=ti.f32, - shape=n_particles) # affine velocity field - F = ti.Matrix.field(2, 2, dtype=ti.f32, - shape=n_particles) # deformation gradient - material = ti.field(dtype=int, shape=n_particles) # material id - Jp = ti.field(dtype=ti.f32, shape=n_particles) # plastic deformation - grid_v = ti.Vector.field(2, dtype=ti.f32, - shape=(n_grid, - n_grid)) # grid node momemtum/velocity - grid_m = ti.field(dtype=ti.f32, shape=(n_grid, n_grid)) # grid node mass - - @ti.kernel - def substep(): - for i, j in grid_m: - grid_v[i, j] = [0, 0] - grid_m[i, j] = 0 - - for p in x: # Particle state update and scatter to grid (P2G) - base = (x[p] * inv_dx - 0.5).cast(int) - fx = x[p] * inv_dx - base.cast(float) - # Quadratic kernels [http://mpm.graphics Eqn. 123, with x=fx, fx-1,fx-2] - w = [0.5 * (1.5 - fx)**2, 0.75 - (fx - 1)**2, 0.5 * (fx - 0.5)**2] - F[p] = (ti.Matrix.identity(ti.f32, 2) + - dt * C[p]) @ F[p] # deformation gradient update - h = ti.exp( - 10 * (1.0 - Jp[p]) - ) # Hardening coefficient: snow gets harder when compressed - if material[p] == 1: # jelly, make it softer - h = 0.3 - mu, la = mu_0 * h, lambda_0 * h - if material[p] == 0: # liquid - mu = 0.0 - U, sig, V = ti.svd(F[p]) - J = 1.0 - for d in ti.static(range(2)): - new_sig = sig[d, d] - if material[p] == 2: # Snow - new_sig = min(max(sig[d, d], 1 - 2.5e-2), - 1 + 4.5e-3) # Plasticity - Jp[p] *= sig[d, d] / new_sig - sig[d, d] = new_sig - J *= new_sig - if material[ - p] == 0: # Reset deformation gradient to avoid numerical instability - F[p] = ti.Matrix.identity(ti.f32, 2) * ti.sqrt(J) - elif material[p] == 2: - F[p] = U @ sig @ V.T( - ) # Reconstruct elastic deformation gradient after plasticity - stress = 2 * mu * (F[p] - U @ V.T()) @ F[p].T( - ) + ti.Matrix.identity(ti.f32, 2) * la * J * (J - 1) - stress = (-dt * p_vol * 4 * inv_dx * inv_dx) * stress - affine = stress + p_mass * C[p] - for i, j in ti.static(ti.ndrange( - 3, 3)): # Loop over 3x3 grid node neighborhood - offset = ti.Vector([i, j]) - dpos = (offset.cast(float) - fx) * dx - weight = w[i][0] * w[j][1] - grid_v[base + - offset] += weight * (p_mass * v[p] + affine @ dpos) - grid_m[base + offset] += weight * p_mass - - for i, j in grid_m: - if grid_m[i, j] > 0: # No need for epsilon here - grid_v[i, j] = ( - 1 / grid_m[i, j]) * grid_v[i, j] # Momentum to velocity - grid_v[i, j][1] -= dt * 50 # gravity - if i < 3 and grid_v[i, j][0] < 0: - grid_v[i, j][0] = 0 # Boundary conditions - if i > n_grid - 3 and grid_v[i, j][0] > 0: grid_v[i, j][0] = 0 - if j < 3 and grid_v[i, j][1] < 0: grid_v[i, j][1] = 0 - if j > n_grid - 3 and grid_v[i, j][1] > 0: grid_v[i, j][1] = 0 - - for p in x: # grid to particle (G2P) - base = (x[p] * inv_dx - 0.5).cast(int) - fx = x[p] * inv_dx - base.cast(float) - w = [ - 0.5 * (1.5 - fx)**2, 0.75 - (fx - 1.0)**2, 0.5 * (fx - 0.5)**2 - ] - new_v = ti.Vector.zero(ti.f32, 2) - new_C = ti.Matrix.zero(ti.f32, 2, 2) - for i, j in ti.static(ti.ndrange( - 3, 3)): # loop over 3x3 grid node neighborhood - dpos = ti.Vector([i, j]).cast(float) - fx - g_v = grid_v[base + ti.Vector([i, j])] - weight = w[i][0] * w[j][1] - new_v += weight * g_v - new_C += 4 * inv_dx * weight * g_v.outer_product(dpos) - v[p], C[p] = new_v, new_C - x[p] += dt * v[p] # advection - - import random - group_size = n_particles // 3 - for i in range(n_particles): - x[i] = [ - random.random() * 0.2 + 0.3 + 0.10 * (i // group_size), - random.random() * 0.2 + 0.05 + 0.32 * (i // group_size) - ] - material[i] = i // group_size # 0: fluid 1: jelly 2: snow - v[i] = [0, 0] - F[i] = [[1, 0], [0, 1]] - Jp[i] = 1 - - ti.benchmark(substep, repeat=4000) diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt new file mode 100644 index 0000000000000..625298e5a2b64 --- /dev/null +++ b/benchmarks/requirements.txt @@ -0,0 +1,2 @@ +jsbeautifier +bokeh diff --git a/benchmarks/run.py b/benchmarks/run.py index f7d7215d6cfd4..5c6574dd22cac 100644 --- a/benchmarks/run.py +++ b/benchmarks/run.py @@ -1,68 +1,62 @@ import os +import warnings -import taichi as ti +from suite_microbenchmarks import MicroBenchmark +from taichi._lib import core as ti_core +from utils import datatime_with_format, dump2json +benchmark_suites = [MicroBenchmark] -def get_benchmark_dir(): - return os.path.join(ti.core.get_repo_dir(), 'benchmarks') - -class Case: - def __init__(self, name, func): - self.name = name - self.func = func - self.records = {} - - def __lt__(self, other): - return self.name < other.name - - def __eq__(self, other): - return self.name == other.name - - def run(self): - print(f'==> {self.name}:') - os.environ['TI_CURRENT_BENCHMARK'] = self.name - self.func() - - -class Suite: - def __init__(self, filename): - self.cases = [] - print(filename) - self.name = filename[:-3] - loc = {} - exec(f'import {self.name} as suite', {}, loc) - suite = loc['suite'] - case_keys = list( - sorted(filter(lambda x: x.startswith('benchmark_'), dir(suite)))) - self.cases = [Case(k, getattr(suite, k)) for k in case_keys] - - def run(self): - print(f'{self.name}:') - for case in sorted(self.cases): - case.run() +class BenchmarkInfo: + def __init__(self): + """init with commit info""" + self.commit_hash = ti_core.get_commit_hash() + self.datetime = datatime_with_format() + self.suites = {} + print(f'commit_hash = {self.commit_hash}') -class TaichiBenchmark: +class BenchmarkSuites: def __init__(self): - self.suites = [] - benchmark_dir = get_benchmark_dir() - for f in map(os.path.basename, sorted(os.listdir(benchmark_dir))): - if f != 'run.py' and f.endswith('.py') and f[0] != '_': - self.suites.append(Suite(f)) + self._suites = [] + for suite in benchmark_suites: + self._suites.append(suite()) def run(self): - output_dir = os.environ.get('TI_BENCHMARK_OUTPUT_DIR', '.') - filename = f'{output_dir}/benchmark.yml' - try: - with open(filename, 'r+') as f: - f.truncate() # clear the previous result - except FileNotFoundError: - pass - print("Running...") - for s in self.suites: - s.run() - - -b = TaichiBenchmark() -b.run() + for suite in self._suites: + suite.run() + + def save(self, benchmark_dir='./'): + for suite in self._suites: + suite_dir = os.path.join(benchmark_dir, suite.suite_name) + os.makedirs(suite_dir, exist_ok=True) + suite.save_as_json(suite_dir) + + def get_suites_info(self): + info_dict = {} + for suite in self._suites: + info_dict[suite.suite_name] = suite.get_benchmark_info() + return info_dict + + +def main(): + + benchmark_dir = os.path.join(os.getcwd(), 'results') + os.makedirs(benchmark_dir, exist_ok=True) + + #init & run + info = BenchmarkInfo() + suites = BenchmarkSuites() + suites.run() + #save benchmark results & info + suites.save(benchmark_dir) + info.suites = suites.get_suites_info() + info_path = os.path.join(benchmark_dir, '_info.json') + info_str = dump2json(info) + with open(info_path, 'w') as f: + print(info_str, file=f) + + +if __name__ == '__main__': + main() diff --git a/benchmarks/suite_microbenchmarks.py b/benchmarks/suite_microbenchmarks.py new file mode 100644 index 0000000000000..dc8e8305404c1 --- /dev/null +++ b/benchmarks/suite_microbenchmarks.py @@ -0,0 +1,66 @@ +import os +import time + +from microbenchmarks import benchmark_plan_list +from utils import dump2json + + +class MicroBenchmark: + suite_name = 'microbenchmarks' + config = { + 'cuda': { + 'enable': True + }, + 'vulkan': { + 'enable': False + }, + 'opengl': { + 'enable': False + } + } + + def __init__(self): + self._results = {} + self._info = {} + + def get_benchmark_info(self): + info_dict = {} + arch_list = [] + for arch, item in self.config.items(): + if item['enable'] == True: + arch_list.append(arch) + info_dict['archs'] = arch_list + return info_dict + + def run(self): + for arch, item in self.config.items(): + if item['enable'] == True: + arch_results = {} + self._info[arch] = {} + for plan in benchmark_plan_list: + plan_impl = plan(arch) + results = plan_impl.run() + self._info[arch][plan_impl.name] = results['info'] + arch_results[plan_impl.name] = results['results'] + + self._results[arch] = arch_results + + def save_as_json(self, suite_dir='./'): + for arch in self._results: + arch_dir = os.path.join(suite_dir, arch) + os.makedirs(arch_dir, exist_ok=True) + self._save_info_as_json(arch, arch_dir) + self._save_cases_as_json(arch, arch_dir) + + def _save_info_as_json(self, arch, arch_dir='./'): + info_path = os.path.join(arch_dir, '_info.json') + with open(info_path, 'w') as f: + print(dump2json(self._info[arch]), file=f) + + def _save_cases_as_json(self, arch, arch_dir='./'): + for case in self._info[arch]: + case_path = os.path.join(arch_dir, (case + '.json')) + case_results = self._results[arch][case] + with open(case_path, 'w') as f: + case_str = dump2json(case_results) + print(case_str, file=f) diff --git a/benchmarks/summarize_async.py b/benchmarks/summarize_async.py deleted file mode 100644 index e3ddf4ada499f..0000000000000 --- a/benchmarks/summarize_async.py +++ /dev/null @@ -1,31 +0,0 @@ -import math - -import yaml - -with open('benchmark.yml') as f: - data = yaml.load(f) - -records = {} - -for case in data: - for metrics in ['exec_t', 'launched_tasks']: - for arch in ['x64', 'cuda']: - key = metrics, arch - - d = [] - for scheme in ['sync', 'async']: - d.append(data[case][metrics][arch][scheme]) - - if key not in records: - records[key] = [] - - records[key].append(d[0] / d[1]) - -for k in records: - rec = records[k] - # Compute geometric mean - p = 1 - for r in rec: - p *= r - p = p**(1 / len(rec)) - print(f'{k}, {p:.3f}x') diff --git a/benchmarks/utils.py b/benchmarks/utils.py index a08a32eecc324..cc9d62371b1fe 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -1,25 +1,21 @@ +import datetime import functools +import json import os -import taichi as ti +import jsbeautifier -def benchmark_async(func): - @functools.wraps(func) - def body(): - for arch in [ti.cpu, ti.cuda]: - for async_mode in [True, False]: - os.environ['TI_CURRENT_BENCHMARK'] = func.__name__ - ti.init(arch=arch, - async_mode=async_mode, - kernel_profiler=True, - verbose=False) - if arch == ti.cpu: - scale = 2 - else: - # Use more data to hide compilation overhead - # (since CUDA runs much faster than CPUs) - scale = 64 - func(scale) +def get_benchmark_dir(): + return os.path.dirname(os.path.realpath(__file__)) - return body + +def dump2json(obj): + obj2dict = obj if type(obj) is dict else obj.__dict__ + options = jsbeautifier.default_options() + options.indent_size = 4 + return jsbeautifier.beautify(json.dumps(obj2dict), options) + + +def datatime_with_format(): + return datetime.datetime.now().isoformat() diff --git a/build.ps1 b/build.ps1 new file mode 100644 index 0000000000000..9857d6d0928db --- /dev/null +++ b/build.ps1 @@ -0,0 +1,3 @@ +$stopwatch = [system.diagnostics.stopwatch]::startNew() +python setup.py develop +$stopwatch.Elapsed diff --git a/ci/Dockerfile b/ci/Dockerfile deleted file mode 100644 index bd8eb13fbfdd2..0000000000000 --- a/ci/Dockerfile +++ /dev/null @@ -1,82 +0,0 @@ -# Taichi Dockerfile for development -ARG UBUNTU -FROM nvidia/cuda:${UBUNTU} -ARG PYTHON -ARG TEST_OPTION -ARG PYPI_PWD -ARG COMMIT_SHA -ENV PYPI_PWD=$PYPI_PWD -ENV DEBIAN_FRONTEND=noninteractive -LABEL maintainer="https://github.com/taichi-dev" - -RUN apt-get update && \ - apt-get install -y software-properties-common \ - $PYTHON \ - python3-pip \ - ${PYTHON}-dev\ - libtinfo-dev \ - clang-10 \ - wget \ - git \ - libx11-dev \ - libxrandr-dev \ - libxinerama-dev \ - libxcursor-dev \ - libxi-dev \ - libglu1-mesa-dev \ - freeglut3-dev \ - mesa-common-dev \ - build-essential \ - libssl-dev \ - libidn11-dev \ - libz-dev \ - unzip - - -# Install the latest version of CMAKE v3.20.2 from source -WORKDIR / -RUN wget https://github.com/Kitware/CMake/releases/download/v3.20.5/cmake-3.20.5-linux-x86_64.tar.gz -RUN tar xf cmake-3.20.5-linux-x86_64.tar.gz -ENV PATH="/cmake-3.20.5-linux-x86_64/bin:$PATH" - -# Intall LLVM 10 -WORKDIR / -# Make sure this URL gets updated each time there is a new prebuilt bin release -RUN wget https://github.com/taichi-dev/taichi_assets/releases/download/llvm10_linux_patch2/taichi-llvm-10.0.0-linux.zip -RUN unzip taichi-llvm-10.0.0-linux.zip -ENV PATH="/taichi-llvm-10.0.0-linux/bin:$PATH" - -# Install Taichi from source -ENV CC="clang-10" -ENV CXX="clang++-10" -WORKDIR /taichi-dev - -# Prevent docker caching when head changes -ADD https://api.github.com/repos/taichi-dev/taichi/git/refs/heads/master version.json -RUN git clone https://github.com/taichi-dev/taichi --branch=master -WORKDIR ./taichi -RUN git cat-file -t $COMMIT_SHA || exit 1 -RUN git checkout $COMMIT_SHA -# Install Taichi's Python dependencies -RUN $PYTHON -m pip install --user -r requirements_dev.txt -# Build Taichi wheel from source -RUN git submodule update --init --recursive --depth=1 -WORKDIR /taichi-dev/taichi -WORKDIR python/ -ENV TAICHI_CMAKE_ARGS="-DTI_WITH_VULKAN:BOOL=OFF" -RUN $PYTHON build.py build $TEST_OPTION -WORKDIR ../ -RUN $PYTHON -m pip install dist/*.whl - -# Link Taichi source repo to Python Path -ENV LANG="C.UTF-8" - -# Show ELF info -RUN ldd build/libtaichi_core.so -RUN strings build/libtaichi_core.so | grep GLIBC - -# Install twine and upload project to pypi. -RUN $PYTHON -m pip install --user twine -RUN ti test -vr2 -WORKDIR python/ -RUN $PYTHON build.py upload --skip_build $TEST_OPTION diff --git a/ci/Dockerfile.manylinux2014.cpu b/ci/Dockerfile.manylinux2014.cpu new file mode 100644 index 0000000000000..ceda86ee0901c --- /dev/null +++ b/ci/Dockerfile.manylinux2014.cpu @@ -0,0 +1,55 @@ +# This file is generated by python Dockerfile_generator.py -o manylinux2014 -t cpu +# Taichi Dockerfile (CPU only) for Manylinux2014 compliant +FROM quay.io/pypa/manylinux2014_x86_64 + +LABEL maintainer="https://github.com/taichi-dev" + +RUN yum check-update && \ + yum install -y git \ + cmake \ + wget \ + libXrandr + +# Build LLVM/Clang 10 from source +WORKDIR / +RUN wget https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.0/llvm-10.0.0.src.tar.xz +RUN tar -xf llvm-10.0.0.src.tar.xz && rm llvm-10.0.0.src.tar.xz +RUN wget https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.0/clang-10.0.0.src.tar.xz +RUN tar -xf clang-10.0.0.src.tar.xz && rm clang-10.0.0.src.tar.xz +RUN cp -r clang-10.0.0.src llvm-10.0.0.src/tools/clang + +WORKDIR /llvm-10.0.0.src/build +RUN cmake .. -DLLVM_ENABLE_RTTI:BOOL=ON -DBUILD_SHARED_LIBS:BOOL=OFF -DCMAKE_BUILD_TYPE=Release -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_ASSERTIONS=ON -DLLVM_ENABLE_TERMINFO=OFF +RUN make -j 8 && make install +ENV CC="/usr/local/bin/clang" +ENV CXX="/usr/local/bin/clang++" + +# Link gcc 10 to build Taichi +WORKDIR /usr/lib/gcc/x86_64-redhat-linux/ +RUN ln -s /opt/rh/devtoolset-10/root/usr/lib/gcc/x86_64-redhat-linux/10 10 +# Check gcc-10 is used +RUN clang++ -v + +# Create non-root user for running the container +RUN useradd -ms /bin/bash dev +WORKDIR /home/dev +USER dev + +# Install miniconda +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + bash Miniconda3-latest-Linux-x86_64.sh -p /home/dev/miniconda -b +ENV PATH="/home/dev/miniconda/bin:$PATH" + +# Set up multi-python environment +RUN conda init bash +RUN conda create -n py36 python=3.6 -y +RUN conda create -n py37 python=3.7 -y +RUN conda create -n py38 python=3.8 -y +RUN conda create -n py39 python=3.9 -y + +# Load scripts for build and test +WORKDIR /home/dev/scripts +COPY ci/scripts/manylinux_build_wheel.sh manylinux_build_wheel.sh + +WORKDIR /home/dev +ENV LANG="C.UTF-8" diff --git a/ci/Dockerfile.manylinux2014.cuda b/ci/Dockerfile.manylinux2014.cuda new file mode 100644 index 0000000000000..347c0f9968a81 --- /dev/null +++ b/ci/Dockerfile.manylinux2014.cuda @@ -0,0 +1,56 @@ +FROM nvidia/cudagl:11.2.2-devel-centos7 + +LABEL maintainer="https://github.com/taichi-dev" + +RUN yum install -y git wget + +# Install cmake 3.x +RUN yum install -y epel-release +RUN yum install -y cmake3 +RUN ln -s /usr/bin/cmake3 /usr/bin/cmake + +# Install gcc 10 (https://git.centos.org/rpms/devtoolset-10-gcc) +RUN yum install -y centos-release-scl +RUN yum install -y devtoolset-10-gcc* +ENV PATH="/opt/rh/devtoolset-10/root/usr/bin:$PATH" + +# Build LLVM/Clang 10 from source +WORKDIR / +RUN wget https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.0/llvm-10.0.0.src.tar.xz +RUN tar -xf llvm-10.0.0.src.tar.xz && rm llvm-10.0.0.src.tar.xz +RUN wget https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.0/clang-10.0.0.src.tar.xz +RUN tar -xf clang-10.0.0.src.tar.xz && rm clang-10.0.0.src.tar.xz +RUN cp -r clang-10.0.0.src llvm-10.0.0.src/tools/clang + +WORKDIR /llvm-10.0.0.src/build +RUN cmake .. -DLLVM_ENABLE_RTTI:BOOL=ON -DBUILD_SHARED_LIBS:BOOL=OFF -DCMAKE_BUILD_TYPE=Release -DLLVM_TARGETS_TO_BUILD="X86;NVPTX" -DLLVM_ENABLE_ASSERTIONS=ON -DLLVM_ENABLE_TERMINFO=OFF +RUN make -j 8 && make install +ENV CC="/usr/local/bin/clang" +ENV CXX="/usr/local/bin/clang++" + +# Link gcc 10 to build Taichi +WORKDIR /usr/lib/gcc/x86_64-redhat-linux/ +RUN ln -s /opt/rh/devtoolset-10/root/usr/lib/gcc/x86_64-redhat-linux/10 10 +# Check gcc-10 is used +RUN clang++ -v + +# Create non-root user for running the container +RUN useradd -ms /bin/bash dev +WORKDIR /home/dev +USER dev + +# Install miniconda +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + bash Miniconda3-latest-Linux-x86_64.sh -p /home/dev/miniconda -b +ENV PATH="/home/dev/miniconda/bin:$PATH" + +# Set up multi-python environment +RUN conda init bash +RUN conda create -n py36 python=3.6 -y +RUN conda create -n py37 python=3.7 -y +RUN conda create -n py38 python=3.8 -y +RUN conda create -n py39 python=3.9 -y +RUN conda create -n py310 python=3.10 -y + +WORKDIR /home/dev +ENV LANG="C.UTF-8" diff --git a/ci/Dockerfile.ubuntu.18.04 b/ci/Dockerfile.ubuntu.18.04 new file mode 100644 index 0000000000000..6addb38ebc664 --- /dev/null +++ b/ci/Dockerfile.ubuntu.18.04 @@ -0,0 +1,101 @@ +# This file is generated by python Dockerfile_generator.py -o ubuntu -t gpu +# Taichi Dockerfile for development +FROM nvidia/cudagl:11.2.2-devel-ubuntu18.04 +# Use 11.2 instead of 11.4 to avoid forward compatibility issue on Nvidia driver 460 + +ENV NVIDIA_DRIVER_CAPABILITIES compute,graphics,utility + +ENV DEBIAN_FRONTEND=noninteractive + +LABEL maintainer="https://github.com/taichi-dev" + +RUN apt-get update && \ + apt-get install -y software-properties-common \ + python3-pip \ + libtinfo-dev \ + clang-10 \ + wget \ + git \ + unzip \ + libxrandr-dev \ + libxinerama-dev \ + libxcursor-dev \ + libxi-dev \ + libglu1-mesa-dev \ + freeglut3-dev \ + mesa-common-dev \ + libssl-dev \ + libglm-dev \ + libxcb-keysyms1-dev \ + libxcb-dri3-dev \ + libxcb-randr0-dev \ + libxcb-ewmh-dev \ + libpng-dev \ + g++-multilib \ + libmirclient-dev \ + libwayland-dev \ + bison \ + libx11-xcb-dev \ + liblz4-dev \ + libzstd-dev \ + qt5-default \ + libglfw3 \ + libglfw3-dev \ + libjpeg-dev \ + libvulkan-dev + +# Install the latest version of CMAKE v3.20.5 from source +WORKDIR / +RUN wget https://github.com/Kitware/CMake/releases/download/v3.20.5/cmake-3.20.5-linux-x86_64.tar.gz +RUN tar xf cmake-3.20.5-linux-x86_64.tar.gz && \ + rm cmake-3.20.5-linux-x86_64.tar.gz +ENV PATH="/cmake-3.20.5-linux-x86_64/bin:$PATH" + +# Intall LLVM 10 +WORKDIR / +# Make sure this URL gets updated each time there is a new prebuilt bin release +RUN wget https://github.com/taichi-dev/taichi_assets/releases/download/llvm10_linux_patch2/taichi-llvm-10.0.0-linux.zip +RUN unzip taichi-llvm-10.0.0-linux.zip && \ + rm taichi-llvm-10.0.0-linux.zip +ENV PATH="/taichi-llvm-10.0.0-linux/bin:$PATH" +# Use Clang as the default compiler +ENV CC="clang-10" +ENV CXX="clang++-10" + +# Setting up Vulkan SDK +# References +# [1] https://github.com/edowson/docker-nvidia-vulkan +# [2] https://gitlab.com/nvidia/container-images/vulkan/-/tree/master/docker +WORKDIR /vulkan +RUN wget https://sdk.lunarg.com/sdk/download/1.2.189.0/linux/vulkansdk-linux-x86_64-1.2.189.0.tar.gz +RUN tar xf vulkansdk-linux-x86_64-1.2.189.0.tar.gz && \ + rm vulkansdk-linux-x86_64-1.2.189.0.tar.gz +# Locate Vulkan components +ENV VULKAN_SDK="/vulkan/1.2.189.0/x86_64" +ENV PATH="$VULKAN_SDK/bin:$PATH" +ENV LD_LIBRARY_PATH="$VULKAN_SDK/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}" +ENV VK_LAYER_PATH="$VULKAN_SDK/etc/vulkan/explicit_layer.d" +WORKDIR /usr/share/vulkan/icd.d +COPY vulkan/icd.d/nvidia_icd.json nvidia_icd.json + +# Create non-root user for running the container +RUN useradd -ms /bin/bash dev +WORKDIR /home/dev +USER dev + +# Install miniconda +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + bash Miniconda3-latest-Linux-x86_64.sh -p /home/dev/miniconda -b +ENV PATH="/home/dev/miniconda/bin:$PATH" + +# Set up multi-python environment +RUN conda init bash +RUN conda create -n py36 python=3.6 pytorch cudatoolkit=10.2 -c pytorch -y +RUN conda create -n py37 python=3.7 pytorch cudatoolkit=10.2 -c pytorch -y +RUN conda create -n py38 python=3.8 pytorch cudatoolkit=10.2 -c pytorch -y +RUN conda create -n py39 python=3.9 pytorch cudatoolkit=10.2 -c pytorch -y +# TODO add torch to 3.10 when supported +RUN conda create -n py310 python=3.10 -y + +WORKDIR /home/dev +ENV LANG="C.UTF-8" diff --git a/ci/Dockerfile.ubuntu.18.04.cpu b/ci/Dockerfile.ubuntu.18.04.cpu new file mode 100644 index 0000000000000..c7c60244b6837 --- /dev/null +++ b/ci/Dockerfile.ubuntu.18.04.cpu @@ -0,0 +1,57 @@ +# This file is generated by python Dockerfile_generator.py -o ubuntu -t cpu +# Taichi Dockerfile for development +FROM ubuntu:18.04 + +ENV DEBIAN_FRONTEND=noninteractive + +LABEL maintainer="https://github.com/taichi-dev" + +RUN apt-get update && \ + apt-get install -y software-properties-common \ + python3-pip \ + libtinfo-dev \ + clang-10 \ + wget \ + git \ + unzip \ + libx11-xcb-dev \ + zlib1g-dev +# Install the latest version of CMAKE v3.20.5 from source +WORKDIR / +RUN wget https://github.com/Kitware/CMake/releases/download/v3.20.5/cmake-3.20.5-linux-x86_64.tar.gz +RUN tar xf cmake-3.20.5-linux-x86_64.tar.gz && \ + rm cmake-3.20.5-linux-x86_64.tar.gz +ENV PATH="/cmake-3.20.5-linux-x86_64/bin:$PATH" + +# Intall LLVM 10 +WORKDIR / +# Make sure this URL gets updated each time there is a new prebuilt bin release +RUN wget https://github.com/taichi-dev/taichi_assets/releases/download/llvm10_linux_patch2/taichi-llvm-10.0.0-linux.zip +RUN unzip taichi-llvm-10.0.0-linux.zip && \ + rm taichi-llvm-10.0.0-linux.zip +ENV PATH="/taichi-llvm-10.0.0-linux/bin:$PATH" +# Use Clang as the default compiler +ENV CC="clang-10" +ENV CXX="clang++-10" + +# Create non-root user for running the container +RUN useradd -ms /bin/bash dev +WORKDIR /home/dev +USER dev + +# Install miniconda +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + bash Miniconda3-latest-Linux-x86_64.sh -p /home/dev/miniconda -b +ENV PATH="/home/dev/miniconda/bin:$PATH" + +# Set up multi-python environment +RUN conda init bash +RUN conda create -n py36 python=3.6 pytorch -y +RUN conda create -n py37 python=3.7 pytorch -y +RUN conda create -n py38 python=3.8 pytorch -y +RUN conda create -n py39 python=3.9 pytorch -y +# TODO add torch to 3.10 when supported +RUN conda create -n py310 python=3.10 -y + +WORKDIR /home/dev +ENV LANG="C.UTF-8" diff --git a/Dockerfile b/ci/Dockerfile.ubuntu.20.04 similarity index 73% rename from Dockerfile rename to ci/Dockerfile.ubuntu.20.04 index 189f487dc870d..b38442b4bc486 100644 --- a/Dockerfile +++ b/ci/Dockerfile.ubuntu.20.04 @@ -1,16 +1,17 @@ +# This file is generated by python Dockerfile_generator.py -o ubuntu -t gpu # Taichi Dockerfile for development FROM nvidia/cudagl:11.2.2-devel-ubuntu20.04 # Use 11.2 instead of 11.4 to avoid forward compatibility issue on Nvidia driver 460 ENV NVIDIA_DRIVER_CAPABILITIES compute,graphics,utility + ENV DEBIAN_FRONTEND=noninteractive + LABEL maintainer="https://github.com/taichi-dev" -# Ubuntu 20.04 installs Python 3.8 by default RUN apt-get update && \ apt-get install -y software-properties-common \ python3-pip \ - python-is-python3 \ libtinfo-dev \ clang-10 \ wget \ @@ -40,10 +41,10 @@ RUN apt-get update && \ qt5-default \ libglfw3 \ libglfw3-dev \ - vulkan-tools \ + libjpeg-dev \ libvulkan-dev \ + vulkan-tools \ vulkan-validationlayers-dev - # Install the latest version of CMAKE v3.20.5 from source WORKDIR / RUN wget https://github.com/Kitware/CMake/releases/download/v3.20.5/cmake-3.20.5-linux-x86_64.tar.gz @@ -51,7 +52,6 @@ RUN tar xf cmake-3.20.5-linux-x86_64.tar.gz && \ rm cmake-3.20.5-linux-x86_64.tar.gz ENV PATH="/cmake-3.20.5-linux-x86_64/bin:$PATH" - # Intall LLVM 10 WORKDIR / # Make sure this URL gets updated each time there is a new prebuilt bin release @@ -59,7 +59,9 @@ RUN wget https://github.com/taichi-dev/taichi_assets/releases/download/llvm10_li RUN unzip taichi-llvm-10.0.0-linux.zip && \ rm taichi-llvm-10.0.0-linux.zip ENV PATH="/taichi-llvm-10.0.0-linux/bin:$PATH" - +# Use Clang as the default compiler +ENV CC="clang-10" +ENV CXX="clang++-10" # Setting up Vulkan SDK # References @@ -77,30 +79,26 @@ ENV VK_LAYER_PATH="$VULKAN_SDK/etc/vulkan/explicit_layer.d" WORKDIR /usr/share/vulkan/icd.d COPY ci/vulkan/icd.d/nvidia_icd.json nvidia_icd.json +# Create non-root user for running the container +RUN useradd -ms /bin/bash dev +WORKDIR /home/dev +USER dev -# Install Taichi from source -ENV CC="clang-10" -ENV CXX="clang++-10" -WORKDIR /taichi-dev -# Prevent docker caching when head changes -ADD https://api.github.com/repos/taichi-dev/taichi/git/refs/heads/master version.json -RUN git clone --recursive https://github.com/taichi-dev/taichi --branch=master -WORKDIR /taichi-dev/taichi -RUN python3 -m pip install --user -r requirements_dev.txt -# Update Torch version, otherwise cuda tests fail. See #2969. -RUN python3 -m pip install torch==1.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html -RUN TAICHI_CMAKE_ARGS="-DTI_WITH_VULKAN:BOOL=ON -DTI_WITH_CUDA:BOOL=ON -DTI_WITH_OPENGL:BOOL=ON" python3 setup.py develop --user -# Show ELF info -RUN ldd build/libtaichi_core.so -RUN strings build/libtaichi_core.so | grep GLIBC +# Install miniconda +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + bash Miniconda3-latest-Linux-x86_64.sh -p /home/dev/miniconda -b +ENV PATH="/home/dev/miniconda/bin:$PATH" -# Link Taichi source repo to Python Path -ENV PATH="/taichi-dev/taichi/bin:$PATH" -ENV TAICHI_REPO_DIR="/taichi-dev/taichi/" -ENV PYTHONPATH="$TAICHI_REPO_DIR/python:$PYTHONPATH" -ENV LANG="C.UTF-8" +# Set up multi-python environment +RUN conda init bash +RUN conda create -n py36 python=3.6 -y +RUN conda create -n py37 python=3.7 -y +RUN conda create -n py38 python=3.8 -y +RUN conda create -n py39 python=3.9 -y + +# Load scripts for build and test +WORKDIR /home/dev/scripts +COPY ci/scripts/ubuntu_build_test.sh ubuntu_build_test.sh -# Add Docker specific ENV -ENV TI_IN_DOCKER=true -WORKDIR /taichi-dev/taichi -CMD /bin/bash +WORKDIR /home/dev +ENV LANG="C.UTF-8" diff --git a/ci/Dockerfile.ubuntu.20.04.cpu b/ci/Dockerfile.ubuntu.20.04.cpu new file mode 100644 index 0000000000000..23d228faa9500 --- /dev/null +++ b/ci/Dockerfile.ubuntu.20.04.cpu @@ -0,0 +1,59 @@ +# This file is generated by python Dockerfile_generator.py -o ubuntu -t cpu +# Taichi Dockerfile for development +FROM ubuntu:20.04 + +ENV DEBIAN_FRONTEND=noninteractive + +LABEL maintainer="https://github.com/taichi-dev" + +RUN apt-get update && \ + apt-get install -y software-properties-common \ + python3-pip \ + libtinfo-dev \ + clang-10 \ + wget \ + git \ + unzip \ + libx11-xcb-dev + +# Install the latest version of CMAKE v3.20.5 from source +WORKDIR / +RUN wget https://github.com/Kitware/CMake/releases/download/v3.20.5/cmake-3.20.5-linux-x86_64.tar.gz +RUN tar xf cmake-3.20.5-linux-x86_64.tar.gz && \ + rm cmake-3.20.5-linux-x86_64.tar.gz +ENV PATH="/cmake-3.20.5-linux-x86_64/bin:$PATH" + +# Intall LLVM 10 +WORKDIR / +# Make sure this URL gets updated each time there is a new prebuilt bin release +RUN wget https://github.com/taichi-dev/taichi_assets/releases/download/llvm10_linux_patch2/taichi-llvm-10.0.0-linux.zip +RUN unzip taichi-llvm-10.0.0-linux.zip && \ + rm taichi-llvm-10.0.0-linux.zip +ENV PATH="/taichi-llvm-10.0.0-linux/bin:$PATH" +# Use Clang as the default compiler +ENV CC="clang-10" +ENV CXX="clang++-10" + +# Create non-root user for running the container +RUN useradd -ms /bin/bash dev +WORKDIR /home/dev +USER dev + +# Install miniconda +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + bash Miniconda3-latest-Linux-x86_64.sh -p /home/dev/miniconda -b +ENV PATH="/home/dev/miniconda/bin:$PATH" + +# Set up multi-python environment +RUN conda init bash +RUN conda create -n py36 python=3.6 -y +RUN conda create -n py37 python=3.7 -y +RUN conda create -n py38 python=3.8 -y +RUN conda create -n py39 python=3.9 -y + +# Load scripts for build and test +WORKDIR /home/dev/scripts +COPY ci/scripts/ubuntu_build_test_cpu.sh ubuntu_build_test_cpu.sh + +WORKDIR /home/dev +ENV LANG="C.UTF-8" diff --git a/ci/Dockerfile_generator.py b/ci/Dockerfile_generator.py new file mode 100644 index 0000000000000..d75b2f45ab643 --- /dev/null +++ b/ci/Dockerfile_generator.py @@ -0,0 +1,351 @@ +import argparse +import functools +import sys +from enum import Enum +from functools import reduce +from pathlib import Path + +OS = { + "windows": (), + "macos": (), + "manylinux2014": ("", ), + "ubuntu": ( + "18.04", + "20.04", + ) +} +HARDWARE = ("cpu", "gpu") + +HEAD_BLOCK = """# This file is generated by python Dockerfile_generator.py -o {os} -t {target} +""" + +CPU_BASE_BLOCK = """# Taichi Dockerfile for development +FROM {os}:{version} +""" + +CPU_MANYLINUX_BASE_BLOCK = """# Taichi Dockerfile (CPU only) for Manylinux2014 compliant +FROM quay.io/pypa/manylinux2014_x86_64 +""" + +GPU_BASE_BLOCK = """# Taichi Dockerfile for development +FROM nvidia/cudagl:11.2.2-devel-ubuntu{version} +# Use 11.2 instead of 11.4 to avoid forward compatibility issue on Nvidia driver 460 +""" + +CPU_YUM_INSTALL_BLOCK = """ +RUN yum check-update && \\ + yum install -y git \\ + cmake \\ + wget \\ + libXrandr +""" + +CPU_APT_INSTALL_BLOCK = """ +RUN apt-get update && \\ + apt-get install -y software-properties-common \\ + python3-pip \\ + libtinfo-dev \\ + clang-10 \\ + wget \\ + git \\ + unzip \\ + libx11-xcb-dev +""" + +GPU_APT_INSTALL_BLOCK = """ +RUN apt-get update && \\ + apt-get install -y software-properties-common \\ + python3-pip \\ + libtinfo-dev \\ + clang-10 \\ + wget \\ + git \\ + unzip \\ + libxrandr-dev \\ + libxinerama-dev \\ + libxcursor-dev \\ + libxi-dev \\ + libglu1-mesa-dev \\ + freeglut3-dev \\ + mesa-common-dev \\ + libssl-dev \\ + libglm-dev \\ + libxcb-keysyms1-dev \\ + libxcb-dri3-dev \\ + libxcb-randr0-dev \\ + libxcb-ewmh-dev \\ + libpng-dev \\ + g++-multilib \\ + libmirclient-dev \\ + libwayland-dev \\ + bison \\ + libx11-xcb-dev \\ + liblz4-dev \\ + libzstd-dev \\ + qt5-default \\ + libglfw3 \\ + libglfw3-dev \\ + libjpeg-dev \\ + libvulkan-dev +""" + +NVIDIA_DRIVER_CAPABILITIES_BLOCK = """ +ENV NVIDIA_DRIVER_CAPABILITIES compute,graphics,utility +""" + +DEBIAN_NONINTERACTIVE_BLOCK = """ +ENV DEBIAN_FRONTEND=noninteractive +""" + +MAINTAINER_BLOCK = """ +LABEL maintainer="https://github.com/taichi-dev" +""" + +CMAKE_BLOCK = """ +# Install the latest version of CMAKE v3.20.5 from source +WORKDIR / +RUN wget https://github.com/Kitware/CMake/releases/download/v3.20.5/cmake-3.20.5-linux-x86_64.tar.gz +RUN tar xf cmake-3.20.5-linux-x86_64.tar.gz && \\ + rm cmake-3.20.5-linux-x86_64.tar.gz +ENV PATH="/cmake-3.20.5-linux-x86_64/bin:$PATH" +""" + +LLVM_BLOCK = """ +# Intall LLVM 10 +WORKDIR / +# Make sure this URL gets updated each time there is a new prebuilt bin release +RUN wget https://github.com/taichi-dev/taichi_assets/releases/download/llvm10_linux_patch2/taichi-llvm-10.0.0-linux.zip +RUN unzip taichi-llvm-10.0.0-linux.zip && \\ + rm taichi-llvm-10.0.0-linux.zip +ENV PATH="/taichi-llvm-10.0.0-linux/bin:$PATH" +# Use Clang as the default compiler +ENV CC="clang-10" +ENV CXX="clang++-10" +""" + +LLVM_CLANG_FROM_SOURCE_BLOCK = """ +# Build LLVM/Clang 10 from source +WORKDIR / +RUN wget https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.0/llvm-10.0.0.src.tar.xz +RUN tar -xf llvm-10.0.0.src.tar.xz && \ + rm llvm-10.0.0.src.tar.xz +RUN wget https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.0/clang-10.0.0.src.tar.xz +RUN tar -xf clang-10.0.0.src.tar.xz && \ + rm clang-10.0.0.src.tar.xz +RUN cp -r clang-10.0.0.src llvm-10.0.0.src/tools/clang + +WORKDIR /llvm-10.0.0.src/build +RUN cmake .. -DLLVM_ENABLE_RTTI:BOOL=ON -DBUILD_SHARED_LIBS:BOOL=OFF -DCMAKE_BUILD_TYPE=Release -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_ASSERTIONS=ON -DLLVM_ENABLE_TERMINFO=OFF +RUN make -j 8 && \ + make install +ENV CC="/usr/local/bin/clang" +ENV CXX="/usr/local/bin/clang++" +""" + +GCC_LINK_BLOCK = """ +# Link gcc 10 to build Taichi +WORKDIR /usr/lib/gcc/x86_64-redhat-linux/ +RUN ln -s /opt/rh/devtoolset-10/root/usr/lib/gcc/x86_64-redhat-linux/10 10 +# Check gcc-10 is used +RUN clang++ -v +""" + +USER_BLOCK = """ +# Create non-root user for running the container +RUN useradd -ms /bin/bash dev +WORKDIR /home/dev +USER dev +""" + +VULKAN_BLOCK = """ +# Setting up Vulkan SDK +# References +# [1] https://github.com/edowson/docker-nvidia-vulkan +# [2] https://gitlab.com/nvidia/container-images/vulkan/-/tree/master/docker +WORKDIR /vulkan +RUN wget https://sdk.lunarg.com/sdk/download/1.2.189.0/linux/vulkansdk-linux-x86_64-1.2.189.0.tar.gz +RUN tar xf vulkansdk-linux-x86_64-1.2.189.0.tar.gz && \\ + rm vulkansdk-linux-x86_64-1.2.189.0.tar.gz +# Locate Vulkan components +ENV VULKAN_SDK="/vulkan/1.2.189.0/x86_64" +ENV PATH="$VULKAN_SDK/bin:$PATH" +ENV LD_LIBRARY_PATH="$VULKAN_SDK/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}" +ENV VK_LAYER_PATH="$VULKAN_SDK/etc/vulkan/explicit_layer.d" +WORKDIR /usr/share/vulkan/icd.d +COPY ci/vulkan/icd.d/nvidia_icd.json nvidia_icd.json +""" + +CONDA_BLOCK = """ +# Install miniconda +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \\ + bash Miniconda3-latest-Linux-x86_64.sh -p /home/dev/miniconda -b +ENV PATH="/home/dev/miniconda/bin:$PATH" + +# Set up multi-python environment +RUN conda init bash +RUN conda create -n py36 python=3.6 -y +RUN conda create -n py37 python=3.7 -y +RUN conda create -n py38 python=3.8 -y +RUN conda create -n py39 python=3.9 -y +""" + +SCRIPTS_BLOCK = """ +# Load scripts for build and test +WORKDIR /home/dev/scripts +COPY ci/scripts/{script} {script} + +WORKDIR /home/dev +ENV LANG="C.UTF-8" +""" + + +class Parser(argparse.ArgumentParser): + def error(self, message): + """Make it print help message by default.""" + sys.stderr.write(f"error: {message}\n") + self.print_help() + sys.exit(2) + + +class AvailableColors(Enum): + GRAY = 90 + RED = 91 + GREEN = 92 + YELLOW = 93 + BLUE = 94 + PURPLE = 95 + WHITE = 97 + BLACK = 30 + DEFAULT = 39 + + +def _apply_color(color: str, message: str) -> str: + """Dye message with color, fall back to default if it fails.""" + color_code = AvailableColors["DEFAULT"].value + try: + color_code = AvailableColors[color.upper()].value + except KeyError: + pass + return f"\033[1;{color_code}m{message}\033[0m" + + +def info(message: str, plain=False): + """Log the info to stdout""" + print(_apply_color("default", message) if not plain else message) + + +def success(message: str): + """Log the success to stdout""" + print(_apply_color("green", f"[✔] {message}")) + + +def error(message: str): + """Log the error to stderr""" + print(_apply_color("red", f"[✗] {message}"), file=sys.stderr) + + +def warn(message: str): + """Log the warning to stdout""" + print(_apply_color("yellow", f"[!] {message}")) + + +def main(arguments=None): + parser = Parser(description="""A CLI to generate Taichi CI Dockerfiles. + Example usage: + python3 Dockerfile_generator.py -o ubuntu -t cpu + """) + parser.add_argument( + "-o", + "--os", + help="The target os of the Dockerfile.", + required=True, + type=str, + choices=OS, + metavar="\b", + ) + parser.add_argument( + "-t", + "--target", + help="The target hardware of the Dockerfile. [cpu/gpu]", + required=True, + type=str, + choices=HARDWARE, + metavar="\b", + ) + args = parser.parse_args() + pwd = Path(__file__).resolve().parent + + head_block = HEAD_BLOCK.format(os=args.os, target=args.target) + + if args.target == "cpu": + info("Generating Dockerfile(s) for CPU.") + + def f(os: str, version: str) -> str: + info(f"OS: {os}, version: {version}") + + if os == "manylinux2014": + base_block = CPU_MANYLINUX_BASE_BLOCK + install_block = CPU_YUM_INSTALL_BLOCK + scripts_block = SCRIPTS_BLOCK.format( + script=f"manylinux_build_wheel.sh") + + dockerfile = reduce( + lambda x, y: x + y, + (head_block, base_block, MAINTAINER_BLOCK, install_block, + LLVM_CLANG_FROM_SOURCE_BLOCK, GCC_LINK_BLOCK, USER_BLOCK, + CONDA_BLOCK, scripts_block)) + + filename = pwd / f"Dockerfile.{os}.cpu" + else: + base_block = CPU_BASE_BLOCK.format(os=os, version=version) + install_block = CPU_APT_INSTALL_BLOCK + scripts_block = SCRIPTS_BLOCK.format( + script=f"{os}_build_test_cpu.sh") + # ubuntu 18.04 needs special treatments + if os == "ubuntu" and version == "18.04": + install_block = install_block.rstrip() + """ \\ + zlib1g-dev""" + + dockerfile = reduce( + lambda x, y: x + y, + (head_block, base_block, DEBIAN_NONINTERACTIVE_BLOCK, + MAINTAINER_BLOCK, install_block, CMAKE_BLOCK, LLVM_BLOCK, + USER_BLOCK, CONDA_BLOCK, scripts_block)) + + filename = pwd / f"Dockerfile.{os}.{version}.cpu" + + info(f"Storing at: {filename}") + with filename.open("w") as fp: + fp.write(dockerfile) + else: + info("Generating Dockerfile(s) for GPU.") + + def f(os: str, version: str) -> str: + info(f"OS: {os}, version: {version}") + base_block = GPU_BASE_BLOCK.format(version=version) + scripts_block = SCRIPTS_BLOCK.format(script=f"{os}_build_test.sh") + install_block = GPU_APT_INSTALL_BLOCK + + # ubuntu 20.04 needs special treatments + if os == "ubuntu" and version == "20.04": + install_block = install_block.rstrip() + """ \\ + vulkan-tools \\ + vulkan-validationlayers-dev""" + + dockerfile = reduce( + lambda x, y: x + y, + (head_block, base_block, NVIDIA_DRIVER_CAPABILITIES_BLOCK, + DEBIAN_NONINTERACTIVE_BLOCK, MAINTAINER_BLOCK, install_block, + CMAKE_BLOCK, LLVM_BLOCK, VULKAN_BLOCK, USER_BLOCK, + CONDA_BLOCK, scripts_block)) + filename = pwd / f"Dockerfile.{os}.{version}" + info(f"Storing at: {filename}") + with (filename).open("w") as fp: + fp.write(dockerfile) + + list(map(functools.partial(f, args.os), OS[args.os])) + success("Dockerfile generation is complete.") + + +if __name__ == "__main__": + main() diff --git a/ci/scripts/manylinux_build_wheel.sh b/ci/scripts/manylinux_build_wheel.sh new file mode 100755 index 0000000000000..e33157490eb8d --- /dev/null +++ b/ci/scripts/manylinux_build_wheel.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +set -ex + +# Parse ARGs +for ARGUMENT in "$@" +do + KEY=$(echo $ARGUMENT | cut -f1 -d=) + VALUE=$(echo $ARGUMENT | cut -f2 -d=) + case "$KEY" in + SHA) SHA=${VALUE} ;; + PY) PY=${VALUE} ;; + *) + esac +done + +source /home/dev/miniconda/etc/profile.d/conda.sh +conda activate $PY + +# Build Taichi from source +git clone --recursive https://github.com/taichi-dev/taichi --branch=master +cd taichi +git checkout $SHA +python3 -m pip install -r requirements_dev.txt -i http://repo.taichigraphics.com/repository/pypi/simple --trusted-host repo.taichigraphics.com +# Add Docker specific ENV +export TI_IN_DOCKER=true + +# TODO, unify this step with wheel build, check #3537 +TAICHI_CMAKE_ARGS="-DTI_WITH_VULKAN:BOOL=OFF -DTI_WITH_CUDA:BOOL=OFF -DTI_WITH_OPENGL:BOOL=OFF -DTI_WITH_CC:BOOL=OFF" python3 setup.py install +# build.py is to be removed +#cd python && python build.py build diff --git a/ci/scripts/release_test.sh b/ci/scripts/release_test.sh new file mode 100644 index 0000000000000..d64140b4f53bb --- /dev/null +++ b/ci/scripts/release_test.sh @@ -0,0 +1,272 @@ +#!/usr/bin/env bash + +# Taichi release test suite + +# Usage: `bash release_test.sh` + +# This script is created mainly for eyeball-testing +# that if all of the official examples are still working +# with the latest version of Taichi. + +# Some of the test cases are fetched from external repositories +# please reach out to us if you are the owner of those +# repos and don't like us to do it. + +# You can add more tests into this script and plug-n-play +# existing tests in the `taichi::test::main` function as +# you need. + +function taichi::utils::set_debug { + if [ ${DEBUG} == "true" ]; then + set -x + fi + set -euo pipefail +} + +function taichi::utils::logger { + # default: gray + if [ "$1" == "info" ]; then + printf '\e[1;90m%-6s\e[m\n' "$(date +"[%m-%d %H:%M:%S]") $2" + # error: red + elif [ "$1" == "error" ]; then + printf '\e[1;91m%-6s\e[m\n' "$(date +"[%m-%d %H:%M:%S]") $2" + # success: green + elif [ "$1" == "success" ]; then + printf '\e[1;92m%-6s\e[m\n' "$(date +"[%m-%d %H:%M:%S]") $2" + # warning: yellow + elif [ "$1" == "warning" ]; then + printf '\e[1;93m%-6s\e[m\n' "$(date +"[%m-%d %H:%M:%S]") $2" + # debug: gray + elif [ "$1" == "debug" ]; then + if [ "${DEBUG}" == "true" ]; then + printf '\e[1;90m%-6s\e[m\n' "$(date +"[%m-%d %H:%M:%S]") $2" + fi + else + printf "$1" + fi +} + +function taichi::utils::logger::info { + taichi::utils::logger "info" "$1" +} + +function taichi::utils::logger::error { + taichi::utils::logger "error" "$1" +} + +function taichi::utils::logger::success { + taichi::utils::logger "success" "$1" +} + +function taichi::utils::logger::warning { + taichi::utils::logger "warning" "$1" +} + +function taichi::utils::logger::debug { + taichi::utils::logger "debug" "$1" +} + +function taichi::utils::line { + printf '%.0s-' {1..20}; echo +} + +function taichi::utils::git_clone { + local GIT_ORG=$1 + local GIT_REPO=$2 + git clone "git@github.com:${GIT_ORG}/${GIT_REPO}.git" +} + +function taichi::utils::pause { + read -p "Press enter to continue" +} + +function taichi::test::ggui { + local WORKDIR=${1} + local PATTERN="*_ggui.py" + local ORG="taichi-dev" + local REPO="taichi" + + # divider + taichi::utils::line + taichi::utils::logger::info "Running GGUI examples" + + # clone the repo + taichi::utils::git_clone "${ORG}" "${REPO}" + cd "${REPO}/python/taichi/examples/ggui_examples" + + # run tests + for match in $(find ./ -name "${PATTERN}"); do + python "${match}" + taichi::utils::line + taichi::utils::pause + done + + # go back to workdir + cd "${WORKDIR}" +} + +function taichi::test::difftaichi { + local WORKDIR=${1} + local PATTERN="*.py" + local ORG="taichi-dev" + local REPO="difftaichi" + + # divider + taichi::utils::line + taichi::utils::logger::info "Running DiffTaichi examples" + + # clone the repo + taichi::utils::git_clone "${ORG}" "${REPO}" + cd "${REPO}/examples" + + # run tests + for match in $(find ./ -name "${PATTERN}"); do + python "${match}" + taichi::utils::line + taichi::utils::pause + done + + # go back to workdir + cd "${WORKDIR}" +} + +function taichi::test::taichi_elements { + local WORKDIR=${1} + local PATTERN="demo_*.py" + local ORG="taichi-dev" + local REPO="taichi_elements" + + # divider + taichi::utils::line + taichi::utils::logger::info "Running Taichi Elements examples" + + # clone the repo + taichi::utils::git_clone "${ORG}" "${REPO}" + cd "${REPO}" + + # install dependencies + python "download_ply.py" + + # run tests + cd "${REPO}/demo" + for match in $(find ./ -name "${PATTERN}"); do + python "${match}" + taichi::utils::line + taichi::utils::pause + done + + # run special tests + # FIXME: this does not work properly yet + # taichi::utils::logger::success $(ls) + # read -p "Please input the directory containing the generated particles, e.g. sim_2022-01-01_20-55-48" particles_dir + # python render_particles.py -i ./"${particles_dir}" \ + # -b 0 -e 400 -s 1 \ + # -o ./output \ + # --gpu-memory 20 \ + # -M 460 \ + # --shutter-time 0.0 \ + # -r 128 + + # go back to workdir + cd "${WORKDIR}" +} + +function taichi::test::stannum { + local WORKDIR=${1} + local ORG="ifsheldon" + local REPO="stannum" + + # divider + taichi::utils::line + taichi::utils::logger::info "Running Stannum examples" + + # clone the repo + taichi::utils::git_clone "${ORG}" "${REPO}" + cd "${REPO}" + + # run tests + pytest -v -s ./ + + # go back to workdir + cd "${WORKDIR}" +} + +function taichi::test::sandyfluid { + local WORKDIR=${1} + local ORG="ethz-pbs21" + local REPO="SandyFluid" + + # divider + taichi::utils::line + taichi::utils::logger::info "Running SandyFluid examples" + + # clone the repo + taichi::utils::git_clone "${ORG}" "${REPO}" + cd "${REPO}" + + # install dependencies + # remove the line contains pinned Taichi version for testing purposes + grep -v "taichi" requirements.txt > tmpfile && mv tmpfile requirements.txt + pip install -r requirements.txt + + # run tests + python src/main.py + + # go back to workdir + cd "${WORKDIR}" +} + +function taichi::test::voxel_editor { + local WORKDIR=${1} + local ORG="taichi-dev" + local REPO="voxel_editor" + + # divider + taichi::utils::line + taichi::utils::logger::info "Running Voxel Editor examples" + + # clone the repo + taichi::utils::git_clone "${ORG}" "${REPO}" + cd "${REPO}" + + # run tests + python voxel_editor.py + + # go back to workdir + cd "${WORKDIR}" +} + +function taichi::test::main { + # set debugging flag + DEBUG="false" + + # create a temporary directory for testing + WORKDIR="$(mktemp -d)" + taichi::utils::logger::info "Running all tests within ${WORKDIR}" + + # make sure to clean up the temp dir on exit + trap '{ rm -rf -- "$WORKDIR"; }' EXIT + + # walk into the working dir + cd "${WORKDIR}" + + # ggui examples + taichi::test::ggui "${WORKDIR}" + + # difftaichi examples + taichi::test::difftaichi "${WORKDIR}" + + # taichi_elements examples + taichi::test::taichi_elements "${WORKDIR}" + + # stannum tests + taichi::test::stannum "${WORKDIR}" + + # sandyfluid tests + taichi::test::sandyfluid "${WORKDIR}" + + # voxel editor tests + taichi::test::voxel_editor "${WORKDIR}" +} + +taichi::test::main diff --git a/ci/scripts/ubuntu_build_test.sh b/ci/scripts/ubuntu_build_test.sh new file mode 100755 index 0000000000000..ed5acc79c99a8 --- /dev/null +++ b/ci/scripts/ubuntu_build_test.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +set -ex + +# Parse ARGs +for ARGUMENT in "$@" +do + KEY=$(echo $ARGUMENT | cut -f1 -d=) + VALUE=$(echo $ARGUMENT | cut -f2 -d=) + case "$KEY" in + SHA) SHA=${VALUE} ;; + PY) PY=${VALUE} ;; + *) + esac +done + +source /home/dev/miniconda/etc/profile.d/conda.sh +conda activate $PY + +# Build Taichi from source +git clone --recursive https://github.com/taichi-dev/taichi --branch=master +cd taichi +git checkout $SHA +python3 -m pip install -r requirements_dev.txt -i http://repo.taichigraphics.com/repository/pypi/simple --trusted-host repo.taichigraphics.com +# Update Torch version, otherwise cuda tests fail. See #2969. +python3 -m pip install torch==1.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html -i http://repo.taichigraphics.com/repository/pypi/simple --trusted-host repo.taichigraphics.com +TAICHI_CMAKE_ARGS="-DTI_WITH_VULKAN:BOOL=ON -DTI_WITH_CUDA:BOOL=ON -DTI_WITH_OPENGL:BOOL=ON" python3 setup.py install + +# Add Docker specific ENV +export TI_IN_DOCKER=true + +# Run tests +ti diagnose +python tests/run_tests.py -vr2 -t2 -k "not ndarray and not torch" +python tests/run_tests.py -vr2 -t1 -k "ndarray or torch" diff --git a/ci/scripts/ubuntu_build_test_cpu.sh b/ci/scripts/ubuntu_build_test_cpu.sh new file mode 100755 index 0000000000000..feba31b80e874 --- /dev/null +++ b/ci/scripts/ubuntu_build_test_cpu.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +set -ex + +# Parse ARGs +for ARGUMENT in "$@" +do + KEY=$(echo $ARGUMENT | cut -f1 -d=) + VALUE=$(echo $ARGUMENT | cut -f2 -d=) + case "$KEY" in + SHA) SHA=${VALUE} ;; + PY) PY=${VALUE} ;; + *) + esac +done + +source /home/dev/miniconda/etc/profile.d/conda.sh +conda activate $PY + +# Build Taichi from source +git clone --recursive https://github.com/taichi-dev/taichi --branch=master +cd taichi +git checkout $SHA +python3 -m pip install -r requirements_dev.txt -i http://repo.taichigraphics.com/repository/pypi/simple --trusted-host repo.taichigraphics.com +TAICHI_CMAKE_ARGS="-DTI_WITH_VULKAN:BOOL=OFF -DTI_WITH_CUDA:BOOL=OFF -DTI_WITH_OPENGL:BOOL=OFF" python3 setup.py install + +# Add Docker specific ENV +export TI_IN_DOCKER=true + +# Run tests +ti diagnose +python tests/run_tests.py -vr2 -t2 -k "not ndarray and not torch" +python tests/run_tests.py -vr2 -t1 -k "ndarray or torch" diff --git a/cmake/PythonNumpyPybind11.cmake b/cmake/PythonNumpyPybind11.cmake index 5957afc7999c3..311630dba74a8 100644 --- a/cmake/PythonNumpyPybind11.cmake +++ b/cmake/PythonNumpyPybind11.cmake @@ -51,39 +51,26 @@ execute_process(COMMAND ${PYTHON_EXECUTABLE} -c sys.stdout.write(str(sys.version_info[1]))" OUTPUT_VARIABLE PYTHON_MINOR_VERSION) + if (WIN32) - link_directories(${PYTHON_LIBRARY_DIR}/../../libs) - set(PYTHON_LIBRARIES ${PYTHON_LIBRARY_DIR}/../../libs/python3.lib) - set(PYTHON_LIBRARIES ${PYTHON_LIBRARY_DIR}/../../libs/python3${PYTHON_MINOR_VERSION}.lib) + execute_process(COMMAND ${PYTHON_EXECUTABLE} -c + "import sys;sys.stdout.write(sys.base_prefix.replace('\\\\', '/'))" + OUTPUT_VARIABLE PYTHON_BASE_PREFIX) + link_directories(${PYTHON_BASE_PREFIX}/libs) + set(PYTHON_LIBRARIES ${PYTHON_BASE_PREFIX}/libs/python3.lib) + set(PYTHON_LIBRARIES ${PYTHON_BASE_PREFIX}/libs/python3${PYTHON_MINOR_VERSION}.lib) else() find_library(PYTHON_LIBRARY NAMES python${PYTHON_VERSION} python${PYTHON_VERSION}m PATHS ${PYTHON_LIBRARY_DIR} NO_DEFAULT_PATH NO_SYSTEM_ENVIRONMENT_PATH PATH_SUFFIXES x86_64-linux-gnu) set(PYTHON_LIBRARIES ${PYTHON_LIBRARY}) endif() -# Creating python enters -file(MAKE_DIRECTORY bin) -file(WRITE ${CMAKE_SOURCE_DIR}/bin/ti "#!${PYTHON_EXECUTABLE_PATH}\nimport taichi\nexit(taichi.main())") -execute_process(COMMAND chmod +x ${CMAKE_SOURCE_DIR}/bin/ti) -execute_process(COMMAND cp ${CMAKE_SOURCE_DIR}/bin/ti ${CMAKE_SOURCE_DIR}/bin/taichi) - include_directories(${PYTHON_INCLUDE_DIRS}) message(" version: ${PYTHON_VERSION}") message(" include: ${PYTHON_INCLUDE_DIRS}") message(" library: ${PYTHON_LIBRARIES}") -execute_process(COMMAND ${PYTHON_EXECUTABLE} -c - "import git; from git import Repo; import sys;\ - sys.stdout.write(git.__version__)" - OUTPUT_VARIABLE GITPYTHON_VERSION - RESULT_VARIABLE GITPYTHON_IMPORT_RET) -if (NOT GITPYTHON_IMPORT_RET) - message(" gitpython version: ${GITPYTHON_VERSION}") -else () - message(FATAL_ERROR "Cannot import git. Please install. ([sudo] pip3 install --user gitpython)") -endif () - execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import numpy.distutils, sys;\ sys.stdout.write(':'.join(numpy.distutils.misc_util.get_numpy_include_dirs()))" diff --git a/cmake/TaichiCXXFlags.cmake b/cmake/TaichiCXXFlags.cmake index 030d58c6a7c29..79fe36770b646 100644 --- a/cmake/TaichiCXXFlags.cmake +++ b/cmake/TaichiCXXFlags.cmake @@ -17,10 +17,23 @@ if (MINGW) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -static-libstdc++") endif () -if (MSVC) +# Do not enable lto for APPLE since it made linking extremely slow. +if (WIN32) + if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -flto=thin") + endif() +endif() + +if (WIN32) link_directories(${CMAKE_CURRENT_SOURCE_DIR}/external/lib) - set(CMAKE_CXX_FLAGS - "${CMAKE_CXX_FLAGS} /Zc:__cplusplus /std:c++17 /MP /Z7 /D \"_CRT_SECURE_NO_WARNINGS\" /D \"_ENABLE_EXTENDED_ALIGNED_STORAGE\"") + if (MSVC) + set(CMAKE_CXX_FLAGS + "${CMAKE_CXX_FLAGS} /Zc:__cplusplus /std:c++17 /bigobj /wd4244 /wd4267 /nologo /Zi /D \"_CRT_SECURE_NO_WARNINGS\" /D \"_ENABLE_EXTENDED_ALIGNED_STORAGE\"") + else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fsized-deallocation -target x86_64-pc-windows-msvc") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -gcodeview") + set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -gcodeview") + endif() else() if ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") message("Clang compiler detected. Using std=c++17.") @@ -37,7 +50,7 @@ else() endif () message("Building for processor ${CMAKE_SYSTEM_PROCESSOR}") -if ("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "x86_64" OR "${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "AMD64") +if ("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "x86_64" OR "${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "AMD64" OR "${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "amd64") if (MSVC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /D \"TI_ARCH_x64\"") else() @@ -48,6 +61,9 @@ if ("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "x86_64" OR "${CMAKE_SYSTEM_PROCESSOR}" elseif ("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "aarch64" OR "${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "arm64") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTI_ARCH_ARM") set(ARCH "arm64") +elseif ("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "x86") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTI_ARCH_x86") + set(ARCH "x86") else() message(FATAL_ERROR "Unknown processor type ${CMAKE_SYSTEM_PROCESSOR}") endif() diff --git a/cmake/TaichiCore.cmake b/cmake/TaichiCore.cmake index c124205f1bf72..3545dc68b0fb2 100644 --- a/cmake/TaichiCore.cmake +++ b/cmake/TaichiCore.cmake @@ -1,9 +1,26 @@ option(USE_STDCPP "Use -stdlib=libc++" OFF) +option(TI_WITH_LLVM "Build with LLVM backends" ON) +option(TI_WITH_METAL "Build with the Metal backend" ON) option(TI_WITH_CUDA "Build with the CUDA backend" ON) option(TI_WITH_CUDA_TOOLKIT "Build with the CUDA toolkit" OFF) option(TI_WITH_OPENGL "Build with the OpenGL backend" ON) option(TI_WITH_CC "Build with the C backend" ON) option(TI_WITH_VULKAN "Build with the Vulkan backend" OFF) +option(TI_WITH_DX11 "Build with the DX11 backend" OFF) +option(TI_EMSCRIPTENED "Build using emscripten" OFF) +set(_TI_SYMBOL_VISIBILITY default) + +if(TI_EMSCRIPTENED) + set(TI_WITH_LLVM OFF) + set(TI_WITH_METAL OFF) + set(TI_WITH_CUDA OFF) + set(TI_WITH_OPENGL OFF) + set(TI_WITH_CC OFF) + set(TI_WITH_DX11 OFF) + + set(TI_WITH_VULKAN ON) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTI_EMSCRIPTENED") +endif() if(UNIX AND NOT APPLE) # Handy helper for Linux @@ -34,16 +51,21 @@ if (WIN32) endif() set(TI_WITH_GGUI OFF) -if(TI_WITH_CUDA AND TI_WITH_VULKAN) +if(TI_WITH_VULKAN AND NOT TI_EMSCRIPTENED) set(TI_WITH_GGUI ON) endif() -if (NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/external/glad/src/glad.c") +if (NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/external/glad/src/gl.c") set(TI_WITH_OPENGL OFF) message(WARNING "external/glad submodule not detected. Settings TI_WITH_OPENGL to OFF.") endif() +if(NOT TI_WITH_LLVM) + set(TI_WITH_CUDA OFF) + set(TI_WITH_CUDA_TOOLKIT OFF) +endif() + file(GLOB TAICHI_CORE_SOURCE @@ -57,6 +79,7 @@ file(GLOB TAICHI_WASM_SOURCE "taichi/backends/wasm/*.cpp" "taichi/backends/wasm/ file(GLOB TAICHI_CUDA_SOURCE "taichi/backends/cuda/*.cpp" "taichi/backends/cuda/*.h") file(GLOB TAICHI_METAL_SOURCE "taichi/backends/metal/*.h" "taichi/backends/metal/*.cpp" "taichi/backends/metal/shaders/*") file(GLOB TAICHI_OPENGL_SOURCE "taichi/backends/opengl/*.h" "taichi/backends/opengl/*.cpp" "taichi/backends/opengl/shaders/*") +file(GLOB TAICHI_DX11_SOURCE "taichi/backends/dx/*.h" "taichi/backends/dx/*.cpp") file(GLOB TAICHI_CC_SOURCE "taichi/backends/cc/*.h" "taichi/backends/cc/*.cpp") file(GLOB TAICHI_VULKAN_SOURCE "taichi/backends/vulkan/*.h" "taichi/backends/vulkan/*.cpp" "external/SPIRV-Reflect/spirv_reflect.c") file(GLOB TAICHI_INTEROP_SOURCE "taichi/backends/interop/*.cpp" "taichi/backends/interop/*.h") @@ -66,16 +89,22 @@ file(GLOB TAICHI_GGUI_SOURCE "taichi/ui/*.cpp" "taichi/ui/*/*.cpp" "taichi/ui/*/*/*.cpp" "taichi/ui/*/*/*/*.cpp" "taichi/ui/*/*/*/*/*.cpp" "taichi/ui/*.h" "taichi/ui/*/*.h" "taichi/ui/*/*/*.h" "taichi/ui/*/*/*/*.h" "taichi/ui/*/*/*/*/*.h" ) +file(GLOB TAICHI_GGUI_GLFW_SOURCE + "taichi/ui/common/window_base.cpp" + "taichi/ui/backends/vulkan/window.cpp" +) list(REMOVE_ITEM TAICHI_CORE_SOURCE ${TAICHI_GGUI_SOURCE}) if(TI_WITH_GGUI) add_definitions(-DTI_WITH_GGUI) - list(APPEND TAICHI_CORE_SOURCE ${TAICHI_GGUI_SOURCE}) - - include_directories(SYSTEM external/glm) + # Remove GLFW dependencies from the build for Android + if(ANDROID) + list(REMOVE_ITEM TAICHI_GGUI_SOURCE ${TAICHI_GGUI_GLFW_SOURCE}) + endif() + list(APPEND TAICHI_CORE_SOURCE ${TAICHI_GGUI_SOURCE}) endif() # These files are compiled into .bc and loaded as LLVM module dynamically. They should not be compiled into libtaichi. So they're removed here @@ -95,14 +124,19 @@ file(GLOB TAICHI_OPENGL_REQUIRED_SOURCE file(GLOB TAICHI_VULKAN_REQUIRED_SOURCE "taichi/backends/vulkan/runtime.h" "taichi/backends/vulkan/runtime.cpp" - "taichi/backends/vulkan/snode_struct_compiler.cpp" - "taichi/backends/vulkan/snode_struct_compiler.h" ) list(REMOVE_ITEM TAICHI_CORE_SOURCE ${TAICHI_BACKEND_SOURCE}) -list(APPEND TAICHI_CORE_SOURCE ${TAICHI_CPU_SOURCE}) -list(APPEND TAICHI_CORE_SOURCE ${TAICHI_WASM_SOURCE}) +if(TI_WITH_LLVM) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTI_WITH_LLVM") + list(APPEND TAICHI_CORE_SOURCE ${TAICHI_CPU_SOURCE}) + list(APPEND TAICHI_CORE_SOURCE ${TAICHI_WASM_SOURCE}) +else() + file(GLOB TAICHI_LLVM_SOURCE "taichi/llvm/*.cpp" "taichi/llvm/*.h") + list(REMOVE_ITEM TAICHI_CORE_SOURCE ${TAICHI_LLVM_SOURCE}) +endif() + list(APPEND TAICHI_CORE_SOURCE ${TAICHI_INTEROP_SOURCE}) @@ -115,14 +149,21 @@ if(NOT CUDA_VERSION) set(CUDA_VERSION 10.0) endif() -# TODO(#529) include Metal source only on Apple MacOS, and OpenGL only when TI_WITH_OPENGL is ON -list(APPEND TAICHI_CORE_SOURCE ${TAICHI_METAL_SOURCE}) + +# By default, TI_WITH_METAL is ON for all platforms. +# As of right now, on non-macOS platforms, the metal backend won't work at all. +# We have future plans to allow metal AOT to run on non-macOS devices. +if (TI_WITH_METAL) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTI_WITH_METAL") + list(APPEND TAICHI_CORE_SOURCE ${TAICHI_METAL_SOURCE}) +endif() + if (TI_WITH_OPENGL) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTI_WITH_OPENGL") # Q: Why not external/glad/src/*.c? # A: To ensure glad submodule exists when TI_WITH_OPENGL is ON. - file(GLOB TAICHI_GLAD_SOURCE "external/glad/src/glad.c") + file(GLOB TAICHI_GLAD_SOURCE "external/glad/src/gl.c" "external/glad/src/egl.c") list(APPEND TAICHI_CORE_SOURCE ${TAICHI_GLAD_SOURCE}) list(APPEND TAICHI_CORE_SOURCE ${TAICHI_OPENGL_SOURCE}) endif() @@ -140,6 +181,11 @@ if (TI_WITH_VULKAN) endif() list(APPEND TAICHI_CORE_SOURCE ${TAICHI_VULKAN_REQUIRED_SOURCE}) +if (TI_WITH_DX11) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTI_WITH_DX11") + list(APPEND TAICHI_CORE_SOURCE ${TAICHI_DX11_SOURCE}) +endif() + # This compiles all the libraries with -fPIC, which is critical to link a static # library into a shared lib. set(CMAKE_POSITION_INDEPENDENT_CODE ON) @@ -158,6 +204,14 @@ file(GLOB TAICHI_PYBIND_SOURCE ) list(REMOVE_ITEM TAICHI_CORE_SOURCE ${TAICHI_PYBIND_SOURCE}) +file(GLOB TAICHI_EMBIND_SOURCE + "taichi/javascript/*.cpp" + "taichi/javascript/*.h" +) +if (TAICHI_EMBIND_SOURCE) + list(REMOVE_ITEM TAICHI_CORE_SOURCE ${TAICHI_EMBIND_SOURCE}) +endif() + # TODO(#2196): Rename these CMAKE variables: # CORE_LIBRARY_NAME --> TAICHI_ISOLATED_CORE_LIB_NAME # CORE_WITH_PYBIND_LIBRARY_NAME --> TAICHI_CORE_LIB_NAME @@ -171,6 +225,7 @@ list(REMOVE_ITEM TAICHI_CORE_SOURCE ${TAICHI_PYBIND_SOURCE}) # everywhere in python. set(CORE_LIBRARY_NAME taichi_isolated_core) add_library(${CORE_LIBRARY_NAME} OBJECT ${TAICHI_CORE_SOURCE}) +set_target_properties(${CORE_LIBRARY_NAME} PROPERTIES CXX_VISIBILITY_PRESET ${_TI_SYMBOL_VISIBILITY}) if (APPLE) # Ask OS X to minic Linux dynamic linking behavior @@ -181,19 +236,26 @@ include_directories(${CMAKE_SOURCE_DIR}) include_directories(external/include) include_directories(external/spdlog/include) if (TI_WITH_OPENGL) - include_directories(external/glad/include) + target_include_directories(${CORE_LIBRARY_NAME} PRIVATE external/glad/include) endif() + target_include_directories(${CORE_LIBRARY_NAME} PRIVATE external/FP16/include) set(LIBRARY_NAME ${CORE_LIBRARY_NAME}) -if (TI_WITH_OPENGL) +# GLFW not available on Android +if (TI_WITH_OPENGL OR TI_WITH_VULKAN AND NOT ANDROID AND NOT TI_EMSCRIPTENED) set(GLFW_BUILD_DOCS OFF CACHE BOOL "" FORCE) set(GLFW_BUILD_TESTS OFF CACHE BOOL "" FORCE) set(GLFW_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) + if (APPLE) + set(GLFW_VULKAN_STATIC ON CACHE BOOL "" FORCE) + endif() + message("Building with GLFW") add_subdirectory(external/glfw) target_link_libraries(${LIBRARY_NAME} glfw) + target_include_directories(${CORE_LIBRARY_NAME} PRIVATE external/glfw/include) endif() if(DEFINED ENV{LLVM_DIR}) @@ -201,48 +263,50 @@ if(DEFINED ENV{LLVM_DIR}) message("Getting LLVM_DIR=${LLVM_DIR} from the environment variable") endif() -# http://llvm.org/docs/CMake.html#embedding-llvm-in-your-project -find_package(LLVM REQUIRED CONFIG) -message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") -if(${LLVM_PACKAGE_VERSION} VERSION_LESS "10.0") - message(FATAL_ERROR "LLVM version < 10 is not supported") -endif() -message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") -include_directories(${LLVM_INCLUDE_DIRS}) -message("LLVM include dirs ${LLVM_INCLUDE_DIRS}") -message("LLVM library dirs ${LLVM_LIBRARY_DIRS}") -add_definitions(${LLVM_DEFINITIONS}) - -llvm_map_components_to_libnames(llvm_libs - Core - ExecutionEngine - InstCombine - OrcJIT - RuntimeDyld - TransformUtils - BitReader - BitWriter - Object - ScalarOpts - Support - native - Linker - Target - MC - Passes - ipo - Analysis - ) -target_link_libraries(${LIBRARY_NAME} ${llvm_libs}) - -if (APPLE AND "${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "arm64") - llvm_map_components_to_libnames(llvm_aarch64_libs AArch64) - target_link_libraries(${LIBRARY_NAME} ${llvm_aarch64_libs}) -endif() +if(TI_WITH_LLVM) + # http://llvm.org/docs/CMake.html#embedding-llvm-in-your-project + find_package(LLVM REQUIRED CONFIG) + message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") + if(${LLVM_PACKAGE_VERSION} VERSION_LESS "10.0") + message(FATAL_ERROR "LLVM version < 10 is not supported") + endif() + message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") + include_directories(${LLVM_INCLUDE_DIRS}) + message("LLVM include dirs ${LLVM_INCLUDE_DIRS}") + message("LLVM library dirs ${LLVM_LIBRARY_DIRS}") + add_definitions(${LLVM_DEFINITIONS}) + + llvm_map_components_to_libnames(llvm_libs + Core + ExecutionEngine + InstCombine + OrcJIT + RuntimeDyld + TransformUtils + BitReader + BitWriter + Object + ScalarOpts + Support + native + Linker + Target + MC + Passes + ipo + Analysis + ) + target_link_libraries(${LIBRARY_NAME} ${llvm_libs}) + + if (APPLE AND "${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "arm64") + llvm_map_components_to_libnames(llvm_aarch64_libs AArch64) + target_link_libraries(${LIBRARY_NAME} ${llvm_aarch64_libs}) + endif() -if (TI_WITH_CUDA) - llvm_map_components_to_libnames(llvm_ptx_libs NVPTX) - target_link_libraries(${LIBRARY_NAME} ${llvm_ptx_libs}) + if (TI_WITH_CUDA) + llvm_map_components_to_libnames(llvm_ptx_libs NVPTX) + target_link_libraries(${LIBRARY_NAME} ${llvm_ptx_libs}) + endif() endif() if (TI_WITH_CUDA_TOOLKIT) @@ -261,38 +325,32 @@ else() message(STATUS "TI_WITH_CUDA_TOOLKIT = OFF") endif() -add_subdirectory(external/SPIRV-Cross) -target_include_directories(${CORE_LIBRARY_NAME} PRIVATE external/SPIRV-Cross) -target_link_libraries(${CORE_LIBRARY_NAME} spirv-cross-glsl spirv-cross-core) - -if (TI_WITH_VULKAN) - # Vulkan libs - # https://cmake.org/cmake/help/latest/module/FindVulkan.html - # https://github.com/PacktPublishing/Learning-Vulkan/blob/master/Chapter%2003/HandShake/CMakeLists.txt - find_package(Vulkan REQUIRED) - - if(NOT Vulkan_FOUND) - message(FATAL_ERROR "TI_WITH_VULKAN is ON but Vulkan could not be found") - endif() +if (TI_WITH_OPENGL) + set(SPIRV_CROSS_CLI false) + add_subdirectory(external/SPIRV-Cross) + target_include_directories(${CORE_LIBRARY_NAME} PRIVATE external/SPIRV-Cross) + target_link_libraries(${CORE_LIBRARY_NAME} spirv-cross-glsl spirv-cross-core) +endif() - message(STATUS "Vulkan_INCLUDE_DIR=${Vulkan_INCLUDE_DIR}") - message(STATUS "Vulkan_LIBRARY=${Vulkan_LIBRARY}") +if (TI_WITH_DX11) + set(SPIRV_CROSS_CLI false) + #target_include_directories(${CORE_LIBRARY_NAME} PRIVATE external/SPIRV-Cross) + target_link_libraries(${CORE_LIBRARY_NAME} spirv-cross-hlsl spirv-cross-core) +endif() - include_directories(external/SPIRV-Headers/include) +# SPIR-V codegen is always there, regardless of Vulkan +set(SPIRV_SKIP_EXECUTABLES true) +set(SPIRV-Headers_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/SPIRV-Headers) +add_subdirectory(external/SPIRV-Tools) +# NOTE: SPIRV-Tools-opt must come before SPIRV-Tools +# https://github.com/KhronosGroup/SPIRV-Tools/issues/1569#issuecomment-390250792 +target_link_libraries(${CORE_LIBRARY_NAME} SPIRV-Tools-opt ${SPIRV_TOOLS}) - set(SPIRV_SKIP_EXECUTABLES true) - set(SPIRV-Headers_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/SPIRV-Headers) - add_subdirectory(external/SPIRV-Tools) - # NOTE: SPIRV-Tools-opt must come before SPIRV-Tools - # https://github.com/KhronosGroup/SPIRV-Tools/issues/1569#issuecomment-390250792 - target_link_libraries(${CORE_LIBRARY_NAME} SPIRV-Tools-opt ${SPIRV_TOOLS}) +if (TI_WITH_VULKAN) + include_directories(SYSTEM external/Vulkan-Headers/include) - # No longer link against vulkan, using volk instead - #target_link_libraries(${CORE_LIBRARY_NAME} ${Vulkan_LIBRARY}) - include_directories(${Vulkan_INCLUDE_DIR}) - include_directories(external/volk) + include_directories(SYSTEM external/volk) - # Is this the best way to include the SPIRV-Headers? target_include_directories(${CORE_LIBRARY_NAME} PRIVATE external/SPIRV-Headers/include) target_include_directories(${CORE_LIBRARY_NAME} PRIVATE external/SPIRV-Reflect) target_include_directories(${CORE_LIBRARY_NAME} PRIVATE external/VulkanMemoryAllocator/include) @@ -303,6 +361,12 @@ if (TI_WITH_VULKAN) find_package(Threads REQUIRED) target_link_libraries(${CORE_LIBRARY_NAME} Threads::Threads) endif() + + if (APPLE) + find_library(MOLTEN_VK libMoltenVK.dylib PATHS $HOMEBREW_CELLAR/molten-vk $VULKAN_SDK REQUIRED) + configure_file(${MOLTEN_VK} ${CMAKE_BINARY_DIR}/libMoltenVK.dylib COPYONLY) + message(STATUS "MoltenVK library ${MOLTEN_VK}") + endif() endif () # Optional dependencies @@ -312,17 +376,31 @@ if (APPLE) endif () if (NOT WIN32) - target_link_libraries(${CORE_LIBRARY_NAME} pthread stdc++) - if (APPLE) - # OS X + # Android has a custom toolchain so pthread is not available and should + # link against other libraries as well for logcat and internal features. + if (ANDROID) + target_link_libraries(${CORE_LIBRARY_NAME} android log) + else() + target_link_libraries(${CORE_LIBRARY_NAME} pthread stdc++) + endif() + + if (UNIX AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Linux") + # OS X or BSD else() # Linux target_link_libraries(${CORE_LIBRARY_NAME} stdc++fs X11) target_link_libraries(${CORE_LIBRARY_NAME} -static-libgcc -static-libstdc++) - if (NOT TI_EXPORT_CORE) # expose api for CHI IR Builder + if ((NOT TI_EXPORT_CORE) AND (NOT ${_TI_SYMBOL_VISIBILITY} STREQUAL hidden)) # expose api for CHI IR Builder + message(WARNING "Using linker.map to hide symbols!") target_link_libraries(${CORE_LIBRARY_NAME} -Wl,--version-script,${CMAKE_CURRENT_SOURCE_DIR}/misc/linker.map) endif () - target_link_libraries(${CORE_LIBRARY_NAME} -Wl,--wrap=log2f) # Avoid glibc dependencies + # Avoid glibc dependencies + if (TI_WITH_VULKAN) + target_link_libraries(${CORE_LIBRARY_NAME} -Wl,--wrap=log2f) + else() + # Enforce compatibility with manylinux2014 + target_link_libraries(${CORE_LIBRARY_NAME} -Wl,--wrap=log2f -Wl,--wrap=exp2 -Wl,--wrap=log2 -Wl,--wrap=logf -Wl,--wrap=powf -Wl,--wrap=exp -Wl,--wrap=log -Wl,--wrap=pow) + endif() endif() else() # windows @@ -338,31 +416,59 @@ endforeach () message("PYTHON_LIBRARIES: " ${PYTHON_LIBRARIES}) -set(CORE_WITH_PYBIND_LIBRARY_NAME taichi_core) -add_library(${CORE_WITH_PYBIND_LIBRARY_NAME} SHARED ${TAICHI_PYBIND_SOURCE}) -# It is actually possible to link with an OBJECT library -# https://cmake.org/cmake/help/v3.13/command/target_link_libraries.html?highlight=target_link_libraries#linking-object-libraries -target_link_libraries(${CORE_WITH_PYBIND_LIBRARY_NAME} PUBLIC ${CORE_LIBRARY_NAME}) - -# These commands should apply to the DLL that is loaded from python, not the OBJECT library. -if (MSVC) - set_property(TARGET ${CORE_WITH_PYBIND_LIBRARY_NAME} APPEND PROPERTY LINK_FLAGS /DEBUG) -endif () +if(NOT TI_EMSCRIPTENED) + set(CORE_WITH_PYBIND_LIBRARY_NAME taichi_core) + # Cannot compile Python source code with Android, but TI_EXPORT_CORE should be set and + # Android should only use the isolated library ignoring those source code. + if (NOT ANDROID) + add_library(${CORE_WITH_PYBIND_LIBRARY_NAME} SHARED ${TAICHI_PYBIND_SOURCE}) + else() + add_library(${CORE_WITH_PYBIND_LIBRARY_NAME} SHARED) + endif () -if (WIN32) - set_target_properties(${CORE_WITH_PYBIND_LIBRARY_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY - "${CMAKE_CURRENT_SOURCE_DIR}/runtimes") -endif () + set_target_properties(${CORE_WITH_PYBIND_LIBRARY_NAME} PROPERTIES CXX_VISIBILITY_PRESET ${_TI_SYMBOL_VISIBILITY}) + # Remove symbols from static libs: https://stackoverflow.com/a/14863432/12003165 + if (LINUX) + target_link_options(${CORE_WITH_PYBIND_LIBRARY_NAME} PUBLIC -Wl,--exclude-libs=ALL) + endif() + # It is actually possible to link with an OBJECT library + # https://cmake.org/cmake/help/v3.13/command/target_link_libraries.html?highlight=target_link_libraries#linking-object-libraries + target_link_libraries(${CORE_WITH_PYBIND_LIBRARY_NAME} PUBLIC ${CORE_LIBRARY_NAME}) + + # These commands should apply to the DLL that is loaded from python, not the OBJECT library. + if (MSVC) + set_property(TARGET ${CORE_WITH_PYBIND_LIBRARY_NAME} APPEND PROPERTY LINK_FLAGS /DEBUG) + endif () + + if (WIN32) + set_target_properties(${CORE_WITH_PYBIND_LIBRARY_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY + "${CMAKE_CURRENT_SOURCE_DIR}/runtimes") + endif () +endif() +if(TI_EMSCRIPTENED) + set(CORE_WITH_EMBIND_LIBRARY_NAME taichi) + add_executable(${CORE_WITH_EMBIND_LIBRARY_NAME} ${TAICHI_EMBIND_SOURCE}) + target_link_libraries(${CORE_WITH_EMBIND_LIBRARY_NAME} PUBLIC ${CORE_LIBRARY_NAME}) + target_compile_options(${CORE_WITH_EMBIND_LIBRARY_NAME} PRIVATE "-Oz") + # target_compile_options(${CORE_LIBRARY_NAME} PRIVATE "-Oz") + set_target_properties(${CORE_LIBRARY_NAME} PROPERTIES LINK_FLAGS "-s ERROR_ON_UNDEFINED_SYMBOLS=0 -s ASSERTIONS=1") + set_target_properties(${CORE_WITH_EMBIND_LIBRARY_NAME} PROPERTIES LINK_FLAGS "--bind -s MODULARIZE=1 -s EXPORT_NAME=createTaichiModule -s WASM=0 --memory-init-file 0 -Oz --closure 1 -s ERROR_ON_UNDEFINED_SYMBOLS=0 -s ASSERTIONS=1 -s NO_DISABLE_EXCEPTION_CATCHING") +endif() if(TI_WITH_GGUI) + include_directories(SYSTEM PRIVATE external/glm) # Dear ImGui add_definitions(-DIMGUI_IMPL_VULKAN_NO_PROTOTYPES) set(IMGUI_DIR external/imgui) - include_directories(external/glfw/include) include_directories(SYSTEM ${IMGUI_DIR} ${IMGUI_DIR}/backends ..) +if(ANDROID) + add_library(imgui ${IMGUI_DIR}/backends/imgui_impl_android.cpp ${IMGUI_DIR}/backends/imgui_impl_vulkan.cpp ${IMGUI_DIR}/imgui.cpp ${IMGUI_DIR}/imgui_draw.cpp ${IMGUI_DIR}/imgui_tables.cpp ${IMGUI_DIR}/imgui_widgets.cpp) +else() + include_directories(external/glfw/include) add_library(imgui ${IMGUI_DIR}/backends/imgui_impl_glfw.cpp ${IMGUI_DIR}/backends/imgui_impl_vulkan.cpp ${IMGUI_DIR}/imgui.cpp ${IMGUI_DIR}/imgui_draw.cpp ${IMGUI_DIR}/imgui_tables.cpp ${IMGUI_DIR}/imgui_widgets.cpp) +endif() target_link_libraries(${CORE_LIBRARY_NAME} imgui) endif() diff --git a/cmake/TaichiExamples.cmake b/cmake/TaichiExamples.cmake new file mode 100644 index 0000000000000..742119013d623 --- /dev/null +++ b/cmake/TaichiExamples.cmake @@ -0,0 +1,30 @@ +cmake_minimum_required(VERSION 3.0) + +if(NOT TI_EMSCRIPTENED) + +set(EXAMPLES_NAME taichi_cpp_examples) + +file(GLOB_RECURSE TAICHI_EXAMPLES_SOURCE +"cpp_examples/main.cpp" +"cpp_examples/run_snode.cpp" +"cpp_examples/autograd.cpp" +"cpp_examples/aot_save.cpp" +) + +include_directories( + ${PROJECT_SOURCE_DIR}, +) + +add_executable(${EXAMPLES_NAME} ${TAICHI_EXAMPLES_SOURCE}) +if (WIN32) + # Output the executable to bin/ instead of build/Debug/... + set(EXAMPLES_OUTPUT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/bin") + set_target_properties(${EXAMPLES_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${EXAMPLES_OUTPUT_DIR}) + set_target_properties(${EXAMPLES_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${EXAMPLES_OUTPUT_DIR}) + set_target_properties(${EXAMPLES_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELEASE ${EXAMPLES_OUTPUT_DIR}) + set_target_properties(${EXAMPLES_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_MINSIZEREL ${EXAMPLES_OUTPUT_DIR}) + set_target_properties(${EXAMPLES_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELWITHDEBINFO ${EXAMPLES_OUTPUT_DIR}) +endif() +target_link_libraries(${EXAMPLES_NAME} taichi_isolated_core) + +endif() diff --git a/cmake/TaichiMain.cmake b/cmake/TaichiMain.cmake deleted file mode 100644 index 3504c0049ed38..0000000000000 --- a/cmake/TaichiMain.cmake +++ /dev/null @@ -1,10 +0,0 @@ -if (WIN32) - message("Added executable 'ti' for Windows") - # On Windows, generate a ti.exe as entry - add_executable(ti python/ti.cpp) - set_target_properties(ti PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/bin") - set_target_properties(ti PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG "${CMAKE_CURRENT_SOURCE_DIR}/bin") - set_target_properties(ti PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELEASE "${CMAKE_CURRENT_SOURCE_DIR}/bin") - set_target_properties(ti PROPERTIES RUNTIME_OUTPUT_DIRECTORY_MINSIZEREL "${CMAKE_CURRENT_SOURCE_DIR}/bin") - set_target_properties(ti PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELWITHDEBINFO "${CMAKE_CURRENT_SOURCE_DIR}/bin") -endif () diff --git a/cmake/TaichiTests.cmake b/cmake/TaichiTests.cmake index 2946b758e56cc..2b6da207c82ff 100644 --- a/cmake/TaichiTests.cmake +++ b/cmake/TaichiTests.cmake @@ -12,6 +12,8 @@ endif() # 2. Re-implement the legacy CPP tests using googletest file(GLOB_RECURSE TAICHI_TESTS_SOURCE "tests/cpp/analysis/*.cpp" + "tests/cpp/aot/*.cpp" + "tests/cpp/backends/*.cpp" "tests/cpp/codegen/*.cpp" "tests/cpp/common/*.cpp" "tests/cpp/ir/*.cpp" diff --git a/cpp_examples/aot_save.cpp b/cpp_examples/aot_save.cpp new file mode 100644 index 0000000000000..cc88ebbb53403 --- /dev/null +++ b/cpp_examples/aot_save.cpp @@ -0,0 +1,78 @@ +#include "taichi/ir/ir_builder.h" +#include "taichi/ir/statements.h" +#include "taichi/program/program.h" + +void aot_save() { + using namespace taichi; + using namespace lang; + auto program = Program(Arch::vulkan); + + program.config.advanced_optimization = false; + + int n = 10; + + // program.materialize_runtime(); + auto *root = new SNode(0, SNodeType::root); + auto *pointer = &root->dense(Axis(0), n, false); + auto *place = &pointer->insert_children(SNodeType::place); + place->dt = PrimitiveType::i32; + program.add_snode_tree(std::unique_ptr(root), /*compile_only=*/true); + + auto aot_builder = program.make_aot_module_builder(Arch::vulkan); + + std::unique_ptr kernel_init, kernel_ret; + + { + /* + @ti.kernel + def init(): + for index in range(n): + place[index] = index + */ + IRBuilder builder; + auto *zero = builder.get_int32(0); + auto *n_stmt = builder.get_int32(n); + auto *loop = builder.create_range_for(zero, n_stmt, 0, 4); + { + auto _ = builder.get_loop_guard(loop); + auto *index = builder.get_loop_index(loop); + auto *ptr = builder.create_global_ptr(place, {index}); + builder.create_global_store(ptr, index); + } + + kernel_init = + std::make_unique(program, builder.extract_ir(), "init"); + } + + { + /* + @ti.kernel + def ret(): + sum = 0 + for index in place: + sum = sum + place[index]; + return sum + */ + IRBuilder builder; + auto *sum = builder.create_local_var(PrimitiveType::i32); + auto *loop = builder.create_struct_for(pointer, 0, 4); + { + auto _ = builder.get_loop_guard(loop); + auto *index = builder.get_loop_index(loop); + auto *sum_old = builder.create_local_load(sum); + auto *place_index = + builder.create_global_load(builder.create_global_ptr(place, {index})); + builder.create_local_store(sum, builder.create_add(sum_old, place_index)); + } + builder.create_return(builder.create_local_load(sum)); + + kernel_ret = std::make_unique(program, builder.extract_ir(), "ret"); + kernel_ret->insert_ret(PrimitiveType::i32); + } + + aot_builder->add_field("place", place, true, place->dt, {n}, 1, 1); + aot_builder->add("init", kernel_init.get()); + aot_builder->add("ret", kernel_ret.get()); + aot_builder->dump(".", "aot.tcb"); + std::cout << "done" << std::endl; +} diff --git a/examples/chi_examples/main.cpp b/cpp_examples/autograd.cpp similarity index 57% rename from examples/chi_examples/main.cpp rename to cpp_examples/autograd.cpp index 7424b5389676b..0b3bc8f43ed95 100644 --- a/examples/chi_examples/main.cpp +++ b/cpp_examples/autograd.cpp @@ -1,147 +1,7 @@ -#include - #include "taichi/ir/ir_builder.h" #include "taichi/ir/statements.h" #include "taichi/program/program.h" -void run_snode() { - /* - import taichi as ti, numpy as np - ti.init() - #ti.init(print_ir = True) - - n = 10 - place = ti.field(dtype = ti.i32) - ti.root.pointer(ti.i, n).place(place) - - @ti.kernel - def init(): - for index in range(n): - place[index] = index - - @ti.kernel - def ret() -> ti.i32: - sum = 0 - for index in place: - sum = sum + place[index] - return sum - - @ti.kernel - def ext(ext_arr: ti.ext_arr()): - for index in place: - ext_arr[index] = place[index] - - init() - print(ret()) - ext_arr = np.zeros(n, np.int32) - ext(ext_arr) - #ext_arr = place.to_numpy() - print(ext_arr) - */ - - using namespace taichi; - using namespace lang; - auto program = Program(Arch::x64); - /*CompileConfig config_print_ir; - config_print_ir.print_ir = true; - prog_.config = config_print_ir;*/ // print_ir = True - - int n = 10; - program.materialize_runtime(); - auto *root = new SNode(0, SNodeType::root); - auto *pointer = &root->pointer(Index(0), n); - auto *place = &pointer->insert_children(SNodeType::place); - place->dt = PrimitiveType::i32; - program.add_snode_tree(std::unique_ptr(root)); - - std::unique_ptr kernel_init, kernel_ret, kernel_ext; - - { - /* - @ti.kernel - def init(): - for index in range(n): - place[index] = index - */ - IRBuilder builder; - auto *zero = builder.get_int32(0); - auto *n_stmt = builder.get_int32(n); - auto *loop = builder.create_range_for(zero, n_stmt, 1, 0, 4); - { - auto _ = builder.get_loop_guard(loop); - auto *index = builder.get_loop_index(loop); - auto *ptr = builder.create_global_ptr(place, {index}); - builder.create_global_store(ptr, index); - } - - kernel_init = - std::make_unique(program, builder.extract_ir(), "init"); - } - - { - /* - @ti.kernel - def ret(): - sum = 0 - for index in place: - sum = sum + place[index]; - return sum - */ - IRBuilder builder; - auto *sum = builder.create_local_var(PrimitiveType::i32); - auto *loop = builder.create_struct_for(pointer, 1, 0, 4); - { - auto _ = builder.get_loop_guard(loop); - auto *index = builder.get_loop_index(loop); - auto *sum_old = builder.create_local_load(sum); - auto *place_index = - builder.create_global_load(builder.create_global_ptr(place, {index})); - builder.create_local_store(sum, builder.create_add(sum_old, place_index)); - } - builder.create_return(builder.create_local_load(sum)); - - kernel_ret = std::make_unique(program, builder.extract_ir(), "ret"); - } - - { - /* - @ti.kernel - def ext(ext: ti.ext_arr()): - for index in place: - ext[index] = place[index]; - # ext = place.to_numpy() - */ - IRBuilder builder; - auto *loop = builder.create_struct_for(pointer, 1, 0, 4); - { - auto _ = builder.get_loop_guard(loop); - auto *index = builder.get_loop_index(loop); - auto *ext = builder.create_external_ptr( - builder.create_arg_load(0, PrimitiveType::i32, true), {index}); - auto *place_index = - builder.create_global_load(builder.create_global_ptr(place, {index})); - builder.create_global_store(ext, place_index); - } - - kernel_ext = std::make_unique(program, builder.extract_ir(), "ext"); - kernel_ext->insert_arg(get_data_type(), true); - } - - auto ctx_init = kernel_init->make_launch_context(); - auto ctx_ret = kernel_ret->make_launch_context(); - auto ctx_ext = kernel_ext->make_launch_context(); - std::vector ext_arr(n); - ctx_ext.set_arg_external_array(0, taichi::uint64(ext_arr.data()), n); - - (*kernel_init)(ctx_init); - (*kernel_ret)(ctx_ret); - std::cout << program.fetch_result(0) << std::endl; - (*kernel_ext)(ctx_ext); - for (int i = 0; i < n; i++) - std::cout << ext_arr[i] << " "; - std::cout << std::endl; -} - void autograd() { /* import taichi as ti, numpy as np @@ -211,16 +71,17 @@ void autograd() { } }; - auto *snode = &root->dense(0, n).insert_children(SNodeType::place); + auto *snode = + &root->dense(Axis(0), n, false).insert_children(SNodeType::place); snode->dt = PrimitiveType::f32; snode->grad_info = std::make_unique( - &root->dense(0, n).insert_children(SNodeType::place)); + &root->dense(Axis(0), n, false).insert_children(SNodeType::place)); snode->get_grad()->dt = PrimitiveType::f32; snode->get_grad()->grad_info = std::make_unique(); return snode; }; auto *a = get_snode_grad(), *b = get_snode_grad(), *c = get_snode_grad(); - program.add_snode_tree(std::unique_ptr(root)); + program.add_snode_tree(std::unique_ptr(root), /*compile_only=*/false); std::unique_ptr kernel_init, kernel_forward, kernel_backward, kernel_ext; @@ -230,7 +91,7 @@ void autograd() { auto *zero = builder.get_int32(0); auto *one = builder.get_int32(1); auto *n_stmt = builder.get_int32(n); - auto *loop = builder.create_range_for(zero, n_stmt, 1, 0, 4); + auto *loop = builder.create_range_for(zero, n_stmt, 0, 4); { auto _ = builder.get_loop_guard(loop); auto *i = builder.get_loop_index(loop); @@ -253,7 +114,7 @@ void autograd() { auto get_kernel_cal = [&](bool grad) -> Kernel * { IRBuilder builder; - auto *loop = builder.create_struct_for(a, 1, 0, 4); + auto *loop = builder.create_struct_for(a, 0, 4); { auto _ = builder.get_loop_guard(loop); auto *i = builder.get_loop_index(loop); @@ -272,7 +133,7 @@ void autograd() { { IRBuilder builder; - auto *loop = builder.create_struct_for(a, 1, 0, 4); + auto *loop = builder.create_struct_for(a, 0, 4); { auto _ = builder.get_loop_guard(loop); auto *i = builder.get_loop_index(loop); @@ -306,9 +167,12 @@ void autograd() { auto ctx_backward = kernel_backward->make_launch_context(); auto ctx_ext = kernel_ext->make_launch_context(); std::vector ext_a(n), ext_b(n), ext_c(n); - ctx_ext.set_arg_external_array(0, taichi::uint64(ext_a.data()), n); - ctx_ext.set_arg_external_array(1, taichi::uint64(ext_b.data()), n); - ctx_ext.set_arg_external_array(2, taichi::uint64(ext_c.data()), n); + ctx_ext.set_arg_external_array(0, taichi::uint64(ext_a.data()), n, + /*is_device_allocation=*/false); + ctx_ext.set_arg_external_array(1, taichi::uint64(ext_b.data()), n, + /*is_device_allocation=*/false); + ctx_ext.set_arg_external_array(2, taichi::uint64(ext_c.data()), n, + /*is_device_allocation=*/false); (*kernel_init)(ctx_init); (*kernel_forward)(ctx_forward); @@ -324,9 +188,3 @@ void autograd() { std::cout << ext_c[i] << " "; std::cout << std::endl; } - -int main() { - run_snode(); - autograd(); - return 0; -} diff --git a/cpp_examples/main.cpp b/cpp_examples/main.cpp new file mode 100644 index 0000000000000..af8656a7d5465 --- /dev/null +++ b/cpp_examples/main.cpp @@ -0,0 +1,14 @@ +#include "taichi/ir/ir_builder.h" +#include "taichi/ir/statements.h" +#include "taichi/program/program.h" + +void run_snode(); +void autograd(); +void aot_save(); + +int main() { + run_snode(); + autograd(); + aot_save(); + return 0; +} diff --git a/cpp_examples/run_snode.cpp b/cpp_examples/run_snode.cpp new file mode 100644 index 0000000000000..992f6ae1d79f2 --- /dev/null +++ b/cpp_examples/run_snode.cpp @@ -0,0 +1,141 @@ +#include "taichi/ir/ir_builder.h" +#include "taichi/ir/statements.h" +#include "taichi/program/program.h" + +void run_snode() { + /* + import taichi as ti, numpy as np + ti.init() + #ti.init(print_ir = True) + + n = 10 + place = ti.field(dtype = ti.i32) + ti.root.pointer(ti.i, n).place(place) + + @ti.kernel + def init(): + for index in range(n): + place[index] = index + + @ti.kernel + def ret() -> ti.i32: + sum = 0 + for index in place: + sum = sum + place[index] + return sum + + @ti.kernel + def ext(ext_arr: ti.ext_arr()): + for index in place: + ext_arr[index] = place[index] + + init() + print(ret()) + ext_arr = np.zeros(n, np.int32) + ext(ext_arr) + #ext_arr = place.to_numpy() + print(ext_arr) + */ + + using namespace taichi; + using namespace lang; + auto program = Program(Arch::x64); + /*CompileConfig config_print_ir; + config_print_ir.print_ir = true; + prog_.config = config_print_ir;*/ // print_ir = True + + int n = 10; + program.materialize_runtime(); + auto *root = new SNode(0, SNodeType::root); + auto *pointer = &root->pointer(Axis(0), n, false); + auto *place = &pointer->insert_children(SNodeType::place); + place->dt = PrimitiveType::i32; + program.add_snode_tree(std::unique_ptr(root), /*compile_only=*/false); + + std::unique_ptr kernel_init, kernel_ret, kernel_ext; + + { + /* + @ti.kernel + def init(): + for index in range(n): + place[index] = index + */ + IRBuilder builder; + auto *zero = builder.get_int32(0); + auto *n_stmt = builder.get_int32(n); + auto *loop = builder.create_range_for(zero, n_stmt, 0, 4); + { + auto _ = builder.get_loop_guard(loop); + auto *index = builder.get_loop_index(loop); + auto *ptr = builder.create_global_ptr(place, {index}); + builder.create_global_store(ptr, index); + } + + kernel_init = + std::make_unique(program, builder.extract_ir(), "init"); + } + + { + /* + @ti.kernel + def ret(): + sum = 0 + for index in place: + sum = sum + place[index]; + return sum + */ + IRBuilder builder; + auto *sum = builder.create_local_var(PrimitiveType::i32); + auto *loop = builder.create_struct_for(pointer, 0, 4); + { + auto _ = builder.get_loop_guard(loop); + auto *index = builder.get_loop_index(loop); + auto *sum_old = builder.create_local_load(sum); + auto *place_index = + builder.create_global_load(builder.create_global_ptr(place, {index})); + builder.create_local_store(sum, builder.create_add(sum_old, place_index)); + } + builder.create_return(builder.create_local_load(sum)); + + kernel_ret = std::make_unique(program, builder.extract_ir(), "ret"); + } + + { + /* + @ti.kernel + def ext(ext: ti.ext_arr()): + for index in place: + ext[index] = place[index]; + # ext = place.to_numpy() + */ + IRBuilder builder; + auto *loop = builder.create_struct_for(pointer, 0, 4); + { + auto _ = builder.get_loop_guard(loop); + auto *index = builder.get_loop_index(loop); + auto *ext = builder.create_external_ptr( + builder.create_arg_load(0, PrimitiveType::i32, true), {index}); + auto *place_index = + builder.create_global_load(builder.create_global_ptr(place, {index})); + builder.create_global_store(ext, place_index); + } + + kernel_ext = std::make_unique(program, builder.extract_ir(), "ext"); + kernel_ext->insert_arg(get_data_type(), true); + } + + auto ctx_init = kernel_init->make_launch_context(); + auto ctx_ret = kernel_ret->make_launch_context(); + auto ctx_ext = kernel_ext->make_launch_context(); + std::vector ext_arr(n); + ctx_ext.set_arg_external_array(0, taichi::uint64(ext_arr.data()), n, false); + + (*kernel_init)(ctx_init); + (*kernel_ret)(ctx_ret); + std::cout << program.fetch_result(0) << std::endl; + (*kernel_ext)(ctx_ext); + for (int i = 0; i < n; i++) + std::cout << ext_arr[i] << " "; + std::cout << std::endl; +} diff --git a/docs/fragments/.gitkeep b/docs/fragments/.gitkeep new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/docs/lang/api/index.md b/docs/lang/api/index.md deleted file mode 100644 index 3973c95ba54bc..0000000000000 --- a/docs/lang/api/index.md +++ /dev/null @@ -1,9 +0,0 @@ ---- -sidebar_position: 1 ---- - -# API Docs - -:::info Taichi language API reference -Please refer to https://api-docs.taichi.graphics for the most up-to-date Taichi language's API reference while we are working on improving this page, thanks for your understanding! -::: diff --git a/docs/lang/articles/advanced/_category_.json b/docs/lang/articles/advanced/_category_.json index 84fc16f9348e5..321360c5fbf3d 100644 --- a/docs/lang/articles/advanced/_category_.json +++ b/docs/lang/articles/advanced/_category_.json @@ -1,4 +1,4 @@ { - "label": "Advanced Programming", + "label": "Advanced Topics", "position": 3 } diff --git a/docs/lang/articles/advanced/differentiable_programming.md b/docs/lang/articles/advanced/differentiable_programming.md index 7bdd46768cc2b..40377b5974f7d 100644 --- a/docs/lang/articles/advanced/differentiable_programming.md +++ b/docs/lang/articles/advanced/differentiable_programming.md @@ -79,14 +79,14 @@ print('dy/dx =', x.grad[None], ' at x =', x[None]) A common problem in physical simulation is that it's usually easy to compute energy but hard to compute force on every particle, e.g [Bond bending (and torsion) in molecular dynamics](https://github.com/victoriacity/taichimd/blob/5a44841cc8dfe5eb97de51f1d46f1bede1cc9936/taichimd/interaction.py#L190-L220) -and [FEM with hyperelastic energy functions](https://github.com/taichi-dev/taichi/blob/master/examples/simulation/fem128.py). +and [FEM with hyperelastic energy functions](https://github.com/taichi-dev/taichi/blob/master/python/taichi/examples/simulation/fem128.py). Recall that we can differentiate (negative) potential energy to get forces: `F_i = -dU / dx_i`. So once you have coded a kernel that computes the potential energy, you may use Taichi's autodiff system to obtain the derivatives and then `F_i` on each particle. Take -[examples/simulation/ad_gravity.py](https://github.com/taichi-dev/taichi/blob/master/examples/simulation/ad_gravity.py) +[examples/simulation/ad_gravity.py](https://github.com/taichi-dev/taichi/blob/master/python/taichi/examples/simulation/ad_gravity.py) as an example: ```python @@ -158,9 +158,9 @@ start up. :::tip See -[examples/simulation/mpm_lagrangian_forces.py](https://github.com/taichi-dev/taichi/blob/master/examples/simulation/mpm_lagrangian_forces.py) +[examples/simulation/mpm_lagrangian_forces.py](https://github.com/taichi-dev/taichi/blob/master/python/taichi/examples/simulation/mpm_lagrangian_forces.py) and -[examples/simulation/fem99.py](https://github.com/taichi-dev/taichi/blob/master/examples/simulation/fem99.py) +[examples/simulation/fem99.py](https://github.com/taichi-dev/taichi/blob/master/python/taichi/examples/simulation/fem99.py) for examples on using autodiff-based force evaluation MPM and FEM. ::: @@ -169,6 +169,10 @@ for examples on using autodiff-based force evaluation MPM and FEM. As mentioned above, `ti.Tape()` can only track a 0D field as the output variable. If there're multiple output variables that you want to back-propagate gradients to inputs, `kernel.grad()` should be used instead of `ti.Tape()`. +Different from using `ti.Tape()`, you need to set the `grad` of the output variables themselves to `1` manually +before calling the `kernel.grad()`. The reason is that the `grad` of the output variables themselves +will always be multiplied to the `grad` with respect to the inputs at the end of the back-propagation. +If using `ti.Tape()`, the program will help you do this under the hood. ```python {13-14} import taichi as ti @@ -188,6 +192,8 @@ def func(): for i in range(N): x[i] = i + +# Set the `grad` of the output variables to `1` before calling `func.grad()`. loss.grad[None] = 1 loss2.grad[None] = 1 @@ -296,41 +302,22 @@ func_break_rule_2.grad() assert x.grad[1] == 4.0 assert x.grad[2] == 3.0 ``` -### Kernel Simplicity Rule -:::note Kernel Simplicity Rule -Kernel body must consist of multiple simply nested for-loops. For example, each for-loop can either contain exactly one (nested) for-loop (and no other statements), or a group of statements without loops. +### Avoid mixed usage of parallel for-loop and non-for statements + +Mixed usage of parallel for-loops and non-for statements are not supported in the autodiff system. +Please split the two kinds of statements into different kernels. + +:::note +Kernel body must only consist of either multiple for-loops or non-for statements. ::: Example: ```python @ti.kernel -def differentiable_task1(): - # Good: simple for loop - for i in x: - x[i] = y[i] - -@ti.kernel -def differentiable_task2(): - # Good: one nested for loop - for i in range(10): - for j in range(20): - for k in range(300): - ... do whatever you want, as long as there are no loops - -@ti.kernel -def differentiable_task3(): - # Bad: the outer for loop contains two for loops. - for i in range(10): - for j in range(20): - ... - for j in range(20): - ... - -@ti.kernel -def differentiable_task4(): - # Bad: mixed usage of for-loop and a statement without looping. Please split them into two kernels. +def differentiable_task(): + # Bad: mixed usage of a parallel for-loop and a statement without looping. Please split them into two kernels. loss[None] += x[0] for i in range(10): ... @@ -345,27 +332,50 @@ to open a [github issue](https://github.com/taichi-dev/taichi/issues/new?assigne if you see any silent wrong results. ::: -### Workaround kernel simplicity rule +### Write differentiable code inside Taichi kernel -:::tip -**static for-loops** (e.g. `for i in ti.static(range(4))`) will get -unrolled by the Python frontend preprocessor and therefore does not -count as a level of loop. -::: +Taichi compiler only captures the code in the Taichi scope when performing the source code transformation for autodiff. Therefore, only the code written in Taichi scope is auto-differentiated. Although you can modify the `grad` of a field in python scope manually, the code is not auto-differentiated. + +Example: + +```python +import taichi as ti + +ti.init() +x = ti.field(dtype=float, shape=(), needs_grad=True) +loss = ti.field(dtype=float, shape=(), needs_grad=True) -For instance, we can rewrite `differentiable_task3` listed above using `ti.static`: -``` python @ti.kernel -def differentiable_task3(): - # Good: ti.static unrolls the inner loops so that it now only has one simple for loop. - for i in range(10): - for j in ti.static(range(20)): - ... - for j in ti.static(range(20)): - ... +def differentiable_task(): + for l in range(3): + loss[None] += ti.sin(x[None]) + 1.0 + +@ti.kernel +def manipulation_in_kernel(): + loss[None] += ti.sin(x[None]) + 1.0 + + +x[None] = 0.0 +with ti.Tape(loss=loss): + # The line below in python scope only contribute to the forward pass + # but not the backward pass i.e., not auto-differentiated. + loss[None] += ti.sin(x[None]) + 1.0 + + # Code in Taichi scope i.e. inside Taichi kernels, is auto-differentiated. + manipulation_in_kernel() + differentiable_task() + +# The outputs are 5.0 and 4.0 +print(loss[None], x.grad[None]) + +# You can modify the grad of a field manually in python scope, e.g., clear the grad. +x.grad[None] = 0.0 +# The output is 0.0 +print(x.grad[None]) ``` + ## Extending Taichi Autodiff system diff --git a/docs/lang/articles/advanced/layout.md b/docs/lang/articles/advanced/layout.md index be7088838e14a..dc41ca0ce38e1 100644 --- a/docs/lang/articles/advanced/layout.md +++ b/docs/lang/articles/advanced/layout.md @@ -3,58 +3,33 @@ sidebar_position: 2 --- # Fields (advanced) +Modern processor cores compute orders of magnitude faster than their equipped memory systems. To shrink this performance gap, multi-level cache systems and high-bandwidth multi-channel memories are built into computer architectures. -This section introduces some advanced features of Taichi fields. -Make sure you have gone through [Fields](../basic/field). +After familiar yourself with the basics of Taichi [Fields](../basic/field), this article helps you one step further by explaining the underlying memory layout that is essential to write high-performance Taichi programs. In particular, we present how to organize an efficient data layout and how to manage memory occupancy. -## Packed mode +## How to organize an efficient data layout -By default, all non-power-of-two dimensions of a field are automatically -padded to a power of two. For instance, a field of shape `(18, 65)` will -have an internal shape of `(32, 128)`. Although the padding has many benefits -such as allowing fast and convenient bitwise operations for coordinate handling, -it will consume potentially much more memory than expected. +In this section, we introduce how to organize data layouts in Taichi fields. The central principle of efficient data layout is _locality_. Generally speaking, a program with desirable locality has at least one of the following features: +* Dense data structure +* Loop over data in small-range (within 32KB is good for most processors) +* Sequential load/store -If you would like to reduce memory usage, you can use the optional packed -mode. In packed mode, no padding is applied such that a field does not have a -larger internal shape than the defined shape when some of its dimensions -are not powers of two. The downside is that the runtime performance will -regress slightly. - -A switch named `packed` for `ti.init()` decides whether to use packed mode: - -```python -ti.init() # default: packed=False -a = ti.field(ti.i32, shape=(18, 65)) # padded to (32, 128) -``` - -```python -ti.init(packed=True) -a = ti.field(ti.i32, shape=(18, 65)) # no padding -``` - -## Advanced data layouts +:::note -Apart from shape and data type, you can also specify the data layout of a -field in a recursive manner. This may allow you to achieve better performance. -Normally, you don't have to worry about the performance nuances between -different layouts, and you can just use the default one (simply by specifying -`shape` when creating fields) as a start. +Be aware that data are always fetched from memory in blocks (pages). The hardware has no knowledge about how a specific data element is used in the block. The processor blindly fetch the entire block according to the requested memory address. Therefore, the memory bandwidth is wasted when data are not fully utilized. -However, when a field gets large, a proper data layout may be critical to -performance, especially for memory-bound applications. A carefully designed -data layout has much better spatial locality, which will significantly -improve cache/TLB-hit rates and cache line utilization. +For sparse fields, refer to [Sparse computation](./sparse.md). -Taichi decouples computation from data structures, and the Taichi compiler -automatically optimizes data accesses on a specific data layout. This allows -you to quickly experiment with different data layouts and figure out the most -efficient one on a specific task and computer architecture. +::: ### Layout 101: from `shape` to `ti.root.X` -The following declares a 0-D field: + +In basic usages, we use the `shape` descriptor to construct a field. Taichi provides flexible statements to describe more advanced data organizations, the `ti.root.X`. +Let's get some familiarity with examples: + +* Declare a 0-D field: ```python {1-2} x = ti.field(ti.f32) @@ -63,7 +38,7 @@ ti.root.place(x) x = ti.field(ti.f32, shape=()) ``` -The following declares a 1D field of shape `3`: +* Declare a 1-D field of shape `3`: ```python {1-2} x = ti.field(ti.f32) @@ -72,7 +47,7 @@ ti.root.dense(ti.i, 3).place(x) x = ti.field(ti.f32, shape=3) ``` -The following declares a 2D field of shape `(3, 4)`: +* Declare a 2-D field of shape `(3, 4)`: ```python {1-2} x = ti.field(ti.f32) @@ -80,46 +55,48 @@ ti.root.dense(ti.ij, (3, 4)).place(x) # is equivalent to: x = ti.field(ti.f32, shape=(3, 4)) ``` +You can also nest two 1D `dense` statements to describe the same 2D array. +```python {1-2} +x = ti.field(ti.f32) +ti.root.dense(ti.i, 3).dense(ti.j, 4).place(x) +``` -After being comfortable with these equivalent definitions, you can move forward -and see how to change the data layout. +In a nutshell, the `ti.root.X` statement progressively binds a shape to the corresponding axis. +By nesting multiple statements, we can construct a field with higher dimensions. + -### Row-major versus column-major +In order to traverse the nested statements, we can use `struct-for`: +```python {1} +for i, j in A: + A[i, j] += 1 +``` +The order to access `A`, namely the order to iterate `i` and `j`, affects the program performance subtly. The Taichi compiler is capable to automatically deduce the underlying data layout and apply a proper access order. This is an advantage over most general-purpose programming languages where the access order has to be optimized manually. -As you might have learned in a computer architecture course, -address spaces are linear in modern computers. To -simplify the discussions, data type size will not be considered and will always -be treated as 1. Assume the starting address of a field is `base`. Then for 1D -Taichi fields, the address of the `i`-th element is simply `base + i`. +### Row-major versus column-major -However, a multi-dimensional field has to be flattened in order to fit into the -1D address space. For example, there are two ways to store a 2D field of size `(3, 2)`: +Memory address space is linear as you might have learnt from a computer architecture course. Without loss of generality, we omit the differences in data types and assume each data element has size 1. Moreover, we denote the starting memory address of a field as `base`, and the indexing formula for 1D Taichi fields is `base + i` for the `i`-th element. -- Row-major: let the address of the `(i, j)`-th element be `base + i * 2 + j`; -- Column-major: let the address of the `(i, j)`-th element be - `base + j * 3 + i`. +For multi-dimensional fields, we can flatten the high-dimension index into the linear memory address space in two ways: Take a 2D field of shape `(M, N)` as an instance, we can either store `M` rows with `N`-length 1D buffers, say the _row-major_ way, or store `N` columns, say the _column-major_ way. The index flatten formula for the `(i, j)`-th element is `base + i * N + j` for row-major and `base + j * M + i` for column-major, respectively. -To specify which layout to use (default layout is row-major): +We can easily derive that elements in the same row are close in memory for row-major fields. The selection of the optimal layout is based on how the elements are accessed, namely, the access patterns. Patterns such as frequently accessing elements of the same row in a column-major field typically lead to performance degradation. +The default Taichi field layout is row-major. With the `ti.root` statements, fields can be defined as follows: ```python -ti.root.dense(ti.i, 3).dense(ti.j, 2).place(x) # row-major -ti.root.dense(ti.j, 2).dense(ti.i, 3).place(y) # column-major +ti.root.dense(ti.i, M).dense(ti.j, N).place(x) # row-major +ti.root.dense(ti.j, N).dense(ti.i, M).place(y) # column-major ``` - -Both `x` and `y` have shape `(3, 2)`, and they can be accessed in the same -manner with `x[i, j]` and `y[i, j]`, where `0 <= i < 3 && 0 <= j < 2`. However, -they have different memory layouts: - +In the code above, the axis denotation in the rightmost `dense` statement indicates the continuous axis. For the `x` field, elements in the same row (with same `i` and different `j`) are close in memory, hence it's row-major; For the `y` field, elements in the same column (same `j` and different `i`) are close, hence it's column-major. With an example of (2, 3), we visualize the memory layouts of `x` and `y` as follows: ``` # address: low ........................................... high # x: x[0, 0] x[0, 1] x[1, 0] x[1, 1] x[2, 0] x[2, 1] # y: y[0, 0] y[1, 0] y[2, 0] y[0, 1] y[1, 1] y[2, 1] ``` -:::note +It is worth noting that the accessor is unified for Taichi fields: the `(i, j)`-th element in the field is accessed with the identical 2D index `x[i, j]` and `y[i, j]`. Taichi handles the layout variants and applies proper indexing equations internally. Thanks to this feature, users can specify their desired layout at definition, and use the fields without concerning about the underlying memory organizations. To change the layout, it's sufficient to just swap the order of `dense` statements, and leave rest of the code intact. -For those who are familiar with C/C++, here is what they look like in C code: +:::note +For readers who are familiar with C/C++, below is an example C code snippet that demonstrates data access in 2D arrays: ```c int x[3][2]; // row-major int y[2][3]; // column-major @@ -132,49 +109,60 @@ for (int i = 0; i < 3; i++) { } ``` +The accessors of `x` and `y` are in reverse order between row-major arrays and column-major arrays, respectively. Compared with Taichi fields, there are much more code to revise when you change the memory layout. + ::: -### Array of Structures (AoS) versus Structure of Arrays (SoA) + +### AoS versus SoA -Fields of same shape can be placed together. +AoS means _array of structures_ and SoA means _structure of arrays_. Consider an RGB image with 4 pixels and 3 color channels, an AoS layout stores `RGBRGBRGBRGB` while an SoA layout stores `RRRRGGGGBBBB`. -For example, the following places two 1D fields of shape `3` together, which -is called Array of Structures (AoS): +The selection of AoS or SoA layout largely depends on the access pattern to the field. Let's discuss a scenario to process large RGB images. The two layouts have the following arrangements in memory: +``` +# address: low ...................... high +# AoS: RGBRGBRGBRGBRGBRGB............. +# SoA: RRRRR...RGGGGGGG...GBBBBBBB...B +``` +To calculate grey scale of each pixel, we need all color channels but do not require the value of other pixels. In this case, the AoS layout has a better memory access pattern: Since color channels are stored continuously, and adjacent channels can be fetched instantly. The SoA layout is not a good option because the color channels of a pixel are stored far apart in the memory space. +We describe how to construct AoS and SoA fields with our `ti.root.X` statements. The SoA fields are trivial: ```python -ti.root.dense(ti.i, 3).place(x, y) +x = ti.field(ti.f32) +y = ti.field(ti.f32) +ti.root.dense(ti.i, M).place(x) +ti.root.dense(ti.i, M).place(y) ``` - -Their memory layout is: - +where M is the length of `x` and `y`. +The data elements in `x` and `y` are continuous in memory: ``` -# address: low ......................... high -# x[0] y[0] x[1] y[1] x[2] y[2] +# address: low ................................. high +# x[0] x[1] x[2] ... y[0] y[1] y[2] ... ``` -By contrast, the following places these two fields separately, which is called -Structure of Arrays (SoA): +For AoS fields, we construct the field with ```python -ti.root.dense(ti.i, 3).place(x) -ti.root.dense(ti.i, 3).place(y) +x = ti.field(ti.f32) +y = ti.field(ti.f32) +ti.root.dense(ti.i, M).place(x, y) ``` - -Now their memory layout is: - +The memroy layout then becomes ``` -# address: low ......................... high -# x[0] x[1] x[2] y[0] y[1] y[2] +# address: low .............................. high +# x[0] y[0] x[1] y[1] x[2] y[2] ... ``` +Here, `place` interleaves the elements of Taichi fields `x` and `y`. -**To improve spatial locality of memory accesses, it may be helpful to -place data elements that are often accessed together within -relatively close addresses.** Take a simple 1D wave equation solver as an example: +As previously introduced, the access methods to `x` and `y` remain the same for both AoS and SoA. Therefore, the data layout can be changed flexibly without revising the application logic. + +For better illustration, let's see an example of an 1D wave equation solver: ```python N = 200000 pos = ti.field(ti.f32) vel = ti.field(ti.f32) +# SoA placement ti.root.dense(ti.i, N).place(pos) ti.root.dense(ti.i, N).place(vel) @@ -182,127 +170,122 @@ ti.root.dense(ti.i, N).place(vel) def step(): pos[i] += vel[i] * dt vel[i] += -k * pos[i] * dt + +... ``` +The above code snippet defines SoA fields and a `step` kernel that sequentially accesses each element. +The kernel fetches an element from `pos` and `vel` for every iteration, respectively. +For SoA fields, the closest distance of any two elements in memory is `N`, which is unlikely to be efficient for large `N`. -Here, `pos` and `vel` are placed separately, so the distance in address -space between `pos[i]` and `vel[i]` is `200000`. This results in poor spatial -locality and poor performance. A better way is to place them together: +We hereby switch the layout to AoS as follows: ```python +N = 200000 +pos = ti.field(ti.f32) +vel = ti.field(ti.f32) +# AoS placement ti.root.dense(ti.i, N).place(pos, vel) + +@ti.kernel +def step(): + pos[i] += vel[i] * dt + vel[i] += -k * pos[i] * dt ``` +Merely revising the place statement is sufficient to change the layout. With this optimization, the instant elements `pos[i]` and `vel[i]` are now adjacent in memory, which is more efficient. -Then `vel[i]` is placed right next to `pos[i]`, which can increase spatial -locality and therefore improve performance. + -Struct-fors on nested dense data structures will automatically follow their -layout in memory. For example, if 2D scalar field `A` is defined in row-major, + -As you may notice, only dense data layouts are covered in this section. For sparse -data layouts, see [Sparse computation](./sparse.md). + -2D field, column-major: -```python -A = ti.field(ti.f32) -ti.root.dense(ti.j, 256).dense(ti.i, 256).place(A) -``` -_8x8_-blocked 2D field of size _1024x1024_: +### AoS extension: hierarchical fields + +Sometimes we want to access memory in a complex but fixed pattern, like traversing an image in 8x8 blocks. The apparent best practice is to flatten each 8x8 block and concatenate them together. From a Taichi user's perspective, however, the field is no longer a flat buffer. It now has a hierarchy with two levels: The image level and the block level. Equivalently, the field is an array of implicit 8x8 block structures. + +We demonstrate the statements as follows: ```python -density = ti.field(ti.f32) -ti.root.dense(ti.ij, (128, 128)).dense(ti.ij, (8, 8)).place(density) +# Flat field +val = ti.field(ti.f32) +ti.root.dense(ti.ij, (M, N)).place(val) ``` -3D particle positions and velocities, AoS: - ```python -pos = ti.Vector.field(3, dtype=ti.f32) -vel = ti.Vector.field(3, dtype=ti.f32) -ti.root.dense(ti.i, 1024).place(pos, vel) -# equivalent to -ti.root.dense(ti.i, 1024).place(pos.get_scalar_field(0), - pos.get_scalar_field(1), - pos.get_scalar_field(2), - vel.get_scalar_field(0), - vel.get_scalar_field(1), - vel.get_scalar_field(2)) +# Hierarchical field +val = ti.field(ti.f32) +ti.root.dense(ti.ij, (M // 8, N // 8)).dense(ti.ij, (8, 8)).place(val) ``` +where `M` and `N` are multiples of 8. We encourage you to try this out! The performance difference can be significant! -3D particle positions and velocities, SoA: +## How to manage memory occupancy -```python -pos = ti.Vector.field(3, dtype=ti.f32) -vel = ti.Vector.field(3, dtype=ti.f32) -for i in range(3): - ti.root.dense(ti.i, 1024).place(pos.get_scalar_field(i)) -for i in range(3): - ti.root.dense(ti.i, 1024).place(vel.get_scalar_field(i)) -``` +### Manual field allocation and destruction -## Dynamic field allocation and destruction +Generally Taichi manages memory allocation and destruction without disturbing the users. However, there are times that users want explicit control over their memory allocations. -You can use the `FieldsBuilder` class for dynamic field allocation and destruction. -`FieldsBuilder` has the same data structure declaration APIs as `ti.root`, -including `dense()`. After declaration, you need to call the `finalize()` -method to compile it to an `SNodeTree` object. +In this scenario, Taichi provides the `FieldsBuilder` for manual field memory allocation and destruction. `FieldsBuilder` features identical declaration APIs as `ti.root`. The extra step is to invoke `finalize()` at the end of all declarations. The `finalize()` returns an `SNodeTree` object to handle subsequent destructions. -A simple example is: +Let's see a simple example: ```python import taichi as ti @@ -318,53 +301,34 @@ x = ti.field(dtype=ti.f32) fb1.dense(ti.ij, (5, 5)).place(x) fb1_snode_tree = fb1.finalize() # Finalizes the FieldsBuilder and returns a SNodeTree func(x) +... +fb1_snode_tree.destroy() # Destruction fb2 = ti.FieldsBuilder() y = ti.field(dtype=ti.f32) fb2.dense(ti.i, 5).place(y) fb2_snode_tree = fb2.finalize() # Finalizes the FieldsBuilder and returns a SNodeTree func(y) +... +fb2_snode_tree.destroy() # Destruction ``` +Actually, the above demonstrated `ti.root` statements are implemented with `FieldsBuilder`, despite that `ti.root` has the capability to automatically manage memory allocations and recycling. -In fact, `ti.root` is implemented by `FieldsBuilder` implicitly, so you can -allocate the fields directly under `ti.root`: -```python -import taichi as ti -ti.init() # Implicitly: ti.root = ti.FieldsBuilder() +### Packed mode -@ti.kernel -def func(v: ti.template()): - for I in ti.grouped(v): - v[I] += 1 +By default, Taichi implicitly fits a field in a larger buffer with power-of-two dimensions. We take the power-of-two padding convention because it is widely adopted in computer graphics. The design enables fast indexing with bitwise arithmetic and better memory address alignment, while trading off memory occupations. -x = ti.field(dtype=ti.f32) -ti.root.dense(ti.ij, (5, 5)).place(x) -func(x) # Automatically calls ti.root.finalize() -# Implicitly: ti.root = ti.FieldsBuilder() +For example, a `(18, 65)` field is materialized with a `(32, 128)` buffer, which is acceptable. As field size grows, the padding strategy can be exaggeratedly unbearable: `(129, 6553600)` will be expanded to `(256, 6335600)`, which allocates considerable unsed blank memory. Therefore, Taichi provides the optional packed mode to allocate buffer that tightly fits the requested field shape. It is especially useful when memory usage is a major concern. -y = ti.field(dtype=ti.f32) -ti.root.dense(ti.i, 5).place(y) -func(y) # Automatically calls ti.root.finalize() +To leverage the packed mode, spcifify `packed` in `ti.init()` argument: +```python +ti.init() # default: packed=False +a = ti.field(ti.i32, shape=(18, 65)) # padded to (32, 128) ``` -Furthermore, if you don't want to use the fields under a certain `SNodeTree` -anymore, you could call the `destroy()` method on the finalized `SNodeTree` -object, which will recycle its memory into the memory pool: - -```py -import taichi as ti -ti.init() - -@ti.kernel -def func(v: ti.template()): - for I in ti.grouped(v): - v[I] += 1 - -fb = ti.FieldsBuilder() -x = ti.field(dtype=ti.f32) -fb.dense(ti.ij, (5, 5)).place(x) -fb_snode_tree = fb.finalize() # Finalizes the FieldsBuilder and returns a SNodeTree -func(x) - -fb_snode_tree.destroy() # x cannot be used anymore +```python +ti.init(packed=True) +a = ti.field(ti.i32, shape=(18, 65)) # no padding ``` + +You might observe mild performance regression with the packed mode due to more complex adressing and memory alignment. Therefore, the packed mode should be specified only when memory capacity is a major concern. diff --git a/docs/lang/articles/advanced/meta.md b/docs/lang/articles/advanced/meta.md index 9b9f7b54d8d61..82947bed60728 100644 --- a/docs/lang/articles/advanced/meta.md +++ b/docs/lang/articles/advanced/meta.md @@ -18,7 +18,7 @@ Every kernel in Taichi is a template kernel, even if it has no template argument ## Template metaprogramming -By using `ti.template()` as a argument type hint, a Taichi field can be passed into a kernel. Template programming also enables the code to be reused for fields with different shapes: +By using `ti.template()` as an argument type hint, a Taichi field or a python object can be passed into a kernel. Template programming also enables the code to be reused for fields with different shapes: ```python {2} @ti.kernel @@ -38,6 +38,10 @@ copy_1D(a, b) copy_1D(c, d) ``` +:::note +If a template parameter is not a Taichi object, it cannot be reassigned inside Taichi kernel. +::: + :::note The template parameters are inlined into the generated kernel after compilation. ::: @@ -122,6 +126,12 @@ def foo(): Using compile-time evaluation allows for some computation to be executed when kernels are instantiated. This helps the compiler to conduct optimization and reduce computational overhead at runtime: +### Static Scope +`ti.static` is a function which receives one argument. It is a hint for the compiler to evaluate the argument at compile time. +The scope of the argument of `ti.static` is called static-scope. + +### Compile-time branching + - Use `ti.static` for compile-time branching (for those who are familiar with C++17, this is similar to [if constexpr](https://en.cppreference.com/w/cpp/language/if).): @@ -139,6 +149,8 @@ def static(): One of the two branches of the `static if` will be discarded after compilation. ::: +### Loop unrolling + - Use `ti.static` for forced loop unrolling: ```python {3} @@ -173,3 +185,28 @@ def reset(): # The inner loop must be unrolled since j is an index for accessing a vector x[i][j] = 0 ``` + +## Compile-time recursion of `ti.func` + +A compile-time recursive function is a function with recursion that can be recursively inlined at compile time. The condition which determines whether to recurse is evaluated at compile time. + +You can combine [compile-time branching](#compile-time-evaluations) and [template](#template-metaprogramming) to write compile-time recursive functions. + +For example, `sum_from_one_to` is a compile-time recursive function that calculates the sum of numbers from `1` to `n`. + +```python {1-6} +@ti.func +def sum_from_one_to(n: ti.template()) -> ti.i32: + ret = 0 + if ti.static(n > 0): + ret = n + sum_from_one_to(n - 1) + return ret + +@ti.kernel +def sum_from_one_to_ten(): + print(sum_from_one_to(10)) # prints 55 +``` + +:::caution WARNING +When the recursion is too deep, it is not recommended to use compile-time recursion because deeper compile-time recursion expands to longer code during compilation, resulting in increased compilation time. +::: diff --git a/docs/lang/articles/advanced/performance.md b/docs/lang/articles/advanced/performance.md index eaed9c0d5f295..d2de92ffec58b 100644 --- a/docs/lang/articles/advanced/performance.md +++ b/docs/lang/articles/advanced/performance.md @@ -12,7 +12,7 @@ the target architecture. Nevertheless, for Ninjas who strive for the last few % of performance, we also provide some APIs to allow developers fine-tune their applications. For example, specifying a suitable `ti.block_dim` could yield an almost 3x performance boost in -[examples/mpm3d.py](https://github.com/taichi-dev/taichi/blob/master/examples/mpm3d.py). +[examples/mpm3d.py](https://github.com/taichi-dev/taichi/blob/master/python/taichi/examples/mpm3d.py). :::note For **performance profiling** utilities, please see [Profiler section of the Contribution Guide](../misc/profiler.md). diff --git a/docs/lang/articles/advanced/sparse.md b/docs/lang/articles/advanced/sparse.md index 7c17b41631e2c..fd0528d688c80 100644 --- a/docs/lang/articles/advanced/sparse.md +++ b/docs/lang/articles/advanced/sparse.md @@ -2,50 +2,29 @@ sidebar_position: 3 --- -# Sparse computation - -Compiler-level support for spatially sparse computation is a unique feature of Taichi. - -![image](https://raw.githubusercontent.com/taichi-dev/public_files/master/taichi/sparse_grids.gif) - -Figure: A swinging "Taichi" pattern represented with a 512x512 sparse grid. The sparse grid has a multi-level *tree* structure. -White stands for inactive tree nodes, and active tree nodes are darker. - -The sparse grid above has the following structure: -- The grid is divided into 8x8 `block1` containers; -- Each `block1` container has 4x4 `block2` cells; -- Each `block2` container has 4x4 `block3` cells; -- Each `block3` container has 4x4 pixel cells; -- Each pixel contains an `i32` value `x[i, j]`. +# Sparse spatial data structures :::note -For more information about *cells* and *containers*, see [**Data structure organization**](../misc/internal.md#data-structure-organization). -In this article, you can assume *containers* and *cells* are the same. +Prerequisite: please read the [Fields](lang/articles/basic/field.md), [Fields (advanced)](lang/articles/advanced/layout.md), and [SNodes](lang/articles/misc/internal.md#data-structure-organization) first. ::: -Taichi allows you to define sparse data structures effortlessly. For example, you can define the grid above as - -```python -x = ti.field(dtype=ti.i32) - -block1 = ti.root.pointer(ti.ij, 8) -block2 = block1.pointer(ti.ij, 4) -block3 = block2.pointer(ti.ij, 4) -block3.dense(ti.ij, 4).place(x) -``` -[[Full source code of this animation]](https://github.com/taichi-dev/taichi/blob/master/examples/features/sparse/taichi_sparse.py) - -Intuitively, a sparse grid in Taichi allows you to use memory space more wisely, since only tree nodes involved in computation are allocated. -Now, let's take a step back and think about *why we need sparse grids, how to define them in Taichi, and how to compute on these data structures*. +![image](https://raw.githubusercontent.com/taichi-dev/public_files/master/taichi/doc/sparse_grids_3d.jpg) +Figure: A 3D fluid simulation that uses both particles and grids. Left to right: particles, 1x1x1 voxels, 4x4x4 blocks, 16x16x16 blocks. ## Motivation High-resolution 2D/3D grids are often needed in large-scale spatial computation, such as physical simulation, rendering, and 3D reconstruction. -However, these grids tend to consume a huge amount of memory space and computation. +However, these grids tend to consume a huge amount of memory space and computation if we use dense data structures (see [field](lang/articles/basic/field.md) and [field advanced](lang/articles/advanced/layout.md)). While a programmer may allocate large dense grids to store spatial data (especially physical quantities such as a density or velocity field), oftentimes, they only care about a small fraction of this dense grid since the rest may be empty space (vacuum or air). -In short, the regions of interest in sparse grids may only occupy a small fraction of the whole bounding box. +
+ +![BVH](https://raw.githubusercontent.com/taichi-dev/public_files/master/taichi/doc/bvh.png) + +
+ +For example, the regions of interest in sparse grids shown above may only occupy a small fraction of the whole bounding box. If we can leverage such "spatial sparsity" and focus computation on the regions we care about, we will significantly save storage and computing power. @@ -53,265 +32,258 @@ we will significantly save storage and computing power. The key to leverage spatial sparsity is to replace *dense* grids with *sparse* grids. ::: -On a sparse data structure, we consider a pixel, voxel, or a grid node to be *active*, -if it is allocated and involved in computation. -The rest of the grid is simply *inactive*. -The *activity* of a leaf or intermediate cell is a boolean value. The activity value of a cell is `True` if and only if the cell is *active*. - -Below is a 2D multi-physics simulation (material point method) with 256x256 grid cells. -Since the simulated objects do not fully occupy the whole domain, we would like to *adaptively* allocate the underlying simulation grid. -We subdivide the whole simulation domain into 16x16 *blocks*, -and each block has 16x16 *grid cells*. -Memory allocation can then happen at *block* granularity, -and we only consume memory space of blocks that are actually in the simulation. +The traditional sparse spatial data stuctures are [Quadtrees](https://en.wikipedia.org/wiki/Quadtree) (2D) and +[Octrees](https://en.wikipedia.org/wiki/Octree) (3D). Since dereferencing pointers is relatively costly on modern computer architectures, compared to quadtrees and octrees, it is more performance-friendly to use shallower trees with larger branching factors. +[VDB](https://www.openvdb.org/) and [SPGrid](http://pages.cs.wisc.edu/~sifakis/papers/SPGrid.pdf) are such examples. +In Taichi, programmers can compose data structures similar to VDB and SPGrid with SNodes. The advantages of Taichi sparse spatial data structures include +1. Access with indices, which just like accessing a dense data structure. +2. Automatic parallelization when iterating. +3. Automatic memory access optimization. -![image](https://raw.githubusercontent.com/taichi-dev/public_files/master/taichi_elements/sparse_mpm_active_blocks.gif) -(Note the changing distribution of active blocks throughout the simulation.) :::note -**Backend compatibility**: The LLVM backends (CPU/CUDA) and the Metal backend offer the full functionality of sparse computation. -Other backends provide no or limited support of sparse computation. +**Backend compatibility**: The LLVM backends (CPU/CUDA) and the Metal backend offer the full functionality of computation on sparse spatial data structures. ::: + :::note -Sparse matrices are usually **not** implemented in Taichi via (spatially-) sparse data structures. Use `ti.SparseMatrixBuilder` instead. +Sparse matrices are usually **not** implemented in Taichi via sparse spatial data structures. See [sparse matrix](lang/articles/advanced/sparse_matrix.md) instead. ::: -## Defining sparse data structures in Taichi +## Sparse spatial data structures in Taichi -Ideally, it would be nice to have a sparse voxel data structure that consumes space or computation only when the voxels are active. -Practically, Taichi programmers use hierarchical data structures (trees) to organize sparse voxel data. +Sparse spatial data structures in Taichi are usually composed of `pointer`, `bitmasked`, `dynamic`, and `dense` SNodes. A SNode tree merely composed of `dense` SNodes is not a sparse spatial data structure. -### Data structure hierarchy - -Traditionally, [Quadtrees](https://en.wikipedia.org/wiki/Quadtree) (2D) and -[Octrees](https://en.wikipedia.org/wiki/Octree) (3D) are often adopted. -Since dereferencing pointers is relatively costly on modern computer architectures, -compared to quadtrees and octrees, it is more performance-friendly to use shallower trees with larger branching factors. -[VDB](https://www.openvdb.org/) and [SPGrid](http://pages.cs.wisc.edu/~sifakis/papers/SPGrid.pdf) are such examples. -In Taichi, programmers can compose data structures similar to VDB and SPGrid with SNodes. - -![image](https://raw.githubusercontent.com/taichi-dev/public_files/master/taichi/doc/sparse_grids_3d.jpg) -Figure: A 3D fluid simulation that uses both particles and grids. Left to right: particles, 1x1x1 voxels, 4x4x4 blocks, 16x16x16 blocks. +On a sparse spatial data structure, we consider a pixel, voxel, or a grid node to be *active*, +if it is allocated and involved in the computation. +The rest of the grid is simply *inactive*. +In SNode terms, the *activity* of a leaf or intermediate cell is a boolean value. The activity value of a cell is `True` if and only if the cell is *active*. When writing to an inactive cell, Taichi automatically activates it. Taichi also provides manual manipulation of the activity of a cell, see [Explicitly manipulating and querying sparsity](#explicitly-manipulating-and-querying-sparsity). -#### Blocked leaf cells and bitmasks +:::note +Reading an inactive pixel returns zero. +::: -While a null pointer can effectively represent an empty sub-tree, at the leaf level using 64 bits to represent the activity -of a single voxel can consume too much space. -For example, if each voxel contains a single `f32` value (4 bytes), -the 64-bit pointer pointing to the value would take 8 bytes. -The fact that storage costs of pointers are higher than the space to store the value themselves -goes against our goal to use sparse data structures to save space. +### Pointer SNode -To amortize the storage cost of pointers, programmers usually organize voxels in a *blocked* manner -and let the pointers directly point to the blocks (instead of voxels). +```python {2} title=pointer.py +x = ti.field(ti.f32) +block = ti.root.pointer(ti.ij, (4,4)) +pixel = block.dense(ti.ij, (2,2)) +pixel.place(x) -One caveat of this design is that voxels in the same `dense` block can no longer change their activity flexibly. -Instead, they share a single activity flag. To address this issue, -the `bitmasked` SNode additionally allocates 1-bit per voxel data to represent the voxel activity. +@ti.kernel +def activate(): + x[2,3] = 1.0 + x[2,4] = 2.0 -### A typical sparse data structure +@ti.kernel +def print_active(): + for i, j in block: + print("Active block", i, j) + # output: Active block 1 1 + # Active block 1 2 + for i, j in x: + print('field x[{}, {}] = {}'.format(i, j, x[i, j])) + # output: field x[2, 2] = 0.000000 + # field x[2, 3] = 1.000000 + # field x[3, 2] = 0.000000 + # field x[3, 3] = 0.000000 + # field x[2, 4] = 2.000000 + # field x[2, 5] = 0.000000 + # field x[3, 4] = 0.000000 + # field x[3, 5] = 0.000000 +``` +The code snippet above creates an 8x8 sparse grid, with the top-level being a 4x4 pointer array (line 2 of `pointer.py`), +and each pointer pointing to a 2x2 dense block. +You can write and read the sparse field like a dense field using indices. The below figure shows the active blocks and pixels in green. -Sparse data structures in Taichi are usually composed of `pointer`, `dense`, and `bitmasked` SNodes. -The code snippet below creates an 8x8 sparse grid, with the top level being 4x4 pointer arrays, -and each pointer points to a 2x2 dense block. +
-```python -x = ti.field(dtype=ti.i32) +![Pointer](https://raw.githubusercontent.com/taichi-dev/public_files/master/taichi/doc/pointer.png) -block = ti.root.pointer(ti.ij, (4, 4)) -pixel = block.dense(ti.ij, (2, 2)) -pixel.place(x) -``` +
-![image](https://raw.githubusercontent.com/taichi-dev/public_files/master/taichi/doc/sparse_grids_2d.png) +Executing the `activate()` function automatically activates `block[1,1]`, which includes `x[2,3]`, and `block[1,2]`, which includes `x[2,4]`. Other pixels of `block[1,1]` (`x[2,2], x[3,2], x[3,3]`) and `block[1,2]` (`x[2,5], x[3,4], x[3,5]`) are also implicitly activated because all pixels in the dense block share the same activity value. -## Computation on sparse data structures +In fact, the sparse field is a SNode tree shown in the following figure. You could use the struct-for loop to loop over the different levels of the SNode tree like the `print_active()` function in `pointer.py`. `for i, j in block` would loop over all active `pointer` SNodes. `for i, j in pixel` would loop over all active `dense` SNodes. -### Activation on write +
-When writing to an inactive cell on a sparse data structure, Taichi automatically populates the data structure. +![Pointer SNode Tree](https://raw.githubusercontent.com/taichi-dev/public_files/master/taichi/doc/pointer_tree.png) -For example, when executing `x[2, 3] = 2` on the aforementioned sparse grid `x`, -Taichi automatically activates `block[1, 1]` so that `pixel[2, 3]` is allocated. +
-:::note -Reading an inactive voxel returns zero. -::: -### Sparse struct-fors -Efficiently looping over sparse grid cells that distribute irregularly can be challenging, especially on parallel devices such as GPUs. -In Taichi, *struct-for's* natively support sparse data structures and only loops over currently active voxels. -The Taichi system ensures efficient parallelization. -You can loop over different levels of the tree. -The code below demonstrates the creation and manipulation of a sparse grid: +### Bitmasked SNode -```python -import taichi as ti +While a null pointer can effectively represent an empty sub-tree, at the leaf level using 64 bits to represent the activity +of a single pixel can consume too much space. +For example, if each pixel contains a single `f32` value (4 bytes), +the 64-bit pointer pointing to the value would take 8 bytes. +The fact that storage costs of pointers are higher than the space to store the value themselves +goes against our goal to use sparse spatial data structures to save space. -use_bitmask = True +To amortize the storage cost of pointers, you could organize pixels in a *blocked* manner +and let the pointers directly point to the blocks like the data structure defined in `pointer.py`. -ti.init() +One caveat of this design is that pixels in the same `dense` block can no longer change their activity flexibly. +Instead, they share a single activity flag. To address this issue, +the `bitmasked` SNode additionally allocates 1-bit per pixel data to represent the pixel activity. -x = ti.field(dtype=ti.i32) -block = ti.root.pointer(ti.ij, (4, 4)) -if use_bitmask: - pixel = block.bitmasked(ti.ij, (2, 2)) -else: - pixel = block.dense(ti.ij, (2, 2)) +```python {3} title=bitmasked.py +x = ti.field(ti.f32) +block = ti.root.pointer(ti.ij, (4,4)) +pixel = block.bitmasked(ti.ij, (2,2)) pixel.place(x) @ti.kernel -def sparse_struct_for(): - x[2, 3] = 2 - x[5, 6] = 3 +def activate(): + x[2,3] = 1.0 + x[2,4] = 2.0 +@ti.kernel +def print_active(): + for i, j in block: + print("Active block", i, j) for i, j in x: print('field x[{}, {}] = {}'.format(i, j, x[i, j])) - - for i, j in block: - print('Active block: [{}, {}]'.format(i, j)) - -print('use_bitmask = {}'.format(use_bitmask)) -sparse_struct_for() ``` -When `bitmask = True`, the program above outputs -``` -field x[2, 3] = 2 -field x[5, 6] = 3 -Active block: [1, 1] -Active block: [2, 3] -``` +The code snippet above also creates an 8x8 sparse grid. The only difference between `bitmasked.py` and `pointer.py` is that the bitmasked SNode replaces the dense SNode (line 3). As shown in the figure below, the active blocks are the same as `pointer.py`. However, the bitmasked pixels in the block are not all activated, because each of them has an activity value. -When `bitmask = False`, we get -``` -field x[2, 2] = 0 -field x[2, 3] = 2 -field x[3, 2] = 0 -field x[3, 3] = 0 -field x[4, 6] = 0 -field x[4, 7] = 0 -field x[5, 6] = 3 -field x[5, 7] = 0 -Active block: [1, 1] -Active block: [2, 3] -``` +
-When using a `dense` SNode as the leaf block, -activating `x[2, 3]` also implicitly activates other pixels in `block[1, 1]`, i.e., `x[2, 2]`, `x[3, 2]`, and `x[3, 3]`. -Without a bitmask, these pixels in the same `block` share the same activity. +![Bitmasked](https://raw.githubusercontent.com/taichi-dev/public_files/master/taichi/doc/bitmasked.png) -### Explicitly manipulating and querying sparsity +
-Taichi also provides APIs that explicitly manipulates data structure sparsity. -- Use `ti.is_active(snode, [i, j, ...])` to query if `snode[i, j, ...]` is active or not. -- `ti.activate/deactivate(snode, [i, j, ...])` to explicitly activate or deactivate a cell of `snode[i, j, ...]`. -- Use `snode.deactivate_all()` to deactivate all cells of SNode `snode`. This operation also recursively deactivates all its children. -- Use `ti.deactivate_all_snodes()` to deactivate all cells of all SNodes with sparsity. -- Use `ti.rescale_index(descendant_snode/field, ancestor_snode, index)` to compute the ancestor index given a descendant index. +The bitmasked SNodes are like dense SNodes with auxiliary activity values. +
-Below is an example of these APIs: +![Bitmasked SNode Tree](https://raw.githubusercontent.com/taichi-dev/public_files/master/taichi/doc/bitmasked_tree.png) -```python -import taichi as ti +
-ti.init() +### Dynamic SNode -x = ti.field(dtype=ti.i32) -block1 = ti.root.pointer(ti.ij, (4, 4)) -block2 = block1.pointer(ti.ij, (2, 2)) -pixel = block2.dense(ti.ij, (2, 2)) +To support variable-length fields, Taichi provides dynamic SNodes. The code snippet below first creates a 5x1 dense block (line 2). Then each cell of the dense block contains a variable-length dynamic container (line 3). The maximum length of the dynamic container is 5. In the `make_lists()` function, you can use `ti.append()` to add a value to the end of a dynamic SNode. `x.parent()` is the same as `pixel`. The dense field `l` stores the length of each dynamic SNode. + +```python {3} title=dynamic.py +x = ti.field(ti.i32) +block = ti.root.dense(ti.i, 5) +pixel = block.dynamic(ti.j, 5) pixel.place(x) +l = ti.field(ti.i32) +ti.root.dense(ti.i, 5).place(l) @ti.kernel -def sparse_api_demo(): - ti.activate(block1, [0, 1]) - ti.activate(block2, [1, 2]) +def make_lists(): + for i in range(5): + for j in range(i): + ti.append(x.parent(), i, j * j) # ti.append(pixel, i, j * j) + l[i] = ti.length(x.parent(), i) # [0, 1, 2, 3, 4] +``` - for i, j in x: - print('field x[{}, {}] = {}'.format(i, j, x[i, j])) - # outputs: - # field x[2, 4] = 0 - # field x[2, 5] = 0 - # field x[3, 4] = 0 - # field x[3, 5] = 0 - for i, j in block2: - print('Active block2: [{}, {}]'.format(i, j)) - # output: Active block2: [1, 2] +
- for i, j in block1: - print('Active block1: [{}, {}]'.format(i, j)) - # output: Active block1: [0, 1] +![Dynamic](https://raw.githubusercontent.com/taichi-dev/public_files/master/taichi/doc/dynamic.png) - for j in range(4): - print('Activity of block2[2, {}] = {}'.format(j, ti.is_active(block2, [1, j]))) +
- ti.deactivate(block2, [1, 2]) +## Computation on sparse spatial data structures - for i, j in block2: - print('Active block2: [{}, {}]'.format(i, j)) - # output: nothing +### Sparse struct-fors - for i, j in block1: - print('Active block1: [{}, {}]'.format(i, j)) - # output: Active block1: [0, 1] +Efficiently looping over sparse grid cells that distribute irregularly can be challenging, especially on parallel devices such as GPUs. +In Taichi, *struct-for*s natively support sparse spatial data structures and only loop over currently active pixels with automatic efficient parallelization. - print(ti.rescale_index(x, block1, ti.Vector([9, 17]))) - # output = [2, 4] +### Explicitly manipulating and querying sparsity - # Note: ti.Vector is optional in ti.rescale_index. - print(ti.rescale_index(x, block1, [9, 17])) - # output = [2, 4] +Taichi also provides APIs that explicitly manipulate data structure sparsity. You can manually **check** the activity of a SNode, **activate** a SNode, or **deactivate** a SNode. We now illustrate these functions based on the field defined below. - ti.activate(block2, [1, 2]) +```python +x = ti.field(dtype=ti.i32) +block1 = ti.root.pointer(ti.ij, (3, 3)) +block2 = block1.pointer(ti.ij, (2, 2)) +pixel = block2.bitmasked(ti.ij, (2, 2)) +pixel.place(x) +``` -sparse_api_demo() +#### 1. Activity checking +You can use `ti.is_active(snode, [i, j, ...])` to explicitly query if `snode[i, j, ...]` is active or not. +```python @ti.kernel -def check_activity(snode: ti.template(), i: ti.i32, j: ti.i32): +def activity_checking(snode: ti.template(), i: ti.i32, j: ti.i32): print(ti.is_active(snode, [i, j])) -check_activity(block2, 1, 2) # output = 1 -block2.deactivate_all() -check_activity(block2, 1, 2) # output = 0 -check_activity(block1, 0, 1) # output = 1 -ti.deactivate_all_snodes() -check_activity(block1, 0, 1) # output = 0 +for i in range(3): + for j in range(3): + activity_checking(block1, i, j) +for i in range(6): + for j in range(6): + activity_checking(block2, i, j) +for i in range(12): + for j in range(12): + activity_checking(pixel, i, j) +``` +#### 2. Activation +You can use `ti.activate(snode, [i, j, ...])` to explicitly activate a cell of `snode[i, j, ...]`. +```python +@ti.kernel +def activate_snodes() + ti.activate(block1, [1, 0]) + ti.activate(block2, [3, 1]) + ti.activate(pixel, [7, 3]) + +activity_checking(block1, [1, 0]) # output: 1 +activity_checking(block2, [3, 1]) # output: 1 +activity_checking(pixel, [7, 3]) # output: 1 ``` +
+ +![Activation](https://raw.githubusercontent.com/taichi-dev/public_files/master/taichi/doc/activation.png) + +
+ +#### 3. Deactivation +- Use `ti.deactivate(snode, [i, j, ...])` to explicitly deactivate a cell of `snode[i, j, ...]`. +- Use `snode.deactivate_all()` to deactivate all cells of SNode `snode`. This operation also recursively deactivates all its children. +- Use `ti.deactivate_all_snodes()` to deactivate all cells of all SNodes with sparsity. + +When deactivation happens, the Taichi runtime automatically recycles and zero-fills memory of the deactivated containers. :::note For performance reasons, `ti.activate(snode, index)` only activates `snode[index]`. -The programmer must ensure all ancestor containers of `snode[index]` are already active. +The programmer must ensure all ancestor containers of `snode[index]` is already active. Otherwise, this operation results in undefined behavior. Similarly, `ti.deactivate` ... - does **not** recursively deactivate all the descendants of a cell. -- does **not** trigger an deactivation of its parent container, even if all the children of the parent container are deactivated. +- does **not** trigger deactivation of its parent container, even if all the children of the parent container are deactivated. ::: -:::note -When deactivation happens, the Taichi runtime automatically recycles and zero-fills memory of the deactivated containers. -::: +#### 4. Ancestor index query +You can use `ti.rescale_index(descendant_snode/field, ancestor_snode, index)` to compute the ancestor index given a descendant index. -:::note -While it is possible to directly use `[i // 2, j // 2]` to compute the `block` index given `pixel` index, -doing so couples computation code with the internal configuration of data structures (in this case, the size of `block` containers). +```python +print(ti.rescale_index(x, block1, ti.Vector([7, 3]))) # output: [1, 0] +print(ti.rescale_index(x, block2, [7, 3])) # output: [3, 1] +print(ti.rescale_index(x, pixel, [7, 3])) # output: [7, 3] +print(ti.rescale_index(block1, block2, [3, 1])) # output: [1, 0] +``` -Use `ti.rescale_index` to avoid hard-coding internal information of data structures. -::: +Regarding line 1, you can also compute the `block1` index given `pixel` index `[7, 3]` as `[7//2//2, 3//2//2]`. However, doing so couples computation code with the internal configuration of data structures (in this case, the size of `block1` containers). By using `ti.rescale_index()`, you can avoid hard-coding internal information of data structures. ## Further reading -Please read our [paper](https://yuanming.taichi.graphics/publication/2019-taichi/taichi-lang.pdf), -watch the [introduction video](https://www.youtube.com/watch?v=wKw8LMF3Djo), or check out -the SIGGRAPH Asia 2019 [slides](https://yuanming.taichi.graphics/publication/2019-taichi/taichi-lang-slides.pdf) -for more details on sparse computation. +Please read the SIGGRAPH Asia 2019 [paper](https://yuanming.taichi.graphics/publication/2019-taichi/taichi-lang.pdf) or watch the associated +[introduction video](https://www.youtube.com/watch?v=wKw8LMF3Djo) with [slides](https://yuanming.taichi.graphics/publication/2019-taichi/taichi-lang-slides.pdf) +for more details on computation of sparse spatial data structures. -[Taichi elements](https://github.com/taichi-dev/taichi_elements) implement a high-performance -MLS-MPM solver on Taichi sparse grids. +[Taichi elements](https://github.com/taichi-dev/taichi_elements) implement a high-performance MLS-MPM solver on Taichi sparse grids. diff --git a/docs/lang/articles/advanced/sparse_matrix.md b/docs/lang/articles/advanced/sparse_matrix.md index 9a619991f6e14..5e69c1d7a457b 100644 --- a/docs/lang/articles/advanced/sparse_matrix.md +++ b/docs/lang/articles/advanced/sparse_matrix.md @@ -22,7 +22,7 @@ n = 4 K = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) @ti.kernel -def fill(A: ti.linalg.sparse_matrix_builder()): +def fill(A: ti.types.sparse_matrix_builder()): for i in range(n): A[i, i] += 1 # Only += and -= operators are supported for now. @@ -146,7 +146,7 @@ K = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) b = ti.field(ti.f32, shape=n) @ti.kernel -def fill(A: ti.linalg.sparse_matrix_builder(), b: ti.template(), interval: ti.i32): +def fill(A: ti.types.sparse_matrix_builder(), b: ti.template(), interval: ti.i32): for i in range(n): A[i, i] += 2.0 @@ -184,5 +184,5 @@ print(f">>>> Computation was successful?: {isSuccess}") ## Examples Please have a look at our two demos for more information: -+ [Stable fluid](https://github.com/taichi-dev/taichi/blob/master/examples/simulation/stable_fluid.py): A 2D fluid simulation using a sparse Laplacian matrix to solve Poisson's pressure equation. -+ [Implicit mass spring](https://github.com/taichi-dev/taichi/blob/master/examples/simulation/implicit_mass_spring.py): A 2D cloth simulation demo using sparse matrices to solve the linear systems. ++ [Stable fluid](https://github.com/taichi-dev/taichi/blob/master/python/taichi/examples/simulation/stable_fluid.py): A 2D fluid simulation using a sparse Laplacian matrix to solve Poisson's pressure equation. ++ [Implicit mass spring](https://github.com/taichi-dev/taichi/blob/master/python/taichi/examples/simulation/implicit_mass_spring.py): A 2D cloth simulation demo using sparse matrices to solve the linear systems. diff --git a/docs/lang/articles/basic/_category_.json b/docs/lang/articles/basic/_category_.json index 60b7ef558624d..245518655465e 100644 --- a/docs/lang/articles/basic/_category_.json +++ b/docs/lang/articles/basic/_category_.json @@ -1,4 +1,4 @@ { - "label": "Taichi Language Basic Concepts", + "label": "Basic Concepts", "position": 2 } diff --git a/docs/lang/articles/basic/differences_between_taichi_and_python_programs.md b/docs/lang/articles/basic/differences_between_taichi_and_python_programs.md new file mode 100644 index 0000000000000..89de2268c05f5 --- /dev/null +++ b/docs/lang/articles/basic/differences_between_taichi_and_python_programs.md @@ -0,0 +1,162 @@ +--- +sidebar_position: 2 +--- + +# Differences between Taichi and Python programs + +Although Taichi uses Python as the frontend, it follows a different set of rules in many aspects, including: + +1. [Taichi only supports return statement outside non-static `if`/`for`/`while` scope in the program](#return-statement) +2. [Variables defined inside an `if`/`for`/`while` block cannot be accessed outside the block.](#variable-scoping) +3. [Taichi does not fully support some language features of Python.](#unsupportedpartially-supported-python-language-features) + - [Set, list, dictionary and operator `in`](#set-list-dictionary-and-operator-in) + - [Comprehensions](#comprehensions) + - [Operator `is`](#operator-is) + +## Return statement and return type annotation + +- If a Taichi kernel/function does not have a return statement, it must not have return type annotation. +- If a Taichi kernel has a return statement, it must have return type annotation. +- If a Taichi function has a return statement, return type annotation is recommended, and it will be mandatory in the future. + +```python {3,7,10,14} +@ti.kernel +def error_kernel_no_return_annotation(): + return 0 # Error: Have return statement but have no return type annotation + +@ti.kernel +def error_kernel_no_return() -> ti.i31: # Error: Have return type annotation but have no return statement + pass + +@ti.func +def error_func_no_return() -> ti.i31: # Error: Have return type annotation but have no return statement + pass +``` + +- The return statement can not be in a scope of non-static `if`/`for`/`while`. + +```python {4} +@ti.kernel +def error_return_inside_non_static_if(a: ti.i32) -> ti.i32: + if a: + return 1 # Error: Return statement inside if scope +``` + +- The compiler discards code after the first return statement. + +```python {4-5} +@ti.kernel +def discarded_after_first_return(a: ti.i32) -> ti.i32: + return 1 + if a: # Discarded + return 1 # Discarded + +discarded_after_first_return(0) # OK: returns 1 +``` +- If there are [compile-time evaluations](/lang/articles/advanced/meta#compile-time-evaluations) in the code, make sure there is a return statement under all circumstances. +Otherwise, error occurs when a branch is chosen which does not have return statement. +```python {7-8,15-16,21,23-24} +@ti.kernel +def return_inside_static_if(a: ti.template()) -> ti.i32: + if ti.static(a): + return 1 + return 0 + +return_inside_static_if(1) # OK: Returns 1 +return_inside_static_if(0) # OK: Returns 0 + +@ti.kernel +def return_inside_static_if_no_return_outside(a: ti.template()) -> ti.i32: + if ti.static(a): + return 1 + +return_inside_static_if_no_return_outside(1) # OK: Returns 1 +return_inside_static_if_no_return_outside(0) # Error: No return statement + +@ti.kernel +def ok_return_inside_static_for() -> ti.i32: + a = 0 + for i in ti.static(range(10)): # Static for + a += i + if ti.static(i == 8): # Static if + return a # OK: Returns 36 +``` + +## Variable scoping + +In Python, a variable defined inside an `if`/`for`/`while` block can be accessed outside the block. +**However**, in Taichi, the variables can only be accessed **within the block it is defined**. + +```python {5,13,17,22} +@ti.kernel +def error_access_var_outside_for() -> ti.i32: + for i in range(10): + a = i + return a # Error: variable "a" not found + +@ti.kernel +def error_access_var_outside_if(a: ti.i32) -> ti.i32: + if a: + b = 1 + else: + b = 2 + return b # Error: variable "b" not found + +@ti.kernel +def ok_define_var_before_if(a: ti.i32) -> ti.i32: + b = 0 + if a: + b = 1 + else: + b = 2 + return b # OK: "b" is defined before "if" + +ok_define_var_before_if(0) # Returns 2 +``` + +## Unsupported/partially supported Python language features + +### Set, list, dictionary and operator `in` + +Currently, Taichi does not support `set`. + +List and dictionary before assigning to a variable works as the python list and dictionary. +However, after assigning to a variable, the content of the list and the values (not keys) of the dictionary are converted to Taichi variables. + +Taichi does not have a runtime implementation of `in` currently. Therefore, operator `in` and `not in` only works in [static scope](/lang/articles/advanced/meta#static-scope) (inside `ti.static()`). + +```python {3,11-12,20} +@ti.kernel +def list_without_assign() -> ti.i32: + if ti.static(1 in [1, 2]): # [1, 2] + return 1 + return 0 + +list_without_assign() # Returns 1 + +@ti.kernel +def list_assigned() -> ti.i32: + a = [1, 2] # a: [Variable(1), Variable(2)] + if ti.static(1 in a): # 1 is not in [Variable(1), Variable(2)] + return 1 + return 0 + +list_assigned() # Returns 0 + +@ti.kernel +def error_non_static_in(): + if i in [1, 2]: # Error: Cannot use `in` outside static scope + pass +``` + +### Comprehensions + +Taichi partially supports list comprehension and dictionary comprehension, +but does not support set comprehension. + +For list comprehensions and dictionary comprehensions, the `if`s and `for`s in them are evaluated at compile time. +The iterators and conditions are implicitly in [static scope](/lang/articles/advanced/meta#static-scope). + +### Operator `is` + +Currently, Taichi does not support operator `is` and `is not`. diff --git a/docs/lang/articles/basic/external.md b/docs/lang/articles/basic/external.md index 2904f40b60577..3e4e70fc07464 100644 --- a/docs/lang/articles/basic/external.md +++ b/docs/lang/articles/basic/external.md @@ -1,5 +1,5 @@ --- -sidebar_position: 4 +sidebar_position: 5 --- # Interacting with external arrays @@ -19,8 +19,8 @@ support NumPy, e.g. `matplotlib`. ```python {8} @ti.kernel def my_kernel(): - for i in x: - x[i] = i * 2 + for i in x: + x[i] = i * 2 x = ti.field(ti.f32, 4) my_kernel() @@ -41,12 +41,38 @@ print(x[2]) # 3 print(x[3]) # 5 ``` +Likewise, Taichi fields can be **imported from and exported to PyTorch tensors**: +```python +@ti.kernel +def my_kernel(): + for i in x: + x[i] = i * 2 + +x = ti.field(ti.f32, 4) +my_kernel() +x_torch = x.to_torch() +print(x_torch) # torch.tensor([0, 2, 4, 6]) + +x.from_numpy(torch.tensor([1, 7, 3, 5])) +print(x[0]) # 1 +print(x[1]) # 7 +print(x[2]) # 3 +print(x[3]) # 5 +``` +When calling `to_torch()`, specify the PyTorch device where the Taichi field is exported using the `device` argument: +```python +x = ti.field(ti.f32, 4) +x.fill(3.0) +x_torch = x.to_torch(device="cuda:0") +print(x_torch.device) # device(type='cuda', index=0) +``` + ## External array shapes -Shapes of Taichi fields and those of corresponding NumPy arrays are closely +Shapes of Taichi fields and those of corresponding NumPy arrays or PyTorch tensors are closely connected via the following rules: -- For scalar fields, **the shape of NumPy array is exactly the same as +- For scalar fields, **the shape of NumPy array or PyTorch tensor equals the shape of the Taichi field**: ```python @@ -60,7 +86,7 @@ field.from_numpy(array) # the input array must be of shape (256, 512) ``` - For vector fields, if the vector is `n`-D, then **the shape of NumPy - array should be** `(*field_shape, vector_n)`: + array or Pytorch tensor should be** `(*field_shape, vector_n)`: ```python field = ti.Vector.field(3, ti.i32, shape=(256, 512)) @@ -74,7 +100,7 @@ field.from_numpy(array) # the input array must be of shape (256, 512, 3) ``` - For matrix fields, if the matrix is `n`-by-`m` (`n x m`), then **the shape of NumPy -array should be** `(*field_shape, matrix_n, matrix_m)`: +array or Pytorch Tensor should be** `(*field_shape, matrix_n, matrix_m)`: ```python field = ti.Matrix.field(3, 4, ti.i32, shape=(256, 512)) @@ -88,7 +114,8 @@ array.shape # (256, 512, 3, 4) field.from_numpy(array) # the input array must be of shape (256, 512, 3, 4) ``` -- For struct fields, the external array will be exported as **a dictionary of arrays** with the keys being struct member names and values being struct member arrays. Nested structs will be exported as nested dictionaries: +- For struct fields, the external array will be exported as **a dictionary of NumPy arrays or PyTorch tensors** with keys +being struct member names and values being struct member arrays. Nested structs will be exported as nested dictionaries: ```python field = ti.Struct.field({'a': ti.i32, 'b': ti.types.vector(float, 3)} shape=(256, 512)) @@ -104,7 +131,7 @@ field.from_numpy(array_dict) # the input array must have the same keys as the fi ## Using external arrays as Taichi kernel arguments -Use the type hint `ti.ext_arr()` for passing external arrays as kernel +Use type hint `ti.ext_arr()` or `ti.any_arr()` to pass external arrays as kernel arguments. For example: ```python {10} @@ -135,3 +162,31 @@ for i in range(n): for j in range(m): assert a[i, j] == i * j + i + j ``` + +Note that the elements in an external array must be indexed using a single square bracket. +This contrasts with a Taichi vector or matrix field where field and matrix indices are indexed separately: +```python +@ti.kernel +def copy_vector(x: ti.template(), y: ti.ext_arr()): + for i, j in ti.ndrange(n, m): + for k in ti.static(range(3)): + y[i, j, k] = x[i, j][k] # correct + # y[i][j][k] = x[i, j][k] incorrect + # y[i, j][k] = x[i, j][k] incorrect +``` +Also, external arrays in a Taichi kernel are indexed using its **physical memory layout**. For PyTorch users, +this implies that the PyTorch tensor [needs to be made contiguous](https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html) +before being passed into a Taichi kernel: +```python +@ti.kernel +def copy_scalar(x: ti.template(), y: ti.ext_arr()): + for i, j in x: + y[i, j] = x[i, j] + +x = ti.field(dtype=int, shape=(3, 3)) +y = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) +y = y.T # Transposing the tensor returns a view of the tensor which is not contiguous +copy(x, y) # error! +copy(x, y.clone()) # correct +copy(x, y.contiguous()) # correct +``` diff --git a/docs/lang/articles/basic/field.md b/docs/lang/articles/basic/field.md index 592ca50aa3a45..d0779f3592d0c 100644 --- a/docs/lang/articles/basic/field.md +++ b/docs/lang/articles/basic/field.md @@ -1,102 +1,168 @@ --- -sidebar_position: 3 +sidebar_position: 4 --- # Fields +Taichi fields are used to store data. +In general, fields are global data containers that can be read and written from both the Python scope and the Taichi scope. -Fields are **global** variables provided by Taichi. **Global** indicates that fields can be read/written from both the Python scope and the Taichi scope. A field can be considered as a multi-dimensional array of elements, and it can be either **dense** or **sparse**. Similar to a NumPy `ndarray` object, a field has a data type and a shape. Moreover, an element of a field can be a scalar, a **vector**, a **matrix**, or a **struct**. +A field has its own data type and shape and can be considered as a multi-dimensional array of elements. +An element of a field can be a **scalar**, a **vector**, a **matrix**, or a **struct**. +The sparsity of a field element is **dense** by default, but it can also be **sparse**, as detailed described in [Sparse spatial data structures](/lang/articles/advanced/sparse). -The term **field** is borrowed from mathematics and physics. If you -have already known [scalar field](https://en.wikipedia.org/wiki/Scalar_field) (e.g., heat field) or vector field (e.g., [gravitational field](https://en.wikipedia.org/wiki/Gravitational_field)) in mathematics and physics, it will be straightforward to understand the fields in Taichi. - -To be noticed: -* Fields are always accessed by indices. -* Field values are initially zero. -* Sparse fields are initially inactive. - -:::tip -In earlier versions of Taichi, you could not allocate new fields after executing the first kernel. Since Taichi v0.8.0, you can use a new class `FieldsBuilder` for dynamic field allocation and destruction. For more details, please see [Field (advanced)](/lang/articles/advanced/layout). +:::note +The term **field** is borrowed from mathematics and physics. +If you have already known [scalar field](https://en.wikipedia.org/wiki/Scalar_field) (e.g., heat field) or vector field (e.g., [gravitational field](https://en.wikipedia.org/wiki/Gravitational_field)) in mathematics and physics, +it will be straightforward to understand the fields in Taichi. ::: ## Scalar fields +We start introducing fields from this very basic type, the elements of scalar fields are simply scalars. +* A 0D scalar field is a single scalar. +* A 1D scalar field is an array. +* A 2D scalar field can be used to represent a 2D regular grid of values. +* A 3D scalar field can be used for volumetric data. -A simple example might help you understand scalar fields. Assume you have a rectangular wok on the top of a fire. At each point of the wok, there would be a temperature. The surface of the wok forms a heat field. The width and height of the wok are similar to the `shape` of the Taichi scalar field. The temperature (0-D scalar) is like the element of the Taichi scalar field. We could use the following field to represent the -heat field on the wok: - +### Declaration ``` python -heat_field = ti.field(dtype=ti.f32, shape=(width_wok, height_wok)) +import taichi as ti +ti.init(arch=ti.cpu) + +energy = ti.field(ti.f32, shape=()) # 0-D +linear_array = ti.field(ti.i32, shape=128) # 1-D +gray_scale_image = ti.field(ti.u8, shape=(640, 480)) # 2-D +volumetric_data = ti.field(ti.f32, shape=(32, 32, 32)) # 3-D ``` ### Access elements of scalar fields -- If `x` is a 3D scalar field (`ti.field(dtype=ti.f32, shape=(10, 20, 30)`), access its element with `x[i, j, k]` (`0 <= i < 10, 0 <= j < 20, 0 <= k < 30`). -- When accessing 0-D field `x`, use `x[None] = 0` instead of `x = 0`. A 0-D field looks like `energy = ti.field(dtype=ti.f32, shape=())`. +``` python +energy[None] = 10.0 +linear_array[0] = 1 +gray_scale_image[1,2] = 255 +volumetric_data[3,3,3] = 2.0 +``` -:::caution -Please **always** use indexing to access entries in fields. +### Meta data +``` python +linear_array.shape # (128,) +volumetric_data.dtype # f32 +``` + +:::note +* Field values are initially zero. +* Fields are **always** accessed by indices. When accessing 0-D field `x`, use `x[None] = 0` instead of `x = 0`. +::: + +### Example +An example might help you understand scalar fields. +Assume you have a gray-scale image. At each point in the image, there would be a pixel value. The width and height of the image are similar to the `shape` of the Taichi scalar field. The pixel value (0-D scalar) is like the element of the Taichi scalar field. We could use the following code to generate a gray-scale image with random pixel values: + +``` python {5} +import taichi as ti + +ti.init(arch=ti.cpu) +width, height = 640,480 +gray_scale_image = ti.field(dtype=ti.f32, shape=(width, height)) + +@ti.kernel +def fill_image(): + for i,j in gray_scale_image: + gray_scale_image[i,j] = ti.random() + +fill_image() + +gui = ti.GUI('gray-scale image with random values', (width, height)) +while gui.running: + gui.set_image(gray_scale_image) + gui.show() +``` + +:::tip +In earlier versions of Taichi, you could not allocate new fields after executing the first kernel. Since Taichi v0.8.0, you can use a new class `FieldsBuilder` for dynamic field allocation and destruction. For more details, please see [Field (advanced)](/lang/articles/advanced/layout). ::: ## Vector fields -We are all live in a gravitational field which is a vector field. At each position of the 3D space, there is a gravity force vector. The gravitational field could be represented with: +We are all living in a gravitational field, which is a vector field. At each position in 3D space, there is a gravity force vector. The gravitational field could be represented by: ```python gravitational_field = ti.Vector.field(n=3, dtype=ti.f32, shape=(x, y, z)) ``` `x, y, z` are the sizes of each dimension of the 3D space respectively. `n` is the number of elements of the gravity force vector. ### Access elements of vector fields +There are **two** indexing operators `[]` when you access a member of a vector field: the first is for field indexing, and the second is for vector indexing. - The gravity force vector could be accessed by `gravitational_field[i, j, k]` (`0 <= i < x, 0 <= j < y, 0 <= k < z`). - The `p`-th member of the gravity force vector could be accessed by `gravitational_field[i, j, k][p]` (`0 <= p < n`). - The 0-D vector field `x = ti.Vector.field(n=3, dtype=ti.f32, shape=())` should be accessed by `x[None][p]` (`0 <= p < n`). -:::note -As you may have noticed, there are **two** indexing operators `[]` when you access a member of a vector from a vector field: the first is for field indexing, and the second is for vector indexing. -::: +### Example +This example helps you understand how to access vector fields: +``` python +import taichi as ti +ti.init(arch=ti.cpu) -## Matrix fields +n,w,h = 3,128,64 +vec_field = ti.Vector.field(n, dtype=ti.f32, shape=(w,h)) + +@ti.kernel +def fill_vector(): + for i,j in vec_field: + for k in ti.static(range(n)): + #ti.static unrolls the inner loops + vec_field[i,j][k] = ti.random() + +fill_vector() +print(vec_field[w-1,h-1][n-1]) +``` +## Matrix fields Field elements can also be matrices. In continuum mechanics, each infinitesimal point in a material exists a strain and a stress tensor. The strain and stress tensor is a 3 by 3 matrix in the 3D space. To represent this tensor field we could use: ```python strain_tensor_field = ti.Matrix.field(n=3, m=3, dtype=ti.f32, shape=(x, y, z)) ``` - `x, y, z` are the sizes of each dimension of the 3D material respectively. `n, m` are the dimensions of the strain tensor. In a general case, suppose you have a `128 x 64` field called `A`, and each element is a `3 x 2` matrix, you can define it with `A = ti.Matrix.field(3, 2, dtype=ti.f32, shape=(128, 64))`. ### Access elements of matrix fields -- If you want to get the matrix of grid node `i, j`, please use - `mat = A[i, j]`. `mat` is simply a `3 x 2` matrix. -- To get the element on the first row and second column of that - matrix, use `mat[0, 1]` or `A[i, j][0, 1]`. +There are **two** indexing operators `[]` when you access a member of a matrix from a matrix field: +the first is for field indexing, and the second is for matrix indexing. +- If you want to get the element `i, j` of the matrix field, please use `mat = A[i, j]`. `mat` is simply a `3 x 2` matrix. +- To get the member on the first row and second column of that element `mat`, use `mat[0, 1]` or `A[i, j][0, 1]`. - The 0-D matrix field `x = ti.Matrix.field(n=3, m=4, dtype=ti.f32, shape=())` should be accessed by `x[None][p, q]` (`0 <= p < n, 0 <= q < m`). - -:::note -- As you may have noticed, there are **two** indexing operators `[]` - when you access a member of a matrix from a matrix field: the - first is for field indexing, and the second is for matrix indexing. - `ti.Vector` is simply an alias of `ti.Matrix`. -::: -### Matrix size +### Example +This example helps you understand element and member in matrix fields: +``` python +matrix_field = ti.Matrix.field(n = 2, m = 3, dtype = ti.f32, shape = (2, 2)) +Element = matrix_field[0, 0] +Member = matrix_field[0, 1][1,1] +``` +![image](https://raw.githubusercontent.com/taichi-dev/public_files/master/taichi/doc/matrix_field.jpg) -For performance reasons matrix operations will be unrolled during the compile stage, therefore we -suggest using only small matrices. For example, `2x1`, `3x3`, `4x4` +### Matrix size +For performance reasons, matrix operations will be unrolled during the compile stage. +Therefore we suggest using only small matrices. For example, `2x1`, `3x3`, `4x4` matrices are fine, yet `32x6` is probably too big as a matrix size. -:::caution -Due to the unrolling mechanisms, operating on large matrices (e.g. -`32x128`) can lead to a very long compilation time and low performance. -::: - If you have a dimension that is too large (e.g. `64`), it's better to declare a field of size `64`. E.g., instead of declaring `ti.Matrix.field(64, 32, dtype=ti.f32, shape=(3, 2))`, declare `ti.Matrix.field(3, 2, dtype=ti.f32, shape=(64, 32))`. Try to put large dimensions to fields instead of matrices. +:::caution +Due to the unrolling mechanism, operating on large matrices (e.g. +`32x128`) can lead to a very long compilation time and low performance. +::: + ## Struct fields -In addition to vectors and matrices, field elements can be user-defined structs. A struct variable may contain scalars, vectors/matrices, or other structs as its members. A struct field is created by providing a dictionary of the name and data type of each member. For example, a 1D field of particles with position, velocity, acceleration, and mass for each particle can be represented as: +Field elements can be user-defined structs. +Struct fields are created by providing the name and data type of each member variable in a dictionary format. +Member variables of struct fields might be scalars, vectors, matrices, or other struct fields. +For example, a 1-D field of particles with position, velocity, acceleration, and mass can be declared as: ```python particle_field = ti.Struct.field({ "pos": ti.types.vector(3, ti.f32), @@ -105,7 +171,8 @@ particle_field = ti.Struct.field({ "mass": ti.f32, }, shape=(n,)) ``` -[Compound types](type.md#compound-types) (`ti.types.vector`, `ti.types.matrix`, and `ti.types.struct`) need to be used to create vectors, matrices, or structs as field members. Apart from using `ti.Struct.field`, the above particle field can be alternatively created using field creation from compound types as: + +[Compound types](type.md#compound-types) (`ti.types.vector`, `ti.types.matrix`, and `ti.types.struct`) are used to declare vectors, matrices, or structs as field members. Apart from using `ti.Struct.field`, the above particle field can also be declared by using the field of compound types: ```python vec3f = ti.types.vector(3, ti.f32) particle = ti.types.struct( @@ -113,6 +180,7 @@ particle = ti.types.struct( ) particle_field = particle.field(shape=(n,)) ``` + Members of a struct field can be accessed either locally (i.e., member of a struct field element) or globally (i.e., member field of a struct field): ```python # set the position of the first particle to origin diff --git a/docs/lang/articles/basic/operator.md b/docs/lang/articles/basic/operator.md new file mode 100644 index 0000000000000..1fb675cb86535 --- /dev/null +++ b/docs/lang/articles/basic/operator.md @@ -0,0 +1,325 @@ +--- +sidebar_position: 4 +--- + +# Operators +Here we present the supported operators in Taichi for both primitive types and +compound types such as matrices. + +## Supported operators for primitive types +### Arithmetic operators + +| Operation | Result | +| --------- | ------------------------------- | +| `-a` | `a` negated | +| `+a` | `a` unchanged | +| `a + b` | sum of `a` and `b` | +| `a - b` | difference of `a` and `b` | +| `a * b` | product of `a` and `b` | +| `a / b` | quotient of `a` and `b` | +| `a // b` | floored quotient of `a` and `b` | +| `a % b` | remainder of `a / b` | +| `a ** b` | `a` to the power of `b` | + +:::note + +The `%` operator in Taichi follows the Python style instead of C style, +e.g., + +```python +# In Taichi-scope or Python-scope: +print(2 % 3) # 2 +print(-2 % 3) # 1 +``` + +For C-style mod (`%`), please use `ti.raw_mod`. This function also receives floating points as arguments. + +`ti.raw_mod(a, b)` returns `a - b * int(float(a) / b)`. + +```python +print(ti.raw_mod(2, 3)) # 2 +print(ti.raw_mod(-2, 3)) # -2 +print(ti.raw_mod(3.5, 1.5)) # 0.5 +``` +::: + +:::note + +Python3 distinguishes `/` (true division) and `//` (floor division), e.g., `1.0 / 2.0 = 0.5`, `1 / 2 = 0.5`, `1 // 2 = 0`, +`4.2 // 2 = 2`. Taichi follows the same design: + +- **True divisions** on integral types first cast their + operands to the default floating point type. +- **Floor divisions** on floating point types first cast their + operands to the default integral type. + +To avoid such implicit casting, you can manually cast your operands to +desired types, using `ti.cast`. Please see +[Default precisions](#default-precisions) for more details on +default numerical types. + +Taichi also provides `ti.raw_div` function which performs true division if one of the operands is floating point type +and performs floor division if both operands are integral types. + +```python +print(ti.raw_div(5, 2)) # 2 +print(ti.raw_div(5, 2.0)) # 2.5 +``` + +::: + + +### Comparison operators + +| Operation | Result | +| ------------------ | ------------------------------------------------------------- | +| `a == b` | if `a` is equal to `b`, then True, else False | +| `a != b` | if `a` is not equal to `b`, then True, else False | +| `a > b` | if `a` is strictly greater than `b`, then True, else False | +| `a < b` | if `a` is strictly less than `b`, then True, else False | +| `a >= b` | if `a` is greater than or equal to `b`, then True, else False | +| `a <= b` | if `a` is less than or equal to `b`, then True, else False | + +### Logical operators + +| Operation | Result | +| ------------------ | ------------------------------------------------------------- | +| `not a` | if `a` is False, then True, else False | +| `a or b` | if `a` is False, then `b`, else `a` | +| `a and b` | if `a` is False, then `a`, else `b` | + +### Conditional operations + +The result of conditional expression `a if cond else b` is `a` if `cond` is True, or `b` otherwise. +`a` and `b` must have a same type. + +The conditional expression does short-circuit evaluation, which means the branch not chosen is not evaluated. + +```python +a = ti.field(ti.i32, shape=(10,)) +for i in range(10): + a[i] = i + +@ti.kernel +def cond_expr(ind: ti.i32) -> ti.i32: + return a[ind] if ind < 10 else 0 + +cond_expr(3) # returns 3 +cond_expr(10) # returns 0, a[10] is not evaluated +``` + + +For element-wise conditional operations on Taichi vectors and matrices, +Taichi provides `ti.select(cond, a, b)` which **does not** do short-circuit evaluation. +```python {4} +cond = ti.Vector([1, 0]) +a = ti.Vector([2, 3]) +b = ti.Vector([4, 5]) +ti.select(cond, a, b) # ti.Vector([2, 5]) +``` + +### Bitwise operators + +| Operation | Result | +| ----------------------- | ----------------------------------- | +| `~a` | the bits of `a` inverted | +| `a & b` | bitwise and of `a` and `b` | +| `a ^ b` | bitwise exclusive or of `a` and `b` | +| a | b | bitwise or of `a` and `b` | +| `a << b` | left-shift `a` by `b` bits | +| `a >> b` | right-shift `a` by `b` bits | + +:::note + +The `>>` operation denotes the +[Shift Arithmetic](https://en.wikipedia.org/wiki/Arithmetic_shift) Right (SAR) operation. +For the [Shift Logical](https://en.wikipedia.org/wiki/Logical_shift) Right (SHR) operation, +consider using `ti.bit_shr()`. For left shift operations, SAL and SHL are the +same. + + +::: + +### Trigonometric functions + +```python +ti.sin(x) +ti.cos(x) +ti.tan(x) +ti.asin(x) +ti.acos(x) +ti.atan2(x, y) +ti.tanh(x) +``` + +### Other arithmetic functions + +```python +ti.sqrt(x) +ti.rsqrt(x) # A fast version for `1 / ti.sqrt(x)`. +ti.exp(x) +ti.log(x) +ti.round(x) +ti.floor(x) +ti.ceil(x) +ti.sum(x) +ti.max(x, y, ...) +ti.min(x, y, ...) +ti.abs(x) # Same as `abs(x)` +ti.pow(x, y) # Same as `pow(x, y)` and `x ** y` +``` + +### Builtin-alike functions + +```python +abs(x) # Same as `ti.abs(x, y)` +pow(x, y) # Same as `ti.pow(x, y)` and `x ** y`. +``` + +### Random number generator + +```python +ti.random(dtype=float) +``` + +:::note + +`ti.random` supports `u32`, `i32`, `u64`, `i64`, and all floating point types. +The range of the returned value is type-specific. + +| Type | Range | +| --- | --- | +| i32 | -2,147,483,648 to 2,147,483,647 | +| u32 | 0 to 4,294,967,295 | +| i64 | -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807 | +| u64 | 0 to 18,446,744,073,709,551,615 | +| floating point | 0.0 to 1.0 | + +::: + +### Supported atomic operations + +In Taichi, augmented assignments (e.g., `x[i] += 1`) are automatically +[atomic](https://en.wikipedia.org/wiki/Fetch-and-add). + +:::caution + +When modifying global variables in parallel, make sure you use atomic +operations. For example, to sum up all the elements in `x`, + +```python +@ti.kernel +def sum(): + for i in x: + # Approach 1: OK + total[None] += x[i] + + # Approach 2: OK + ti.atomic_add(total[None], x[i]) + + # Approach 3: Wrong result since the operation is not atomic. + total[None] = total[None] + x[i] +``` +::: + +:::note + +When atomic operations are applied to local values, the Taichi compiler +will try to demote these operations into their non-atomic counterparts. +::: + +Apart from the augmented assignments, explicit atomic operations, such +as `ti.atomic_add`, also do read-modify-write atomically. These +operations additionally return the **old value** of the first argument. +For example, + +```python +x[i] = 3 +y[i] = 4 +z[i] = ti.atomic_add(x[i], y[i]) +# now x[i] = 7, y[i] = 4, z[i] = 3 +``` + +Below is a list of all explicit atomic operations: + +| Operation | Behavior | +| --------------------- | ---------------------------------------------------------------------------------------------------- | +| `ti.atomic_add(x, y)` | atomically compute `x + y`, store the result in `x`, and return the old value of `x` | +| `ti.atomic_sub(x, y)` | atomically compute `x - y`, store the result in `x`, and return the old value of `x` | +| `ti.atomic_and(x, y)` | atomically compute `x & y`, store the result in `x`, and return the old value of `x` | +| `ti.atomic_or(x, y)` | atomically compute x | y, store the result in `x`, and return the old value of `x` | +| `ti.atomic_xor(x, y)` | atomically compute `x ^ y`, store the result in `x`, and return the old value of `x` | +| `ti.atomic_max(x, y)` | atomically compute `max(x, y)`, store the result in `x`, and return the old value of `x` | +| `ti.atomic_min(x, y)` | atomically compute `min(x, y)`, store the result in `x`, and return the old value of `x` | + +:::note + +Supported atomic operations on each backend: + +| type | CPU | CUDA | OpenGL | Metal | C source | +| ---- | ---- | ---- | ------ | ----- | -------- | +| i32 |:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:| +| f32 |:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:| +| i64 |:heavy_check_mark:|:heavy_check_mark:|:large_orange_diamond:|:x:|:heavy_check_mark:| +| f64 |:heavy_check_mark:|:heavy_check_mark:|:large_orange_diamond:|:x:|:heavy_check_mark:| + +(:large_orange_diamond: requires extension) +::: + + +## Supported operators for matrices + +The previously mentioned operations on primitive types can also be applied on +compound types such as matrices. +In these cases, they are applied in an element-wise manner. For example: + +```python +B = ti.Matrix([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) +C = ti.Matrix([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]) + +A = ti.sin(B) +# is equivalent to +for i in ti.static(range(2)): + for j in ti.static(range(3)): + A[i, j] = ti.sin(B[i, j]) + +A = B ** 2 +# is equivalent to +for i in ti.static(range(2)): + for j in ti.static(range(3)): + A[i, j] = B[i, j] ** 2 + +A = B ** C +# is equivalent to +for i in ti.static(range(2)): + for j in ti.static(range(3)): + A[i, j] = B[i, j] ** C[i, j] + +A += 2 +# is equivalent to +for i in ti.static(range(2)): + for j in ti.static(range(3)): + A[i, j] += 2 + +A += B +# is equivalent to +for i in ti.static(range(2)): + for j in ti.static(range(3)): + A[i, j] += B[i, j] +``` + +In addition, the following methods are supported matrices operations: + +```python +a = ti.Matrix([[2, 3], [4, 5]]) +a.transpose() # the transposed matrix of `a`, will not effect the data in `a`. +a.trace() # the trace of matrix `a`, the returned scalar value can be computed as `a[0, 0] + a[1, 1] + ...`. +a.determinant() # the determinant of matrix `a`. +a.inverse() # (ti.Matrix) the inverse of matrix `a`. +a@a # @ denotes matrix multiplication +``` + +:::note +For now, determinant() and inverse() only works in Taichi-scope, and the +size of the matrix must be 1x1, 2x2, 3x3 or 4x4. +::: diff --git a/docs/lang/articles/basic/overview.md b/docs/lang/articles/basic/overview.md index 80bb9d7c9abd1..c558e06e0ae22 100644 --- a/docs/lang/articles/basic/overview.md +++ b/docs/lang/articles/basic/overview.md @@ -2,22 +2,51 @@ sidebar_position: 0 --- -# Why new programming language +# Why a new programming language -Taichi is a high-performance programming language for computer graphics -applications. The design goals are +Imagine you'd like to write a new particle-based fluid algorithm. You started simple, didn't spend much time before finding a reference C++/CUDA work online (or derived the work from your labmate, unfortunately). `cmake .. && make`, you typed. Oops, cmake threw out an error due to a random incompatible third party library. Installed and rebuilt, now it passed. Then you ran it, which immediately segfaulted (without any stacktrace, of course). Then you started gazing at the code, placed the necessary asset files at the right place, fixed a few dangling pointers and reran. It... actually worked, until you plugged in your revised algorithm. Now another big fight with the GPU or CPU code. More often than not, you get lost in the language details. -- Productivity -- Performance -- Portability -- Spatially sparse computation -- Differentiable programming -- Metaprogramming +If all these sound too familiar to you, congratulations, you are probably looking at the right solution. -## Design decisions +Born from the MIT CSAIL lab, Taichi was designed to facilitate computer graphics researchers' everyday life, by helping them quickly implement visual computing and physics simulation algorithms that are executable on GPU. The path Taichi took was an innovative one: Taichi is embedded in Python and uses modern just-in-time (JIT) frameworks (for example LLVM, SPIR-V) to offload the Python source code to native GPU or CPU instructions, offering the performance at both development time and runtime. -- Decouple computation from data structures -- Domain-specific compiler optimizations -- Megakernels -- Two-scale automatic differentiation -- Embedding in Python +To be fair, a domain-specific language (DSL) with a Python frontend is not something new. In the past few years, frameworks like Halide, PyTorch, and TVM have matured into the de facto standards in areas such as image processing and deep learning (DL). What distinguishes Taichi the most from these frameworks is its imperative programming paradigm. As a DSL, Taichi is not so specialized in a particular computing pattern. This provides better flexibility. While one may argue that flexibility usually comes at the cost of not being fully optimized, we often find this not the case for a few reasons: + +* Taichi's workload typically does *not* exhibit an exploitable pattern (e.g., element-wise operations), meaning that the arithmetic intensity is bounded anyway. By simply switching to the GPU backend, one can already enjoy a nice performance gain. +* Unlike the traditional DL frameworks, where operators are simple math expressions and have to be fused at the graph level to achieve higher arithmetic intensity, Taichi's imperative paradigm makes it quite easy to write a large amount of computation in a single kernel. We call it *mega-kernel*. +* Taichi heavily optimizes the source code using various compiler technologies: common subexpression elimination, dead code elimination, control flow graph analysis, etc. These optimizations are backend neutral, because Taichi hosts its own intermediate representation (IR) layer. +* JIT compilation provides additional optimization opportunities. + +That said, Taichi goes beyond a Python JIT transpiler. One of the initial design goals is to *decouple the computation from the data structures*. The mechanism that Taichi provides is a set of generic data containers, called *SNode* (/ˈsnoʊd/). SNodes can be used to compose hierarchical, dense or sparse, multi-dimensional fields conveniently. Switching between array-of-structures and structure-of-arrays layouts is usually a matter of ≤10 lines of code. This has sparked many use cases in numerical simulation. If you are interested to learn them, please check out [Fields (advanced)](https://docs.taichi.graphics/lang/articles/advanced/layout), [Sparse spatial data structures](https://docs.taichi.graphics/lang/articles/advanced/sparse), or [the original Taichi paper](https://yuanming.taichi.graphics/publication/2019-taichi/taichi-lang.pdf). + +The concept of decoupling is further extended to the type system. With GPU memory capacity and bandwidth becoming the major bottlenecks nowadays, it is vital to be able to pack more data per memory unit. Since 2021, Taichi has introduced customizable quantized types, allowing for the definition of fixed point or floating point numbers with arbitrary bits (still needs to be under 64). This has allowed an MPM simulation of over 400 million particles on a single GPU device. Learn more details in [the QuanTaichi paper](https://yuanming.taichi.graphics/publication/2021-quantaichi/quantaichi.pdf). + +Taichi is intuitive. If you know Python, you know Taichi. If you write Taichi, you awaken your GPU (or CPU as a fallback). Ever since its debut, this simple idea has gained so much popularity, that many were attracted to contribute new backends, including Vulkan, OpenGL and DirectX (working in progress). Without our strong and dedicated community, Taichi would never have been where it is now. + +Going forward, we see many new opportunities lying ahead, and would like to share some of our vision with you. + +**Academia** + +90% of the research code will be trashed due to the nature of research where assumptions keep being broken and ideas keep being iterated. Swiftly coding without thinking too much about performance may lead to incorrect conclusions, while pre-matured code optimization can be a waste of time and often produces a tangled mess. The high performance and productivity are, therefore, extremely helpful for research projects. + +Taichi will keep embracing the academia. The key features we have (or plan to have) for high-performance computing research projects include small-scale linear algebra (inside kernels), large-scale sparse systems, and efficient neighbor accessing for both structured and unstructured data. + +Taichi also provides an automatic differentiation module via source code transformation (at IR level), making it a sweet differentiable simulation tool for machine learning projects. + +**Apps & game engine integration** + +One huge advantange of Taichi lies in its portability, thanks to the support for a wide variety of backends. During the development process, we have also recognized the increasing demands from our industry users for multi-platform packaging and deployment. Below shows an experimental demo of integrating Taichi with Unity. By exporting Taichi kernels as SPIR-V shaders, we can easily import them into a Unity project. + +![](https://github.com/taichi-dev/taichi_assets/blob/master/static/imgs/unity_fluid.gif?raw=true) + +**General-purpose computing** + +While originally designed for physics simulation, Taichi has found its application in many other areas that can be boosted by GPU general-purpose computing. + +* [taichimd](https://github.com/victoriacity/taichimd): Interactive, GPU-accelerated Molecular (& Macroscopic) Dynamics using the Taichi programming language +* [TaichiSLAM](https://github.com/xuhao1/TaichiSLAM): a 3D Dense mapping backend library of SLAM based Taichi-Lang, designed for the aerial swarm. +* [Stannum](https://github.com/ifsheldon/stannum): Fusing Taichi into PyTorch. + +**Maybe a new frontend?** + +The benefit of adopting the compiler approach is that you can decouple the frontend from the backend. Taichi is *currently* embedded in Python, but who says it needs to stay that way? Stay tuned :-) diff --git a/docs/lang/articles/basic/syntax.md b/docs/lang/articles/basic/syntax.md index c61311e10d781..afbd33300e466 100644 --- a/docs/lang/articles/basic/syntax.md +++ b/docs/lang/articles/basic/syntax.md @@ -4,6 +4,29 @@ sidebar_position: 1 # Kernels and functions +Taichi has two types of functions: Taichi kernels, and Taichi functions. + +Scope inside Taichi kernels and Taichi functions is called Taichi scope, and scope outside them is called Python scope. + +A Taichi kernel is the entrypoint of a Taichi program, and it is similar to a `__global__` function in CUDA. It can only be called inside Python scope. + +A Taichi function can only be called inside Taichi scope, and it is similar to a `__device__` function in CUDA. + +Major differences between Taichi kernels and Taichi functions are listed in the table below. + +| | Taichi kernels | Taichi functions | +| :--- | :--- | :--- | +| Can be called in | Python scope | Taichi scope | +| Argument type annotation | Mandatory | Recommended | +| Return type annotation | Mandatory| Recommended | +| Return value | Scalar/Vector/Matrix | Arbitrary | +| Max number of total elements in arguments | 8 (for OpenGL and CC) or 64 (other) | Unlimited | +| Max number of return values in a return statement | 1 | Unlimited | +| Max number of total elements in return values | 30 | Unlimited | + + + + ## Taichi-scope vs Python-scope Code decorated by `@ti.kernel` or `@ti.func` is in the **Taichi-scope**. @@ -45,7 +68,6 @@ For people from CUDA, Taichi kernels are similar to `__global__` functions. Kernels can have multiple arguments, which support passing values from Python-scope to Taichi-scope conveniently. :::note -For kernels executed on OpenGL and CC backends, the number of arguments is limited to 8. ::: Kernel arguments must be type hinted: @@ -59,18 +81,22 @@ my_kernel(24, 3.2) # prints: 27.2 ``` :::note +Taichi supports scalars, `ti.Matrix` and`ti.Vector` as kernel arguments. +The total number of elements in kernel arguments must not exceed 8 on OpenGL and CC backends, or 64 on other backends. +The number of elements in a scalar argument is 1, and the number of elements in a `ti.Matrix` or`ti.Vector` is the number of elements inside it. -For now, Taichi supports scalars as kernel arguments. Specifying `ti.Matrix` or -`ti.Vector` as an argument is not supported yet: - -```python {2,7} +```python {2,7,11} @ti.kernel -def valid_kernel(vx: ti.f32, vy: ti.f32): +def valid_scalar_argument(vx: ti.f32, vy: ti.f32): v = ti.Vector([vx, vy]) ... @ti.kernel -def error_kernel(v: ti.Vector): # Error: Invalid type annotation +def valid_matrix_argument(u: ti.i32, v: ti.types.matrix(2, 2, ti.i32)): # OK: has 5 elements in total + ... + +@ti.kernel +def error_too_many_arguments(u: ti.i32, v: ti.i64, w: ti.types.matrix(7, 9, ti.i64)): # Error: has 65 elements in total ... ``` @@ -78,7 +104,7 @@ def error_kernel(v: ti.Vector): # Error: Invalid type annotation ### Return value -It is optional for a kernel to have a return value. If specified, it must be a type hinted **scalar** value: +It is optional for a kernel to have a return value. If specified, it must be a type hinted **scalar/vector/matrix** value: ```python {2} @ti.kernel @@ -100,21 +126,19 @@ print(my_kernel()) # 128, cast into ti.i32 :::note -For now, a kernel can only have one scalar return value. Returning -`ti.Matrix`, `ti.Vector` or Python-style tuple is not supported: - -```python {3,9} +For now, a kernel can only have one return value, and the number of elements in the return value must not exceed 30. +```python {2,6,10} @ti.kernel -def valid_kernel() -> ti.f32: +def valid_scalar_return() -> ti.f32: return 128.0 # Return 128.0 @ti.kernel -def error_kernel() -> ti.Matrix: - return ti.Matrix([[1, 0], [0, 1]]) # Compilation error +def valid_matrix_return() -> ti.types.matrix(2, 2, ti.i32): + return ti.Matrix([[1, 0], [0, 1]]) @ti.kernel -def error_kernel() -> (ti.i32, ti.f32): +def error_multiple_return() -> (ti.i32, ti.f32): x = 1 y = 0.5 return x, y # Compilation error @@ -141,7 +165,7 @@ differentiation. Instead, it is recommended to store the result into a global va `loss[None]`). ::: -### Functions +## Functions A Python function decorated by `@ti.func` is a **Taichi function**: @@ -170,8 +194,9 @@ Taichi functions can be nested. ::: :::caution -Currently, all functions are force-inlined. Therefore, no recursion is -allowed. +Currently, all functions are force-inlined. Therefore, no runtime recursion is allowed. + +Compile-time recursion is an advanced metaprogramming feature for experienced programmers. See [Metaprogramming](/lang/articles/advanced/meta#compile-time-recursion-of-tifunc) for more information. ::: ### Arguments and return values @@ -230,27 +255,6 @@ def my_kernel(): ... ``` -:::note - -Unlike kernels, functions **do support vectors or matrices as arguments -and return values**: - -```python {2,6} -@ti.func -def sdf(u): # functions support matrices and vectors as arguments. No type-hints needed. - return u.norm() - 1 - -@ti.kernel -def render(d_x: ti.f32, d_y: ti.f32): # Kernels do not support vector/matrix arguments yet. - d = ti.Vector([d_x, d_y]) - p = ti.Vector([0.0, 0.0]) - t = sdf(p) - p += d * t - ... -``` - -::: - :::caution Functions with multiple `return` statements are not supported for now. diff --git a/docs/lang/articles/basic/type.md b/docs/lang/articles/basic/type.md index 28161a96bac7b..a67d2935544ab 100644 --- a/docs/lang/articles/basic/type.md +++ b/docs/lang/articles/basic/type.md @@ -1,106 +1,52 @@ --- -sidebar_position: 2 +sidebar_position: 3 --- # Type system -Data types in Taichi consist of Primitive Types and Compound Types. Primitive Types are the numerical data types used by backends, while Compound Types are user-defined types of data records composed of multiple members. +Data types in Taichi consist of _primitive types_ and _compound types_. Primitive types are the numerical data types used by different backends, while compound types are user-defined types of data records composed of multiple members. ## Primitive types -Taichi supports common numerical data types. Each type is denoted as a -character indicating its _category_ and a number of _precision bits_, -e.g., `i32` and `f64`. +Taichi supports common numerical data types as its primitive types. Each type is denoted as a +character indicating its _category_ followed by a number indicating its _precision bits_. The +_category_ can be either `i` (for signed integers), `u` (for unsigned integers), or `f` (for floating-point numbers). The _precision bits_ can be either `8`, `16`, `32`, or `64`, +which represents the number of **bits** for storing the data. For example, the two most commonly used types: -The _category_ can be one of: +- `i32` represents a 32-bit signed integer; +- `f32` represents a 32-bit floating-point number. -- `i` for signed integers, e.g. 24, -32 -- `u` for unsigned integers, e.g. 128, 256 -- `f` for floating point numbers, e.g. 3.14, 1.0, 1e-4 +### Supported primitive types on each backend -The _digital number_ can be one of: +| type | CPU | CUDA | OpenGL | Metal | Vulkan | +| ---- | ---------------- | ---------------- | -------------------- | ---------------- | -------------------- | +| i8 |:heavy_check_mark:|:heavy_check_mark:|:x: |:heavy_check_mark:|:large_orange_diamond:| +| i16 |:heavy_check_mark:|:heavy_check_mark:|:x: |:heavy_check_mark:|:large_orange_diamond:| +| i32 |:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark: |:heavy_check_mark:|:heavy_check_mark: | +| i64 |:heavy_check_mark:|:heavy_check_mark:|:large_orange_diamond:|:x: |:large_orange_diamond:| +| u8 |:heavy_check_mark:|:heavy_check_mark:|:x: |:heavy_check_mark:|:large_orange_diamond:| +| u16 |:heavy_check_mark:|:heavy_check_mark:|:x: |:heavy_check_mark:|:large_orange_diamond:| +| u32 |:heavy_check_mark:|:heavy_check_mark:|:x: |:heavy_check_mark:|:heavy_check_mark: | +| u64 |:heavy_check_mark:|:heavy_check_mark:|:x: |:x: |:large_orange_diamond:| +| f16 |:heavy_check_mark:|:heavy_check_mark:|:x: |:x: |:heavy_check_mark: | +| f32 |:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark: |:heavy_check_mark:|:heavy_check_mark: | +| f64 |:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark: |:x: |:large_orange_diamond:| -- `8` -- `16` -- `32` -- `64` +(:large_orange_diamond: Requiring extensions of the backend) -It represents how many **bits** are used in storing the data. The larger -the bit number, the higher the precision is. +### Default types for integers and floating-point numbers -For example, the two most commonly used types: - -- `i32` represents a 32-bit signed integer. -- `f32` represents a 32-bit floating point number. - -## Supported primitive types - -Currently, supported primitive types in Taichi are - -- int8 `ti.i8` -- int16 `ti.i16` -- int32 `ti.i32` -- int64 `ti.i64` -- uint8 `ti.u8` -- uint16 `ti.u16` -- uint32 `ti.u32` -- uint64 `ti.u64` -- float32 `ti.f32` -- float64 `ti.f64` - -:::note - -Supported types on each backend: - -| type | CPU/CUDA | OpenGL | Metal | Vulkan | -| ---- | -------- | ------- | ----- | -------- | -| i8 | > OK | > N/A | > OK | > EXT | -| i16 | > OK | > N/A | > OK | > EXT | -| i32 | > OK | > OK | > OK | > OK | -| i64 | > OK | > EXT | > N/A | > EXT | -| u8 | > OK | > N/A | > OK | > EXT | -| u16 | > OK | > N/A | > OK | > EXT | -| u32 | > OK | > N/A | > OK | > OK | -| u64 | > OK | > N/A | > N/A | > EXT | -| f32 | > OK | > OK | > OK | > OK | -| f64 | > OK | > OK | > N/A | > EXT | - -(OK: supported, EXT: require extension, N/A: not available) -::: - -:::note -Boolean types are represented using `ti.i32`. -::: - -## Type promotion - -Binary operations on different types will give you a promoted type, -following the C programming language convention, e.g.: - -- `i32 + f32 = f32` (integer + float = float) -- `i32 + i64 = i64` (less-bits + more-bits = more-bits) - -Basically it will try to choose the more precise type to contain the -result value. - -## Default precisions - -By default, all numerical literals have 32-bit precisions. For example, -`42` has type `ti.i32` and `3.14` has type `ti.f32`. - -Default integer and float-point precisions (`default_ip` and -`default_fp`) can be specified when initializing Taichi: +An integer literal, e.g., `42`, has default type `ti.i32`, while a floating-point literal, +e.g., `3.14`, has default type `ti.f32`. This behavior can be changed by explicitly specifying +default types when initializing Taichi: ```python -ti.init(default_fp=ti.f32) -ti.init(default_fp=ti.f64) - -ti.init(default_ip=ti.i32) -ti.init(default_ip=ti.i64) +ti.init(default_ip=ti.i64) # set default integer type to ti.i64 +ti.init(default_fp=ti.f64) # set default floating-point type to ti.f64 ``` -Also note that you may use `float` or `int` in type definitions as -aliases for default precisions, e.g.: +In addition, you can use `int` as an alias for the default integer type, and `float` as an alias +for the default floating-point type: ```python ti.init(default_ip=ti.i64, default_fp=ti.f32) @@ -113,102 +59,114 @@ y = ti.field(ti.i64, 5) def func(a: float) -> int: ... - # is equivalent to: def func(a: ti.f32) -> ti.i64: ... ``` -## Type casts +### Explicit type casting -All data types are static in the **Taichi scope**. Therefore, casts are needed when you want to assign a certain type of data to another one. - -### Implicit casts +Just like programming in other languages, you may encounter situations where you have a certain +type of data, but it is not feasible for the assignment or calculation you want to perform. In this +case, you can do *explicit type casting*. There are two kinds of explicit type casting in Taichi, +namely *normal casting* and *bit casting*. :::caution -The type of a variable is **determined on its initialization**. +In Taichi-scope, the type of a variable is **static** and **determined on its initialization**. +That is, you can never change the type of a variable. The compiler relies on this compile-time +information to check the validity of expressions in Taichi programs. ::: -When a _low-precision_ variable is assigned to a _high-precision_ -variable, it will be implicitly promoted to the _high-precision_ type -and no warning will be raised: +#### Normal casting -```python {4} -@ti.kernel -def foo(): - a = 3.14 - a = 1 - print(a) # 1.0 -``` +`ti.cast()` is used for normal type casting as in other programming languages: -When a _high-precision_ variable is assigned to a _low-precision_ type, -it will be implicitly down-cast into the _low-precision_ type and Taichi -will raise a warning: - -```python {4} +```python {4-5} @ti.kernel def foo(): - a = 1 a = 3.14 - print(a) # 3 + b = ti.cast(a, ti.i32) # 3 + c = ti.cast(b, ti.f32) # 3.0 ``` -### Explicit casts - -You may use `ti.cast` to explicitly cast scalar values between different +You can also use `int()` and `float()` to convert values to default integer and floating-point types: ```python {4-5} @ti.kernel def foo(): a = 3.14 - b = ti.cast(a, ti.i32) # 3 - c = ti.cast(b, ti.f32) # 3.0 + b = int(a) # 3 + c = float(b) # 3.0 ``` -Equivalently, use `int()` and `float()` to convert values to float-point -or integer types of default precisions: +#### Bit casting + +Use `ti.bit_cast()` to cast a value into another type **with its underlying bits preserved**: ```python {4-5} @ti.kernel def foo(): a = 3.14 - b = int(a) # 3 - c = float(b) # 3.0 + b = ti.bit_cast(a, ti.i32) # 1078523331 + c = ti.bit_cast(b, ti.f32) # 3.14 ``` -### Casting vectors and matrices +Note that the new type must have the same precision bits as the old type (`i32`->`f64` is not +allowed). Use this operation with caution. -Type casts applied to vectors/matrices are element-wise: +:::note +`ti.bit_cast` is equivalent to `reinterpret_cast` in C++. +::: -```python {4,6} +### Implicit type casting + +When you accidentally use a value in a place where a different type is expected, implicit type +casting is triggered for the following cases. + +:::caution +Relying on implicit type casting is bad practice and one major source of bugs. +::: + +#### In binary operations + +Following the [implicit conversion rules](https://en.cppreference.com/w/c/language/conversion) of the C programming language, Taichi implicitly casts +binary operation operands into a *common type* if they have different types. Some simple but most +commonly used rules to determine the common type of two types are listed below: + +- `i32 + f32 = f32` (int + float = float) +- `i32 + i64 = i64` (low precision bits + high precision bits = high precision bits) + +#### In assignments + +When a value is assigned to a variable with a different type, the value is implicitly cast into that +type. If the type of the variable differs from the common type of the variable and the value, a +warning about losing precisions is raised. + +In the following example, variable `a` is initialized with type `float`. On the next line, the +assignment casts `1` from `int` to `float` implicitly without any warning because the type of the +variable is the same as the common type `float`: + +```python {4} @ti.kernel def foo(): - u = ti.Vector([2.3, 4.7]) - v = int(u) # ti.Vector([2, 4]) - # If you are using ti.i32 as default_ip, this is equivalent to: - v = ti.cast(u, ti.i32) # ti.Vector([2, 4]) + a = 3.14 + a = 1 + print(a) # 1.0 ``` -### Bit-casts - -Use `ti.bit_cast` to bit-cast a value into another data type. The -underlying bits will be preserved in this cast. The new type must have -the same width as the the old type. For example, bit-casting `i32` to -`f64` is not allowed. Use this operation with caution. +In the following example, variable `a` is initialized with type `int`. On the next line, the +assignment casts `3.14` from `float` to `int` implicitly with a warning because the type of the +variable differs from the common type `float`: -```python {4-5} +```python {4} @ti.kernel def foo(): + a = 1 a = 3.14 - b = ti.bit_cast(a, ti.i32) # 1078523331 - c = ti.bit_cast(b, ti.f32) # 3.14 + print(a) # 3 ``` -:::note -For people from C++, `ti.bit_cast` is equivalent to `reinterpret_cast`. -::: - ## Compound types User-defined compound types can be created using the `ti.types` module. Supported compound types include vectors, matrices, and structs: @@ -219,6 +177,7 @@ my_vec3f = ti.types.vector(3, float) my_mat2f = ti.types.matrix(2, 2, float) my_ray3f = ti.types.struct(ro=my_vec3f, rd=my_vec3f, l=ti.f32) ``` +In this example, we define four compound types for creating fields and local variables. ### Creating fields @@ -234,8 +193,10 @@ vec1 = ti.Vector.field(2, dtype=ti.i32, shape=(128, 128, 128)) mat2 = ti.Matrix.field(2, 2, dtype=ti.i32, shape=(24, 32)) ray3 = ti.Struct.field({'ro': my_vec3f, 'rd': my_vec3f, 'l': ti.f32}, shape=(1024, 768)) ``` +In this example, we define three fields in two different ways but of exactly the same effect. ### Creating local variables + Compound types can be directly called to create vector, matrix or struct instances. Vectors, matrices and structs can be created using GLSL-like broadcast syntax since their shapes are already known: ```python ray1 = my_ray3f(0.0) # ti.Struct(ro=[0.0, 0.0, 0.0], rd=[0.0, 0.0, 0.0], l=0.0) @@ -244,3 +205,17 @@ mat1 = my_mat2f(1.0) # ti.Matrix([[1.0, 1.0], [1.0, 1.0]]) vec2 = my_vec3f(my_vec2i(0), 1) # ti.Vector([0.0, 0.0, 1.0]), will perform implicit cast ray2 = my_ray3f(ro=vec1, rd=vec2, l=1.0) ``` +In this example, we define five local variables, each of a different type. In the definition statement of `vec2`, `my_vec3f()` performs an implicit cast operation when combining `my_vec2i(0)` with `1`. + +### Type casting on vectors and matrices + +Type casting on vectors/matrices is element-wise: + +```python {4,6} +@ti.kernel +def foo(): + u = ti.Vector([2.3, 4.7]) + v = int(u) # ti.Vector([2, 4]) + # If you are using ti.i32 as default_ip, this is equivalent to: + v = ti.cast(u, ti.i32) # ti.Vector([2, 4]) +``` diff --git a/docs/lang/articles/contribution/_category_.json b/docs/lang/articles/contribution/_category_.json index 74136a441f97d..e5adcd33e7752 100644 --- a/docs/lang/articles/contribution/_category_.json +++ b/docs/lang/articles/contribution/_category_.json @@ -1,4 +1,4 @@ { "label": "Contribution Guide", - "position": 5 + "position": 6 } diff --git a/docs/lang/articles/contribution/contributor_guide.md b/docs/lang/articles/contribution/contributor_guide.md index 492ab454ecad6..8f2599cf1f5f4 100644 --- a/docs/lang/articles/contribution/contributor_guide.md +++ b/docs/lang/articles/contribution/contributor_guide.md @@ -4,295 +4,300 @@ sidebar_position: 1 # Contribution guidelines -First of all, thank you for contributing! We welcome all kinds of contributions, including but not limited to - -- Bug fixes -- New feature proposals and implementations -- Documentation improvements and translations -- More user-friendly error messages -- New test cases and examples -- Compiler performance enhancements -- High-quality blog posts and tutorials -- Participation in the [Taichi forum](https://forum.taichi.graphics/) -- Introducing Taichi to your friends or simply staring [the - project on GitHub](https://github.com/taichi-dev/taichi) -- Typo fixes in the documentation, code or comments (please go ahead and - make a pull request for minor issues like these) - -:::tip reminder -Please take some time to familiarize yourself with this contribution guide before opening a pull request. -For more details regarding development of the Taichi compiler, read the [development tips](./development_tips). -::: -## Where to find contribution opportunities - -- Issues marked with ["good first -issue"](https://github.com/taichi-dev/taichi/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) -are great chances for starters. -- Issues marked with ["welcome - contribution"](https://github.com/taichi-dev/taichi/issues?q=is%3Aopen+is%3Aissue+label%3A%22welcome+contribution%22) - are slightly more challenging but still friendly to beginners. - -## How to take over an issue - -- Please first leave a comment (e.g. _I know how to fix this and would - like to help!_) on the issue, so that people know someone is already - working on it. This helps prevent redundant work. -- If no core developer has commented and described a potential - solution on the issue, please briefly describe your plan, and wait - for a core developer to reply before you start. This helps keep - implementations simple and effective. - -## High-level guidelines - -- Be pragmatic: practically solving problems is our ultimate goal. -- No overkills: always use _easy_ solutions to solve easy problems, so - that you have time and energy for real hard ones. -- Almost every design decision has pros and cons. A decision is - *good* if its pros outweigh its cons. Always think about - both sides. -- Debugging is hard. Changesets should be small so that sources of - bugs can be easily pinpointed. -- Unit tests and integration tests are our friends. -:::note -"There are two ways of constructing a software design: One way is to -make it so simple that there are obviously no deficiencies, and the -other way is to make it so complicated that there are no obvious -deficiencies. _The first method is far more difficult_." -— [C.A.R. Hoare](https://en.wikipedia.org/wiki/Tony_Hoare) -::: +Thank you for your interest in contributing to Taichi. Taichi was born as an academic research project. Though we are working hard to improve its code quality, Taichi has a long way to go to become a mature, large-scale engineering project. This is also why we decided to open source Taichi from the very beginning: We rely on our community to help Taichi evolve and thrive. From document updates, bug fix, to feature implementation, wherever you spot an issue, you are very welcome to file a PR (pull request) with us!:-) -One thing to keep in mind is that, Taichi was originally born as an -academic research project. This usually means that some parts did not -have the luxury to go through a solid design. While we are always trying -to improve the code quality, it doesn't mean that the project is free -from technical debts. Some places may be confusing or overly -complicated. Whenever you spot one, you are more than welcome to shoot -us a PR! :-) - -## Effective communication - -A few tips for effective communication in the Taichi community: - -- How much information one effectively conveys, is way more important - than how many words one typed. -- Be constructive. Be polite. Be organized. Be concise. -- Bulleted lists are our friends. -- Proofread before you post: if you are the reader, can you understand - what you typed? -- If you are not a native speaker, consider using a spell checker such - as [Grammarly](https://app.grammarly.com/). - -Please base your discussion and feedback on facts, and not personal -feelings. It is very important for all of us to maintain a friendly and -blame-free community. Some examples: - -:::tip Acceptable :-) -This design could be confusing to new Taichi users. -::: +Centered around the common process of taking on an issue, testing, and making a corresponding PR, this document provides guidelines, tips, and major considerations for Taichi's contributors. We highly recommend that you spend some time familiarizing yourself with this contribution guide before contributing to Taichi. -:::danger Not Acceptable -This design is terrible. -::: +## General guidelines and tips -## Making good pull requests (PRs) - -- PRs with **small** changesets are preferred. - - A PR should ideally address **only one issue**. - - It is fine to include off-topic **trivial** refactoring such as - typo fixes; - - The reviewers reserve the right to ask PR authors to remove - off-topic **non-trivial** changes. - - When implementing a complex feature, consider breaking it down into - small PRs to keep a more detailed development history and to - interact with core developers more frequently. -- PR titles should be short sentences describing the changes and following - [certain format](./contributor_guide#pr-title-format-and-tags). -- In the description of a PR, it will be nice to link relevant GitHub issues - (e.g. `fixes #issue_number`) or provide a little context on the motivation. - Some important implementation decisions you made in the PR is also helpful. -- If you want early feedback from core developers, - - Open a PR in - [Draft](https://github.blog/2019-02-14-introducing-draft-pull-requests/) - state on GitHub to share your progress; - - Make sure you @ the corresponding developers in the comments or - request reviews from them. -- All PRs should ideally come with corresponding **tests**. See [Testing](./contributor_guide#testing). -- All PRs should come with **documentation updates**, except for - internal compiler implementations. See [Documentation](./contributor_guide#documentation). -- All PRs must pass **continuous integration tests** before they get - merged. See [Using continuous integration](./contributor_guide#using-continuous-integration). -- All PRs must pass **code format checks**. See [Enforcing code style](./contributor_guide#enforcing-code-style). -- Read a great article from Google on [how to have your PR merged - quickly](https://testing.googleblog.com/2017/06/code-health-too-many-comments-on-your.html). - [\[PDF\]](https://github.com/yuanming-hu/public_files/blob/master/graphics/taichi/google_review_comments.pdf) -- All commits in a PR will always be **squashed and merged into master - as a single commit**. However, PR authors **should not squash commits on their own**. -- If you are making multiple PRs, - - Independent PRs should be based on **different** branches - forking from `master`; - - PRs with dependencies should be raised only after all - prerequisite PRs are merged into `master`. - - -## PR reviewing & merging - -- Please try to follow these tips from Google: - - [Code Health: Understanding Code In - Review](https://testing.googleblog.com/2018/05/code-health-understanding-code-in-review.html); - [\[PDF\]](https://github.com/yuanming-hu/public_files/blob/master/graphics/taichi/google_understanding_code.pdf) - - [Code Health: Respectful Reviews == Useful - Reviews](https://testing.googleblog.com/2019/11/code-health-respectful-reviews-useful.html). - [\[PDF\]](https://github.com/yuanming-hu/public_files/blob/master/graphics/taichi/google_respectful_reviews.pdf) -- The merger should always **squash and merge** PRs into the master - branch. -- The master branch is required to have a **linear history**. -- Make sure the PR passes **continuous integration tests**, except for - cases like documentation updates. See [Using continuous integration](./contributor_guide#using-continuous-integration). -- Make sure the title follows [PR tag rules](./contributor_guide#pr-title-format-and-tags). - -## Using continuous integration - -- Continuous Integration (CI) will **build** and **test** your - commits in a PR in multiple environments. -- Currently, Taichi uses [Github Actions](https://github.com/features/actions). -- CI will be triggered every time you push commits to an open PR. -- You can prepend `[skip ci]` to your commit message to avoid - triggering CI. For example, a commit with the message - `[skip ci] This commit will not trigger CI` will not trigger CI. -- A tick on the left-hand side of a commit hash means CI passed, while - a cross means CI failed. - -## Enforcing code style - -- Locally, you can run `ti format` in the command line to re-format - code style. Note that you have to install `clang-format-10` and - `yapf v0.31.0` locally before you use `ti format`. - -- If you don't have these formatting tools locally, feel free to - leverage GitHub Actions: simply comment `/format` in a PR - (e.g., [#2481](https://github.com/taichi-dev/taichi/pull/2481#issuecomment-872226701)) - and then [Taichi Gardener](https://github.com/taichi-gardener) - will automatically push a commit to your branch that formats the code. - Note if you want to make more changes afterwards, you'll need to - `git pull` first. - -- For your C++ code, please also follow [C++ style](./cpp_style). - -## PR title format and tags - -PR titles will be part of the commit history reflected in the `master` -branch, therefore it is important to keep PR titles readable. - -- Please always prepend **at least one tag** such as `[Lang]` to PR - titles: - - When using multiple tags, make sure there is exactly one - space between tags; - - For example, `[Lang][refactor]` (no space) should be replaced - by `[Lang] [refactor]`. -- The first letter of the PR title body should be capitalized: - - For example, `[Doc] improve documentation` should be replaced by - `[Doc] Improve documentation`; - - `[Lang] "ti.sqr(x)" is now deprecated` is fine because `"` - is a symbol. -- Please do not include back quotes ("`") in PR titles. -- Good examples include `[Metal] Support bitmasked SNode`, `[Vulkan] - ti.atomic_min/max support`, or `[Opt] [ir] Enhanced intra-function optimizations`. - -Frequently used tags: - -- `[CPU]`, `[CUDA]`, `[Metal]`, `[Vulkan]`, `[OpenGL]`: backends; -- `[LLVM]`: the LLVM backend shared by CPUs and CUDA; -- `[Lang]`: frontend language features, including syntax sugars; -- `[Std]`: standard library, e.g., `ti.Matrix` and `ti.Vector`; -- `[Sparse]`: sparse computation; -- `[IR]`: intermediate representation; -- `[Opt]`: IR optimization passes; -- `[GUI]`: the built-in GUI system; -- `[Refactor]`: code refactoring; -- `[CLI]`: commandline interfaces, e.g., the `ti` command; -- `[Doc]`: documentation under [docs/](https://github.com/taichi-dev/taichi/blob/master/docs/); -- `[Example]`: examples under [examples/](https://github.com/taichi-dev/taichi/blob/master/examples/); -- `[Test]`: tests under [tests/](https://github.com/taichi-dev/taichi/blob/master/tests/); -- `[Linux]`: Linux platform; -- `[Mac]`: macOS platform; -- `[Windows]`: Windows platform; -- `[Perf]`: performance improvements; -- `[CI]`: CI/CD workflow; -- `[Misc]`: something that doesn't belong to any category, such as - version bump, reformatting; -- `[Bug]`: bug fixes. - -Check out more tags in - [misc/prtags.json](https://github.com/taichi-dev/taichi/blob/master/misc/prtags.json). When introducing a new tag, please update - [misc/prtags.json](https://github.com/taichi-dev/taichi/blob/master/misc/prtags.json) in the first PR with that tag, so that people can - follow. +This section provides some general guidelines for the Taichi community and tips that we find practically useful. -:::note +### Be pragmatic & no overkills -We do appreciate all kinds of contributions, yet we should not expose -the title of every PR to end-users. Therefore the changelog will -distinguish *what the user should know* from *what the -developers are doing*. This is done by **capitalizing PR -tags**: - -- PRs with visible or notable features to the users should be marked - with tags starting with **the first letter capitalized**, e.g., - `[Metal]`, `[Vulkan]`, `[IR]`, `[Lang]`, `[CLI]`. These PRs will be - [highlighted in the release note](https://github.com/taichi-dev/taichi/blob/master/misc/make_changelog.py) - for end-users, therefore it is important to make sure your PR title is - effective in telling what your PR does. -- Other PRs (underlying development or intermediate implementation) - should use tags with **everything in lowercase letters**, e.g., - `[metal]`, `[vulkan]`, `[ir]`, `[lang]`, `[cli]`. -- Because of the way the release changelog is generated, there - should be **at most one capitalized tag** in a PR title to prevent - duplicate PR highlights. For example, - `[GUI] [Mac] Support modifier keys` ([#1189](https://github.com/taichi-dev/taichi/pull/1189)) - is an improper tag choice, and we - should have used `[gui] [Mac] Support modifier keys in GUI` instead. - Please capitalize the tag that is the *most* relevant to the PR. -::: +Always use straightforward (sometimes even brute-force) solutions: Complicated code usually suggests a lack of design or over-engineering. + +> - "There are two ways of constructing a software design: One way is to make it so simple that there are obviously no deficiencies, and the other way is to make it so complicated that there are no obvious deficiencies. *The first method is far more difficult*." — [C.A.R. Hoare](https://en.wikipedia.org/wiki/Tony_Hoare) +> - "Perfection (in design) is achieved not when there is nothing more to add, but rather when there is nothing more to take away." — [Antoine de Saint-Exupéry](https://en.wikipedia.org/wiki/The_Cathedral_and_the_Bazaar) + +### Juxtapose pros and cons + +When it comes to making a design decision, weigh up its pros and cons. A design is *good to go* so long as its advantages outweigh its disadvantages. + +### Communicate effectively + +Our ultimate goal is to build a sustainable, prosperous Taichi community, and effective communication is the cornerstone of that goal. Following are tips that may contribute to effective communication: + +- Concise: + - The message behind your words outweighs the number of your words. Use as few words as possible to drive your point home. + - Use tables, figures, and lists where possible. + +- Professional: + - Read twice before you post: Would your point get across with your words? + - Use a spell checker, such as [Grammarly](https://app.grammarly.com/), to improve your writing in terms of grammar, style, and tone. + +- Constructive and courteous: Base your feedback and discussions on facts, *NOT* on personal feelings. + - Acceptable😃: *"This design could be confusing to new Taichi users. If it were designed this way, it could..."* + - Undesirable😞: ~~*"This design is terrible."*~~ + +## What you can contribute + + + +We welcome all kinds of contributions, including but not limited to: + +- Fixing a bug +- Proposing and implementing new features +- Improving or refactoring an existing document +- Suggesting more friendly error messages +- Adding new test cases and examples (demos) +- Posting blog articles and tutorials +- Enhancing compiler performance +- Minor updates to documentation, codes, or annotations. + +## Take over an issue + +Except for minor updates, most PRs start from a developer taking over an issue. This section provides some corresponding tips and best practices. + +### Where to find issues for starters + +| Issue Tag | Description | Target developer | +| ------------------------------------------------------------ | ------------------------- | ---------------------------------------------- | +| [good first issue](https://github.com/taichi-dev/taichi/issues?q=is:open+is:issue+label:"good+first+issue") | Issues that are easy to start with | Developers new to Taichi | +| [welcome contribution](https://github.com/taichi-dev/taichi/issues?q=is:open+is:issue+label:"welcome+contribution") | Issues *slightly* more challenging | Developers who wish to dive deeper into Taichi | + +### Best practices + +- When you plan to take over an issue: + - **Best practice**: Leave a message claiming that you are working on it. + - **Goal**: Avoid unnecessary repeated work. + - **Example**: *"I know how to fix this and would like to help."* +- After you take over an issue: + - **Best practice**: + 1. Briefly describe how you plan to handle it (if no solution has been provided). + 2. Hold off until a core developer responds to your action plan. + - **Goal**: Keep your implementation neat and effective. + - **Example**: See [#2610](https://github.com/taichi-dev/taichi/issues/2610). + +## References for documentation updates + +As part of the effort to increase visibility of the community and to improve developer experience, we highly recommend including documentation updates in your PR if applicable. Here are some of the documentation-specific references and tips: + +- Documentation source files are hosted under [docs/](https://github.com/taichi-dev/taichi/blob/master/docs/). +- We use GitHub Flavored Markdown (GFM) and [Docusaurus](https://docusaurus.io/) to build our documentation site. For information on the supported Markdown syntax, see the [Documentation Writing Guide](./doc_writing). +- When it comes to writing, we adhere to the [Google Developer Documentation Style Guide](https://developers.google.com/style/). +- For instructions on setting up a local server and previewing your updated documentation in real-time, see the [Local Development](https://github.com/taichi-dev/docs.taichi.graphics#local-development). + +## Add test cases for your local changes + +If your PR is to implement a new feature, we recommend that you write your own test cases to cover corner cases for your codes before filing a PR. + +- To write a Python test case, see the [Workflow for writing a Python test](./write_test). +- To write a C++ test case, see the [Workflow for writing a C++ test](./writing_cpp_tests). + +## Conduct style checks and integration tests locally + +We highly recommend that you complete code style checks and integration tests on your local computer before filing a PR. + +### Enforce code style + +1. Ensure that you have installed `clang-format-10`. +2. Ensure that you have installed `yapf v0.31.0`. +3. Re-format your code style: + +``` +python misc/code_format.py +``` +
+ How to install clang-format-10 on M1 Mac + +1. Download and extract [Clang + LLVM 10.0.0 pre-built binary for macOS](https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.0/clang+llvm-10.0.0-x86_64-apple-darwin.tar.xz) + +2. Copy the `clang-format` binary to `~/.local/bin` and add `~/.local/bin` to `PATH` + +```shell +mkdir -p ~/.local/bin +cp clang+llvm-10.0.0-x86_64-apple-darwin/bin/clang-format ~/.local/bin/clang-format-10 +echo "export PATH=$HOME/.local/bin:\$PATH" >> ~/.zshrc +source ~/.zshrc +``` + +Please refer to [this](./dev_install#llvm-as-cannot-be-opened-on-macos) if you get an error message like `clang-format-10 can’t be opened because Apple cannot check it for malicious software on macOS`. + +
+ +
+ What if I didn't format my code style locally? + +1. Have your reviewer leave a comment `/format` in your PR to enable GitHub Actions. See [#2481](https://github.com/taichi-dev/taichi/pull/2481). + *[Taichi Gardener](https://github.com/taichi-gardener)* *automatically pushes a commit to your branch to format your code.* -## Testing +2. If you wish to submit more changes after someone leaves the `/format` comment, ensure that your branch is up to date with your remote counterpart. -Tests should be added to [tests/](https://github.com/taichi-dev/taichi/blob/master/tests/). We -have both Python tests and C++ tests. +
-### Python tests + -- Use `ti test` to run all the tests. -- Use `ti test -v` for verbose outputs. -- Use `ti test -s` for original output from the tests. -- Use `ti test -a ` to test only specified backends, e.g., - `ti test -a cuda,metal`. -- Use `ti test -na ` to test all backends excluding specified ones, - e.g., `ti test -na opengl,x64`. -- Use `ti test ` to run tests in specified files. For example, - `ti test numpy_io` will run all tests in [tests/python/test_numpy_io.py](https://github.com/taichi-dev/taichi/blob/master/tests/python/test_numpy_io.py). -- Use `ti test -k ` to run tests that match the specified key. For - example, `ti test linalg -k "cross or diag"` will run `test_cross` - and `test_diag` in [tests/python/test_linalg.py](https://github.com/taichi-dev/taichi/blob/master/tests/python/test_linalg.py). -- Use `ti test -t ` to set custom number of threads for parallel testing. +> For more style information for your C++ code, see [our C++ style](./cpp_style). + +### Run integration tests + +To run all the C++ and Python tests: +`python tests/run_tests.py` + +- **Example 1:** +`python tests/run_tests.py -v -t3 -a cpu,metal -s` + - `-v`: Verbose output. + - `-t `: Set a custom number of threads for parallel testing. + - `-a `: Test only the specified backends (separated by comma). + - `-s`: Original output from the tests. + +- **Example 2:** +`python tests/run_tests.py numpy_io` + - ``: Run test cases in specified files only (separated by comma). + - This command runs all tests in [tests/python/test_numpy_io.py](https://github.com/taichi-dev/taichi/blob/master/tests/python/test_numpy_io.py). + +- **Example 3:** +`python tests/run_tests.py linalg -k "cross or diag"` + - `-k `: Run only the tests that match the specified keys (supports expression in a key string). + - This command runs `test_cross()` and `test_diag()` in [tests/python/test_linalg.py](https://github.com/taichi-dev/taichi/blob/master/tests/python/test_linalg.py). + +- **To show all available options** +`python tests/run_tests.py -h` + +> We have both Python and C++ test cases, but C++ test cases are disabled by default. To enable C++ test cases: +> +> 1. Build Taichi from source using the `python setup.py develop` command. +> 2. Set `TAICHI_CMAKE_ARGS="-DTI_BUILD_TESTS:BOOL=ON"`. + +## File a pull request (PR) + +Now you get to the point where you need to get your hands dirty with your PRs. This section provides the following: +- [Considerations when you create PRs](#considerations) +- [PR naming conventions](#pr-naming-conventions) +- [PR review & merging checklist](#pr-review-merging-checklist) + +### Considerations + + + +- **When implementing a complex feature:** + + - Consider breaking it down to multiple separate, self-contained PRs to provide the community with a clearer context and keep a more traceable development history. + +- **When creating a PR:** + + - Have your PR address only one issue: + - In this way, you keep your changesets small so that potential issues can be readily identified. + - If you include in your PR irrevelant implementations, ensure that they are minor. + - Your reviewers have the right to request you to remove massive, irrevelant changes from your PR. + - If your PR is to implement a new feature, ensure that you have designed test cases for it. See [Add test cases for your local changes](#add-test-cases-for-your-local-changes). + - You are required to conduct code style checks and integration tests locally for your PR. See [Conduct style checks and integration tests locally](#conduct-style-checks-and-integration-tests-locally) + +- **When describing your PR:** + - Provide sufficient information in the description of your PR to provide the community with clearer context: + - Link to a specific GitHub issue if applicable, for example `fixes #`. + - Share important design decisions in your description. + +- **If you create a PR still in progress:** + + - Click **Convert to draft** on your PR page to convert the PR to draft, indicating that you are still working on it. + - Click **Ready for review** when you are all set and up for a review. + - See [Draft](https://github.blog/2019-02-14-introducing-draft-pull-requests/) for more information. + + +### PR naming conventions + +Your PR will make it into the commit history in the the master branch or even Taichi's release notes, therefore it is important to keep your PR title self-explanatory. This section describes our PR naming conventions: + +```Gherkin +[tag1] [tag2]...[tagN] Your PR title must be short but carry necessary info + +^----^ ^----^...^----^ ^--------------------------------------------------^ + +| | | | + +| | | +---> Capitalize the initial of your title. + +| | +---> Adjacent tags are separated with precisely one space. + +| +---> Frequently used tags: [cuda], [lang], [ci], [ir], [refactor]. + ++---> Prepend at least one tag to your PR title. +``` + +- **Tag naming conventions:** + - Prepend at least one tag, such as `[lang]`, to your PR title. + - If you have multiple tags, separate adjacent tags with one space. + - See [misc/prtags.json](https://github.com/taichi-dev/taichi/blob/master/misc/prtags.json) for a full list of available tags. + - We differentiate PRs for end-users from PRs for developers by *capitalizing tag initial*. + - If a PR deals with a feature visible to the end-users, initialize the most relevant tag and the PR will [make it into the release notes](https://github.com/taichi-dev/taichi/blob/master/misc/make_changelog.py). For example, `[Metal]`, `[Vulkan]`, `[IR]`, `[Lang]`, or `[CUDA]`. Ensure that your PR title has *AT MOST* one tag dealt this way. + - If a PR deals with the underlying or intermediate implementation, then it is for the developers and you need to ensure that all its tags are *in lowercase*. For example, `[metal]`, `[vulkan]`, `[ir]`, `[lang]`, or `[cuda]`. + + :::danger INCORRECT + `[Lang][refactor]` (sans space) + ::: + + :::tip CORRECT + `[Lang] [refactor]` + ::: + + :::danger INCORRECT + `[GUI] [Mac] Support modifier keys` (both tags have their initial capitalized) + ::: + + :::tip CORRECT + `[gui] [Mac] Support modifier keys` (only one tag has its initial capitalized) + ::: + +- **Title naming conventions:** + - Keep your PR title short enough but ensure that it carries necessary information. + - Do not include back quotes ("\`") in your PR title. + - Capitalize the initial letter of your title, which is the word immediately after your tag(s). + + :::danger INCORRECT + `[Doc] improve documentation` (the initial of the title is not capitalized) + ::: + + :::tip CORRECT + `[Doc] Improve documentation` + ::: + +:::note + +Following are some frequently used tags: + +- `[cuda]`: Backend-specific changes. +- `[lang]`: Frontend language features, including syntax sugars. +- `[ir]`: Intermediate representation-specific changes. +- `[refactor]`: Code refactoring changes. +- `[ci]`: CI/CD workflow-specific changes. +- `[Doc]`: Documentation updates. + +When introducing a new tag, ensure that you add it to [misc/prtags.json](https://github.com/taichi-dev/taichi/blob/master/misc/prtags.json) so that others can follow. + +::: -For more options, see `ti test -h`. +### PR review & merging checklist -For more details on how to write a Python test case, see -[Workflow for writing a Python test](./write_test). +Follow this checklist during PR review or merging: -### C++ tests +1. Ensure that your PR title follows our [naming conventions](#pr-naming-conventions). +2. Ensure that Taichi's master branch has a *linear history*. See [Linear vs Non-Linear History](https://idiv-biodiversity.github.io/git-knowledge-base/linear-vs-nonlinear.html) for more information. +3. Ensure that your PR passes all Continuous Integration (CI) tests before merging it. -For more details on C++ tests, see -[Workflow for writing a CPP test](./writing_cpp_tests). + CI is triggered each time you push a commit to an open PR. It builds and tests all commits in your PR in multiple environments. Keep an eye on the CI test results: + - A ✔️ on the left-hand side of a commit hash: CI has passed, + - A ❌ on the left-hand side of a commit hash: CI has failed. -## Documentation +Here, we do not want to repeat some best practices summarized in the following Google blog articles. But please spare a couple of minutes reading them if your PR is being reviewed or if you are reviewing a PR. They have our recommendation! + - [Code Health: Understanding Code In Review](https://testing.googleblog.com/2018/05/code-health-understanding-code-in-review.html) + - [Code Health: Respectful Reviews == Useful Reviews](https://testing.googleblog.com/2019/11/code-health-respectful-reviews-useful.html) + - [How to have your PR merged quickly](https://testing.googleblog.com/2017/06/code-health-too-many-comments-on-your.html) -Documentation source files are under [docs/](https://github.com/taichi-dev/taichi/blob/master/docs/) of [**the main Taichi repo**](https://github.com/taichi-dev/taichi). -An automatic service syncs the updated content with our [documentation repo](https://github.com/taichi-dev/docs.taichi.graphics) and deploys the documentation at [the Taichi documentation site](https://docs.taichi.graphics). +## Still have issues? -We use [Markdown](https://www.markdownguide.org/getting-started/) (.md) to write documentation. -Please see [the documentation writing guide](./doc_writing) for more tips. +If you encounter any issue that is not covered here, feel free to report it by asking us on GitHub discussions or by [opening an issue on GitHub](https://github.com/taichi-dev/taichi/issues/new?labels=potential+bug&template=bug_report.md) and including the details. We are always there to help! -To set up a local server and preview your documentation edits in real time, -see instructions for [Local Development](https://github.com/taichi-dev/docs.taichi.graphics#local-development). +Finally, thanks again for your interest in contributing to Taichi. We look forward to seeing your contributions! diff --git a/docs/lang/articles/contribution/dev_install.md b/docs/lang/articles/contribution/dev_install.md index 366a191975319..852030a0361a4 100644 --- a/docs/lang/articles/contribution/dev_install.md +++ b/docs/lang/articles/contribution/dev_install.md @@ -4,306 +4,560 @@ sidebar_position: 2 # Developer installation -:::note -End users should use the [pip packages](../get-started.md) instead of building from source. -::: - -This section documents how to configure the Taichi development environment and build Taichi from source for compiler developers. The installation instructions might vary among different operating systems. We also provide a Dockerfile which helps setup a containerized development environment with CUDA support based on the Ubuntu docker image. +## Target audience -## Installing dependencies -1. Python: Currently, 3.6/3.7/3.8/3.9 are supported. - - If you are on an Apple M1 machine, you might want to install `conda` from [Miniforge](https://github.com/conda-forge/miniforge/#download). +Developers who are interested in the compiler, computer graphics, or high-performance computing, and would like to contribute new features or bug fixes to the [Taichi programming language](https://github.com/taichi-dev/taichi). -2. Clang: Make sure you have `clang-8` (or later) on Linux, or download `clang-10` on Windows: - - On OSX: Normally, you don’t need to do anything. - - On Ubuntu: Execute `sudo apt install libtinfo-dev clang-8`. - - On Arch Linux, download `llvm == 10.0.0` prebuilt binary for `ubuntu 18.04` from [here](https://releases.llvm.org/download.html#10.0.1). Then update environment variables `TAICHI_CMAKE_ARGS` and `PATH`: +:::danger IMPORTANT - ```bash - export TAICHI_CMAKE_ARGS="-DCMAKE_CXX_COMPILER=/bin/clang++:$TAICHI_CMAKE_ARGS" - export PATH=/bin:$PATH - ``` +This installation guide is *NOT* intended for end users who only wish to do simulation or high performance numerical computation. We recommend that end users install Taichi via `pip install taichi` and that there is no need for you to build Taichi from source. Doing both at the same time may cause unnecessary conflicts. - - On other Linux distributions, please search [this site](https://pkgs.org) for clang version \>= 8. - - On Windows: Please download [clang-10](https://github.com/taichi-dev/taichi_assets/releases/download/llvm10/clang-10.0.0-win.zip). Make sure you add the `bin` folder containing `clang.exe` to the `PATH` environment variable. +See the [Get Started](https://docs.taichi.graphics/) for more information on quickly setting up Taichi for end users. -:::note -On Linux, `clang` is the **only** supported compiler for compiling the Taichi package. ::: -:::note -On Linux, some additional packages might be required to build Taichi. E.g., on Ubuntu 20.04, you may need `libxi-dev` `libxcursor-dev` `libxinerama-dev` `libxrandr-dev` `libx11-dev` `libgl-dev` `libtinfo5`. please check the output of of CMake when building from source. -::: +## Introduction -3. LLVM: Make sure you have version 10.0.0 installed. Taichi uses a **customized LLVM**, which we provided as binaries depending on your system environment. Note that the pre-built binaries from the LLVM official website or other sources may not work. - - [LLVM 10.0.0 for Linux](https://github.com/taichi-dev/taichi_assets/releases/download/llvm10_linux_patch2/taichi-llvm-10.0.0-linux.zip) - - [LLVM 10.0.0 for macOS](https://github.com/taichi-dev/taichi_assets/releases/download/llvm10/taichi-llvm-10.0.0-macos.zip) - - [LLVM 10.0.0 for Windows MSVC 2019](https://github.com/taichi-dev/taichi_assets/releases/download/llvm10/taichi-llvm-10.0.0-msvc2019.zip) + This installation guide covers the following: + + - [Prerequisites for building Taichi from source](#prerequisites) + - [Installing optional dependencies](#install-optional-dependencies) + - [Building Taichi from source](#build-taichi-from-source) + - [Troubleshooting and debugging](#troubleshooting-and-debugging) + - [Frequently asked questions](#frequently-asked-questions) :::note -When using the above pre-built LLVM for Taichi, please add `$LLVM_FOLDER/bin` to `PATH`, e.g., `export PATH=/bin:$PATH` on Linux. + +Installation instructions vary depending on which operating system (OS) you are using. Choose the right OS or platform before you proceed. + ::: - - If the previous LLVM binaries do not work, please build from source: - - For Linux & Mac OSX: - - ```bash - wget https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.0/llvm-10.0.0.src.tar.xz - tar xvJf llvm-10.0.0.src.tar.xz - cd llvm-10.0.0.src - mkdir build - cd build - cmake .. -DLLVM_ENABLE_RTTI:BOOL=ON -DBUILD_SHARED_LIBS:BOOL=OFF -DCMAKE_BUILD_TYPE=Release -DLLVM_TARGETS_TO_BUILD="X86;NVPTX" -DLLVM_ENABLE_ASSERTIONS=ON -DLLVM_ENABLE_TERMINFO=OFF - # If you are building on Apple M1, use -DLLVM_TARGETS_TO_BUILD="AArch64". - # If you are building on NVIDIA Jetson TX2, use -DLLVM_TARGETS_TO_BUILD="ARM;NVPTX" - # If you are building for a PyPI release, add -DLLVM_ENABLE_Z3_SOLVER=OFF to reduce the library dependency. - make -j 8 - sudo make install - # Check your LLVM installation - llvm-config --version # You should get 10.0.0 - ``` - - - For Windows: - - ```bash - # For Windows - # LLVM 10.0.0 + MSVC 2019 - cmake .. -G "Visual Studio 16 2019" -A x64 -DLLVM_ENABLE_RTTI:BOOL=ON -DBUILD_SHARED_LIBS:BOOL=OFF -DCMAKE_BUILD_TYPE=Release -DLLVM_TARGETS_TO_BUILD="X86;NVPTX" -DLLVM_ENABLE_ASSERTIONS=ON -Thost=x64 -DLLVM_BUILD_TESTS:BOOL=OFF -DCMAKE_INSTALL_PREFIX=installed - ``` - - - Then open `LLVM.sln` and use Visual Studio 2017+ to build. - - Please make sure you are using the `Release` configuration. - After building the `INSTALL` project (under folder - `CMakePredefinedTargets` in the Solution Explorer window). - - If you use MSVC 2019, **make sure you use C++17** for the - `INSTALL` project. - - After the build is complete, find your LLVM binaries and - headers in `build/installed`. - - Please add `build/installed/bin` to `PATH`. Later, when you build Taichi, please use `cmake -DLLVM_DIR=/build/installed/lib/cmake/llvm`. - - -### Setting up CUDA (optional) +## Prerequisites + + + + + +| Category | Prerequisites | +|:----------------------------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| OS | macOS / Ubuntu / Arch Linux / Other Linux distributions | +| Python | 3.6/3.7/3.8/3.9 We recommend installing Python from [Miniforge](https://github.com/conda-forge/miniforge/#download) conda if you are on a MacBook with M1 chip. | +| Clang++ | 8≤ Clang++ <12 | +| LLVM | 10.0.0 (Taichi customized version) | +| Command line tools for Xcode | For macOS users only: `xcode-select --install ` | + + + + + +| Category | Prerequisites | +|:-------------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| OS | Windows 7/8/10/11 | +| Python | 3.6/3.7/3.8/3.9 | +| Clang++ | 8≤ Clang++ <12 (We provide pre-built versions in the clang section) | +| LLVM | 10.0.0 (Taichi customized version) | +| Visual Studio | Visual Studio 2019/2022 with "Desktop Development with C++" component. If you want to use Clang++ as the compiler, also install "C++ Clang Compiler for Windows" component | + + + + + +### Install Clang + +
+This Clang compiler is used to compile the Taichi device runtime. It is **not required** to use this compiler for the C++ compiler. +
+ + + + + +1. Ensure that the Clang that ships with your MacBook has a version ≥8 and <12: + + ``` + clang --version + ``` + +2. If your Clang version is ≥12, install Clang 11: + + ``` + brew install llvm@11 + export CXX=/opt/homebrew/opt/llvm@11/bin/clang++ + ``` + + + + + +Download and extract [Clang 10.0.0 pre-built binary for windows](https://github.com/taichi-dev/taichi_assets/releases/download/llvm10/clang-10.0.0-win.zip). + + + + + +``` +sudo apt install clang-10 +``` + +:::tip NOTE + +- Some Linux distributions may require additional packages to build Taichi. For example, you may need `libxi-dev` `libxcursor-dev` `libxinerama-dev` `libxrandr-dev` `libx11-dev` `libgl-dev` for Ubuntu 20.04. Keep an eye on the output of CMake when building from source. +- If this installation fails, you may want to `apt-get` the corresponding Clang package for your distribution following [this page](https://apt.llvm.org/). -:::note -To build with NVIDIA GPU support, CUDA 10.0+ is needed. This installation guide works for Ubuntu 16.04+. ::: -If you don't have CUDA, go to [this website](https://developer.nvidia.com/cuda-downloads) and download the installer. + -- To check if CUDA is installed, run `nvcc --version` or - `cat /usr/local/cuda/version.txt`. -- On **Ubuntu** we recommend choosing `deb (local)` as **Installer - Type**. -- On **Arch Linux**, you can easily install CUDA via `pacman -S cuda` - without downloading the installer manually. + + +1. Download [Clang + LLVM 10.0.0 pre-built binary for Ubuntu 18.04](https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.0/clang+llvm-10.0.0-x86_64-linux-gnu-ubuntu-18.04.tar.xz). +2. Update the environment variables `TAICHI_CMAKE_ARGS` and `PATH`: + + ```shell + export TAICHI_CMAKE_ARGS="-DCMAKE_CXX_COMPILER=/bin/clang++ $TAICHI_CMAKE_ARGS" + + export PATH=/bin:$PATH + ``` + + :::tip NOTE + + Some Linux distributions may require additional packages to build Taichi. Keep an eye on the output of CMake when building from source. + + ::: + + + + + +Search [this site](https://pkgs.org/) for a Clang version that Taichi supports. + +:::tip NOTE + +Some Linux distributions may require additional packages to build Taichi. Keep an eye on the output of CMake when building from source. -:::note -If you are using a machine with an earlier CUDA version and/or old generation GPUs. We suggest to consult the [Compatibility Document](https://docs.nvidia.com/deploy/cuda-compatibility/) and the [CUDA Installation Guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) first. ::: -### Setting up Vulkan (optional) + + + + +### Install LLVM + +#### Install pre-built, customized LLVM binaries + +We provide pre-built, customized LLVM binaries. For now, Taichi supports LLVM 10.0.0 only. + +1. Download and install customized binaries from the following list per your system environment: + + + + + LLVM 10.0.0 for Linux + + + LLVM 10.0.0 for macOS (without M1 chip) + + + LLVM 10.0.0 for macOS (with M1 chip) + + + LLVM 10.0.0 for Windows MSVC 2019 + + + +2. Configure environment variable: + + + + + +1. Add LLVM to your PATH variable: + ``` + echo "export PATH=/bin:\$PATH" >> ~/.bashrc + ``` +2. Update your path for the remainder of the session: + + ```shell + source ~/.bashrc + ``` + + + + + +Add an environment variable `LLVM_DIR` with value `` + + + + + + +
+ +

Build LLVM 10.0.0 from source

+ +We provide instructions here if you need to build LLVM 10.0.0 from source. + + + + + +```shell +wget https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.0/llvm-10.0.0.src.tar.xz + +tar xvJf llvm-10.0.0.src.tar.xz + +cd llvm-10.0.0.src + +mkdir build + +cd build + +cmake .. -DLLVM_ENABLE_RTTI:BOOL=ON -DBUILD_SHARED_LIBS:BOOL=OFF -DCMAKE_BUILD_TYPE=Release -DLLVM_TARGETS_TO_BUILD="X86;NVPTX" -DLLVM_ENABLE_ASSERTIONS=ON -DLLVM_ENABLE_TERMINFO=OFF + +# If you are building on Apple M1, use -DLLVM_TARGETS_TO_BUILD="AArch64". + +# If you are building on NVIDIA Jetson TX2, use -DLLVM_TARGETS_TO_BUILD="ARM;NVPTX" + +# If you are building for a PyPI release, add -DLLVM_ENABLE_Z3_SOLVER=OFF to reduce the library dependency. + +make -j 8 + +sudo make install + +# Check your LLVM installation + +llvm-config --version # You should get 10.0.0 +``` + + + + + +```shell +# For Windows + +# LLVM 10.0.0 + MSVC 2019 + +cmake .. -G "Visual Studio 16 2019" -A x64 -DLLVM_ENABLE_RTTI:BOOL=ON -DBUILD_SHARED_LIBS:BOOL=OFF -DCMAKE_BUILD_TYPE=Release -DLLVM_TARGETS_TO_BUILD="X86;NVPTX" -DLLVM_ENABLE_ASSERTIONS=ON -Thost=x64 -DLLVM_BUILD_TESTS:BOOL=OFF -DCMAKE_INSTALL_PREFIX=installed +``` + +1. Use Visual Studio 2017+ to build **LLVM.sln**. +2. Ensure that you use the **Release** configuration. After building the `INSTALL` project (under folde **CMakePredefinedTargets** in the Solution Explorer window). +3. If you use MSVC 2019, ensure that you use **C++17** for the `INSTALL` project. +4. When the build completes, add an environment variable `LLVM_DIR` with value `/build/installed/lib/cmake/llvm`. + + + + + +
+ +## Install optional dependencies + +[CUDA](https://en.wikipedia.org/wiki/CUDA) is NVIDIA's answer to high-performance computing. Taichi has implemented a backend based on CUDA 10.0.0+. Vulkan is a next-generation, cross-platform API, open standard for 3D graphics and computing. Taichi has added a Vulkan backend as of v0.8.0. + +This section provides instructions on installing these two optional dependencies. + +
+

Install CUDA

+ +This section works for you if you have a Nvidia GPU supporting CUDA. Note that the required CUDA version is 10.0+. + +To install CUDA: + +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + + + + + +1. Go to [the official site](https://developer.nvidia.com/cuda-downloads) to download the installer. +2. Choose **deb (local)** as **Installer Type**. +3. Check if CUDA is properly installed: + + ``` + nvidia-smi + ``` + + + + + +1. `pacman -S cuda` +2. Check if CUDA is properly installed: + + ``` + nvidia-smi + ``` + + + + + +1. Go to [the official site](https://developer.nvidia.com/cuda-downloads) and download the installer. +2. Choose **exe (local)** as **Installer Type**. +3. Check if CUDA is properly installed: + + ``` + nvidia-smi + ``` + + + + +
+ +
+

Install Vulkan

+ +You must install the Vulkan SDK in order to debug Taichi's Vulkan backend. To proceed: + + + + + +1. Go to [Vulkan's SDK download page](https://vulkan.lunarg.com/sdk/home) and follow the instructions for your OS. +2. Check if environment variables `VULKAN_SDK`, `PATH`, `LD_LIBRARY_PATH`, and `VK_LAYER_PATH` are updated. + + > The SDK for Ubuntu provides a `setup-env.sh` for updating these variables. + +3. Ensure that you have a Vulkan driver from a GPU vendor properly installed. + + > On Ubuntu, check if a JSON file with a name corresponding to your GPU vendor is in: `/etc/vulkan/icd.d/` or `/usr/share/vulkan/icd.d/`. + +4. Check if the SDK is properly installed: `vulkaninfo`. + +5. If the SDK is properly installed, add an environment variable `TAICHI_CMAKE_ARGS` with the value `-DTI_WITH_VULKAN:BOOL=ON` to enable the Vulkan backend: (Otherwise Vulkan backend is disabled by default when compiling from source.) + + ```shell + export TAICHI_CMAKE_ARGS="$TAICHI_CMAKE_ARGS -DTI_WITH_VULKAN:BOOL=ON" + ``` + + + + + +1. Go to [Vulkan's SDK download page](https://vulkan.lunarg.com/sdk/home) and follow the instructions for your OS. +2. Set the environment variable `VULKAN_SDK` to `C:/VulkanSDK/${YOUR_VULKAN_VERSION}`. +3. If the SDK is properly installed, add an environment variable `TAICHI_CMAKE_ARGS` with the value `-DTI_WITH_VULKAN:BOOL=ON` to enable the Vulkan backend: + + ```shell + $env:TAICHI_CMAKE_ARGS += " -DTI_WITH_VULKAN:BOOL=ON" + ``` + + + + +
-If you wish to build taichi with Vulkan. You will need to install the Vulkan SDK. Please visit [this website](https://vulkan.lunarg.com/sdk/home) and follow the instructions for your OS. -- If you are working on Windows, please set the environment variable `VULKAN_SDK` to `C:/VulkanSDK/${YOUR_VULKAN_VERSION}`. (For example, when using Vulkan 1.2.189.0, set `VULKAN_SDK` to `C:/VulkanSDK/1.2.189.0`). -- On Linux, also make sure the environment variable `VULKAN_SDK` `PATH` `LD_LIBRARY_PATH` and `VK_LAYER_PATH` are updated. On Ubuntu, the downloaded SDK provides a `setup-env.sh` that can be sourced. -- Make sure you have a Vulkan driver from a GPU vendor installed. On Ubuntu, you - can verify there is a JSON file in one of these two locations: `/etc/vulkan/icd.d/` or `/usr/share/vulkan/icd.d`. -- You can verify the installation of the Vulkan SDK by running `vkvia`, `vulkaninfo`, and/or `vkcube`. +## Build Taichi from source -After Vulkan is successfully installed. You can build Taichi with Vulkan by adding an environment variable `TAICHI_CMAKE_ARGS` with the value `-DTI_WITH_VULKAN:BOOL=ON`. + -### Setting up Taichi for development + -1. Clone the Taichi repo **recursively**, and build: +1. Clone the Taichi repo *recursively* and build[^1]: - ```bash + ```shell git clone --recursive https://github.com/taichi-dev/taichi + cd taichi + python3 -m pip install --user -r requirements_dev.txt - # export CXX=/path/to/clang # Uncomment if clang is not system default compiler. - python3 setup.py develop --user # Optionally add DEBUG=1 to keep debug information. + + # Exports CXX=/path/to/clang++ # Uncomment if clang++ is not default compiler of the system. Note that clang is not acceptable due to requirements of some submodules. + + # export DEBUG=1 #Uncomment it if you wish to keep debug information. + + python3 setup.py develop --user ``` -:::note -We use `MSBUILD.exe` to build the generated project on Windows. Please note that Windows -could have multiple instances of `MSBUILD.exe` shipped with different -products. Please make sure you add the path for `MSBUILD.exe` within your -MSVS directory and make it a higher priority (for instance than the one -shipped with .NET). -::: +2. Try out some of the demos in the **examples/** folder to see if Taichi is properly installed. For example: + + ```shell + python3 examples/simulation/mpm128.py + ``` :::note -`python setup.py develop` command (recommended for developers) works very similarly to -`setup.py install` command (recommended for users) except -that it doesn't actually install anything. It fits developer need better since edits -on python file take effect immediately without rebuilding. You only need to rerun `develop` -commands when you change a project’s C extensions or similarly compiled files. See -[development mode](https://setuptools.pypa.io/en/stable/userguide/development_mode.html) for more details. + +[^1]Although the two commands work similarly, `python setup.py develop` is recommended for you as a developer and `python setup.py install`more for end users. The difference is: + +- The `develop` command does not actually install anything but only symbolically links the source code to the deployment directory. +- The `install` command deep copies the source code so that end users need to rerun the command every time they modify the source code. + +The `develop` command serves the developers' needs better because edits to the Python files take effect immediately without the need to rerun the command. A rerun is needed only if you have modified the project's C extension or compiled files. See the [Development Mode](https://setuptools.pypa.io/en/stable/userguide/development_mode.html) for more information. + ::: -2. Check out the `examples` folder for runnable examples. Run them with commands - like `python3 examples/simulation/mpm128.py`. + -3. Execute `python3 -m taichi test` to run all the tests. It may take - up to 5 minutes to run all tests. + -4. Execute `python3 setup.py clean` to clean up the local information of your - previous builds. This allows a fresh build without any cache from the previous - builds. Note that to uninstall the Taichi package from your Python - environment, please use `pip uninstall taichi`. +1. Set-up the environment variable `TAICHI_CMAKE_ARGS` with value `-DCLANG_EXECUTABLE=;/bin/clang.exe -DLLVM_AS_EXECUTABLE=/bin/llvm-as.exe` +2. Open the "x64 Native Tools Command Prompt" for VS2019 or VS2022. Please make sure you opened the x64 version. (Or load the Visual Studio environment yourself) +3. Clone the Taichi repo *recursively* & install python dependencies -## Conda -To avoid directly installing Taichi's dependencies into your existing -Python environment, we have provided a pre-defined `conda` environment. -You can find the instructions [here](https://github.com/taichi-dev/taichi/blob/master/conda/README.md). + ```shell + git clone --recursive https://github.com/taichi-dev/taichi -:::note -This step only helps you setup the development environment, -you would still need to run `python3 setup.py develop` to re-build -Taichi. -::: + cd taichi -## Docker + python -m pip install --user -r requirements_dev.txt + ``` -For those who prefer to use Docker, we also provide a Dockerfile which -helps setup the Taichi development environment with CUDA support based -on Ubuntu docker image. +4. Build taichi by using `python setup.py develop` :::note -In order to follow the instructions in this section, please make sure -you have the [Docker Desktop (or Engine for -Linux)](https://www.docker.com/products/docker-desktop) installed and -set up properly. -::: -### Build the Docker image +[^1]Although the two commands work similarly, `python setup.py develop` is recommended for you as a developer and `python setup.py install`more for end users. The difference is: + +- The `develop` command does not actually install anything but only symbolically links the source code to the deployment directory. +- The `install` command deep copies the source code so that end users need to rerun the command every time they modify the source code. + +The `develop` command serves the developers' needs better because edits to the Python files take effect immediately without the need to rerun the command. A rerun is needed only if you have modified the project's C extension or compiled files. See the [Development Mode](https://setuptools.pypa.io/en/stable/userguide/development_mode.html) for more information. -From within the root directory of the taichi Git repository, execute -`docker build -t taichi:latest .` to build a Docker image based off the -local master branch tagged with _latest_. Since this builds the image -from source, please expect up to 40 mins build time if you don't have -cached Docker image layers. +::: :::note -In order to save the time on building Docker images, you could always -visit our [Docker Hub -repository](https://hub.docker.com/r/taichidev/taichi) and pull the -versions of pre-built images you would like to use. +If you want to build Taichi with Clang or maybe utilize `ccache` to cache and speed-up builds, add the following to the end of environment variable `TAICHI_CMAKE_ARGS`: ` -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang`. -For example, to pull a image built from release v0.6.17, run -`docker pull taichidev/taichi:v0.6.17` ::: -:::caution + -The nature of Docker container determines that no changes to the file -system on the container could be preserved once you exit from the -container. If you want to use Docker as a persistent development -environment, we recommend you [mount the taichi Git repository to the -container as a volume](https://docs.docker.com/storage/volumes/) and set -the Python path to the mounted directory. -::: + -### Use Docker image on macOS (CPU only) +## Troubleshooting and debugging -1. Make sure `XQuartz` and `socat` are installed: +### `llvm-as` cannot be opened on macOS -```bash -brew cask install xquartz -brew install socat -``` +**Description** -2. Temporally disable the xhost access-control: `xhost +`. -3. Start the Docker container with - `docker run -it -e DISPLAY=$(ipconfig getifaddr en0):0 taichidev/taichi:v0.6.17`. -4. Do whatever you want within the container, e.g. you could run tests - or an example, try: `ti test` or `ti example mpm88`. -5. Exit from the container with `exit` or `ctrl+D`. -6. **[To keep your xhost safe]** Re-enable the xhost access-control: - `xhost -`. - -### Use Docker image on Ubuntu (with CUDA support) - -1. Make sure your host machine has CUDA properly installed and - configured. Usually you could verify it by running `nvidia-smi`. -2. Make sure [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-docker) is properly - installed: - -```bash -distribution=$(. /etc/os-release;echo $ID$VERSION_ID) -curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - -curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list - -sudo apt-get update && sudo apt-get install -y nvidia-container-toolkit -sudo systemctl restart docker -``` +Gets an error message `llvm-as can’t be opened because Apple cannot check it for malicious software on macOS`. -3. Make sure `xorg` is installed: `sudo apt-get install xorg`. -4. Temporally disable the xhost access-control: `xhost +`. -5. Start the Docker container with - `sudo docker run -it --gpus all -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix taichidev/taichi:v0.6.17`. -6. Do whatever you want within the container, e.g. you could run tests - or an example, try: `ti test` or `ti example mpm88`. -7. Exit from the container with `exit` or `ctrl+D`. -8. **[To keep your xhost safe]** Re-enable the xhost access-control: - `xhost -`. +**Workaround** +One-off: **System Preferences > Security & Privacy > General > Allow anyway**. -## Troubleshooting developer installation +### Permission denied -- If `python3 setup.py develop`(or `python3 setup.py install`) gives `permission denied` error, it means you're - installing into system python without write permission. You can work around this by: - - `python3 setup.py develop --user` or `python3 setup.py install --user`. - - Install conda and use python from conda enviroments. +**Description** -- If `make` fails to compile and reports - `fatal error: 'spdlog/XXX.h' file not found`, please try runing - `git submodule update --init --recursive --depth=1`. +Gets a `permission denied` after `python3 setup.py develop` or `python3 setup.py install`. -- If the build succeeded but running any Taichi code results in errors - like +**Root cause** - ``` - Bitcode file (/tmp/taichi-tero94pl/runtime//runtime_x64.bc) not found - ``` +You were trying to install packages into the Python environment without write permission. - please double check `clang` is in your `PATH`: +**Workaround** - ```bash - clang --version - # version should be >= 7 - ``` +1. `python3 setup.py develop --user` or `python3 setup.py install --user`. +2. Install Conda and use python from within the conda environment. - and our **Taichi configured** `llvm-as`: +### `make` fails to compile - ```bash - llvm-as --version - # version should be >= 8 - which llvm-as - # should be /usr/local/bin/llvm-as or /opt/XXX/bin/llvm-as, which is our configured installation - ``` +**Description** + +`make` fails to compile and reports `fatal error: 'spdlog/XXX.h' file not found`. + +**Root cause** + +You did not use the `--recursive` flag when cloning the Taichi repository. + +**Workaround** + +Run `git submodule update --init --recursive --depth=1`. + +### `which python` still returns the system's Python location - If not, please install `clang` and **build LLVM from source** with - instructions above in [dev_install](#installing-dependencies-1), - then add their path to environment variable `PATH`. +**Description** -- If you don't have `wget` on OSX, try installing [homebrew](https://brew.sh/) and then run `brew install wget`. +`which python` still returns the system's Python location after Conda is installed. -- If you get a new Apple machine, you might need to run `xcode-select --install` first. +**Workaround** -- If you installed `conda` but `which python` still points to the system `python` location, run the following commands to enable it: +Run the following commands to activate Conda: + +```shell +source /bin/activate + +conda init +``` + +## Frequently asked questions + +### How can I get a fresh Taichi build? + +1. Clean up cache from your previous builds: ``` - source /bin/activate - conda init + python3 setup.py clean ``` -- See also [Installation Troubleshooting](../misc/install.md) for issues - that may share with end-user installation. +2. Uninstall the Taichi package from your Python environment: + +- `python setup.py develop --uninstall`, if you build Taichi using `python setup.py develop`. +- `pip uninstall taichi`, if you build Taichi using `python setup.py install`. + +### What if I don't have `wget` on my macOS? + +1. Install [Homebrew](https://brew.sh/). +2. Use Homebrew to install `wget`: + + `brew install wget` + +## Still have issues? + +- See [Installation Troubleshooting](../misc/install.md) for issues that may share with the end-user installation. -- If you encounter other issues, feel free to report (please include the details) by [opening an - issue on - GitHub](https://github.com/taichi-dev/taichi/issues/new?labels=potential+bug&template=bug_report.md). - We are willing to help! +- If you encounter any issue that is not covered here, feel free to report it by [opening an issue on GitHub](https://github.com/taichi-dev/taichi/issues/new?labels=potential+bug&template=bug_report.md) and including the details. We are always there to help! diff --git a/docs/lang/articles/contribution/development_tips.md b/docs/lang/articles/contribution/development_tips.md index fc1f825a221f2..065206bcff030 100644 --- a/docs/lang/articles/contribution/development_tips.md +++ b/docs/lang/articles/contribution/development_tips.md @@ -71,9 +71,6 @@ When creating a Taichi program using `ti.init(arch=desired_arch, **kwargs)`, pass in the following parameters to make the Taichi compiler print out IRs in different stages: -- `print_preprocessed=True`: print results of the frontend Python - AST transform. The resulting scripts will generate a Taichi Frontend - AST when executed. - `print_ir=True`: print the Taichi IR transformation process of kernel (excluding accessors) compilation. - `print_accessor_ir=True`: print the IR transformation process of diff --git a/docs/lang/articles/contribution/style_guide_en.md b/docs/lang/articles/contribution/style_guide_en.md new file mode 100644 index 0000000000000..15f282a4fd367 --- /dev/null +++ b/docs/lang/articles/contribution/style_guide_en.md @@ -0,0 +1,374 @@ +--- +sidebar_position: 12 +--- + +# Taichi Technical Documentation Style Guide + +This is a reference for the developers and users at Taichi community to improve their writing and the consistency of Taichi's documentation. You can find detailed style, usage, and grammar in the following sections. + +## General principles + +### Style and tone + +- Use active voice when possible. +- Write for scannability; use bullet points, short paragraphs, and sections/headers to break up your content. +- Oxford comma: In a list of three or more items, add a comma before the conjunction (for example: "Android, iOS, and Windows"). +- Spell check your content. +- Be friendly by using "you". +- Review your content. Edit out any information your reader does not need to know. +- Remove ambiguity by choosing words with clear meaning. + +Avoid the following: + +- Exclamation marks, except in code snippets. +- Using phrases such as "simply" or "it is that simple" or "it is easy" in a procedure. +- Do not use dangling modifiers. A modifier "dangles" when the sentence is not clear about what is being modified. + +### Write for a global audience + +- Use simple verbs. For example, do not use words like utilize when the simpler word use conveys the same information. +- Do not use idiomatic or colloquial expressions. +- Avoid making negative voice constructions. +- Do not use double negative. +- Keep paragraphs short. Dense pages full of text are intimidating for readers. +- Address the reader directly by using “you” instead of “the developer”. +- Be inclusive in your writing. Use gender-neutral pronouns. +- Be consistent in your word usage. Do not use the same word to mean different things, and vice versa. + +## Language and grammar + +### Abbreviations and acronyms + +Abbreviations are the shortened version of a word or phrase used to represent the whole. Examples include "s" for "seconds,” "approx." for "approximately," and "e.g." for "exempli gratia" (meaning "for example"). +Abbreviations and acronyms can affect the clarity of Taichi content for the reader. While many are understood by our readers and do not need to be spelled out, for new or novel terms always spell out the first mention of an abbreviated term in the text, followed immediately by the abbreviation in parentheses . Use the abbreviated form for all subsequent references of the abbreviation on the same page. + +#### Latin abbreviations + +Do not use Latin abbreviations in your technical writing. +Many abbreviations derived from Latin are used in written English. Examples include "e.g." for "exempli gratia" (meaning "for example"), "i.e." for "id est" (meaning "in other words"), and "etc." for "et cetera" (meaning "and the other things"). +Plain language principles suggest avoiding these abbreviated terms. + +#### Contractions + +Do not use contractions except in FAQs. +A contraction is a word or phrase that is shortened by dropping one or more letters. Examples include "aren't" for "are not", "let's" for "let us", and "can’t” for “cannot”. While any native English reader understands them, they add unnecessary complexity for non-native readers. For example, contractions that end with the letter "s" can be mistaken for possessive nouns. In business communication, the use of contractions is frowned upon as they make the tone of the writing too informal. +The only exception to this rule is when you are writing content for an FAQ. The more conversational tone of an FAQ allows for the use of contractions in titles and headers . + +### Articles (a, an, the) + +"A" and "an" are indefinite articles and are used before a singular countable noun. They refer to any member of a group. "The" is a definite article. It is used before singular and plural nouns and refers to one or more particular members of a group. +Sound rule for using "a" and "an" +The way a word or acronym is spoken determines whether "a" or "an" precedes it. Use "a" before a word that starts with a consonant sound, and "an" for words that begin with a vowel sound. For example "a URL", and "an SDK". + +### Capitalization + +- Use an uppercase letter to begin the first word of the text immediately following a colon. +- Use sentence case for captions and other figure-related text. +- Use sentence case for items in all types of lists. +- Use sentence case for all the elements in a table: contents, headings, labels, and captions. +- Refer to web page titles with the same casing used on the page. + +### Ornamental words + +An expletive is an added word or phrase that does not contribute meaning to the sentence. The most common expletives are "there are" and "there is". +- Not recommended: There are 10 users in the workspace. +- Recommended: The workspace has 10 users. + +### Direct address or imperative mood +Use the imperative mood for task steps or a command line, a shell command for example. +Use third person singular for a description of an API method. +The imperative mood keeps the content concise. The direct address is more personal. +- Not recommended: Developers can download the SDK from here. +- Better: You can download the SDK from here. +- Recommended: Download the SDK from here. + +### Gender-neutral + +- Avoid using "his" or "her", or "he/she". +- Use the second person, "you", or the collective noun. + +### Present tense + +- Use the present tense as it creates concise sentences and provides a tone of immediacy. An exception to this rule is the release date of an SDK or other product. Always frame the release date in the past tense, as that tense will only be correct on the day of release. For example, use "v0.8.0 was released on September 23, 2021", NOT "v0.8.0 is released on September 23, 2021". +- Avoid using "will" unless you want to stress that something happens at a later point. +- Use future tense if there is a significant time delay that matters in the context. + +### Second person + +- In general, use second person "you" (sometimes implicit) in your docs. +- In glossary terms, avoid using person where possible. Use "developer" to refer to the reader if necessary. + +### Clause order +- Put the most important information at the beginning of a sentence, followed by what the user can do with that information. +- Provide the context before you provide the instruction; that way, the reader can skip the additional information if it does not apply to their circumstance. + +### Punctuations + +#### Colons + +- The first word after the colon should be in uppercase. +- When a colon introduces a list, the phrase before the colon must be a complete sentence. + +#### Ampersands + +- Do not use ampersands ("&") unless they are part of a name, UI string, or in a piece of code. + +#### Hyphens + +- All words following a hyphen are in lowercase, even if the word is at the beginning of a sentence. For example, "32-bit", or "Multi-threaded". +- Use a hyphen to form a compound adjective which is an adjective made up of more than one word. Examples include, "A 25-minute interval", "16-bit color", "a six-figure price", and more. +- Use a hyphen to indicate a common second element. For example, "a 10- to 11-digit number", "a six- or seven-hour delay", "a two-, three-, or fourfold increase". +- Many common prefixes, such as "co", "de", "pre", "pro", "re", and "sub" do not need a hyphen. +- Do not use a hyphen when forming a compound adjective with an adverb that ends in "ly". + +#### Spaces + +- Add a space before an opening parenthesis. Example: operating system (OS) +- Use only one space after full stops. Example: Add a space. One after the full stop. +- Use one space between numbers and units. Example: 256 Kbps. +- Use spaces around mathematical symbols. Example: V = off, width x height, x < y. Use spaces around dimensions. Example: 3.2 x 3.6 x 0.6 mm. +Note that million is not a unit, so there is no need to add a space between a number and M. For example, 10M is the right Taichi style. + +## Plagiarism + +Plagiarism puts the firm in a questionable position. Ensure that you do not copy and paste anything that you find from an online search to our technical documentation. As a tip, you can paraphrase contents that you find online. + +## Formatting + +### Headings and titles + +Headings assist the reader in scanning content, helping them discover exactly what they are seeking. They provide structure and are visual points of reference for the reader. +Use headers to help outline your draft content. Some other points for consideration: +- Capitalize all words in a document title, except for articles and prepositions. +- Use sentence case for section titles. +- Be descriptive and concise. +- Focus on what the reader needs to know and what they can accomplish. +- Use ampersands or other symbols only if they appear in a UI or product name. +- Do NOT conclude a heading with a period or colon. (An exception are FAQs whose titles are often phrased as a conversation with the reader). + +### Table headings + +When referencing something specific (such as a unit of measure) in a table header, do not repeat it in the cells in that column. For example, if a table column header uses “Kbps”, then there is no need to repeat it in the cells for that column. + +### Information + +#### Note +Provides supplemental information that may not apply to all readers, but is important for those specific readers to know. +Wrap the notes in: +:::note +This is a note. +::: + +#### Warning +Suggests proceeding with caution. +Wrap the notes in +:::caution WARNING +This is a warning. +::: + +#### DANGER +Designed to guide the reader away from a circumstance that poses some form of problem or hazard. +Stronger than a Caution; it means "Don't do this." +Wrap the notes in: +:::danger DANGER +This is a danger! +::: + +## Writing examples + +### Example 1 + +- Not recommended: Taichi Zoo uses cookies for security, improvement and analytics purposes. +- Recommended: Taichi Zoo uses cookies for security, improvement, and analytics purposes. + +**Comments:** +In a list of three or more items, add a comma before the conjunction。 + +### Example 2 + +- Not recommended: Two of the most commonly used types: + - f32 represents a 32-bit floating point number. + - i32 represents a 32-bit signed integer. +- Recommended: Two of the most commonly used types: + - f32: 32-bit floating point number. + - i32: 32-bit signed integer. + +**Comments:** + +Avoid repetitive information in a bullet list. + +### Example 3 + +- Not recommended: If you run into this situation, Taichi's handy automatic differentiation (autodiff) system comes to the rescue! +- Recommended: Taichi's automatic differentiation (autodiff) system addresses this situation. + +**Comments:** + +- Avoid subjective descriptions, such as "handy" and "very", and dramatic expressions, for example "come to the rescue" in a technical document. + +### Example 4 + +- Not recommended: ScopedProfileris used to analyze the performance of the Taichi JIT compiler (host). +- Recommended: ScopedProfiler analyzes the performance of the Taichi JIT compiler (host). + +**Comments:** + +- Use third person singular when describing a function, a method, or a callback. +- Use active voice as much as possible in a technical document. + +### Example 5 + +- Not recommended: The easiest way is to make use of ti.GUI. +- Recommended: The easiest way is to use ti.GUI. + +**Comments:** + +Use simple verbs. A noun phrase, for example "make use of", is usually wordier than its original verb form, in this case "use". + +### Example 6 + +- Not recommended: Use ti video -f40for creating a video with 40 FPS. +- Recommended: Use ti video -f40to create a video with a frame rate of 40 FPS. + +### Example 7 + +- Not recommended: Write less bugs. +- Recommended: Write fewer bugs. + +**Comments:** + +"Less" describes uncountable noun; "fewer" describes countable noun. + +### Example 8 + +- Not recommended: Sometimes user may want to override the gradients provided by the Taichi autodiff system. +- Recommended: Sometimes you may want to override the gradients provided by the Taichi autodiff system. + +**Comments:** + +Address your audience directly by using "you". + +### Example 9 + +- Not recommended: Compared to FLAT , query speed is much faster. Compared with IVFFLAT , less disk and CPU/GPU memory is required for IVF_SQ8. +- Recommended: IVF_SQ8 has a much higher query speed than FLAT, and requires less disk space and CPU/GPU memory than IVFFLAT. + +**Comments:** + +- IVF_SQ8 has a much faster query speed than FLAT (has). The second instance of "has" here can be omitted. +- "Compared to" and "Compared with" are usually wordy. + +### Example 10 + +- Not recommended: Different from IVF_SQ8 , IVF_SQ8H uses a GPU-based coarse quantizer that greatly reduces the quantization time . +- Recommended: Unlike IVF_SQ8 , IVF_SQ8H uses a GPU-based coarse quantizer , which greatly reduces the time to quantize. + +**Comments:** + +- In technical writing, one word is always better than two. +- Which is used in a non-restrictive attributive clause; that is used in a restrictive attributive clause. Always precede a which-clause with a comma. + + +### Example 11 + +- Not recommended: When you use a client to update the following parameters, the updates take effect immediately. +- Recommended: Updates to the following parameters from a client take effect immediately: + +**Comments:** + +- The original is wordy. + +### Example 12 + +- Not recommended: Vectors are quantized to 8-bit floats , which may cause accuracy loss. +- Recommended: Vectors are quantized to 8-bit floats. This may cause accuracy loss. +- Not recommended: However, the process to build a search graph requires a lot of computations for distances between vectors, which results in high computation costs. +- Recommended: However, the process of building a search graph requires a lot of computations for distances between vectors, resulting in high computation costs. + +**Comments:** + +- You cannot use which to refer to an entire preceding clause. Which only modifies the noun or noun phrase ( noun + prep. + noun) immediately preceding it. Use this to refer to the entire preceding clause . + +### Example 13 + +- Not recommended: Make sure the Memory available to Docker Engine exceeds the sum of insert_buffer_size and cpu_cache_capacity you set in the config.yaml file. +- Recommended: Ensure that the Memory available to Docker Engine exceeds the sum of insert_buffer_size and cpu_cache_capacity , both of which are defined in config.yaml. + +**Comments:** + +- When it comes to technical writing, do not use more than one word when one word can convey the same information. +- Always use that to lead an appositive clause. +- If you have already spelt out the file name, you do not need to emphasize it is a file. Your readers can tell for themselves. + +### Example 14 + +- Not recommended: Start the Prometheus server, with the --config.file flag pointing to the configuration file: +$ ./prometheus --config.file=prometheus.yml + +- Recommended: Start the Prometheus server and specify the configuration file: +$ ./prometheus --config.file=prometheus.yml + +**Comments:** + +- Misuse of with. With modifies the subject of the main clause. +- The original is horribly specific. The revised version speaks for itself. + +### Example 15 + +- Not recommended: This document talks about the following topics: +- Recommended: This document covers the following topics: + +**Comments:** + +- Anthropomorphism is not accepted in technical documents. + +### Example 16 + +- Not recommended: + - True: Enables the debug mode. + - False: Disables the debug mode. +- Recommended: + - True: Enable debug mode. + - False: Disable debug mode. + + +**Comments:** + +- Use imperative mood when desbribing a binary parameter; use third person singular when describing a function, a method, or a callback. +- Do not use the definite article before the word mode. + +### Example 17 + +- Not recommended: This parameter is used to enable or disable Write Ahead Log (WAL). +- Recommended: This parameter enables or disables Write Ahead Log (WAL). + +**Comments:** + +- Clean, clear, and straight to the point! + +### Example 18 + +- Not recommended: Active monitoring helps you identify problems early. But it is also essential to create alerting rules that promptly send notifications when there are events that require investigation or intervention. +- Recommended: Proactively monitoring metrics helps identify issues when they emerge. Creating alerting rules for events that require immediate intervention is essential as well. + +**Comments:** + +- Do not use "but" to lead a separate sentence. +- Way too many that-clauses! +- The "there be" construction is always awkward. An expletive is an added word or phrase that does not contribute meaning to the sentence. The most common expletives are "there are" and "there is". + +### Example 19 + +- Not recommended: However, for delete operations, the operation speed is faster when write ahead log is enabled. +- Recommended: Delete operations are faster when write ahead log is enabled. + +**Comments:** + +- You cannot say faster speed. You can say higher speed or greater speed . You can also say an operation is faster. +- The original is wordy. + +## English style guide references + +- Microsoft Writing Style Guide +- The Chicago Manual of Style +- Merriam-Webster's Dictionary diff --git a/docs/lang/articles/contribution/utilities.md b/docs/lang/articles/contribution/utilities.md index d537ec7df20c3..2fd7dfc5fa62c 100644 --- a/docs/lang/articles/contribution/utilities.md +++ b/docs/lang/articles/contribution/utilities.md @@ -86,7 +86,7 @@ int func(void *p) { ## Benchmarking and regression tests - Run `ti benchmark` to run tests in benchmark mode. This will record - the performance of `ti test`, and save it in `benchmarks/output`. + the performance of `python tests/run_tests.py`, and save it in `benchmarks/output`. - Run `ti regression` to show the difference between the previous result in `benchmarks/baseline`. And you can see if the performance is increasing or decreasing after your commits. This is really @@ -150,7 +150,7 @@ when the program crashes. ```python # Python -ti.set_gdb_trigger(True) +ti.init(gdb_trigger=True) ``` ```cpp @@ -188,7 +188,7 @@ in is executed in test. not C++ yet. ```bash -ti test -C # run tests and save results to .coverage +python tests/run_tests.py -C # run tests and save results to .coverage coverage report # generate a coverage report on terminal output coverage html # generate a HTML form report in htmlcov/index.html ``` diff --git a/docs/lang/articles/contribution/write_test.md b/docs/lang/articles/contribution/write_test.md index 475b47e104507..b26a1805d831b 100644 --- a/docs/lang/articles/contribution/write_test.md +++ b/docs/lang/articles/contribution/write_test.md @@ -51,7 +51,7 @@ def test_log10(): assert r[None] == 2 ``` -Execute `ti test logarithm`, and the functions starting with `test_` in +Execute `python tests/run_tests.py logarithm`, and the functions starting with `test_` in `tests/python/test_logarithm.py` will be executed. ## Testing against multiple backends @@ -229,7 +229,7 @@ exclude them from the test in order to move forward: ```python # Run this test on all backends except for OpenGL -@ti.test(excludes=[ti.opengl]) +@ti.test(exclude=[ti.opengl]) def test_sparse_field(): # ... (some tests that requires sparse feature which is not supported by OpenGL) ``` diff --git a/docs/lang/articles/get-started/_category_.json b/docs/lang/articles/get-started/_category_.json new file mode 100644 index 0000000000000..3562d433d76f3 --- /dev/null +++ b/docs/lang/articles/get-started/_category_.json @@ -0,0 +1,4 @@ +{ + "label": "Getting Started", + "position": 1 +} diff --git a/docs/lang/articles/get-started.md b/docs/lang/articles/get-started/index.md similarity index 89% rename from docs/lang/articles/get-started.md rename to docs/lang/articles/get-started/index.md index b6d250d157f26..f607ce2f224d6 100644 --- a/docs/lang/articles/get-started.md +++ b/docs/lang/articles/get-started/index.md @@ -25,25 +25,15 @@ import TabItem from '@theme/TabItem'; There are a few of extra requirements depend on which operating system you are using: - - - On Ubuntu 19.04+, you need to install `libtinfo5`: - - ```sudo apt install libtinfo5``` - - - On Arch Linux, you need to install `ncurses5-compat-libs` package from the Arch User Repository: - - ```yaourt -S ncurses5-compat-libs``` + On Arch Linux, you need to install `ncurses5-compat-libs` package from the Arch User Repository: `yaourt -S ncurses5-compat-libs` @@ -54,13 +44,13 @@ There are a few of extra requirements depend on which operating system you are u -Please refer to the [Installation Troubleshooting](./misc/install.md) section if you run into any issues when installing Taichi. +Please refer to the [Installation Troubleshooting](../misc/install.md) section if you run into any issues when installing Taichi. ## Hello, world! We introduce the Taichi programming language through a very basic _fractal_ example. -Running the Taichi code below using either `python3 fractal.py` or `ti example fractal` _(you can find more information about the Taichi CLI in the [Command line utilities](./misc/cli_utilities.md) section)_ will give you an animation of [Julia set](https://en.wikipedia.org/wiki/Julia_set): +Running the Taichi code below using either `python3 fractal.py` or `ti example fractal` _(you can find more information about the Taichi CLI in the [Command line utilities](../misc/cli_utilities.md) section)_ will give you an animation of [Julia set](https://en.wikipedia.org/wiki/Julia_set):
@@ -178,7 +168,7 @@ type-hinted (if any). Taichi **functions** are defined with the decorator `@ti.func`. They can **only** be called by Taichi kernels or other Taichi functions. -See [syntax](./basic/syntax.md) for more details about Taichi +See [syntax](../basic/syntax.md) for more details about Taichi kernels and functions. The language used in Taichi kernels and functions looks exactly like @@ -273,7 +263,7 @@ over all the pixel coordinates, i.e., :::note -Struct-for is the key to [sparse computation](./advanced/sparse.md) in +Struct-for is the key to [sparse computation](../advanced/sparse.md) in Taichi, as it will only loop over active elements in a sparse field. In dense fields, all elements are active. ::: @@ -326,6 +316,20 @@ def foo(): ::: +### GUI system + +Taichi provides a cpu-based [GUI system](../gui/gui.md) for users to render +their results on the screen. + +```python +gui = ti.GUI("Julia Set", res=(n * 2, n)) + +for i in range(1000000): + paint(i * 0.03) + gui.set_image(pixels) + gui.show() +``` + ### Interacting with other Python packages #### Python-scope data access @@ -373,18 +377,18 @@ while gui.running: gui.show() ``` -See [Interacting with external arrays](./basic/external.md#interacting-with-external-arrays) for more details. +See [Interacting with external arrays](../basic/external.md#interacting-with-external-arrays) for more details. ## What's next? Now we have gone through core features of the Taichi programming language using the fractal example, feel free to dive into the language concepts in -the next section, or jump to the advanced topics, such as the [Metaprogramming](./advanced/meta.md) or [Differentiable programming](./advanced/differentiable_programming.md). Remember that you can +the next section, or jump to the advanced topics, such as the [Metaprogramming](../advanced/meta.md) or [Differentiable programming](../advanced/differentiable_programming.md). Remember that you can use the search bar at the top right corner to search for topics or keywords at any time! If you are interested in joining the Taichi community, we strongly recommend you take some time to -familiarize yourself with our [contribution guide](./contribution/contributor_guide.md). +familiarize yourself with our [contribution guide](../contribution/contributor_guide.md). We hope you enjoy your adventure with Taichi! diff --git a/docs/lang/articles/gui/_category_.json b/docs/lang/articles/gui/_category_.json new file mode 100644 index 0000000000000..d05d0602cb17c --- /dev/null +++ b/docs/lang/articles/gui/_category_.json @@ -0,0 +1,4 @@ +{ + "label": "GUI", + "position": 4 +} diff --git a/docs/lang/articles/misc/ggui.md b/docs/lang/articles/gui/ggui.md similarity index 83% rename from docs/lang/articles/misc/ggui.md rename to docs/lang/articles/gui/ggui.md index 39a2bb5ed24d7..edf084ba11caf 100644 --- a/docs/lang/articles/misc/ggui.md +++ b/docs/lang/articles/gui/ggui.md @@ -1,6 +1,5 @@ --- -sidebar_position: 1 - +sidebar_position: 2 --- # A New UI system: GGUI @@ -14,7 +13,7 @@ You also need to install the Vulkan environment: [https://vulkan.lunarg.com/sdk/ A new UI system has been added to Taichi in version `v0.8.0`. The new GUI system uses GPU for rendering, enabling it to be much faster and to render 3d scenes. For these reasons, this new system is sometimes referred to as GGUI. This doc describes the APIs provided. -Apart from this doc, a good way of getting familiarized with GGUI is to look at the examples. Please checkout the examples provided in [`examples/ggui_examples`](https://github.com/taichi-dev/taichi/tree/master/examples/ggui_examples). +Apart from this doc, a good way of getting familiarized with GGUI is to look at the examples. Please checkout the examples provided in [`examples/ggui_examples`](https://github.com/taichi-dev/taichi/tree/master/python/taichi/examples/ggui_examples). ## Creating a window @@ -42,7 +41,7 @@ this retrieves a `Canvas` object that covers the entire window. ### Drawing on the canvas ```python -canvas.set_back_ground_color(color) +canvas.set_background_color(color) canvas.triangles(vertices, color, indices, per_vertex_color) canvas.circles(vertices, radius, color, per_vertex_color) canvas.lines(vertices, width, indices, color, per_vertex_color) @@ -114,7 +113,7 @@ window.GUI.begin(name, x, y, width, height) window.GUI.text(text) is_clicked = window.GUI.button(name) new_value = window.GUI.slider_float(name, old_value, min_value, max_value) -new_color = window.GUI.slider_float(name, old_color) +new_color = window.GUI.color_edit_3(name, old_color) window.GUI.end() ``` @@ -145,7 +144,7 @@ To check if a specific key is currently pressed: -Here is an input processing example in GGUI version [`mpm128`](https://github.com/taichi-dev/taichi/blob/master/examples/ggui_examples/mpm128_ggui.py): +Here is an input processing example in GGUI version [`mpm128`](https://github.com/taichi-dev/taichi/blob/master/python/taichi/examples/ggui_examples/mpm128_ggui.py): ```python while window.running: @@ -167,3 +166,24 @@ while window.running: if window.is_pressed(ti.ui.RMB): attractor_strength[None] = -1 ``` + + +## Image I/O + +To write the current screen content into an image file: + +```python +window.write_image(filename) +``` + +Notice that, when the window is showing, you have to call `window.write_image()` before the `window.show()` call. + + +## Off-screen rendering + +GGUI supports rendering contents off-screen, that is, writing the results into image files without showing the window at all. This is sometimes referred to as "headless" rendering. To enable this mode, initialize the window with the argument `show_window=False`: + +```python +window = ti.ui.Window('Window Title', (640, 360), show_window = False) +``` +Then, you can use `window.write_image()` as normal, and remove the `window.show()` call at the end. diff --git a/docs/lang/articles/gui/gui.md b/docs/lang/articles/gui/gui.md new file mode 100644 index 0000000000000..0d2495739185c --- /dev/null +++ b/docs/lang/articles/gui/gui.md @@ -0,0 +1,348 @@ +--- +sidebar_position: 1 +--- + +# GUI system + +Taichi has a built-in cpu-based GUI system to help users visualize results. + +## Create a window + +The following code show how to create a window of resolution `640x360`: + +```python +gui = ti.GUI('Window Title', (640, 360)) +``` + +:::note + +If you are running Taichi on a machine without a GUI environment, consider setting `show_gui` to `False`: + +```python +gui = ti.GUI('Window Title', (640, 360), show_gui=False) + +while gui.running: + ... + gui.show(f'{gui.frame:06d}.png') # save current frame to local file +``` + +::: + +## Display a window + +The following code snippet display frame of the current windows: + +```python + for frame in range(10000): + ... + gui.show() # display current frame +``` + +:::note +Current FPS will show besides the title of the window. By default, FPS is limited to 60. +We can change this number by setting `gui.fps_limit = the_number_we_want`. +::: + + +## Paint on a window +Taichi's GUI supports painting simple geometric objects, such as lines, triangles, rectangles, circles, and text. + +:::note + +The position parameter of every drawing API expects input of 2-element tuples, +whose values are the relative position of the object range from 0.0 to 1.0. +(0.0, 0.0) stands for the lower left corner of the window, and (1.0, 1.0) stands for the upper right corner. + +Acceptable input for positions are Taichi fields or numpy arrays. Primitive arrays in python are NOT acceptable. + +To convert Taichi field to numpy array, use `to_numpy()` on Taichi fields. By doing this, we can also use data +from Taichi program in other visualization APIs such as matplotlib. +::: + +:::tip + +Here we only list the most commonly-used APIs. For a full list of APIs and the detailed API descriptions, please +see the [API docs](https://api-docs.taichi.graphics/autoapi/taichi/ui/gui/index.html#module-taichi.ui.gui). + +::: + +```python +gui.circles(pos, radius=3, palette=[0x068587, 0xED553B, 0xEEEEF0], palette_indices=material) +``` +draws circles with radius of 1.5 and three different colors differed by `material`, an integer array with the same size as +`pos`. Each integer in `material` indicates which color the associated circle use (i.e. array [0, 1, 2] indicates these three +circles are colored separately by the first, second, and third color in `palette`. + +![circles](../static/assets/colored_circles.png) + +```python +gui.lines(begin=X, end=Y, radius=2, color=0x068587) +``` +draws line segments from X positions to Y positions with width of 2 and color in light blue. + +![lines](../static/assets/lines.png) + +```python +gui.triangles(a=X, b=Y, c=Z, color=0xED553B) +``` +draws triangles with color in red and three points positioned at X, Y, and Z. + +![triangles](../static/assets/triangles.png) + +## RGB & Hex conversion. + +A handy tool for converting colors from RGB to hex and vice versa. + +```python +rgb = (0.4, 0.8, 1.0) +hex = ti.rgb_to_hex(rgb) # 0x66ccff + +rgb = ti.hex_to_rgb(0x007fff) # (0.0, 0.5, 1.0) + +rgb = np.array([[0.4, 0.8, 1.0], [0.0, 0.5, 1.0]]) +hex = ti.rgb_to_hex(rgb) # np.array([0x66ccff, 0x007fff]) +``` + +The return values can be used in GUI drawing APIs. + + +## Event processing + +Every event have a key and type. + +_Event type_ is the type of event, for now, there are just three type of event: + + ti.GUI.RELEASE # key up or mouse button up + ti.GUI.PRESS # key down or mouse button down + ti.GUI.MOTION # mouse motion or mouse wheel + +_Event key_ is the key that you pressed on keyboard or mouse, can be one of: + + # for ti.GUI.PRESS and ti.GUI.RELEASE event: + ti.GUI.ESCAPE # Esc + ti.GUI.SHIFT # Shift + ti.GUI.LEFT # Left Arrow + 'a' # we use lowercase for alphabet + 'b' + ... + ti.GUI.LMB # Left Mouse Button + ti.GUI.RMB # Right Mouse Button + + # for ti.GUI.MOTION event: + ti.GUI.MOVE # Mouse Moved + ti.GUI.WHEEL # Mouse Wheel Scrolling + +A _event filter_ is a list combined of _key_, _type_ and _(type, key)_ tuple, e.g.: + +```python +# if ESC pressed or released: +gui.get_event(ti.GUI.ESCAPE) + +# if any key is pressed: +gui.get_event(ti.GUI.PRESS) + +# if ESC pressed or SPACE released: +gui.get_event((ti.GUI.PRESS, ti.GUI.ESCAPE), (ti.GUI.RELEASE, ti.GUI.SPACE)) +``` + +`gui.running` checks the state of the window. `ti.GUI.EXIT` occurs when +you click on the close (X) button of a window. `gui.running` will obtain +`False` when the GUI is being closed. + +For example, loop until the close button is clicked: + + while gui.running: + render() + gui.set_image(pixels) + gui.show() + +You can also close the window by manually setting `gui.running` to`False`: + + while gui.running: + if gui.get_event(ti.GUI.ESCAPE): + gui.running = False + + render() + gui.set_image(pixels) + gui.show() + +`gui.get_event(a, ...)` tries to pop an event from the queue, and stores it into `gui.event`. + +For example: + + if gui.get_event(): + print('Got event, key =', gui.event.key) + +For example, loop until ESC is pressed: + + gui = ti.GUI('Title', (640, 480)) + while not gui.get_event(ti.GUI.ESCAPE): + gui.set_image(img) + gui.show() + +`gui.is_pressed(key, ...)` detects the keys you pressed. You must use it +together with `gui.get_event`. Otherwise, it is not updated. For example: + + while True: + gui.get_event() # must be called before is_pressed + if gui.is_pressed('a', ti.GUI.LEFT): + print('Go left!') + elif gui.is_pressed('d', ti.GUI.RIGHT): + print('Go right!') + +:::caution + +`gui.is_pressed()` must be used together with `gui.get_event()`, or it won't be updated! + +::: + +For example: + +```python +while True: + gui.get_event() # must be called before is_pressed + if gui.is_pressed('a', ti.GUI.LEFT): + print('Go left!') + elif gui.is_pressed('d', ti.GUI.RIGHT): + print('Go right!') +``` + +`gui.get_cursor_pos()` retrieves the current cursor position on the window. For example: + + mouse_x, mouse_y = gui.get_cursor_pos() + + +## GUI Widgets + +Sometimes it's more intuitive to use widgets like slider or button to control the program variables instead of using chaotic keyboard bindings. Taichi GUI provides a set of widgets for that reason: + +```python +import taichi as ti + +gui = ti.GUI('GUI widgets') + +radius = gui.slider('Radius', 1, 50, step=1) +xcoor = gui.label('X-coordinate') +okay = gui.button('OK') + +xcoor.value = 0.5 +radius.value = 10 + +while gui.running: + for e in gui.get_events(gui.PRESS): + if e.key == gui.ESCAPE: + gui.running = False + elif e.key == 'a': + xcoor.value -= 0.05 + elif e.key == 'd': + xcoor.value += 0.05 + elif e.key == 's': + radius.value -= 1 + elif e.key == 'w': + radius.value += 1 + elif e.key == okay: + print('OK clicked') + + gui.circle((xcoor.value, 0.5), radius=radius.value) + gui.show() +``` + + + +## Image I/O + +`ti.imwrite(img, filename)` exports an `np.ndarray` or a Taichi field +(`ti.Matrix.field`, `ti.Vector.field`, or `ti.field`) to a file with a specified `filename`. + +Same as `ti.GUI.show(filename)`, the format of the exported image is determined by **the suffix of** `filename` as well. Now `ti.imwrite` supports exporting images to `png`, `img` and `jpg` and we recommend using `png`. + +Please make sure that the input image has **a valid shape**. If you want to export a grayscale image, the input shape of field should be `(height, weight)` or `(height, weight, 1)`. For example: + +```python +import taichi as ti + +ti.init() + +shape = (512, 512) +type = ti.u8 +pixels = ti.field(dtype=type, shape=shape) + +@ti.kernel +def draw(): + for i, j in pixels: + pixels[i, j] = ti.random() * 255 # integers between [0, 255] for ti.u8 + +draw() + +ti.imwrite(pixels, f"export_u8.png") +``` + +Besides, for RGB or RGBA images, `ti.imwrite` needs to receive a field which has shape `(height, width, 3)` and `(height, width, 4)` individually. + +Generally the value of the pixels on each channel of a `png` image is an integer in \[0, 255\]. For this reason, `ti.imwrite` will **cast fields** which has different data types all **into integers between \[0, 255\]**. As a result, `ti.imwrite` has the following requirements for different data types of input fields: + +- For float-type (`ti.f16`, `ti.f32`, etc.) input fields, **the value of each pixel should be float between \[0.0, 1.0\]**. Otherwise `ti.imwrite` will first clip them into \[0.0, 1.0\]. Then they are multiplied by 256 and cast to integers ranging from \[0, 255\]. +- For int-type (`ti.u8`, `ti.u16`, etc.) input fields, **the value of each pixel can be any valid integer in its own bounds**. These integers in this field will be scaled to \[0, 255\] by being divided over the upper bound of its basic type accordingly. + +Here is another example: + +```python +import taichi as ti + +ti.init() + +shape = (512, 512) +channels = 3 +type = ti.f32 +pixels = ti.Matrix.field(channels, dtype=type, shape=shape) + +@ti.kernel +def draw(): + for i, j in pixels: + for k in ti.static(range(channels)): + pixels[i, j][k] = ti.random() # floats between [0, 1] for ti.f32 + +draw() + +ti.imwrite(pixels, f"export_f32.png") +``` + +## Zero-copying frame buffer +When the GUI resolution (window size) is large, it sometimes becomes difficult to achieve 60 FPS even without any kernel +invocations between two frames. + +This is mainly due to the copy overhead, where Taichi GUI needs to copy the image buffer from one place to another. +This process is necessary for the 2D drawing functions, such as `gui.circles`, to work. The larger the image shape is, +the larger the overhead. + +Fortunately, sometimes your program only needs `gui.set_image` alone. In such cases, you can enable the `fast_gui` option +for better performance. This mode allows Taichi GUI to directly write the image data to the frame buffer without additional +copying, resulting in a much better FPS. + +```python +gui = ti.GUI(res, title, fast_gui=True) +``` + +:::note + +Because of the zero-copying mechanism, the image passed into `gui.set_image` must already be in the display-compatible +format. That is, this field must either be a `ti.Vector(3)` (RGB) or a `ti.Vector(4)` (RGBA). In addition, each channel +must be of type `ti.f32`, `ti.f64` or `ti.u8`. + +::: + +:::note + +If possible, consider enabling this option, especially when `fullscreen=True`. + +::: + +:::caution + +Despite the performance boost, it has many limitations as trade off: + +`gui.set_image` is the only available paint API in this mode. + +`gui.set_image` will only take Taichi 3D or 4D vector fields (RGB or RGBA) as input. + +::: diff --git a/docs/lang/articles/misc/_category_.json b/docs/lang/articles/misc/_category_.json index da3cf059960d1..f32e7be2120fd 100644 --- a/docs/lang/articles/misc/_category_.json +++ b/docs/lang/articles/misc/_category_.json @@ -1,4 +1,4 @@ { "label": "Miscellaneous Topics", - "position": 4 + "position": 5 } diff --git a/docs/lang/articles/misc/debugging.md b/docs/lang/articles/misc/debugging.md index e87c5e15f1a87..cd8f08da247e5 100644 --- a/docs/lang/articles/misc/debugging.md +++ b/docs/lang/articles/misc/debugging.md @@ -221,8 +221,7 @@ def copy(dst: ti.template(), src: ti.template()): ## Pretty Taichi-scope traceback -Sometimes the Python stack tracebacks resulted from **Taichi-scope** errors -could be too complicated to read. For example: +Taichi reports traceback messages when encountered errors in **Taichi-scope**. For example: ```python import taichi as ti @@ -247,117 +246,88 @@ def func0(): func0() ``` -The above snippet would result in an `AssertionError`: +The above snippet would trigger a long and scaring `AssertionError`: ``` Traceback (most recent call last): - File "misc/demo_excepthook.py", line 20, in - func0() - File "/root/taichi/python/taichi/lang/kernel.py", line 559, in wrapped - return primal(*args, **kwargs) - File "/root/taichi/python/taichi/lang/kernel.py", line 488, in __call__ - self.materialize(key=key, args=args, arg_features=arg_features) - File "/root/taichi/python/taichi/lang/kernel.py", line 367, in materialize - taichi_kernel = taichi_kernel.define(taichi_ast_generator) - File "/root/taichi/python/taichi/lang/kernel.py", line 364, in taichi_ast_generator - compiled() - File "misc/demo_excepthook.py", line 18, in func0 - func1() - File "/root/taichi/python/taichi/lang/kernel.py", line 39, in decorated - return fun.__call__(*args) - File "/root/taichi/python/taichi/lang/kernel.py", line 79, in __call__ - ret = self.compiled(*args) - File "misc/demo_excepthook.py", line 14, in func1 - func2() - File "/root/taichi/python/taichi/lang/kernel.py", line 39, in decorated - return fun.__call__(*args) - File "/root/taichi/python/taichi/lang/kernel.py", line 79, in __call__ - ret = self.compiled(*args) - File "misc/demo_excepthook.py", line 10, in func2 - func3() - File "/root/taichi/python/taichi/lang/kernel.py", line 39, in decorated - return fun.__call__(*args) - File "/root/taichi/python/taichi/lang/kernel.py", line 79, in __call__ - ret = self.compiled(*args) - File "misc/demo_excepthook.py", line 6, in func3 - ti.static_assert(1 + 1 == 3) - File "/root/taichi/python/taichi/lang/error.py", line 14, in wrapped - return foo(*args, **kwargs) - File "/root/taichi/python/taichi/lang/impl.py", line 252, in static_assert + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/ast/ast_transformer_utils.py", line 23, in __call__ + return method(ctx, node) + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/ast/ast_transformer.py", line 342, in build_Call + node.ptr = node.func.ptr(*args, **keywords) + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/impl.py", line 471, in static_assert assert cond AssertionError -``` -Many stack frames are the Taichi compiler implementation details, which -could be too noisy to read. You could choose to ignore them by using -`ti.init(excepthook=True)`, which _hooks_ on the exception handler and makes -the stack traceback from Taichi-scope more intuitive: +During handling of the above exception, another exception occurred: + +Traceback (most recent call last): + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/ast/ast_transformer_utils.py", line 23, in __call__ + return method(ctx, node) + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/ast/ast_transformer.py", line 360, in build_Call + node.ptr = node.func.ptr(*args, **keywords) + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/kernel_impl.py", line 59, in decorated + return fun.__call__(*args) + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/kernel_impl.py", line 178, in __call__ + ret = transform_tree(tree, ctx) + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/ast/transform.py", line 8, in transform_tree + ASTTransformer()(ctx, tree) + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/ast/ast_transformer_utils.py", line 26, in __call__ + raise e + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/ast/ast_transformer_utils.py", line 23, in __call__ + return method(ctx, node) + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/ast/ast_transformer.py", line 488, in build_Module + build_stmt(ctx, stmt) + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/ast/ast_transformer_utils.py", line 26, in __call__ + raise e + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/ast/ast_transformer_utils.py", line 23, in __call__ + return method(ctx, node) + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/ast/ast_transformer.py", line 451, in build_FunctionDef + build_stmts(ctx, node.body) + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/ast/ast_transformer.py", line 1086, in build_stmts + build_stmt(ctx, stmt) + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/ast/ast_transformer_utils.py", line 26, in __call__ + raise e + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/ast/ast_transformer_utils.py", line 23, in __call__ + return method(ctx, node) + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/ast/ast_transformer.py", line 964, in build_Expr + build_stmt(ctx, node.value) + File "/Users/lanhaidong/taichi/taichi/python/taichi/lang/ast/ast_transformer_utils.py", line 32, in __call__ + raise TaichiCompilationError(msg) +taichi.lang.exception.TaichiCompilationError: On line 10 of file "misc/demo_traceback.py": + ti.static_assert(1 + 1 == 3) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +AssertionError: + +... +``` +The error message can be verbose and scary. However, many stack frames reveal +Taichi compiler implementation details, which are too noisy for debugging. +In current verison, you could choose to supress the level of traceback messages by setting `sys.tracebacklimit`, +which makes the stack traceback from Taichi-scope more intuitive: ```python {2} import taichi as ti -ti.init(excepthook=True) +import sys +sys.tracebacklimit=0 ... ``` which makes the result look like: ```python -========== Taichi Stack Traceback ========== -In () at misc/demo_excepthook.py:21: --------------------------------------------- -@ti.kernel -def func0(): - func1() - -func0() <-- --------------------------------------------- -In func0() at misc/demo_excepthook.py:19: --------------------------------------------- - func2() - -@ti.kernel -def func0(): - func1() <-- +AssertionError -func0() --------------------------------------------- -In func1() at misc/demo_excepthook.py:15: --------------------------------------------- - func3() +During handling of the above exception, another exception occurred: -@ti.func -def func1(): - func2() <-- - -@ti.kernel --------------------------------------------- -In func2() at misc/demo_excepthook.py:11: --------------------------------------------- +taichi.lang.exception.TaichiCompilationError: On line 10 of file "misc/demo_traceback.py": ti.static_assert(1 + 1 == 3) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +AssertionError: -@ti.func -def func2(): - func3() <-- - -@ti.func --------------------------------------------- -In func3() at misc/demo_excepthook.py:7: --------------------------------------------- -ti.enable_excepthook() - -@ti.func -def func3(): - ti.static_assert(1 + 1 == 3) <-- - -@ti.func --------------------------------------------- -AssertionError +... ``` -:::note -For IPython / Jupyter notebook users, the IPython stack traceback hook -will be overriden by the Taichi one when `ti.enable_excepthook()` is called. -::: +Moreover, when filing an issue, please always unset the `sys.tracebacklimit` value and paste full traceback messages. ## Debugging Tips diff --git a/docs/lang/articles/misc/export_kernels.md b/docs/lang/articles/misc/export_kernels.md index 51edb87f7ee01..8847dcb2b569a 100644 --- a/docs/lang/articles/misc/export_kernels.md +++ b/docs/lang/articles/misc/export_kernels.md @@ -32,7 +32,7 @@ Linux. In the future, we will support macOS and Windows. Use `ti.core.start_recording` in the Taichi program you want to export. Suppose you want to export -[examples/mpm88.py](https://github.com/taichi-dev/taichi/blob/master/examples/mpm88.py), +[examples/mpm88.py](https://github.com/taichi-dev/taichi/blob/master/python/taichi/examples/mpm88.py), here is the workflow: ### Export YAML diff --git a/docs/lang/articles/misc/export_results.md b/docs/lang/articles/misc/export_results.md index 9db2c468f7c7e..818928e1cd2da 100644 --- a/docs/lang/articles/misc/export_results.md +++ b/docs/lang/articles/misc/export_results.md @@ -94,13 +94,13 @@ print(f'The image has been saved to {filename}') :::note All Taichi fields have their own data types, such as `ti.u8` and `ti.f32`. Different data types can lead to different behaviors of -`ti.imwrite`. Please check out [GUI system](./gui.md) for +`ti.imwrite`. Please check out [GUI system](../gui/gui.md) for more details. ::: - Taichi offers other helper functions that read and show images in addition to `ti.imwrite`. They are also demonstrated in - [GUI system./gui.md). + [GUI system](../gui/gui.md). ## Export videos diff --git a/docs/lang/articles/misc/global_settings.md b/docs/lang/articles/misc/global_settings.md index 6dd76d93ef0cb..ef87dcdc886b0 100644 --- a/docs/lang/articles/misc/global_settings.md +++ b/docs/lang/articles/misc/global_settings.md @@ -20,10 +20,6 @@ sidebar_position: 7 errors: `ti.init(advanced_optimization=False)`. - Disable fast math to prevent possible undefined math behavior: `ti.init(fast_math=False)`. -- To print preprocessed Python code: - `ti.init(print_preprocessed=True)`. -- To show pretty Taichi-scope stack traceback: - `ti.init(excepthook=True)`. - To print intermediate IR generated: `ti.init(print_ir=True)`. ## Runtime @@ -46,7 +42,7 @@ sidebar_position: 7 - Cache compiled runtime bitcode in **dev mode** to save start up time: `export TI_CACHE_RUNTIME_BITCODE=1`. - To specify how many threads to run test: `export TI_TEST_THREADS=4` - or `ti test -t4`. + or `python tests/run_tests.py -t4`. ## Specifying `ti.init` arguments from environment variables diff --git a/docs/lang/articles/misc/gui.md b/docs/lang/articles/misc/gui.md deleted file mode 100644 index fc321deb817dd..0000000000000 --- a/docs/lang/articles/misc/gui.md +++ /dev/null @@ -1,529 +0,0 @@ ---- -sidebar_position: 1 - ---- - -# GUI system - -Taichi has a built-in GUI system to help users visualize results. - -## Create a window - -[`ti.GUI(name, res)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=gui%20gui#taichi.misc.gui.GUI) -creates a window. - -The following code show how to create a window of resolution `640x360`: - -```python -gui = ti.GUI('Window Title', (640, 360)) -``` - -:::note - -If you are running Taichi on a machine without a GUI environment, consider setting `show_gui` to `False`: - -```python -gui = ti.GUI('Window Title', (640, 360), show_gui=False) - -while gui.running: - ... - gui.show(f'{gui.frame:06d}.png') # save a series of screenshot -``` - -::: - -## Display a window - -[`gui.show(filename)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=show#taichi.misc.gui.GUI.show) -helps display a window. If `filename` is specified, a screenshot will be saved to the path. For example, the following saves frames of the window to `.png`s: - - for frame in range(10000): - render(img) - gui.set_image(img) - gui.show(f'{frame:06d}.png') - - - -## Paint on a window -Taichi's GUI supports painting simple geometric objects, such as lines, triangles, rectangles, circles, and text. - -:::note - -The position parameter of every drawing API expects input of 2-element tuples, -whose values are the relative position of the object range from 0.0 to 1.0. -(0.0, 0.0) stands for the lower left corner of the window, and (1.0, 1.0) stands for the upper right corner. - -Acceptable input for positions are taichi fields or numpy arrays. Primitive arrays in python are NOT acceptable. - -For simplicity, we use numpy arrays in the examples below. - -::: - -:::tip - -For detailed API description, please click on the API code. For instance, click on -`gui.get_image()` to see the description to get a GUI images. - -::: - -[`gui.set_image(pixels)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=set_image#taichi.misc.gui.GUI.set_image) -sets an image to display on the window. - -The image pixels are set from the values of `img[i, j]`, where `i` indicates the horizontal coordinates (from left to right) and `j` the vertical coordinates (from bottom to top). - -If the window size is `(x, y)`, then `img` must be one of: - -- `ti.field(shape=(x, y))`, a gray-scale image - -- `ti.field(shape=(x, y, 3))`, where `3` is for `(r, g, b)` channels - -- `ti.field(shape=(x, y, 2))`, where `2` is for `(r, g)` channels - -- `ti.Vector.field(3, shape=(x, y))` `(r, g, b)` channels on each component - -- `ti.Vector.field(2, shape=(x, y))` `(r, g)` channels on each component - -- `np.ndarray(shape=(x, y))` - -- `np.ndarray(shape=(x, y, 3))` - -- `np.ndarray(shape=(x, y, 2))` - -The data type of `img` must be one of: - -- `uint8`, range `[0, 255]` - -- `uint16`, range `[0, 65535]` - -- `uint32`, range `[0, 4294967295]` - -- `float32`, range `[0, 1]` - -- `float64`, range `[0, 1]` - -:::note - -When using `float32` or `float64` as the data type, `img` entries will be clipped into range [0, 1] for display. - -::: - -[`gui.get_image()`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=get_image#taichi.misc.gui.GUI.get_image) -gets the 4-channel (RGBA) image shown in the current GUI system. - -[`gui.circle(pos)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=circle#taichi.misc.gui.GUI.circle) -draws one solid circle. - -The color and radius of circles can be further specified with additional parameters. - -[`gui.circles(pos)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=circles#taichi.misc.gui.GUI.circles) -draws solid circles. - -The color and radius of circles can be further specified with additional parameters. For a single color, use the `color` parameter. -For multiple colors, use `palette` and `palette_indices` instead. - -:::note - -The unit of raduis in GUI APIs is number of pixels. - -::: - -For examples: -```python -gui.circles(pos, radius=3, color=0x068587) -``` -draws circles all with radius of 1.5 and blue color positioned at pos array. - -![circles](../static/assets/circles.png) -```python -gui.circles(pos, radius=3, palette=[0x068587, 0xED553B, 0xEEEEF0], palette_indices=material) -``` -draws circles with radius of 1.5 and three different colors differed by `material`, an integer array with the same size as -`pos`. Each integer in `material` indicates which color the associated circle use (i.e. array [0, 1, 2] indicates these three -circles are colored separately by the first, second, and third color in `palette`. - -![circles](../static/assets/colored_circles.png) - -[`gui.line(begin, end)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=line#taichi.misc.gui.GUI.line) -draws one line. - -The color and radius of lines can be further specified with additional parameters. - -[`gui.lines(begin, end)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=line#taichi.misc.gui.GUI.lines) -draws lines. - -`begin` and `end` both require input of positions. - -The color and radius of lines can be further specified with additional parameters. - -For example: -```python -gui.lines(begin=X, end=Y, radius=2, color=0x068587) -``` -draws line segments from X positions to Y positions with width of 2 and color in light blue. - -![lines](../static/assets/lines.png) - -[`gui.triangle(a, b, c)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=triangle#taichi.misc.gui.GUI.triangle) -draws one solid triangle. - -The color of triangles can be further specified with additional parameters. - -[`gui.triangles(a, b, c)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=triangles#taichi.misc.gui.GUI.triangles) -draws solid triangles. - -The color of triangles can be further specified with additional parameters. - -For example: -```python -gui.triangles(a=X, b=Y, c=Z, color=0xED553B) -``` -draws triangles with color in red and three points positioned at X, Y, and Z. - -![triangles](../static/assets/triangles.png) - -[`gui.rect(topleft, bottomright)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=rect#taichi.misc.gui.GUI.rect) -draws a hollow rectangle. - -The color and radius of the stroke of rectangle can be further specified with additional parameters. - -For example: -```python -gui.rect([0, 0], [0.5, 0.5], radius=1, color=0xED553B) -``` -draws a rectangle of top left corner at [0, 0] and bottom right corner at [0.5, 0.5], with stroke of radius of 1 and color in red. - -![rect](../static/assets/rect.png) - -[`gui.arrows(origin, direction)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=arrows#taichi.misc.gui.GUI.arrows) -draws arrows. - -`origin` and `direction` both require input of positions. `origin` refers to the positions of arrows' origins, `direction` -refers to the directions where the arrows point to relative to their origins. - -The color and radius of arrows can be further specified with additional parameters. - -For example: -```python -x = nunpy.array([[0.1, 0.1], [0.9, 0.1]]) -y = nunpy.array([[0.3, 0.3], [-0.3, 0.3]]) -gui.arrows(x, y, radius=1, color=0xFFFFFF) -``` -draws two arrow originated at [0.1, 0.1], [0.9, 0.1] and pointing to [0.3, 0.3], [-0.3, 0.3] with radius of 1 and color in white. - -![arrows](../static/assets/arrows.png) - -[`gui.arrow_field(direction)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=arrow_field#taichi.misc.gui.GUI.arrow_field) -draws a field of arrows. - -The `direction` requires a field of `shape=(col, row, 2)` where `col` refers to the number of columns of arrow field and `row` -refers to the number of rows of arrow field. - -The color and bound of arrow field can be further specified with additional parameters. - -For example: -```python -gui.arrow_field(x, bound=0.5, color=0xFFFFFF) # x is a field of shape=(5, 5, 2) -``` -draws a 5 by 5 arrows pointing to random directions. - -![arrow_field](../static/assets/arrow_field.png) - -[`gui.point_field(radius)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=point_field#taichi.misc.gui.GUI.point_field) -draws a field of points. - -The `radius` requires a field of `shape=(col, row)` where `col` refers to the number of columns of arrow field and `row` -refers to the number of rows of arrow field. - -The color and bound of point field can be further specified with additional parameters. - -For example: -```python -x = numpy.array([[3, 5, 7, 9], [9, 7, 5, 3], [6, 6, 6, 6]]) -gui.point_field(radius=x, bound=0.5, color=0xED553B) -``` -draws a 3 by 4 point field of radius stored in the array. - -![point_field](../static/assets/point_field.png) - -[`gui.text(content, pos)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=text#taichi.misc.gui.GUI.text) -draws a line of text on screen. - -The font size and color of text can be further specified with additional parameters. - -## RGB & Hex conversion. - -[`ti.hex_to_rgb(hex)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=hex_to_rgb#taichi.misc.gui.hex_to_rgb) -can convert a single integer value to a (R, G, B) tuple of floats. - -[`ti.rgb_to_hex(rgb)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=rgb#taichi.misc.gui.rgb_to_hex) -can convert a (R, G, B) tuple of floats into a single integer value, e.g., - -```python -rgb = (0.4, 0.8, 1.0) -hex = ti.rgb_to_hex(rgb) # 0x66ccff - -rgb = np.array([[0.4, 0.8, 1.0], [0.0, 0.5, 1.0]]) -hex = ti.rgb_to_hex(rgb) # np.array([0x66ccff, 0x007fff]) -``` - -The return values can be used in GUI drawing APIs. - - -## Event processing - -Every event have a key and type. - -_Event type_ is the type of event, for now, there are just three type of event: - - ti.GUI.RELEASE # key up or mouse button up - ti.GUI.PRESS # key down or mouse button down - ti.GUI.MOTION # mouse motion or mouse wheel - -_Event key_ is the key that you pressed on keyboard or mouse, can be one of: - - # for ti.GUI.PRESS and ti.GUI.RELEASE event: - ti.GUI.ESCAPE # Esc - ti.GUI.SHIFT # Shift - ti.GUI.LEFT # Left Arrow - 'a' # we use lowercase for alphabet - 'b' - ... - ti.GUI.LMB # Left Mouse Button - ti.GUI.RMB # Right Mouse Button - - # for ti.GUI.MOTION event: - ti.GUI.MOVE # Mouse Moved - ti.GUI.WHEEL # Mouse Wheel Scrolling - -A _event filter_ is a list combined of _key_, _type_ and _(type, key)_ tuple, e.g.: - -```python -# if ESC pressed or released: -gui.get_event(ti.GUI.ESCAPE) - -# if any key is pressed: -gui.get_event(ti.GUI.PRESS) - -# if ESC pressed or SPACE released: -gui.get_event((ti.GUI.PRESS, ti.GUI.ESCAPE), (ti.GUI.RELEASE, ti.GUI.SPACE)) -``` - -[`gui.running`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=running#taichi.misc.gui.GUI.running) -can help check the state of the window. `ti.GUI.EXIT` occurs when you click on the close (X) button of a window. - `gui.running` will obtain `False` when the GUI is being closed. - -For example, loop until the close button is clicked: - - while gui.running: - render() - gui.set_image(pixels) - gui.show() - -You can also close the window by manually setting `gui.running` to`False`: - - while gui.running: - if gui.get_event(ti.GUI.ESCAPE): - gui.running = False - - render() - gui.set_image(pixels) - gui.show() - -[`gui.get_event(a, ...)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=get_event#taichi.misc.gui.GUI.get_event) -tries to pop an event from the queue, and stores it into `gui.event`. - -For example: - - if gui.get_event(): - print('Got event, key =', gui.event.key) - -For example, loop until ESC is pressed: - - gui = ti.GUI('Title', (640, 480)) - while not gui.get_event(ti.GUI.ESCAPE): - gui.set_image(img) - gui.show() - -[`gui.get_events(a, ...)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=get_event#taichi.misc.gui.GUI.get_events) -is basically the same as `gui.get_event`, except that it returns a generator of events instead of storing into `gui.event`: - - for e in gui.get_events(): - if e.key == ti.GUI.ESCAPE: - exit() - elif e.key == ti.GUI.SPACE: - do_something() - elif e.key in ['a', ti.GUI.LEFT]: - ... - -[`gui.is_pressed(key, ...)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=is_pressed#taichi.misc.gui.GUI.is_pressed) -can detect the keys you pressed. It must be used together with `gui.get_event`, or it won't be updated! For -example: - - while True: - gui.get_event() # must be called before is_pressed - if gui.is_pressed('a', ti.GUI.LEFT): - print('Go left!') - elif gui.is_pressed('d', ti.GUI.RIGHT): - print('Go right!') - -:::caution - -`gui.is_pressed()` must be used together with `gui.get_event()`, or it won't be updated! - -::: - -For example: - -```python -while True: - gui.get_event() # must be called before is_pressed - if gui.is_pressed('a', ti.GUI.LEFT): - print('Go left!') - elif gui.is_pressed('d', ti.GUI.RIGHT): - print('Go right!') -``` - -[`gui.get_cursor_pos()`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=get_cursor#taichi.misc.gui.GUI.get_cursor_pos) -can return current cursor position within the window. For example: - - mouse_x, mouse_y = gui.get_cursor_pos() - -[`gui.fps_limit`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=fps#taichi.misc.gui.GUI.fps_limit) -sets the FPS limit for a window. For example, to cap FPS at 24, simply use `gui.fps_limit = 24`. This helps reduce the overload on your hardware especially when you're using OpenGL on your integrated GPU which could make desktop slow to response. - - - -## GUI Widgets - -Sometimes it's more intuitive to use widgets like slider or button to control the program variables instead of using chaotic keyboard bindings. Taichi GUI provides a set of widgets for that reason: - -[`gui.slider(text, min, max)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=slider#taichi.misc.gui.GUI.slider) -creates a slider following the text `{text}: {value:.3f}`. - -[`gui.label(text)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=label#taichi.misc.gui.GUI.label) -displays the label as: `{text}: {value:.3f}`. - -[`gui.button(text)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=button#taichi.misc.gui.GUI.button) -creates a button with text on it. - -For example: -```python -radius = gui.slider('Radius', 1, 50) - -while gui.running: - print('The radius now is', radius.value) - ... - radius.value += 0.01 - ... - gui.show() -``` - - - -## Image I/O - -[`ti.imwrite(img, filename)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=imwrite#taichi.misc.image.imwrite) -can export a `np.ndarray` or Taichi field (`ti.Matrix.field`, `ti.Vector.field`, or `ti.field`) to a specified location `filename`. - -Same as `ti.GUI.show(filename)`, the format of the exported image is determined by **the suffix of** `filename` as well. Now `ti.imwrite` supports exporting images to `png`, `img` and `jpg` and we recommend using `png`. - -Please make sure that the input image has **a valid shape**. If you want to export a grayscale image, the input shape of field should be `(height, weight)` or `(height, weight, 1)`. For example: - -```python -import taichi as ti - -ti.init() - -shape = (512, 512) -type = ti.u8 -pixels = ti.field(dtype=type, shape=shape) - -@ti.kernel -def draw(): - for i, j in pixels: - pixels[i, j] = ti.random() * 255 # integers between [0, 255] for ti.u8 - -draw() - -ti.imwrite(pixels, f"export_u8.png") -``` - -Besides, for RGB or RGBA images, `ti.imwrite` needs to receive a field which has shape `(height, width, 3)` and `(height, width, 4)` individually. - -Generally the value of the pixels on each channel of a `png` image is an integer in \[0, 255\]. For this reason, `ti.imwrite` will **cast fields** which has different data types all **into integers between \[0, 255\]**. As a result, `ti.imwrite` has the following requirements for different data types of input fields: - -- For float-type (`ti.f16`, `ti.f32`, etc.) input fields, **the value of each pixel should be float between \[0.0, 1.0\]**. Otherwise `ti.imwrite` will first clip them into \[0.0, 1.0\]. Then they are multiplied by 256 and cast to integers ranging from \[0, 255\]. -- For int-type (`ti.u8`, `ti.u16`, etc.) input fields, **the value of each pixel can be any valid integer in its own bounds**. These integers in this field will be scaled to \[0, 255\] by being divided over the upper bound of its basic type accordingly. - -Here is another example: - -```python -import taichi as ti - -ti.init() - -shape = (512, 512) -channels = 3 -type = ti.f32 -pixels = ti.Matrix.field(channels, dtype=type, shape=shape) - -@ti.kernel -def draw(): - for i, j in pixels: - for k in ti.static(range(channels)): - pixels[i, j][k] = ti.random() # floats between [0, 1] for ti.f32 - -draw() - -ti.imwrite(pixels, f"export_f32.png") -``` - -[`ti.imread(filename)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=imread#taichi.misc.image.imread) -loads an image from the target filename and returns it as a `np.ndarray(dtype=np.uint8)`. -Each value in this returned field is an integer in [0, 255]. - -[`ti.imshow(img, windname)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=imshow#taichi.misc.image.imshow) -creates an instance of ti.GUI and show the input image on the screen. It has the same logic as `ti.imwrite` for different data types. - -[`ti.imresize(img, w)`](https://api-docs.taichi.graphics/src/taichi.misc.html?highlight=imresize#taichi.misc.image.imresize) -resizes the img specified. - -## Zero-copying frame buffer -When the GUI resolution (window size) is large, it sometimes becomes difficult to achieve 60 FPS even without any kernel -invocations between two frames. - -This is mainly due to the copy overhead, where Taichi GUI needs to copy the image buffer from one place to another. -This process is necessary for the 2D drawing functions, such as `gui.circles`, to work. The larger the image shape is, -the larger the overhead. - -Fortunately, sometimes your program only needs `gui.set_image` alone. In such cases, you can enable the `fast_gui` option -for better performance. This mode allows Taichi GUI to directly write the image data to the frame buffer without additional -copying, resulting in a much better FPS. - -```python -gui = ti.GUI(res, title, fast_gui=True) -``` - -:::note - -Because of the zero-copying mechanism, the image passed into `gui.set_image` must already be in the display-compatible -format. That is, this field must either be a `ti.Vector(3)` (RGB) or a `ti.Vector(4)` (RGBA). In addition, each channel -must be of type `ti.f32`, `ti.f64` or `ti.u8`. - -::: - -:::note - -If possible, consider enabling this option, especially when `fullscreen=True`. - -::: - -:::caution - -Despite the performance boost, it has many limitations as trade off: - -`gui.set_image` is the only available paint API in this mode. - -`gui.set_image` will only take Taichi 3D or 4D vector fields (RGB or RGBA) as input. - -::: diff --git a/docs/lang/articles/misc/internal.md b/docs/lang/articles/misc/internal.md index 610f93292347a..692b50f26947f 100644 --- a/docs/lang/articles/misc/internal.md +++ b/docs/lang/articles/misc/internal.md @@ -330,7 +330,7 @@ a = ti.field(ti.f32, shape=(128, 32, 8)) b = ti.field(ti.f32) ti.root.dense(ti.j, 32).dense(ti.i, 16).place(b) -ti.get_runtime().materialize() +ti.lang.impl.get_runtime().materialize() # This is an internal api for dev, we don't make sure it is stable for user. mapping_a = a.snode().physical_index_position() diff --git a/docs/lang/articles/tutorials/_category_.json b/docs/lang/articles/tutorials/_category_.json new file mode 100644 index 0000000000000..6e052e4632618 --- /dev/null +++ b/docs/lang/articles/tutorials/_category_.json @@ -0,0 +1,4 @@ +{ + "label": "Tutorials", + "position": 7 +} diff --git a/docs/lang/articles/tutorials/ndarray_android.md b/docs/lang/articles/tutorials/ndarray_android.md new file mode 100644 index 0000000000000..0349ed9b2e1a0 --- /dev/null +++ b/docs/lang/articles/tutorials/ndarray_android.md @@ -0,0 +1,370 @@ +--- +sidebar_position: 1 +--- + +# Run a Taichi Program using Ndarray on Android + +Taichi's JIT (Just In Time) module compiles a Taichi kernel to the compute shaders according to the specified backend (`arch` in `ti.init()`) and executes these shaders in Taichi's JIT runtime. Taichi's AOT (Ahead of Time) module, however, builds and saves the necessary compute shaders so that you can load and execute these shaders in your own runtime without a Python environment. + +Taking a simulation of celestial bodies' orbits as an example, this tutorial walks you through the process of running a Taichi program using Ndarray on Android. + +> [Taichi's AOT (Ahead Of Time) module](https://github.com/taichi-dev/taichi/issues/3642) is currently a proof of concept under development and subject to change in the future. + +## A definition of Ndarray + +Taichi provides a data container called Ndarray. An Ndarray is a multidimensional container of elements of the same type and size; an element in an Ndarray is virtually a scalar or a tensor. + +### Ndarray shape + +Ndarray shape defines the Ndarray's layout; element shape defines the element's layout. For example: + +- An Ndarray with an Ndarray shape of [2, 1024] and an element shape of [] is an array of 2 x 1,024 = 2,048 scalars. +- An Ndarray with an Ndarray shape of [128, 128] and an element shape of [2, 4] is an array of 128 x 128 = 16,384 2 x 4 matrices. + +### Ndarray dimension + +The dimension here refers to the number of dimensions of an Ndarray. For example: + +- The dimension of an Ndarray with a shape of [1, 2, 3] is three. +- The dimension of an Ndarray with a shape of [500] is one. + +### Benefit of Ndarray + +Each Ndarray has a fixed dimension but gives you the flexibility of changing its shape in accordance with its dimension. + +Unlike a field's shape, which requires you to rewrite and recompile your Taichi program once it is changed, an Ndarray's shape can be *dynamically* changed without the need to recompile. + +Taking the simulation of celestial bodies' orbits as an example, suppose you wish to double the number of your celestial bodies to 2,000: + +- With Taichi field, you have to compile twice; +- With Ndarray, all you need to do is to update your runtime program. + +## Run a Taichi program using Ndarray on Android + +The following section walks you through the process of running a Taichi program using Ndarray on Android. + +1. [Generate necessary compute shaders](#generate-necessary-compute-shaders) +2. [Parse the generated JSON file](#parse-the-generated-json-file) +3. [Prepare SSBO and shape information](#prepare-ssbo-and-shape-information) +4. [Prepare rendering shaders](#prepare-rendering-shaders) +5. [Execute all shaders](#execute-all-shaders) + +:::note + +From Step 2, you are required to come up with your own runtime program. We provide an [example Java runtime program for Android](https://github.com/taichi-dev/taichi-aot-demo/blob/master/nbody_ndarray/java_runtime/NbodyNdarray.java) for your reference, but you may need to adapt these codes for your platform and programming language. + +::: + +### Generate necessary compute shaders + +The following Python script defines a Taichi AOT module for generating and saving the necessary compute shaders (GLES shaders in this case) based on the chosen backend (OpenGL). + +> Taichi kernels and compute shaders are *not* one-to-one mapping. Each Taichi kernel can generate multiple compute shaders, the number *usually* comparable to that of the loops in the kernel. + + + +```python +import taichi as ti + +ti.init(arch=ti.opengl, use_gles=True, allow_nv_shader_extension=False) + +# Define constants for computation +G = 1 +PI = 3.141592653 +N = 1000 +m = 5 +galaxy_size = 0.4 +planet_radius = 1 +init_vel = 120 +h = 1e-5 +substepping = 10 + +# Define Taichi kernels +@ti.kernel +def initialize(pos: ti.any_arr(element_dim=1), vel: ti.any_arr(element_dim=1)): + center=ti.Vector([0.5, 0.5]) + for i in pos: + theta = ti.random() * 2 * PI + r = (ti.sqrt(ti.random()) * 0.7 + 0.3) * galaxy_size + offset = r * ti.Vector([ti.cos(theta), ti.sin(theta)]) + pos[i] = center+offset + vel[i] = [-offset.y, offset.x] + vel[i] *= init_vel + +@ti.kernel +def compute_force(pos: ti.any_arr(element_dim=1), vel: ti.any_arr(element_dim=1), force: ti.any_arr(element_dim=1)): + for i in pos: + force[i] = ti.Vector([0.0, 0.0]) + for i in pos: + p = pos[i] + for j in pos: + if i != j: + diff = p-pos[j] + r = diff.norm(1e-5) + f = -G * m * m * (1.0/r)**3 * diff + force[i] += f + dt = h/substepping + for i in pos: + vel[i].atomic_add(dt*force[i]/m) + pos[i].atomic_add(dt*vel[i]) + +# Define Ndarrays +pos = ti.Vector.ndarray(2, ti.f32, N) +vel = ti.Vector.ndarray(2, ti.f32, N) +force = ti.Vector.ndarray(2, ti.f32, N) + +# Run the AOT module builder +def aot(): + m = ti.aot.Module(ti.opengl) + m.add_kernel(initialize, (pos, vel)) + m.add_kernel(compute_force, (pos, vel, force)) + + dir_name = 'nbody_aot' + m.save(dir_name, '') +aot() +``` + +**In line 3, you initialize Taichi:** + +1. Set `use_gles` to `True` to generate GLES compute shaders for Android. +2. Set `allow_nv_shader_extension` to `False` to prevent the generated GLES compute shaders from using Nvidia GL extensions on Android. + +> This setting is because Android supports GLES APIs but GLES does not support `NV_SHADER_EXTENSION`. + +**In line 50-58, you define and build the Taichi AOT module:** + +1. Create a Taichi AOT module, specifying its backend as OpenGL: + +```python + m = ti.aot.Module(ti.opengl) +``` + +2. Add the required kernels `initialize` and `compute_force`, each with its own Ndarrays, to the module: + +```python +m.add_kernel(initialize, (pos, vel)) + +m.add_kernel(compute_force, (pos, vel, force)) +``` + +3. Specify a folder under your current working directory for holding the files that the module generates: + +```python +dir_name = 'nbody_aot' + +m.save(dir_name, '') +``` + +*The necessary compute shaders together with a JSON file appear under the specified directory.* + +### Parse the generated JSON file + +:::note + +From this section, you are required to come up with your own runtime program. We provide an [example Java runtime program for Android](https://github.com/taichi-dev/taichi-aot-demo/blob/master/nbody_ndarray/java_runtime/NbodyNdarray.java) for your reference, but you may need to adapt these codes for your platform and programming language. + +::: + +After generating the necessary GLES compute shaders, you need to write your runtime program to parse the following JSON file to some data structures. The JSON file contains all the necessary information for executing the compute shaders. Organized by Taichi kernel, it provides a clearer image of the compute shaders and Ndarrays in each kernel. Let's take a closer look at the structure. + +> Here, the JSON object for the `compute_force` kernel is omitted for brevity. For a complete JSON file, see [metadata.json](https://github.com/taichi-dev/taichi-aot-demo/blob/master/nbody_ndarray/res/metadata.json). + +- **Organized by Taichi kernel** + + - `initialize` (line 4) + - `compute_force` (line 51) + +- **Kernel-specific compute shaders** + + Taking `initialize` as an example, the kernel has generated one compute shader named `initialize_c54_00` (line 7) and the other named `initialize_c54_01` (line 13). + +- **Kernel-specific** `args_buff` + + The `initialize` kernel is assigned an `args_buffer` of `128` Bytes (line 21). Note that the size of `args_buffer` is dependent on the number of Ndarrays (`pos` and `vel`) that the kernel takes, (see `arg_count` in line 19). The `initialize` kernel, or each kernel more precisely, has a dedicated `args_buffer` for storing scalar arguments specified in `scalar_args` (line 27) and Ndarray shape information in accordance with what `array_args` (line 28-45) specifies. + + Ndarrays' shape information is organized by their argument index in the `array_args` JSON array: `0` (line 29) corresponds to the `pos` Ndarray, and `1` (line 37) corresponds to the `vel` Ndarray. The argument index is determined by the sequence by which you pass in the Ndarrays when calling `add_kernel()`. See line 53 in the Python script. + + The `pos` Ndarray's shape information in `args_buffer` has an offset of `64` Bytes in `args_buffer` (line 64). According to line 35 and line 43, the `pos` Ndarray's shape information occupies 96 - 64 = 32 Bytes in `args_buffer`. + + :::tip ATTENTION + The JSON file only specifies the dimension of the corresponding Ndarray (line 30, 38), allowing you to dynamically update an Ndarray's shape in your runtime program. + ::: + +- **Kernel-specific binding index** + + `used.arr_arg_to_bind_idx` (line 46) maps the SSBO of each Ndarray in the kernel to a "more global" binding index for the compute shaders. For example, `"1": 5` (line 48) binds the `vel` Ndarray to the binding index `5`. + +```json +{ + "aot_data": { + "kernels": { + "initialize": { + "tasks": [ + { + "name": "initialize_c54_00", + "src": "nbody_aot/initialize_c54_00.glsl", + "workgroup_size": 1, + "num_groups": 1 + }, + { + "name": "initialize_c54_01", + "src": "nbody_aot/initialize_c54_01.glsl", + "workgroup_size": 128, + "num_groups": 256 + } + ], + "arg_count": 2, + "ret_count": 0, + "args_buf_size": 128, + "ret_buf_size": 0, + "ext_arr_access": { + "0": 2, + "1": 3 + }, + "scalar_args": {}, + "arr_args": { + "0": { + "field_dim": 1, + "is_scalar": false, + "element_shape": [ + 2 + ], + "shape_offset_in_bytes_in_args_buf": 64 + }, + "1": { + "field_dim": 1, + "is_scalar": false, + "element_shape": [ + 2 + ], + "shape_offset_in_bytes_in_args_buf": 96 + } + }, + "used.arr_arg_to_bind_idx": { + "0": 4, + "1": 5 + } + }, + "compute_force": {...} + }, + "kernel_tmpls": {}, + "fields": [], + "root_buffer_size": 0 + } +} +``` +The following provides a detailed description of the keys in the generated JSON file: + +`aot_data`: The overarching JSON object. + - `kernels`: All Taichi kernels. + - `$(kernel_name)`: Name of a specific Taichi kernel. + - `tasks`: A JSON array of the generated compute shaders. + - `name`: Name of a specific compute shader. + - `src`: Relative path to the shader file. + - `workgroup_size`: N/A + - `num_groups`: N/A + - `arg_count`: Number of the arguments that the Taichi kernel takes. + - `ret_count`: Number of the values that the Taichi kernel returns. + - `args_buf_size`: The size of `args_buf` in Bytes. + - `ret_buf_size`: The size of `ret_buf` in Bytes. + - `scalar_args`: Scalar arguments that the kernel takes. + - `arr_args`: Shape information of the Ndarrays in the kernel. + - `$(arg_index)`: Argument index of an Ndarray + - `field_dim`: The dimension of the Ndarray. + - `is_scalar`: Whether the elements in the Ndarray are scalar. + - `element_shape`: An `int` array indicating the shape of each element in the Ndarray. + - `shape_offset_in_bytes_in_args_buf`: The offset of the Ndarray's shape information in `args_buf`. + - `used.arr_arg_to_bind_idx`: A map specifying the SSBO to bind for a given Ndarray. For example, `"1": 5` (line 48) binds the `vel` Ndarray to the binding index `5`. + +*Well, we hope you were not overwhelmed with that much information coming in all at once. In the following section, we will revisit the JSON file, as well as provide tables and graphs that help illustrate some of the concepts and notions listed above.* + +### Prepare SSBO and shape information + +Before executing the GLES compute shaders in your runtime program, you need to get all your resources ready, including: + +- Bind SSBO for the corresponding buffer +- Bind SSBO for each Ndarray +- Fill `args_buffer` with Ndarray shape information + +#### Bind SSBO for the corresponding buffer + +The following table lists the buffers commonly used in a Taichi program together with their binding indexes: + +| **Buffer** | **Global/kernel-spedific** | **Storing** | **Binding index** | +| ------------- | -------------------------- | ------------------------------------------------------------ | ----------------- | +| `root_buffer` | Global | All fields with fixed offsets and of fixed sizes. | `0` | +| `gtmp_buffer` | Global | Global temporary data | `1` | +| `args_buffer` | Kernel-specific | Arguments passed to the Taichi kernel
  • Scalar arguments
  • Each Ndarray's shape information:
    • Shape of the Ndarray
    • Element shape
| `2` | + +1. You *only* need to bind an SSBO for `root_buffer` if your Taichi script uses at least one field. Skip this step if your script does not involve field. +2. Bind a small SSBO, say an SSBO of 1,024 Bytes, to `1`, the binding index of `gtmp_buffer`. +3. Bind an SSBO of 64 x 5 = 320 Bytes to `2`, the binding index of `args_buffer`. + +#### Bind SSBO for each Ndarray + +Before running a specific kernel in your runtime program (the `initialize` kernel for example), you must bind SSBO of a proper size for each Ndarray in the kernel in accordance to the value of `used.arr_arg_to_bind_idx`. + +The following is a summary of line 29-49 of the above JSON file: + +| Ndarray | Taichi kernel | Dimension | Element shape | Argument index | Binding index | +| ------- | ------------- | --------- | ------------- | -------------- | ------------- | +| `pos` | `initialize` | `1` | `[2]` | `0` | `4` | +| `vel` | `initialize` | `1` | `[2]` | `1` | `5` | + +If you give each Ndarray a shape [500], and an element shape [2] (meaning that each element is a 2-D vector): + +- Each Ndarray has 500 x 2 = 1,000 numbers +- Because the number type is float (as specified in the above Python script), the size of each Ndarray's SSBO is 1,000 x 4 = 4,000 Bytes. + +Therefore you need to: + +- Bind an SSBO of 4,000 Bytes to the binding index `4` for the `pos` Ndarray. +- Bind an SSBO of 4,000 Bytes to the binding index `5` for the `vel` Ndarray. + +#### Fill `args_buffer` with Ndarray shape information + +When explaining the JSON file, we mention that each kernel has a dedicated `args_buffer` for storing scalar arguments specified in `scalar_args` and Ndarray shape information in accordance with what `array_args` specifies. `array_args` does not specify the Ndarray shape, therefore the final step in your preparation is to fill `args_buffer` with each Ndarray's shape information in your runtime program. + +The typical size of an `args_buffer` is 64 + 64 x 4 Bytes. The first 64 Bytes are reserved for scalar arguments; the remaining buffer is then 64 x 4 Bytes. Each Ndarray is allocated 8 x 4 Bytes for storing its shape information (each has *at most* 8 numbers to indicate its shape information), therefore the remaining buffer can store up to 8 Ndarrays' shape information. + +- If your Ndarray shape is [100, 200] and element dimension [3, 2], then you fill 100, 200, 3, and 2 in the corresponding location. +- In this case, both `pos` and `vel` have an Ndarray shape of [500] and an element dimension of [2]. Therefore, you fill 500 and 2 in the corresponding locations. + +### Prepare rendering shaders + +To perform the rendering (drawing celestial bodies in this case), you are required to write a vertex shader and a fragment shader. + +### Execute all shaders + +When executing shaders in your runtime program, ensure that you bind SSBOs before executing a Taichi kernel and unbind them when you are done. + + Our [example Android Java runtime](https://github.com/taichi-dev/taichi-aot-demo/blob/master/nbody_ndarray/java_runtime/NbodyNdarray.java) does the following: + +1. Run the GLES compute shaders in `initialize` once. +2. For each frame: + 1. Run the GLES compute shaders in `compute_force` 10 times. + 2. Run the vertex and fragment shaders once to do the rendering. + + + +## OpenGL-specific Terms & Definitions + +### OpenGL ES (GLES) + +OpenGL ES (GLES) is the OpenGL APIs for Embedded Systems. According to [its specifications](https://www.khronos.org/api/opengles), a desktop OpenGL driver supports all GLES APIs. + +### OpenGL Shading Language (GLSL) + +The OpenGL Shading Language (GLSL) is the primary shading language for OpenGL. GLSL is a C-style language supported directly by OpenGL without extensions. + +### Shader + +A shader is a user-defined program designed for computing or rendering at a certain stage of a graphics processor. + +### SSBO (Shader Storage Buffer Object) + +Each Taichi kernel can generate multiple compute shaders, which use SSBO (Shader Storage Buffer Object) as buffer for accessing data. + +There are two types of SSBOs: One type corresponds to the buffers maintained by Taichi and includes `root_buffer`, `gtmp_buffer`, and `args_buffer`; the other type corresponds to the Ndarrays maintained by developers and used for sharing data. + +> You are required to bind the generated shaders to the corresponding SSBOs in your runtime program. The binding index of an Ndarray's SSBO starts off with `4`. diff --git a/docs/variable.json b/docs/variable.json new file mode 100644 index 0000000000000..0967ef424bce6 --- /dev/null +++ b/docs/variable.json @@ -0,0 +1 @@ +{} diff --git a/examples/algorithm/print_offset.py b/examples/algorithm/print_offset.py deleted file mode 100644 index b905126bd4f9b..0000000000000 --- a/examples/algorithm/print_offset.py +++ /dev/null @@ -1,36 +0,0 @@ -import taichi as ti - -ti.init(arch=ti.cpu, print_ir=True) - -n = 4 -m = 8 - -a = ti.field(dtype=ti.i32) -ti.root.dense(ti.ij, (1, 2)).dense(ti.ij, 2).dense(ti.ij, 2).place(a) - - -@ti.kernel -def fill(): - for i, j in a: - base = ti.get_addr(a.snode, [0, 0]) - a[i, j] = int(ti.get_addr(a.snode, [i, j]) - base) // 4 - - -fill() -print(a.to_numpy()) - -ti.get_runtime().prog.visualize_layout('layout.pdf') - -gui = ti.GUI('layout', res=(256, 512), background_color=0xFFFFFF) - -while True: - for i in range(1, m): - gui.line(begin=(0, i / m), end=(1, i / m), radius=2, color=0x000000) - for i in range(1, n): - gui.line(begin=(i / n, 0), end=(i / n, 1), radius=2, color=0x000000) - for i in range(n): - for j in range(m): - gui.text(f'{a[i, j]}', ((i + 0.3) / n, (j + 0.75) / m), - font_size=30, - color=0x0) - gui.show() diff --git a/examples/autodiff/minimization.py b/examples/autodiff/minimization.py deleted file mode 100644 index c59af83e0cdec..0000000000000 --- a/examples/autodiff/minimization.py +++ /dev/null @@ -1,40 +0,0 @@ -import random - -import taichi as ti - -ti.init(arch=ti.cpu) - -n = 8 -x = ti.field(dtype=ti.f32, shape=n, needs_grad=True) -y = ti.field(dtype=ti.f32, shape=n) -L = ti.field(dtype=ti.f32, shape=(), needs_grad=True) - - -@ti.kernel -def reduce(): - for i in range(n): - L[None] += 0.5 * (x[i] - y[i])**2 - - -# Initialize vectors -for i in range(n): - x[i] = random.random() - y[i] = random.random() - - -@ti.kernel -def gradient_descent(): - for i in x: - x[i] -= x.grad[i] * 0.1 - - -# Optimize with 100 gradient descent iterations -for k in range(100): - with ti.Tape(loss=L): - reduce() - print('Loss =', L[None]) - gradient_descent() - -for i in range(n): - # Now you should approximately have x[i] == y[i] - print(x[i], y[i]) diff --git a/examples/autodiff/regression.py b/examples/autodiff/regression.py deleted file mode 100644 index 2692b8afafa87..0000000000000 --- a/examples/autodiff/regression.py +++ /dev/null @@ -1,90 +0,0 @@ -import random - -import matplotlib.pyplot as plt -import numpy as np - -import taichi as ti -import taichi as tc - -ti.init(arch=ti.cpu) - -tc.set_gdb_trigger(True) - -number_coeffs = 4 -learning_rate = 1e-4 - -N = 32 -x, y = ti.field(ti.f32, shape=N, needs_grad=True), ti.field(ti.f32, - shape=N, - needs_grad=True) -coeffs = ti.field(ti.f32, shape=number_coeffs, needs_grad=True) -loss = ti.field(ti.f32, shape=(), needs_grad=True) - - -@ti.kernel -def regress(): - for i in x: - v = x[i] - est = 0.0 - for j in ti.static(range(number_coeffs)): - est += coeffs[j] * (v**j) - loss[None] += 0.5 * (y[i] - est)**2 - - -@ti.kernel -def update(): - for i in ti.static(range(number_coeffs)): - coeffs[i] -= learning_rate * coeffs.grad[i] - - -xs = [] -ys = [] - -for i in range(N): - v = random.random() * 5 - 2.5 - xs.append(v) - x[i] = v - y[i] = (v - 1) * (v - 2) * (v + 2) + random.random() - 0.5 - -regress() - -print('y') -for i in range(N): - y.grad[i] = 1 - ys.append(y[i]) -print() - -use_tape = True - -for i in range(1000): - if use_tape: - with ti.Tape(loss=loss): - regress() - else: - ti.clear_all_gradients() - loss[None] = 0 - loss.grad[None] = 1 - regress() - regress.grad() - print('Loss =', loss[None]) - update() - for i in range(number_coeffs): - print(coeffs[i], end=', ') - print() - -curve_xs = np.arange(-2.5, 2.5, 0.01) -curve_ys = curve_xs * 0 -for i in range(number_coeffs): - curve_ys += coeffs[i] * np.power(curve_xs, i) - -plt.title('Nonlinear Regression with Gradient Descent (3rd order polynomial)') -ax = plt.gca() -ax.scatter(xs, ys, label='data', color='r') -ax.plot(curve_xs, curve_ys, label='fitted') -ax.legend() -ax.grid(True) -ax.spines['left'].set_position('zero') -ax.spines['right'].set_color('none') -ax.spines['bottom'].set_position('zero') -ax.spines['top'].set_color('none') -plt.show() diff --git a/examples/autodiff/simple_derivative.py b/examples/autodiff/simple_derivative.py deleted file mode 100644 index 32d3e2bf93eb9..0000000000000 --- a/examples/autodiff/simple_derivative.py +++ /dev/null @@ -1,58 +0,0 @@ -import matplotlib.pyplot as plt - -import taichi as ti - -ti.init(arch=ti.cpu) - -N = 2048 -x, y = ti.field(ti.f32), ti.field(ti.f32) - -ti.root.dense(ti.i, N).place(x, x.grad, y, y.grad) - - -@ti.kernel -def poly(): - for i in x: - v = x[i] - ret = 0.0 - guard = 0.2 - if v < -guard or v > guard: - ret = 4 / ti.max(v, 0.1) - else: - ret = 0 - y[i] = ret - - -xs = [] -ys = [] -grad_xs = [] - -for i in range(N): - v = ((i + 0.5) / N) * 2 - 1 - xs.append(v) - x[i] = v - -poly() - -print('y') -for i in range(N): - y.grad[i] = 1 - ys.append(y[i]) -print() - -poly.grad() -print('grad_x') -for i in range(N): - grad_xs.append(x.grad[i]) - -plt.title('Auto Diff') -ax = plt.gca() -ax.plot(xs, ys, label='f(x)') -ax.plot(xs, grad_xs, label='f\'(x)') -ax.legend() -ax.grid(True) -ax.spines['left'].set_position('zero') -ax.spines['right'].set_color('none') -ax.spines['bottom'].set_position('zero') -ax.spines['top'].set_color('none') -plt.show() diff --git a/examples/chi_examples/CMakeLists.txt b/examples/chi_examples/CMakeLists.txt deleted file mode 100644 index c83db5cb1895f..0000000000000 --- a/examples/chi_examples/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -cmake_minimum_required(VERSION 3.12) - -set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fsized-deallocation -Wall -march=nehalem -DTI_ARCH_x64 -DTI_INCLUDED -O3 -DNDEBUG") -set(CHI_EXAMPLES chi_examples) - -project(chi_examples) - -add_executable(${CHI_EXAMPLES} main.cpp) -target_include_directories(${CHI_EXAMPLES} - PUBLIC $ENV{TAICHI_REPO_DIR} - PUBLIC $ENV{TAICHI_REPO_DIR}/external/include - PUBLIC $ENV{TAICHI_REPO_DIR}/external/spdlog/include - PUBLIC $ENV{TAICHI_REPO_DIR}/external/glad/include - PUBLIC $ENV{TAICHI_REPO_DIR}/external/glfw/include -) -target_link_libraries(${CHI_EXAMPLES} $ENV{TAICHI_REPO_DIR}/build/libtaichi_export_core.so) diff --git a/examples/chi_examples/README.md b/examples/chi_examples/README.md deleted file mode 100644 index 280a58c5ad5c2..0000000000000 --- a/examples/chi_examples/README.md +++ /dev/null @@ -1,23 +0,0 @@ -# How to use CHI IR Builder - -## Build Taichi - -Follow the steps in `https://docs.taichi.graphics/lang/articles/contribution/dev_install` - -Add option `-DTI_EXPORT_CORE=ON` to your `cmake` command (i.e. use `cmake .. DTI_EXPORT_CORE=ON`). - -Make sure taichi is built under `$TAICHI_REPO_DIR/build` directory. - -After building, `$TAICHI_REPO_DIR/build/libtaichi_export_core.so` should exist. - -## Link with the Taichi Shared Library - -`main.cpp` shows how to construct and run Taichi kernels using CHI IR Builder. - -```bash -mkdir build -cd build -cmake .. -make -./chi_examples -``` diff --git a/examples/rendering/taichi_logo.py b/examples/rendering/taichi_logo.py deleted file mode 100644 index a3ad0a5bd67a6..0000000000000 --- a/examples/rendering/taichi_logo.py +++ /dev/null @@ -1,22 +0,0 @@ -import taichi as ti - -ti.init() - -n = 512 -x = ti.field(dtype=ti.f32, shape=(n, n)) - - -@ti.kernel -def paint(): - for i, j in ti.ndrange(n * 4, n * 4): - # 4x4 super sampling: - ret = ti.taichi_logo(ti.Vector([i, j]) / (n * 4)) - x[i // 4, j // 4] += ret / 16 - - -paint() - -gui = ti.GUI('Logo', (n, n)) -while gui.running: - gui.set_image(x) - gui.show() diff --git a/external/FP16 b/external/FP16 new file mode 160000 index 0000000000000..0a92994d729ff --- /dev/null +++ b/external/FP16 @@ -0,0 +1 @@ +Subproject commit 0a92994d729ff76a58f692d3028ca1b64b145d91 diff --git a/external/SPIRV-Cross b/external/SPIRV-Cross index 97a438d214b24..131278458ea8e 160000 --- a/external/SPIRV-Cross +++ b/external/SPIRV-Cross @@ -1 +1 @@ -Subproject commit 97a438d214b24e4958ca137a18639670648cedd0 +Subproject commit 131278458ea8eebe6a6e9c476fbcf71278726e1a diff --git a/external/SPIRV-Headers b/external/SPIRV-Headers index 5ea2d62e8c0dd..b42ba6d92faf6 160000 --- a/external/SPIRV-Headers +++ b/external/SPIRV-Headers @@ -1 +1 @@ -Subproject commit 5ea2d62e8c0ddd9e2a7d0ca5e3f2335e09e5f408 +Subproject commit b42ba6d92faf6b4938e6f22ddd186dbdacc98d78 diff --git a/external/SPIRV-Reflect b/external/SPIRV-Reflect index 272e050728de8..1aceb6af56e74 160000 --- a/external/SPIRV-Reflect +++ b/external/SPIRV-Reflect @@ -1 +1 @@ -Subproject commit 272e050728de8d4a4ce9e7101c1244e6ff56e5b0 +Subproject commit 1aceb6af56e74b92a00378842dda5c5a73f49a4b diff --git a/external/SPIRV-Tools b/external/SPIRV-Tools index b46995741b97c..845f3efb8a4eb 160000 --- a/external/SPIRV-Tools +++ b/external/SPIRV-Tools @@ -1 +1 @@ -Subproject commit b46995741b97c714e211fe5df8590991ae998475 +Subproject commit 845f3efb8a4eb32e5f484aa6ea9b9e3716d6f7ec diff --git a/external/Vulkan-Headers b/external/Vulkan-Headers new file mode 160000 index 0000000000000..5c0fa1d68f48a --- /dev/null +++ b/external/Vulkan-Headers @@ -0,0 +1 @@ +Subproject commit 5c0fa1d68f48ab7fdb331d75f31efc11aa313090 diff --git a/external/VulkanMemoryAllocator b/external/VulkanMemoryAllocator index b0fce340b6c58..5c710e86a0ce0 160000 --- a/external/VulkanMemoryAllocator +++ b/external/VulkanMemoryAllocator @@ -1 +1 @@ -Subproject commit b0fce340b6c581d2bb75ca6c8c6e55235a52d8e2 +Subproject commit 5c710e86a0ce0664bbc2374c39b956b9928b2f29 diff --git a/external/glad b/external/glad index 7ec5f98f091e5..23685e3caeb24 160000 --- a/external/glad +++ b/external/glad @@ -1 +1 @@ -Subproject commit 7ec5f98f091e5cb082e0f1b4f251ee5cb560552b +Subproject commit 23685e3caeb249de6e082ae513bc1402e68b0643 diff --git a/external/glfw b/external/glfw index 8bc966bbae496..168d6d24b9ce8 160000 --- a/external/glfw +++ b/external/glfw @@ -1 +1 @@ -Subproject commit 8bc966bbae4967a008252f1ac9b625e4dc77ad64 +Subproject commit 168d6d24b9ce85ef3a37114f00b1251b36162c89 diff --git a/external/volk b/external/volk index b4eb550e25615..96fda88ef5a72 160000 --- a/external/volk +++ b/external/volk @@ -1 +1 @@ -Subproject commit b4eb550e2561556db7bf477543d42abd9c4c3217 +Subproject commit 96fda88ef5a7222d7db85e016cbae9343613483b diff --git a/misc/benchmark_bit_struct_stores.py b/misc/benchmark_bit_struct_stores.py index 1f80578f1840c..748b41a06bb9b 100644 --- a/misc/benchmark_bit_struct_stores.py +++ b/misc/benchmark_bit_struct_stores.py @@ -7,7 +7,7 @@ n = 1024 * 1024 * 256 if quant: - ci16 = ti.quant.int(16, True) + ci16 = ti.types.quantized_types.quant.int(16, True) x = ti.field(dtype=ci16) y = ti.field(dtype=ci16) diff --git a/misc/benchmark_rebuild_graph.py b/misc/benchmark_rebuild_graph.py index 8d43870e2c83e..416605fa0411e 100644 --- a/misc/benchmark_rebuild_graph.py +++ b/misc/benchmark_rebuild_graph.py @@ -1,3 +1,5 @@ +from taichi.lang import impl + import taichi as ti ti.init(arch=ti.cuda, async_mode=True) @@ -23,4 +25,4 @@ def foo(): for i in range(1000): foo() -ti.get_runtime().prog.benchmark_rebuild_graph() +impl.get_runtime().prog.benchmark_rebuild_graph() diff --git a/misc/ci_check_pr_title.py b/misc/ci_check_pr_title.py index b6ac85d445562..90945f8dcc46c 100644 --- a/misc/ci_check_pr_title.py +++ b/misc/ci_check_pr_title.py @@ -27,7 +27,8 @@ def get_old_ver(): with open(json_path) as f: prtags = json.load(f) -if not title.startswith('['): +# PR must be properly tagged. The only exception allowed here is the revert pr automatically generated by github. +if not title.startswith('[') and not title.startswith('Revert '): exit(f'PR title does not start with any tag: {title}') if title.endswith(' '): diff --git a/misc/ci_create_pr_card.py b/misc/ci_create_pr_card.py new file mode 100644 index 0000000000000..5794e3f88a565 --- /dev/null +++ b/misc/ci_create_pr_card.py @@ -0,0 +1,112 @@ +import json +import os +from typing import Any, List, Mapping + +from github import Github +from github.Project import Project +from github.Repository import Repository + + +def load_project_map() -> Mapping[str, str]: + with open(os.path.join(os.path.dirname(__file__), + 'tag_to_project.json')) as f: + return json.load(f) + + +PROJECT_MAP = load_project_map() + + +def extract_tags(title: str) -> List[str]: + """ + Extract tags from PR title like "[ci] [bug] fix a bug" + """ + tags: List[str] = [] + for x in title.split('] ')[:-1]: + if x[0] != '[': + raise ValueError(f'No starting [ for tag: {x}]') + tags.append(x[1:].lower()) + return tags + + +def get_project(repo: Repository, name: str) -> Project: + """ + Get project from repository by name + """ + for project in repo.get_projects(): + if project.name == name: + return project + raise ValueError(f'No project with name: {name}') + + +def _create_pr_card(pr: dict, project: Project) -> None: + to_do_column = next(iter(project.get_columns())) + print(f"Creating card for PR #{pr['number']} in project {project.name}") + to_do_column.create_card(content_id=pr['id'], content_type="PullRequest") + + +def _remove_pr_card(pr: dict, project: Project) -> None: + to_do_column = next(iter(project.get_columns())) + for card in to_do_column.get_cards(): + if not card.content_url: + continue + if card.content_url.split('/')[-1] == str(pr['number']): + print(f"Deleting PR #{pr['number']} from project {project.name}") + card.delete() + return + print( + f"PR #{pr['number']} doesn't exist in the To-do column of project {project.name}" + ) + + +def create_pr_card(event: Mapping[str, Any]) -> None: + new_projects = { + PROJECT_MAP[tag] + for tag in extract_tags(event['pull_request']['title']) + if tag in PROJECT_MAP + } + gh = Github(os.environ['GITHUB_TOKEN']) + repo = gh.get_repo(event['repository']['full_name']) + pr = event['pull_request'] + if event['action'] == 'opened': + for project_name in new_projects: + _create_pr_card(pr, get_project(repo, project_name)) + else: + old_title = event.get("changes", {}).get("title", {}).get("from") + if not old_title: + print("PR title isn't changed, nothing to do") + return + old_projects = { + PROJECT_MAP[tag] + for tag in extract_tags(old_title) if tag in PROJECT_MAP + } + to_remove = old_projects - new_projects + to_add = new_projects - old_projects + for project_name in to_remove: + _remove_pr_card(pr, get_project(repo, project_name)) + for project_name in to_add: + _create_pr_card(pr, get_project(repo, project_name)) + + +def main() -> None: + event = json.loads(os.environ['GH_EVENT']) + create_pr_card(event) + + +def test(): + event = { + "action": "opened", + "repository": { + "full_name": "taichi-dev/taichi" + }, + "pull_request": { + "id": 841657847, + "number": 4224, + "title": "[lang] Annotate constants with dtype without casting." + } + } + os.environ["GH_EVENT"] = json.dumps(event) + main() + + +if __name__ == '__main__': + main() diff --git a/misc/ci_download.py b/misc/ci_download.py index 5bbeec6d32b00..3b24ab46bc3ed 100644 --- a/misc/ci_download.py +++ b/misc/ci_download.py @@ -1,4 +1,6 @@ import os +import sys +import urllib.request platform = os.environ['CI_PLATFORM'] if platform.startswith('macos'): @@ -11,5 +13,12 @@ raise Exception(f'Bad CI_PLATFORM={platform}') llvm_url = f'https://github.com/taichi-dev/taichi_assets/releases/download/llvm10/taichi-llvm-10.0.0-{platform}.zip' +target_dir = 'taichi-llvm' print(f'Downloading LLVM from {llvm_url}...') -os.system(f'wget {llvm_url} --waitretry=3 --tries=5 -O taichi-llvm.zip') +urllib.request.urlretrieve(llvm_url, "taichi-llvm.zip") +print(f'Extract zip to local dir {target_dir}...') +if not os.path.exists(target_dir): + os.makedirs(target_dir) + +retcode = os.system(f"unzip taichi-llvm.zip -d {target_dir}") +sys.exit(retcode) diff --git a/misc/code_format.py b/misc/code_format.py index 0b08ff8a81935..a7259bd9c297f 100644 --- a/misc/code_format.py +++ b/misc/code_format.py @@ -130,6 +130,8 @@ def find_diff_or_empty(s): continue if fn.find('docs/build/') != -1: continue + if fn.find(os.path.join('tests', 'python', 'test_exception.py')) != -1: + continue if re.match(r'.*examples\/[a-z_]+\d\d+\.py$', fn): print(f'Skipping example file "{fn}"...') continue diff --git a/misc/demo_external_func.py b/misc/demo_external_func.py deleted file mode 100644 index 09bef0b1c82d3..0000000000000 --- a/misc/demo_external_func.py +++ /dev/null @@ -1,77 +0,0 @@ -import ctypes -import os - -import taichi as ti - -ti.init() - -N = 1024 -x = ti.field(ti.i32, shape=N) -y = ti.field(ti.i32, shape=N) -z = ti.field(ti.i32, shape=N) - -source = ''' -extern "C" { - void add_and_mul(float a, float b, float *c, float *d, int *e) { - *c = a + b; - *d = a * b; - *e = int(a * b + a); - } - void pow_int(int a, int b, int *c) { - int ret = 1; - for (int i = 0; i < b; i++) - ret = ret * a; - *c = ret; - } -} -''' - -with open('a.cpp', 'w') as f: - f.write(source) - -os.system("g++ a.cpp -o a.so -fPIC -shared") - -so = ctypes.CDLL("./a.so") - - -@ti.kernel -def call_ext() -> ti.i32: - a = 2.0 - b = 3.0 - c = 0.0 - d = 0.0 - e = 3 - ti.external_func_call(func=so.add_and_mul, args=(a, b), outputs=(c, d, e)) - p = 0 - ti.external_func_call(func=so.pow_int, args=(int(c + d), e), outputs=(p, )) - return p - - -# Wrap the external function to make it easier to use -@ti.func -def pow_int_wrapper(a, b): - p = 0 - ti.external_func_call(func=so.pow_int, - args=(int(a), int(b)), - outputs=(p, )) - return p - - -@ti.kernel -def call_parallel(): - for i in range(N): - z[i] = pow_int_wrapper(x[i], y[i]) - - -assert call_ext() == 11**8 - -for i in range(N): - x[i] = i - y[i] = 3 - -call_parallel() -for i in range(N): - assert z[i] == i**3 - -os.remove('a.cpp') -os.remove('a.so') diff --git a/misc/demo_excepthook.py b/misc/demo_trackback.py similarity index 90% rename from misc/demo_excepthook.py rename to misc/demo_trackback.py index 7e51a466ccf7e..9f9c3b8d5f735 100644 --- a/misc/demo_excepthook.py +++ b/misc/demo_trackback.py @@ -1,7 +1,6 @@ import taichi as ti ti.init() -ti.enable_excepthook() @ti.func diff --git a/misc/demo_warning.py b/misc/demo_warning.py deleted file mode 100644 index f41bfb1e99628..0000000000000 --- a/misc/demo_warning.py +++ /dev/null @@ -1,14 +0,0 @@ -import taichi as ti - -x = ti.Vector([2, 3]) - -x.transposed(x) - - -@ti.kernel -def func(): - x = 0 - x = 0.1 - - -func() diff --git a/misc/examples.md b/misc/examples.md index 909de5a194163..c714afe28a9d1 100644 --- a/misc/examples.md +++ b/misc/examples.md @@ -1,11 +1,11 @@ # More examples - - - - + + + + - - - - + + + + diff --git a/misc/generate_commit_hash.py b/misc/generate_commit_hash.py deleted file mode 100644 index 1b23d7968ba54..0000000000000 --- a/misc/generate_commit_hash.py +++ /dev/null @@ -1,23 +0,0 @@ -import os - -from git import Repo - -repo_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../') -repo = Repo(repo_dir) -commit_hash = str(repo.head.commit) -print(f"Building commit {commit_hash}") - -output_fn = os.path.join(repo_dir, 'taichi/common/commit_hash.h') -content = f"#define TI_COMMIT_HASH \"{commit_hash}\"\n" - -# First read the file to see if an update is needed -# This reduces unnecessary file changes/linkings -if os.path.exists(output_fn): - with open(output_fn, 'r') as f: - old_content = f.read() - if old_content == content: - # No update needed - exit(0) - -with open(output_fn, 'w') as f: - f.write(content) diff --git a/misc/generate_example_videos.py b/misc/generate_example_videos.py new file mode 100644 index 0000000000000..d11520cc3cd45 --- /dev/null +++ b/misc/generate_example_videos.py @@ -0,0 +1,24 @@ +import argparse +import os +import re +import subprocess + +parser = argparse.ArgumentParser(description='Generate all videos of examples') +parser.add_argument('output_directory', + help='output directory of generated videos') +output_dir = parser.parse_args().output_directory + +example_root = os.path.join('..', 'tests', 'python', 'examples') +for example_dir in os.listdir(example_root): + full_dir = os.path.join(example_root, example_dir) + if not os.path.isdir(full_dir): + continue + for filename in os.listdir(full_dir): + match = re.match(r'test_(\w+)\.py', filename) + if match: + subprocess.run([ + "python", + os.path.join(full_dir, filename), + os.path.join(output_dir, match.group(1)) + ], + check=True) diff --git a/misc/links.md b/misc/links.md index e8c3ce2c04f78..8b564922781b2 100644 --- a/misc/links.md +++ b/misc/links.md @@ -1,14 +1,11 @@ # Links -- [DiffTaichi](https://github.com/yuanming-hu/difftaichi): 10 differentiable physical simulators built with Taichi differentiable programming, by [Yuanming Hu (yuanming-hu)](https://github.com/yuanming-hu). +- [DiffTaichi](https://github.com/taichi-dev/difftaichi): 10 differentiable physical simulators built with Taichi differentiable programming, by [Yuanming Hu (yuanming-hu)](https://github.com/yuanming-hu). - [Taichi elements](https://github.com/taichi-dev/taichi_elements): A high-performance multi-material continuum physics engine based on Taichi (work in progress). - [taichi-fvm2d-fluid-ns](https://github.com/hejob/taichi-fvm2d-fluid-ns/): 2D FVM Compressible CFD Solver for multiblock structured mesh, by [hejob](https://github.com/hejob). - [karman_tachi](https://github.com/houkensjtu/karman_taichi): An incompressible fluid solver using FVM that simulates a Karman vortex street, by [houkensjtu](https://github.com/houkensjtu). - [TaichiMD](https://github.com/victoriacity/taichimd): Interactive, GPU-accelerated Molecular (& Macroscopic) Dynamics using Taichi, by [Andrew Sun (victoriacity)](https://github.com/victoriacity). - [LBM Taichi](https://github.com/hietwll/LBM_Taichi): A fluid solver based on the Lattice Boltzmann Method (LBM) using Taichi, by [Zhuo Wang (hietwll)](https://github.com/hietwll). - [Gitee mirror of Taichi](https://gitee.com/mirrors/Taichi): For the convenience of Chinese contributors, clone from the mirror repo hosted on Gitee (码云). -- [Taichi THREE](https://github.com/taichi-dev/taichi_three): A 3D rendering library based on Taichi. -- [Taichi GLSL](https://github.com/taichi-dev/taichi_glsl): A Taichi extension library that provides a set of GLSL-style helper functions. -- [Taichi Blend](https://github.com/taichi-dev/taichi_blend): Taichi Blender intergration for physics-based animations (work in progress) -- [Taichi.js](https://github.com/taichi-dev/taichi.js): Run compiled Taichi programs in Javascript and WASM (work in progress). +- [Taichi GLSL](https://github.com/taichi-dev/taichi_glsl): A Taichi extension library, which provides a set of GLSL-style helper functions. - [Shadertoy in Taichi](https://github.com/Phonicavi/Shadertoy-taichi): Some shadertoy examples implemented in Taichi, by [Qiu Feng (Phonicavi)](https://github.com/Phonicavi). diff --git a/misc/make_changelog.py b/misc/make_changelog.py index 6007db5ace5ad..4a2032d10fd34 100644 --- a/misc/make_changelog.py +++ b/misc/make_changelog.py @@ -17,7 +17,7 @@ def load_pr_tags(): return details -def main(ver='master', repo_dir='.'): +def main(ver=None, repo_dir='.'): g = Repo(repo_dir) commits_with_tags = set([tag.commit for tag in g.tags]) commits = list(g.iter_commits(ver, max_count=200)) @@ -33,11 +33,8 @@ def format(c): for i, c in enumerate(commits): s = format(c) - if c in commits_with_tags: - if i == 0: - continue - else: - break + if c in commits_with_tags and i > 0: + break tags = [] while s[0] == '[': @@ -78,11 +75,11 @@ def format(c): if __name__ == '__main__': - ver = sys.argv[1] if len(sys.argv) > 1 else 'master' + ver = sys.argv[1] if len(sys.argv) > 1 else None repo = sys.argv[2] if len(sys.argv) > 2 else '.' save = sys.argv[3] if len(sys.argv) > 3 else False res = main(ver, repo) if save: - with open('../python/taichi/CHANGELOG.md', 'w') as f: + with open('./python/taichi/CHANGELOG.md', 'w') as f: f.write(res) print(res) diff --git a/misc/prtags.json b/misc/prtags.json index e93b009704ddb..c4e1d3ff73fab 100644 --- a/misc/prtags.json +++ b/misc/prtags.json @@ -1,8 +1,10 @@ { + "javascript" : "Taichi in javascript", "ci" : "CI/CD workflow", "cpu" : "CPU backends", "cuda" : "CUDA backend", "doc" : "Documentation", + "docs" : "Documentation", "infra" : "Infrastructure", "cli" : "Command line interface", "ir" : "Intermediate representation", @@ -11,6 +13,8 @@ "metal" : "Metal backend", "opengl" : "OpenGL backend", "vulkan" : "Vulkan backend", + "dx11" : "DirectX 11 backend", + "spirv" : "SPIR-V common codegen", "wasm" : "WebAssembly backend", "misc" : "Miscellaneous", "std" : "Standard library", @@ -26,8 +30,10 @@ "test" : "Tests", "benchmark" : "Benchmarking", "async" : "AsyncEngine", + "mesh" : "MeshTaichi", "workflow" : "GitHub Actions/Workflows", "linux" : "Linux", + "android" : "Android", "mac" : "Mac OS X", "windows" : "Windows", "docker" : "Docker container", @@ -38,5 +44,6 @@ "blender" : "Blender intergration", "export" : "Exporting kernels", "type" : "Type system", - "release" : "Release" + "release" : "Release", + "build" : "Build system" } diff --git a/misc/save_new_version.py b/misc/save_new_version.py new file mode 100644 index 0000000000000..1db89916a374e --- /dev/null +++ b/misc/save_new_version.py @@ -0,0 +1,47 @@ +import os +from datetime import date + +import requests + +version = os.getenv('RELEASE_VERSION') +version = version[1:] +version_num = version.split('.') +major = int(version_num[0]) +minor = int(version_num[1]) +patch = int(version_num[2]) +release_date = date.today().strftime('%Y-%m-%d') + +payload = { + 'version': version, + 'major': major, + 'minor': minor, + 'patch': patch, + 'date': release_date +} + +username = os.getenv('METADATA_USERNAME') +password = os.getenv('METADATA_PASSWORD') +url = os.getenv('METADATA_URL') + +try: + response = requests.post(f'https://{url}/add_version/main', + json=payload, + auth=(username, password), + timeout=5) + response.raise_for_status() +except requests.exceptions.ConnectionError as err: + print('Updating latest version failed: No internet,', err) + exit(1) +except requests.exceptions.HTTPError as err: + print('Updating latest version failed: Server error,', err) + exit(1) +except requests.exceptions.Timeout as err: + print('Updating latest version failed: Time out when connecting server,', + err) + exit(1) +except requests.exceptions.RequestException as err: + print('Updating latest version failed:', err) + exit(1) + +response = response.json() +print(response['message']) diff --git a/misc/spMv_linear_solve.py b/misc/spMv_linear_solve.py index 87bea0728673e..721648f45bb8c 100644 --- a/misc/spMv_linear_solve.py +++ b/misc/spMv_linear_solve.py @@ -9,7 +9,7 @@ @ti.kernel -def fill(A: ti.linalg.sparse_matrix_builder(), b: ti.template(), +def fill(A: ti.types.sparse_matrix_builder(), b: ti.template(), interval: ti.i32): for i in range(n): A[i, i] += 2.0 diff --git a/misc/sparse_matrix.py b/misc/sparse_matrix.py index 49cac7af58e87..a56054f7aedcd 100644 --- a/misc/sparse_matrix.py +++ b/misc/sparse_matrix.py @@ -9,8 +9,8 @@ @ti.kernel -def fill(A: ti.linalg.sparse_matrix_builder(), - b: ti.linalg.sparse_matrix_builder(), interval: ti.i32): +def fill(A: ti.types.sparse_matrix_builder(), + b: ti.types.sparse_matrix_builder(), interval: ti.i32): for i in range(n): if i > 0: A[i - 1, i] += -1.0 diff --git a/misc/tag_to_project.json b/misc/tag_to_project.json new file mode 100644 index 0000000000000..23bffd761cd9a --- /dev/null +++ b/misc/tag_to_project.json @@ -0,0 +1,12 @@ +{ + "aot": "AOT", + "autodiff": "Autodiff", + "benchmark": "Backend Performance", + "build": "CI/CD & Build & Tests", + "ci": "CI/CD & Build & Tests", + "doc": "Docs & Examples & Tutorials", + "docs": "Docs & Examples & Tutorials", + "gui": "GGUI", + "ir": "Compiler Frontend & Middle-end", + "lang": "Lang Features & Python" +} diff --git a/misc/test_poly_timed.py b/misc/test_poly_timed.py index 368798185fa04..b104d6c6c8adc 100644 --- a/misc/test_poly_timed.py +++ b/misc/test_poly_timed.py @@ -1,7 +1,7 @@ from autograd import grad +from taichi._testing import approx import taichi as ti -from taichi import approx # Note: test happens at v = 0.2 diff --git a/misc/upload_release.py b/misc/upload_release.py new file mode 100644 index 0000000000000..d240c7b4bc201 --- /dev/null +++ b/misc/upload_release.py @@ -0,0 +1,65 @@ +import os +import subprocess +import sys + +import requests + + +def upload_taichi_version(): + username = os.getenv('METADATA_USERNAME') + password = os.getenv('METADATA_PASSWORD') + url = os.getenv('METADATA_URL') + for filename in os.listdir('./dist'): + filename = filename[:len(filename) - 4] + parts = filename.split('-') + payload = { + 'version': parts[1], + 'platform': parts[4], + 'python': parts[2] + } + try: + response = requests.post(f'https://{url}/add_version/detail', + json=payload, + auth=(username, password), + timeout=5) + response.raise_for_status() + except requests.exceptions.ConnectionError as err: + print('Updating latest version failed: No internet,', err) + except requests.exceptions.HTTPError as err: + print('Updating latest version failed: Server error,', err) + except requests.exceptions.Timeout as err: + print( + 'Updating latest version failed: Time out when connecting server,', + err) + except requests.exceptions.RequestException as err: + print('Updating latest version failed:', err) + else: + response = response.json() + print(response['message']) + + +def upload_artifact(is_taichi): + pwd_env = 'PROD_PWD' if is_taichi else 'NIGHT_PWD' + twine_password = os.getenv(pwd_env) + if not twine_password: + sys.exit(f'Missing password env var {pwd_env}') + command = [sys.executable, '-m', 'twine', 'upload'] + if not is_taichi: + command.extend(['--repository', 'testpypi']) + command.extend( + ['--verbose', '-u', '__token__', '-p', twine_password, 'dist/*']) + try: + subprocess.check_call(command) + except subprocess.CalledProcessError as e: + sys.exit(f"Twine upload returns error {e.returncode}") + + +if __name__ == '__main__': + if os.getenv('GITHUB_REPOSITORY', + 'taichi-dev/taichi') != 'taichi-dev/taichi': + print('This script should be run from taichi repo') + sys.exit(0) + is_taichi = os.getenv('PROJECT_NAME', 'taichi') == 'taichi' + upload_artifact(is_taichi) + if is_taichi: + upload_taichi_version() diff --git a/misc/visualize_quant_types.py b/misc/visualize_quant_types.py index 24767428cc390..6f51d5ab03ada 100644 --- a/misc/visualize_quant_types.py +++ b/misc/visualize_quant_types.py @@ -7,9 +7,9 @@ ti.init() -f19 = ti.quant.float(exp=6, frac=13, signed=True) -f16 = ti.quant.float(exp=5, frac=11, signed=True) -fixed16 = ti.quant.fixed(frac=16, range=2) +f19 = ti.types.quantized_types.quant.float(exp=6, frac=13, signed=True) +f16 = ti.types.quantized_types.quant.float(exp=5, frac=11, signed=True) +fixed16 = ti.types.quantized_types.quant.fixed(frac=16, range=2) vf19 = ti.Vector.field(2, dtype=f19) bs_vf19 = ti.root.bit_struct(num_bits=32) diff --git a/misc/windows_build.py b/misc/windows_build.py deleted file mode 100644 index b7897054f39cc..0000000000000 --- a/misc/windows_build.py +++ /dev/null @@ -1,69 +0,0 @@ -import os -import shutil -import sys - - -def execute_cmd(cmd): - print('Executing', resolve_env(cmd)) - return os.system(cmd) - - -def resolve_env(v): - # replace `%` - modified = True - while modified: - modified = False - for i in range(len(v)): - if v[i] == '%': - for j in range(i + 1, len(v)): - if v[j] == '%': - var = v[i + 1:j] - v = v[:i] + os.environ[var] + v[j + 1:] - modified = True - print(v) - break - break - return v - - -def set_env(**kwargs): - for k, v in kwargs.items(): - v = resolve_env(v) - print(f"Setting {k} to '{v}'") - os.environ[k] = v - - -repo_dir = "E:\\repos\\taichi" -assert len( - sys.argv -) == 3, 'Usage: windows_build.py [python_executable] [cuda_version=10.X]' -python_executable = sys.argv[1] -cuda_version = sys.argv[2] -assert cuda_version in ["10.0", "10.1", "10.2"] -print("Python =", python_executable) -print("CUDA =", cuda_version) -set_env(PYTHON=python_executable) -set_env(TAICHI_REPO_DIR=repo_dir) -set_env(PYTHONPATH="%TAICHI_REPO_DIR%\\python") -set_env(PATH=r"%TAICHI_REPO_DIR%\bin;%PATH%") -execute_cmd("clang --version") - -os.chdir(repo_dir) -build_dir = os.path.join(repo_dir, 'build') -if os.path.exists(build_dir): - shutil.rmtree(build_dir) -os.mkdir(build_dir) -os.chdir(build_dir) -llvm_dir = "E:\\repos\\llvm-8.0.1\\build\\installed\\lib\\cmake\\llvm" -cuda_dir = f"C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v{cuda_version}" -execute_cmd( - f'cmake .. -G"Visual Studio 15 2017 Win64" -DPYTHON_EXECUTABLE="%PYTHON%" -DLLVM_DIR="{llvm_dir}" -DTI_WITH_CUDA:BOOL=True -DCUDA_VERSION={cuda_version} -DCUDA_DIR="f{cuda_dir}"' -) -execute_cmd( - r'msbuild /p:Configuration=RelWithDebInfo /p:Platform=x64 /m taichi.sln') -os.chdir(repo_dir) -execute_cmd('%PYTHON% -c "import taichi"') -execute_cmd('%PYTHON% examples/laplace.py') -execute_cmd('%PYTHON% bin/taichi test') -os.chdir(os.path.join(repo_dir, 'python')) -execute_cmd('%PYTHON% build.py try_upload') diff --git a/netlify.toml b/netlify.toml index 461c6e304a23b..d5c51c2b39cee 100644 --- a/netlify.toml +++ b/netlify.toml @@ -1,8 +1,7 @@ [build] - command = "git clone https://github.com/taichi-dev/docs.taichi.graphics.git; rm -rf docs.taichi.graphics/website/docs/lang; cp -rf docs/lang docs.taichi.graphics/website/docs/lang; cd docs.taichi.graphics/website; npm install --global yarn@1.22; yarn install; yarn build" + command = "git clone https://github.com/taichi-dev/docs.taichi.graphics.git; rm -rf docs.taichi.graphics/website/docs/lang; cp -rf docs/lang docs.taichi.graphics/website/docs/lang; git clone https://github.com/taichi-dev/docstring-gen docsgen; export DOCSTRING_GEN_PATH=\"$(pwd)/docsgen\"; export TAICHI_PATH=\"$(pwd)/python/taichi\"; export TAICHI_WEBSITE=\"$(pwd)/docs.taichi.graphics\"; pip install sphinx-autoapi==1.8.4 gitpython pydata-sphinx-theme==0.7.2; cd $DOCSTRING_GEN_PATH/experimental; export current_version=master; make clean; make version; make apideploy; cd $TAICHI_WEBSITE/website; npm install --global yarn@1.22; yarn install; yarn build; yarn run apiversion;" publish = "docs.taichi.graphics/website/build" - # Cancel the build if there're no changes detected in docs/ folder. - ignore = "git remote add upstream https://github.com/taichi-dev/taichi.git; git fetch upstream master; git diff --quiet $COMMIT_REF upstream/master -- docs/" + ignore = "git remote add upstream https://github.com/taichi-dev/taichi.git; git fetch upstream master; git diff --quiet $COMMIT_REF upstream/master -- docs/ python/" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000..facad75d3914b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools", "wheel", "numpy", "pybind11", "cmake"] +build-backend = "setuptools.build_meta" diff --git a/python/.gitignore b/python/.gitignore index 12389e8d49a2f..1a30ec8f567f0 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -1,7 +1,9 @@ -lib +taichi/_lib/runtime +taichi/_lib/core/*.so +taichi/_lib/core/*.pyd taichi.egg-info taichi/include -taichi/examples taichi/assets taichi/tests +taichi/tests38 release diff --git a/python/build.py b/python/build.py deleted file mode 100644 index e320fb95a1023..0000000000000 --- a/python/build.py +++ /dev/null @@ -1,148 +0,0 @@ -import argparse -import os -import platform -import re -import shutil -import sys - - -def get_os_name(): - name = platform.platform() - # in python 3.8, platform.platform() uses mac_ver() on macOS - # it will return 'macOS-XXXX' instead of 'Darwin-XXXX' - if name.lower().startswith('darwin') or name.lower().startswith('macos'): - return 'osx' - elif name.lower().startswith('windows'): - return 'win' - elif name.lower().startswith('linux'): - return 'linux' - assert False, "Unknown platform name %s" % name - - -def get_python_executable(): - return '"' + sys.executable.replace('\\', '/') + '"' - - -def build(project_name): - """Build and package the wheel file in root `dist` dir""" - if platform.system() == 'Linux': - if re.search("^clang\+\+-*\d*", str(os.environ.get('CXX'))) is None: - raise RuntimeError( - 'Only the wheel with clang will be released to PyPI') - - print("Using python executable", get_python_executable()) - os.system( - '{} -m pip install --user --upgrade twine setuptools wheel'.format( - get_python_executable())) - - os.system( - f'{get_python_executable()} ../misc/make_changelog.py origin/master ../ True' - ) - - # This env var is used in setup.py below. - os.environ['PROJECT_NAME'] = project_name - project_tag = '' - if project_name == 'taichi-nightly': - project_tag = 'egg_info --tag-date' - if get_os_name() == 'linux': - os.system( - f'cd ..; {get_python_executable()} setup.py {project_tag} bdist_wheel -p manylinux1_x86_64' - ) - else: - os.system( - f'cd .. && {get_python_executable()} setup.py {project_tag} bdist_wheel' - ) - - try: - os.remove('taichi/CHANGELOG.md') - except FileNotFoundError: - pass - - -def parse_args(): - parser = argparse.ArgumentParser(description=( - 'Build and uploads wheels to PyPI. Make sure to run this script ' - 'inside `python/`')) - parser.add_argument('mode', - type=str, - default='', - help=('Choose one of the modes: ' - '[build, test, try_upload, upload]')) - parser.add_argument('--skip_build', - action='store_true', - help=('Skip the build process if this is enabled')) - parser.add_argument('--testpypi', - action='store_true', - help='Upload to test server if this is enabled') - parser.add_argument('--project_name', - action='store', - dest='project_name', - default='taichi', - help='Set the project name') - return parser.parse_args() - - -def main(): - args = parse_args() - mode = args.mode - pypi_user = '__token__' - pypi_repo = '' - project_name = args.project_name - - env_pypi_pwd = os.environ.get('PYPI_PWD', '') - - if not args.skip_build: - shutil.rmtree('../dist', ignore_errors=True) - - if mode == 'try_upload': - if env_pypi_pwd == '': - print("Missing environment variable PYPI_PWD") - print("Giving up and exiting 0 [try_upload mode]") - exit(0) - mode = 'upload' - - if mode == 'upload' and env_pypi_pwd == '': - raise RuntimeError("Missing environment variable PYPI_PWD") - - os.environ['TWINE_PASSWORD'] = env_pypi_pwd - - if mode == 'upload' and args.testpypi: - pypi_repo = '--repository testpypi' - - if not args.skip_build: - build(project_name) - - if mode == 'build': - return - elif mode == 'upload': - os.system('{} -m twine upload {} ../dist/* --verbose -u {}'.format( - get_python_executable(), pypi_repo, pypi_user)) - elif mode == 'test': - print('Uninstalling old taichi packages...') - os.system( - f'{get_python_executable()} -m pip uninstall -y taichi-nightly') - os.system(f'{get_python_executable()} -m pip uninstall -y taichi') - dists = os.listdir('../dist') - assert len(dists) == 1 - dist = dists[0] - print('Installing ', dist) - os.environ['PYTHONPATH'] = '' - os.makedirs('test_env', exist_ok=True) - os.system( - 'cd test_env && {} -m pip install ../../dist/{} --user'.format( - get_python_executable(), dist)) - print('Entering test environment...') - if get_os_name() == 'win': - os.system( - 'cmd /V /C "set PYTHONPATH=&& set TAICHI_REPO_DIR=&& cd test_env && cmd"' - ) - else: - os.system( - 'cd test_env && PYTHONPATH= TAICHI_REPO_DIR= bash --noprofile --norc ' - ) - else: - raise ValueError("Unknown mode: %s" % mode) - - -if __name__ == '__main__': - main() diff --git a/python/make_release.py b/python/make_release.py deleted file mode 100644 index 0f61f316e98f2..0000000000000 --- a/python/make_release.py +++ /dev/null @@ -1,39 +0,0 @@ -import os -import shutil -import zipfile - -import requests - -projects = ['nightly', 'nightly-cuda-10-0', 'nightly-cuda-10-1'] - - -def download(url): - fn = url.split('/')[-1] - with requests.get(url, stream=True) as r: - with open(fn, 'wb') as f: - shutil.copyfileobj(r.raw, f) - return fn - - -for p in projects: - pkg_name_dash = f'taichi-{p}' - pkg_name_underscore = pkg_name_dash.replace('-', '_') - package = requests.get( - f"https://pypi.python.org/pypi/{pkg_name_dash}/json").json() - version = '0.0.75' - wheels = package["releases"][version] - for wheel in wheels: - py_ver = wheel['python_version'] - print(py_ver, wheel['url']) - fn = download(wheel['url']) - folder = wheel['python_version'] + '-' + fn[:-4] - package_extracted_folder = f"release/{folder}" - with zipfile.ZipFile(fn, 'r') as zip_ref: - zip_ref.extractall(package_extracted_folder) - os.remove(fn) - - pkg_ver = f"{pkg_name_underscore}-{version}" - shutil.make_archive( - f'release/{folder}', 'zip', - f'release/{folder}/{pkg_ver}.data/purelib/taichi/lib') - shutil.rmtree(package_extracted_folder) diff --git a/python/taichi/__init__.py b/python/taichi/__init__.py index 8823b50caee62..7d834259bcba4 100644 --- a/python/taichi/__init__.py +++ b/python/taichi/__init__.py @@ -1,48 +1,69 @@ import sys -import taichi.ad as ad +from taichi._funcs import * +from taichi._lib import core as _ti_core from taichi._logging import * -from taichi.core import get_os_name, package_root, require_version -from taichi.core import ti_core as core -from taichi.lang import * # TODO(archibate): It's `taichi.lang.core` overriding `taichi.core` -from taichi.main import main -from taichi.misc import * -from taichi.testing import * -from taichi.tools import * -from taichi.torch_io import from_torch, to_torch -from taichi.type import * - -import taichi.ui as ui +from taichi._snode import * +from taichi.lang import * # pylint: disable=W0622 # TODO(archibate): It's `taichi.lang.core` overriding `taichi.core` +from taichi.types.annotations import * +# Provide a shortcut to types since they're commonly used. +from taichi.types.primitive_types import * + +from taichi import ad, experimental, linalg, tools +from taichi.ui import GUI, hex_to_rgb, rgb_to_hex, ui # Issue#2223: Do not reorder, or we're busted with partially initialized module from taichi import aot # isort:skip -deprecated_names = {'SOA': 'Layout.SOA', 'AOS': 'Layout.AOS'} +__deprecated_names__ = { + 'SOA': 'Layout.SOA', + 'AOS': 'Layout.AOS', + 'print_profile_info': 'profiler.print_scoped_profiler_info', + 'clear_profile_info': 'profiler.clear_scoped_profiler_info', + 'print_memory_profile_info': 'profiler.print_memory_profiler_info', + 'CuptiMetric': 'profiler.CuptiMetric', + 'get_predefined_cupti_metrics': 'profiler.get_predefined_cupti_metrics', + 'print_kernel_profile_info': 'profiler.print_kernel_profiler_info', + 'query_kernel_profile_info': 'profiler.query_kernel_profiler_info', + 'clear_kernel_profile_info': 'profiler.clear_kernel_profiler_info', + 'kernel_profiler_total_time': 'profiler.get_kernel_profiler_total_time', + 'set_kernel_profiler_toolkit': 'profiler.set_kernel_profiler_toolkit', + 'set_kernel_profile_metrics': 'profiler.set_kernel_profiler_metrics', + 'collect_kernel_profile_metrics': + 'profiler.collect_kernel_profiler_metrics', + 'VideoManager': 'tools.VideoManager', + 'PLYWriter': 'tools.PLYWriter', + 'imread': 'tools.imread', + 'imresize': 'tools.imresize', + 'imshow': 'tools.imshow', + 'imwrite': 'tools.imwrite', + 'quant': 'types.quantized_types.quant', + 'type_factory': 'types.quantized_types.type_factory' +} + if sys.version_info.minor < 7: - for name, alter in deprecated_names.items(): + for name, alter in __deprecated_names__.items(): exec(f'{name} = {alter}') else: def __getattr__(attr): - if attr in deprecated_names: - warning('ti.{} is deprecated. Please use ti.{} instead.'.format( - attr, deprecated_names[attr]), - DeprecationWarning, - stacklevel=2) - exec(f'{attr} = {deprecated_names[attr]}') + # There's no easy way to hook accessing attribute with function calls in python3.6. + # So let's skip it for now. + import warnings # pylint: disable=C0415,W0621 + if attr == 'cfg': + return None if lang.impl.get_runtime( + ).prog is None else lang.impl.current_cfg() + if attr in __deprecated_names__: + warnings.warn( + f'ti.{attr} is deprecated. Please use ti.{__deprecated_names__[attr]} instead.', + DeprecationWarning) + exec(f'{attr} = {__deprecated_names__[attr]}') return locals()[attr] raise AttributeError(f"module '{__name__}' has no attribute '{attr}'") -__all__ = [ - 'ad', 'core', 'misc', 'lang', 'tools', 'main', 'torch_io', 'ui', 'profiler' -] - -complex_kernel = deprecated('ti.complex_kernel', - 'ti.ad.grad_replaced')(ad.grad_replaced) - -complex_kernel_grad = deprecated('ti.complex_kernel_grad', - 'ti.ad.grad_for')(ad.grad_for) +__version__ = (_ti_core.get_version_major(), _ti_core.get_version_minor(), + _ti_core.get_version_patch()) -__version__ = (core.get_version_major(), core.get_version_minor(), - core.get_version_patch()) +del sys +del _ti_core diff --git a/python/taichi/__main__.py b/python/taichi/__main__.py index 5d6a8109e6ce6..925fa9116efa4 100644 --- a/python/taichi/__main__.py +++ b/python/taichi/__main__.py @@ -1,3 +1,3 @@ -from .main import main +from ._main import main main() diff --git a/python/taichi/_funcs.py b/python/taichi/_funcs.py new file mode 100644 index 0000000000000..0d9e0392e5223 --- /dev/null +++ b/python/taichi/_funcs.py @@ -0,0 +1,437 @@ +import math + +from taichi.lang import impl, matrix, ops +from taichi.lang.impl import expr_init, get_runtime, static +from taichi.lang.kernel_impl import func, pyfunc +from taichi.lang.matrix import Matrix, Vector +from taichi.types import f32, f64 + + +@func +def _randn(dt): + """ + Generate a random float sampled from univariate standard normal + (Gaussian) distribution of mean 0 and variance 1, using the + Box-Muller transformation. + """ + assert dt == f32 or dt == f64 + u1 = ops.cast(1.0, dt) - ops.random(dt) + u2 = ops.random(dt) + r = ops.sqrt(-2 * ops.log(u1)) + c = ops.cos(math.tau * u2) + return r * c + + +def randn(dt=None): + """Generate a random float sampled from univariate standard normal + (Gaussian) distribution of mean 0 and variance 1, using the + Box-Muller transformation. Must be called in Taichi scope. + + Args: + dt (DataType): Data type of the required random number. Default to `None`. + If set to `None` `dt` will be determined dynamically in runtime. + + Returns: + The generated random float. + + Example:: + + >>> @ti.kernel + >>> def main(): + >>> print(ti.randn()) + >>> + >>> main() + -0.463608 + """ + if dt is None: + dt = impl.get_runtime().default_fp + return _randn(dt) + + +@pyfunc +def _matrix_transpose(mat): + """Permute the first two axes of the matrix. + + Args: + mat (:class:`~taichi.lang.matrix.Matrix`): Input matrix. + + Returns: + Transpose of the input matrix. + """ + return matrix.Matrix([[mat[i, j] for i in range(mat.n)] + for j in range(mat.m)]) + + +@pyfunc +def _matrix_cross3d(self, other): + return matrix.Matrix([ + self[1] * other[2] - self[2] * other[1], + self[2] * other[0] - self[0] * other[2], + self[0] * other[1] - self[1] * other[0], + ]) + + +@pyfunc +def _matrix_cross2d(self, other): + return self[0] * other[1] - self[1] * other[0] + + +@pyfunc +def _matrix_outer_product(self, other): + """Perform the outer product with the input Vector (1-D Matrix). + + Args: + other (:class:`~taichi.lang.matrix.Matrix`): The input Vector (1-D Matrix) to perform the outer product. + + Returns: + :class:`~taichi.lang.matrix.Matrix`: The outer product result (Matrix) of the two Vectors. + + """ + impl.static( + impl.static_assert(self.m == 1, + "lhs for outer_product is not a vector")) + impl.static( + impl.static_assert(other.m == 1, + "rhs for outer_product is not a vector")) + return matrix.Matrix([[self[i] * other[j] for j in range(other.n)] + for i in range(self.n)]) + + +@func +def polar_decompose2d(A, dt): + """Perform polar decomposition (A=UP) for 2x2 matrix. + + Mathematical concept refers to https://en.wikipedia.org/wiki/Polar_decomposition. + + Args: + A (ti.Matrix(2, 2)): input 2x2 matrix `A`. + dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. + + Returns: + Decomposed 2x2 matrices `U` and `P`. + """ + x, y = A(0, 0) + A(1, 1), A(1, 0) - A(0, 1) + scale = (1.0 / ops.sqrt(x * x + y * y)) + c = x * scale + s = y * scale + r = Matrix([[c, -s], [s, c]], dt=dt) + return r, r.transpose() @ A + + +@func +def polar_decompose3d(A, dt): + """Perform polar decomposition (A=UP) for 3x3 matrix. + + Mathematical concept refers to https://en.wikipedia.org/wiki/Polar_decomposition. + + Args: + A (ti.Matrix(3, 3)): input 3x3 matrix `A`. + dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. + + Returns: + Decomposed 3x3 matrices `U` and `P`. + """ + U, sig, V = svd(A, dt) + return U @ V.transpose(), V @ sig @ V.transpose() + + +# https://www.seas.upenn.edu/~cffjiang/research/svd/svd.pdf +@func +def svd2d(A, dt): + """Perform singular value decomposition (A=USV^T) for 2x2 matrix. + + Mathematical concept refers to https://en.wikipedia.org/wiki/Singular_value_decomposition. + + Args: + A (ti.Matrix(2, 2)): input 2x2 matrix `A`. + dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. + + Returns: + Decomposed 2x2 matrices `U`, 'S' and `V`. + """ + R, S = polar_decompose2d(A, dt) + c, s = ops.cast(0.0, dt), ops.cast(0.0, dt) + s1, s2 = ops.cast(0.0, dt), ops.cast(0.0, dt) + if abs(S[0, 1]) < 1e-5: + c, s = 1, 0 + s1, s2 = S[0, 0], S[1, 1] + else: + tao = ops.cast(0.5, dt) * (S[0, 0] - S[1, 1]) + w = ops.sqrt(tao**2 + S[0, 1]**2) + t = ops.cast(0.0, dt) + if tao > 0: + t = S[0, 1] / (tao + w) + else: + t = S[0, 1] / (tao - w) + c = 1 / ops.sqrt(t**2 + 1) + s = -t * c + s1 = c**2 * S[0, 0] - 2 * c * s * S[0, 1] + s**2 * S[1, 1] + s2 = s**2 * S[0, 0] + 2 * c * s * S[0, 1] + c**2 * S[1, 1] + V = Matrix.zero(dt, 2, 2) + if s1 < s2: + tmp = s1 + s1 = s2 + s2 = tmp + V = Matrix([[-s, c], [-c, -s]], dt=dt) + else: + V = Matrix([[c, s], [-s, c]], dt=dt) + U = R @ V + return U, Matrix([[s1, ops.cast(0, dt)], [ops.cast(0, dt), s2]], dt=dt), V + + +def svd3d(A, dt, iters=None): + """Perform singular value decomposition (A=USV^T) for 3x3 matrix. + + Mathematical concept refers to https://en.wikipedia.org/wiki/Singular_value_decomposition. + + Args: + A (ti.Matrix(3, 3)): input 3x3 matrix `A`. + dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. + iters (int): iteration number to control algorithm precision. + + Returns: + Decomposed 3x3 matrices `U`, 'S' and `V`. + """ + assert A.n == 3 and A.m == 3 + inputs = tuple([e.ptr for e in A.entries]) + assert dt in [f32, f64] + if iters is None: + if dt == f32: + iters = 5 + else: + iters = 8 + if dt == f32: + rets = get_runtime().prog.current_ast_builder().sifakis_svd_f32( + *inputs, iters) + else: + rets = get_runtime().prog.current_ast_builder().sifakis_svd_f64( + *inputs, iters) + assert len(rets) == 21 + U_entries = rets[:9] + V_entries = rets[9:18] + sig_entries = rets[18:] + U = expr_init(Matrix.zero(dt, 3, 3)) + V = expr_init(Matrix.zero(dt, 3, 3)) + sigma = expr_init(Matrix.zero(dt, 3, 3)) + for i in range(3): + for j in range(3): + U(i, j)._assign(U_entries[i * 3 + j]) + V(i, j)._assign(V_entries[i * 3 + j]) + sigma(i, i)._assign(sig_entries[i]) + return U, sigma, V + + +@func +def eig2x2(A, dt): + """Compute the eigenvalues and right eigenvectors (Av=lambda v) of a 2x2 real matrix. + + Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix. + + Args: + A (ti.Matrix(2, 2)): input 2x2 matrix `A`. + dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. + + Returns: + eigenvalues (ti.Matrix(2, 2)): The eigenvalues in complex form. Each row stores one eigenvalue. The first number of the eigenvalue represents the real part and the second number represents the imaginary part. + eigenvectors: (ti.Matrix(4, 2)): The eigenvectors in complex form. Each column stores one eigenvector. Each eigenvector consists of 2 entries, each of which is represented by two numbers for its real part and imaginary part. + """ + tr = A.trace() + det = A.determinant() + gap = tr**2 - 4 * det + lambda1 = Vector.zero(dt, 2) + lambda2 = Vector.zero(dt, 2) + v1 = Vector.zero(dt, 4) + v2 = Vector.zero(dt, 4) + if gap > 0: + lambda1 = Vector([tr + ops.sqrt(gap), 0.0], dt=dt) * 0.5 + lambda2 = Vector([tr - ops.sqrt(gap), 0.0], dt=dt) * 0.5 + A1 = A - lambda1[0] * Matrix.identity(dt, 2) + A2 = A - lambda2[0] * Matrix.identity(dt, 2) + if all(A1 == Matrix.zero(dt, 2, 2)) and all( + A1 == Matrix.zero(dt, 2, 2)): + v1 = Vector([0.0, 0.0, 1.0, 0.0]).cast(dt) + v2 = Vector([1.0, 0.0, 0.0, 0.0]).cast(dt) + else: + v1 = Vector([A2[0, 0], 0.0, A2[1, 0], 0.0], dt=dt).normalized() + v2 = Vector([A1[0, 0], 0.0, A1[1, 0], 0.0], dt=dt).normalized() + else: + lambda1 = Vector([tr, ops.sqrt(-gap)], dt=dt) * 0.5 + lambda2 = Vector([tr, -ops.sqrt(-gap)], dt=dt) * 0.5 + A1r = A - lambda1[0] * Matrix.identity(dt, 2) + A1i = -lambda1[1] * Matrix.identity(dt, 2) + A2r = A - lambda2[0] * Matrix.identity(dt, 2) + A2i = -lambda2[1] * Matrix.identity(dt, 2) + v1 = Vector([A2r[0, 0], A2i[0, 0], A2r[1, 0], A2i[1, 0]], + dt=dt).normalized() + v2 = Vector([A1r[0, 0], A1i[0, 0], A1r[1, 0], A1i[1, 0]], + dt=dt).normalized() + eigenvalues = Matrix.rows([lambda1, lambda2]) + eigenvectors = Matrix.cols([v1, v2]) + + return eigenvalues, eigenvectors + + +@func +def sym_eig2x2(A, dt): + """Compute the eigenvalues and right eigenvectors (Av=lambda v) of a 2x2 real symmetric matrix. + + Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix. + + Args: + A (ti.Matrix(2, 2)): input 2x2 symmetric matrix `A`. + dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. + + Returns: + eigenvalues (ti.Vector(2)): The eigenvalues. Each entry store one eigen value. + eigenvectors (ti.Matrix(2, 2)): The eigenvectors. Each column stores one eigenvector. + """ + tr = A.trace() + det = A.determinant() + gap = tr**2 - 4 * det + lambda1 = (tr + ops.sqrt(gap)) * 0.5 + lambda2 = (tr - ops.sqrt(gap)) * 0.5 + eigenvalues = Vector([lambda1, lambda2], dt=dt) + + A1 = A - lambda1 * Matrix.identity(dt, 2) + A2 = A - lambda2 * Matrix.identity(dt, 2) + v1 = Vector.zero(dt, 2) + v2 = Vector.zero(dt, 2) + if all(A1 == Matrix.zero(dt, 2, 2)) and all(A1 == Matrix.zero(dt, 2, 2)): + v1 = Vector([0.0, 1.0]).cast(dt) + v2 = Vector([1.0, 0.0]).cast(dt) + else: + v1 = Vector([A2[0, 0], A2[1, 0]], dt=dt).normalized() + v2 = Vector([A1[0, 0], A1[1, 0]], dt=dt).normalized() + eigenvectors = Matrix.cols([v1, v2]) + return eigenvalues, eigenvectors + + +@func +def _svd(A, dt): + """Perform singular value decomposition (A=USV^T) for arbitrary size matrix. + + Mathematical concept refers to https://en.wikipedia.org/wiki/Singular_value_decomposition. + 2D implementation refers to :func:`taichi.svd2d`. + 3D implementation refers to :func:`taichi.svd3d`. + + Args: + A (ti.Matrix(n, n)): input nxn matrix `A`. + dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. + + Returns: + Decomposed nxn matrices `U`, 'S' and `V`. + """ + if static(A.n == 2): # pylint: disable=R1705 + ret = svd2d(A, dt) + return ret + else: + return svd3d(A, dt) + + +@func +def _polar_decompose(A, dt): + """Perform polar decomposition (A=UP) for arbitrary size matrix. + + Mathematical concept refers to https://en.wikipedia.org/wiki/Polar_decomposition. + 2D implementation refers to :func:`taichi.polar_decompose2d`. + 3D implementation refers to :func:`taichi.polar_decompose3d`. + + Args: + A (ti.Matrix(n, n)): input nxn matrix `A`. + dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. + + Returns: + Decomposed nxn matrices `U` and `P`. + """ + if static(A.n == 2): # pylint: disable=R1705 + ret = polar_decompose2d(A, dt) + return ret + else: + return polar_decompose3d(A, dt) + + +def polar_decompose(A, dt=None): + """Perform polar decomposition (A=UP) for arbitrary size matrix. + + Mathematical concept refers to https://en.wikipedia.org/wiki/Polar_decomposition. + This is only a wrapper for :func:`taichi.polar_decompose`. + + Args: + A (ti.Matrix(n, n)): input nxn matrix `A`. + dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. + + Returns: + Decomposed nxn matrices `U` and `P`. + """ + if dt is None: + dt = impl.get_runtime().default_fp + if A.n != 2 and A.n != 3: + raise Exception( + "Polar decomposition only supports 2D and 3D matrices.") + return _polar_decompose(A, dt) + + +def svd(A, dt=None): + """Perform singular value decomposition (A=USV^T) for arbitrary size matrix. + + Mathematical concept refers to https://en.wikipedia.org/wiki/Singular_value_decomposition. + This is only a wrappers for :func:`taichi.svd`. + + Args: + A (ti.Matrix(n, n)): input nxn matrix `A`. + dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. + + Returns: + Decomposed nxn matrices `U`, 'S' and `V`. + """ + if dt is None: + dt = impl.get_runtime().default_fp + if A.n != 2 and A.n != 3: + raise Exception("SVD only supports 2D and 3D matrices.") + return _svd(A, dt) + + +def eig(A, dt=None): + """Compute the eigenvalues and right eigenvectors of a real matrix. + + Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix. + 2D implementation refers to :func:`taichi.eig2x2`. + + Args: + A (ti.Matrix(n, n)): 2D Matrix for which the eigenvalues and right eigenvectors will be computed. + dt (DataType): The datatype for the eigenvalues and right eigenvectors. + + Returns: + eigenvalues (ti.Matrix(n, 2)): The eigenvalues in complex form. Each row stores one eigenvalue. The first number of the eigenvalue represents the real part and the second number represents the imaginary part. + eigenvectors (ti.Matrix(n*2, n)): The eigenvectors in complex form. Each column stores one eigenvector. Each eigenvector consists of n entries, each of which is represented by two numbers for its real part and imaginary part. + """ + if dt is None: + dt = impl.get_runtime().default_fp + if A.n == 2: + return eig2x2(A, dt) + raise Exception("Eigen solver only supports 2D matrices.") + + +def sym_eig(A, dt=None): + """Compute the eigenvalues and right eigenvectors of a real symmetric matrix. + + Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix. + 2D implementation refers to :func:`taichi.sym_eig2x2`. + + Args: + A (ti.Matrix(n, n)): Symmetric Matrix for which the eigenvalues and right eigenvectors will be computed. + dt (DataType): The datatype for the eigenvalues and right eigenvectors. + + Returns: + eigenvalues (ti.Vector(n)): The eigenvalues. Each entry store one eigen value. + eigenvectors (ti.Matrix(n, n)): The eigenvectors. Each column stores one eigenvector. + """ + assert all(A == A.transpose()), "A needs to be symmetric" + if dt is None: + dt = impl.get_runtime().default_fp + if A.n == 2: + return sym_eig2x2(A, dt) + raise Exception("Symmetric eigen solver only supports 2D matrices.") + + +__all__ = ['randn', 'polar_decompose', 'eig', 'sym_eig', 'svd'] diff --git a/python/taichi/_kernels.py b/python/taichi/_kernels.py new file mode 100644 index 0000000000000..683b703ba14df --- /dev/null +++ b/python/taichi/_kernels.py @@ -0,0 +1,235 @@ +from taichi._lib.utils import get_os_name +from taichi.lang import ops +from taichi.lang._ndrange import ndrange +from taichi.lang.expr import Expr +from taichi.lang.field import ScalarField +from taichi.lang.impl import grouped, static, static_assert +from taichi.lang.kernel_impl import kernel +from taichi.lang.runtime_ops import sync +from taichi.lang.snode import deactivate +from taichi.types.annotations import any_arr, ext_arr, template +from taichi.types.primitive_types import f16, f32, f64, u8 + + +# A set of helper (meta)functions +@kernel +def fill_tensor(tensor: template(), val: template()): + for I in grouped(tensor): + tensor[I] = val + + +@kernel +def fill_ndarray(ndarray: any_arr(), val: template()): + for I in grouped(ndarray): + ndarray[I] = val + + +@kernel +def fill_ndarray_matrix(ndarray: any_arr(), val: template()): + for I in grouped(ndarray): + ndarray[I].fill(val) + + +@kernel +def tensor_to_ext_arr(tensor: template(), arr: ext_arr()): + for I in grouped(tensor): + arr[I] = tensor[I] + + +@kernel +def ndarray_to_ext_arr(ndarray: any_arr(), arr: ext_arr()): + for I in grouped(ndarray): + arr[I] = ndarray[I] + + +@kernel +def ndarray_matrix_to_ext_arr(ndarray: any_arr(), arr: ext_arr(), + as_vector: template()): + for I in grouped(ndarray): + for p in static(range(ndarray[I].n)): + for q in static(range(ndarray[I].m)): + if static(as_vector): + arr[I, p] = ndarray[I][p] + else: + arr[I, p, q] = ndarray[I][p, q] + + +@kernel +def vector_to_fast_image(img: template(), out: ext_arr()): + # FIXME: Why is ``for i, j in img:`` slower than: + for i, j in ndrange(*img.shape): + r, g, b = 0, 0, 0 + color = img[i, img.shape[1] - 1 - j] + if static(img.dtype in [f16, f32, f64]): + r, g, b = min(255, max(0, int(color * 255))) + else: + static_assert(img.dtype == u8) + r, g, b = color + idx = j * img.shape[0] + i + # We use i32 for |out| since OpenGL and Metal doesn't support u8 types + if static(get_os_name() != 'osx'): + out[idx] = (r << 16) + (g << 8) + b + else: + # What's -16777216? + # + # On Mac, we need to set the alpha channel to 0xff. Since Mac's GUI + # is big-endian, the color is stored in ABGR order, and we need to + # add 0xff000000, which is -16777216 in I32's legit range. (Albeit + # the clarity, adding 0xff000000 doesn't work.) + alpha = -16777216 + out[idx] = (b << 16) + (g << 8) + r + alpha + + +@kernel +def tensor_to_image(tensor: template(), arr: ext_arr()): + for I in grouped(tensor): + t = ops.cast(tensor[I], f32) + arr[I, 0] = t + arr[I, 1] = t + arr[I, 2] = t + + +@kernel +def vector_to_image(mat: template(), arr: ext_arr()): + for I in grouped(mat): + for p in static(range(mat.n)): + arr[I, p] = ops.cast(mat[I][p], f32) + if static(mat.n <= 2): + arr[I, 2] = 0 + + +@kernel +def tensor_to_tensor(tensor: template(), other: template()): + for I in grouped(tensor): + tensor[I] = other[I] + + +@kernel +def ext_arr_to_tensor(arr: ext_arr(), tensor: template()): + for I in grouped(tensor): + tensor[I] = arr[I] + + +@kernel +def ndarray_to_ndarray(ndarray: any_arr(), other: any_arr()): + for I in grouped(ndarray): + ndarray[I] = other[I] + + +@kernel +def ext_arr_to_ndarray(arr: ext_arr(), ndarray: any_arr()): + for I in grouped(ndarray): + ndarray[I] = arr[I] + + +@kernel +def ext_arr_to_ndarray_matrix(arr: ext_arr(), ndarray: any_arr(), + as_vector: template()): + for I in grouped(ndarray): + for p in static(range(ndarray[I].n)): + for q in static(range(ndarray[I].m)): + if static(as_vector): + ndarray[I][p] = arr[I, p] + else: + ndarray[I][p, q] = arr[I, p, q] + + +@kernel +def matrix_to_ext_arr(mat: template(), arr: ext_arr(), as_vector: template()): + for I in grouped(mat): + for p in static(range(mat.n)): + for q in static(range(mat.m)): + if static(as_vector): + arr[I, p] = mat[I][p] + else: + arr[I, p, q] = mat[I][p, q] + + +@kernel +def ext_arr_to_matrix(arr: ext_arr(), mat: template(), as_vector: template()): + for I in grouped(mat): + for p in static(range(mat.n)): + for q in static(range(mat.m)): + if static(as_vector): + mat[I][p] = arr[I, p] + else: + mat[I][p, q] = arr[I, p, q] + + +@kernel +def clear_gradients(_vars: template()): + for I in grouped(ScalarField(Expr(_vars[0]))): + for s in static(_vars): + ScalarField(Expr(s))[I] = 0 + + +@kernel +def clear_loss(l: template()): + # Using SNode writers would result in a forced sync, therefore we wrap these + # writes into a kernel. + l[None] = 0 + l.grad[None] = 1 + + +@kernel +def fill_matrix(mat: template(), vals: template()): + for I in grouped(mat): + for p in static(range(mat.n)): + for q in static(range(mat.m)): + mat[I][p, q] = vals[p][q] + + +@kernel +def snode_deactivate(b: template()): + for I in grouped(b): + deactivate(b, I) + + +@kernel +def snode_deactivate_dynamic(b: template()): + for I in grouped(b.parent()): + deactivate(b, I) + + +# Odd-even merge sort +# References: +# https://developer.nvidia.com/gpugems/gpugems2/part-vi-simulation-and-numerical-algorithms/chapter-46-improved-gpu-sorting +# https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort +@kernel +def sort_stage(keys: template(), use_values: int, values: template(), N: int, + p: int, k: int, invocations: int): + for inv in range(invocations): + j = k % p + inv * 2 * k + for i in range(0, min(k, N - j - k)): + a = i + j + b = i + j + k + if int(a / (p * 2)) == int(b / (p * 2)): + key_a = keys[a] + key_b = keys[b] + if key_a > key_b: + keys[a] = key_b + keys[b] = key_a + if use_values != 0: + temp = values[a] + values[a] = values[b] + values[b] = temp + + +def parallel_sort(keys, values=None): + N = keys.shape[0] + + num_stages = 0 + p = 1 + while p < N: + k = p + while k >= 1: + invocations = int((N - k - k % p) / (2 * k)) + 1 + if values is None: + sort_stage(keys, 0, keys, N, p, k, invocations) + else: + sort_stage(keys, 1, values, N, p, k, invocations) + num_stages += 1 + sync() + k = int(k / 2) + p = int(p * 2) + print(num_stages) diff --git a/python/taichi/_lib/__init__.py b/python/taichi/_lib/__init__.py new file mode 100644 index 0000000000000..1334e8635e899 --- /dev/null +++ b/python/taichi/_lib/__init__.py @@ -0,0 +1 @@ +from taichi._lib.utils import ti_core as core diff --git a/python/taichi/_lib/core/__init__.py b/python/taichi/_lib/core/__init__.py new file mode 100644 index 0000000000000..8b137891791fe --- /dev/null +++ b/python/taichi/_lib/core/__init__.py @@ -0,0 +1 @@ + diff --git a/python/taichi/core/util.py b/python/taichi/_lib/utils.py similarity index 57% rename from python/taichi/core/util.py rename to python/taichi/_lib/utils.py index 7341bc6c59d77..8b2db08261479 100644 --- a/python/taichi/core/util.py +++ b/python/taichi/_lib/utils.py @@ -1,28 +1,19 @@ -import ctypes -import datetime -import multiprocessing import os import platform -import random -import shutil import sys -import time -from colorama import Back, Fore, Style +from colorama import Fore, Style if sys.version_info[0] < 3 or sys.version_info[1] <= 5: raise RuntimeError( "\nPlease restart with Python 3.6+\n" + "Current Python version:", sys.version_info) -ti_core = None - def in_docker(): if os.environ.get("TI_IN_DOCKER", "") == "": return False - else: - return True + return True def get_os_name(): @@ -31,23 +22,26 @@ def get_os_name(): # it will return 'macOS-XXXX' instead of 'Darwin-XXXX' if name.lower().startswith('darwin') or name.lower().startswith('macos'): return 'osx' - elif name.lower().startswith('windows'): + if name.lower().startswith('windows'): return 'win' - elif name.lower().startswith('linux'): + if name.lower().startswith('linux'): return 'linux' - assert False, "Unknown platform name %s" % name + if 'bsd' in name.lower(): + return 'unix' + assert False, f"Unknown platform name {name}" def import_ti_core(): - global ti_core if get_os_name() != 'win': + # pylint: disable=E1101 old_flags = sys.getdlopenflags() sys.setdlopenflags(2 | 8) # RTLD_NOW | RTLD_DEEPBIND else: - pyddir = os.path.join(package_root(), 'lib') - os.environ['PATH'] += ';' + pyddir + pyddir = os.path.dirname(os.path.realpath(__file__)) + os.environ['PATH'] += os.pathsep + pyddir try: - import taichi_core as core # pylint: disable=C0415 + from taichi._lib.core import \ + taichi_core as core # pylint: disable=C0415 except Exception as e: if isinstance(e, ImportError): print(Fore.YELLOW + "Share object taichi_core import failed, " @@ -55,28 +49,28 @@ def import_ti_core(): "https://docs.taichi.graphics/lang/articles/misc/install" + Fore.RESET) if get_os_name() == 'win': + # pylint: disable=E1101 e.msg += '\nConsider installing Microsoft Visual C++ Redistributable: https://aka.ms/vs/16/release/vc_redist.x64.exe' - elif get_os_name() == 'linux': - e.msg += '\nConsider installing libtinfo5: sudo apt-get install libtinfo5' raise e from None - ti_core = core + if get_os_name() != 'win': - sys.setdlopenflags(old_flags) - lib_dir = os.path.join(package_root(), 'lib') + sys.setdlopenflags(old_flags) # pylint: disable=E1101 + lib_dir = os.path.join(package_root, '_lib', 'runtime') core.set_lib_dir(locale_encode(lib_dir)) + return core def locale_encode(path): try: import locale # pylint: disable=C0415 return path.encode(locale.getdefaultlocale()[1]) - except: + except (UnicodeEncodeError, TypeError): try: return path.encode(sys.getfilesystemencoding()) - except: + except UnicodeEncodeError: try: return path.encode() - except: + except UnicodeEncodeError: return path @@ -84,12 +78,12 @@ def is_ci(): return os.environ.get('TI_CI', '') == '1' -def package_root(): - return os.path.join(os.path.dirname(os.path.realpath(__file__)), '../') +package_root = os.path.join( + os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) def get_core_shared_object(): - directory = os.path.join(package_root(), 'lib') + directory = os.path.join(package_root, '_lib') return os.path.join(directory, 'libtaichi_core.so') @@ -106,16 +100,9 @@ def check_exists(src): ) -def get_unique_task_id(): - return datetime.datetime.now().strftime('task-%Y-%m-%d-%H-%M-%S-r') + ( - '%05d' % random.randint(0, 10000)) +ti_core = import_ti_core() - -sys.path.append(os.path.join(package_root(), 'lib')) -import_ti_core() - -ti_core.set_python_package_dir(package_root()) -os.makedirs(ti_core.get_repo_dir(), exist_ok=True) +ti_core.set_python_package_dir(package_root) log_level = os.environ.get('TI_LOG_LEVEL', '') if log_level: @@ -124,43 +111,42 @@ def get_unique_task_id(): def get_dll_name(name): if get_os_name() == 'linux': - return 'libtaichi_%s.so' % name - elif get_os_name() == 'osx': - return 'libtaichi_%s.dylib' % name - elif get_os_name() == 'win': - return 'taichi_%s.dll' % name - else: - raise Exception(f"Unknown OS: {get_os_name()}") + return f'libtaichi_{name}.so' + if get_os_name() == 'osx': + return f'libtaichi_{name}.dylib' + if get_os_name() == 'win': + return f'taichi_{name}.dll' + raise Exception(f"Unknown OS: {get_os_name()}") def at_startup(): ti_core.set_core_state_python_imported(True) -def require_version(major, minor=None, patch=None): - versions = [ - int(ti_core.get_version_major()), - int(ti_core.get_version_minor()), - int(ti_core.get_version_patch()), - ] - match = major == versions[0] and ( - minor < versions[1] or minor == versions[1] and patch <= versions[2]) - if match: - return - else: - print("Taichi version mismatch. required >= {}.{}.{}".format( - major, minor, patch)) - print("Installed =", ti_core.get_version_string()) - raise Exception("Taichi version mismatch") +at_startup() -at_startup() +def compare_version(latest, current): + latest_num = map(int, latest.split('.')) + current_num = map(int, current.split('.')) + return tuple(latest_num) > tuple(current_num) def _print_taichi_header(): header = '[Taichi] ' header += f'version {ti_core.get_version_string()}, ' + try: + timestamp_path = os.path.join(ti_core.get_repo_dir(), 'timestamp') + if os.path.exists(timestamp_path): + latest_version = '' + with open(timestamp_path, 'r') as f: + latest_version = f.readlines()[1].rstrip() + if compare_version(latest_version, ti_core.get_version_string()): + header += f'latest version {latest_version}, ' + except: + pass + llvm_version = ti_core.get_llvm_version_string() header += f'llvm {llvm_version}, ' @@ -177,10 +163,3 @@ def _print_taichi_header(): _print_taichi_header() - -__all__ = [ - 'ti_core', - 'get_os_name', - 'package_root', - 'require_version', -] diff --git a/python/taichi/_logging.py b/python/taichi/_logging.py index 927c9b2e59036..29696b9fe062d 100644 --- a/python/taichi/_logging.py +++ b/python/taichi/_logging.py @@ -1,10 +1,19 @@ import inspect import os -from taichi.core import ti_core +from taichi._lib import core as ti_core def _get_logging(name): + """Generates a decorator to decorate a specific logger function. + + Args: + name (str): The string represents logging level. + Effective levels include: 'trace', 'debug', 'info', 'warn', 'error', 'critical'. + + Returns: + Callabe: The decorated function. + """ def logger(msg, *args, **kwargs): # Python inspection takes time (~0.1ms) so avoid it as much as possible if ti_core.logging_effective(name): @@ -20,19 +29,83 @@ def logger(msg, *args, **kwargs): def set_logging_level(level): + """Setting the logging level to a specified value. + Available levels are: 'trace', 'debug', 'info', 'warn', 'error', 'critical'. + + Note that after calling this function, logging levels below the specified one will + also be effective. For example if `level` is set to 'warn', then the levels below + it, which are 'error' and 'critical' in this case, will also be effective. + + See also https://docs.taichi.graphics/lang/articles/contribution/utilities#logging. + + Args: + level (str): Logging level. + + Example:: + + >>> set_logging_level('debug') + """ ti_core.set_logging_level(level) def is_logging_effective(level): + """Check if the specified logging level is effective. + All levels below current level will be effective. + The default level is 'info'. + + See also https://docs.taichi.graphics/lang/articles/contribution/utilities#logging. + + Args: + level (str): The string represents logging level. \ + Effective levels include: 'trace', 'debug', 'info', 'warn', 'error', 'critical'. + + Returns: + Bool: Indicate whether the logging level is effective. + + Example:: + + >>> # assume current level is 'info' + >>> print(ti.is_logging_effective("trace")) # False + >>> print(ti.is_logging_effective("debug")) # False + >>> print(ti.is_logging_effective("info")) # True + >>> print(ti.is_logging_effective("warn")) # True + >>> print(ti.is_logging_effective("error")) # True + >>> print(ti.is_logging_effective("critical")) # True + """ return ti_core.logging_effective(level) +# ------------------------ + DEBUG = 'debug' +"""The `str` 'debug', used for the `debug` logging level. +""" +# ------------------------ + TRACE = 'trace' +"""The `str` 'trace', used for the `debug` logging level. +""" +# ------------------------ + INFO = 'info' +"""The `str` 'info', used for the `info` logging level. +""" +# ------------------------ + WARN = 'warn' +"""The `str` 'warn', used for the `warn` logging level. +""" +# ------------------------ + ERROR = 'error' +"""The `str` 'error', used for the `error` logging level. +""" +# ------------------------ + CRITICAL = 'critical' +"""The `str` 'critical', used for the `critical` logging level. +""" +# ------------------------ supported_log_levels = [DEBUG, TRACE, INFO, WARN, ERROR, CRITICAL] @@ -44,7 +117,6 @@ def is_logging_effective(level): critical = _get_logging(CRITICAL) __all__ = [ - 'DEBUG', 'TRACE', 'INFO', 'WARN', 'ERROR', 'CRITICAL', 'debug', 'trace', - 'info', 'warn', 'error', 'critical', 'supported_log_levels', - 'set_logging_level', 'is_logging_effective' + 'DEBUG', 'TRACE', 'INFO', 'WARN', 'ERROR', 'CRITICAL', 'set_logging_level', + 'is_logging_effective' ] diff --git a/python/taichi/main.py b/python/taichi/_main.py similarity index 66% rename from python/taichi/main.py rename to python/taichi/_main.py index f1886b75d9984..254553f3ec841 100644 --- a/python/taichi/main.py +++ b/python/taichi/_main.py @@ -1,7 +1,6 @@ import argparse import math import os -import random import runpy import shutil import subprocess @@ -12,11 +11,10 @@ from pathlib import Path import numpy as np -import taichi.cc_compose -import taichi.diagnose -from colorama import Back, Fore, Style -from taichi.core import ti_core as _ti_core -from taichi.tools import video +from colorama import Fore +from taichi._lib import core as _ti_core +from taichi._lib import utils +from taichi.tools import cc_compose, diagnose, video import taichi as ti @@ -78,7 +76,7 @@ def __call__(self): # Parse the command args = self.main_parser.parse_args(sys.argv[1:2]) - if args.command not in self.registered_commands: + if args.command not in self.registered_commands: # pylint: disable=E1101 # TODO: do we really need this? if args.command.endswith(".py"): TaichiMain._exec_python_file(args.command) @@ -89,25 +87,19 @@ def __call__(self): return getattr(self, args.command)(sys.argv[2:]) - def _get_friend_links(self): - uri = 'en/stable' - try: - import locale # pylint: disable=C0415 - if 'zh' in locale.getdefaultlocale()[0]: - uri = 'zh_CN/latest' - except: - pass + @staticmethod + def _get_friend_links(): return '\n' \ - f'Docs: https://taichi.rtfd.io/{uri}\n' \ - f'GitHub: https://github.com/taichi-dev/taichi\n' \ - f'Forum: https://forum.taichi.graphics\n' + 'Docs: https://docs.taichi.graphics/\n' \ + 'GitHub: https://github.com/taichi-dev/taichi/\n' \ + 'Forum: https://forum.taichi.graphics/\n' def _usage(self) -> str: """Compose deterministic usage message based on registered_commands.""" # TODO: add some color to commands msg = "\n" space = 20 - for command in sorted(self.registered_commands): + for command in sorted(self.registered_commands): # pylint: disable=E1101 msg += f" {command}{' ' * (space - len(command))}|-> {getattr(self, command).__doc__}\n" return msg @@ -121,7 +113,7 @@ def _exec_python_file(filename: str): def _get_examples_dir() -> Path: """Get the path to the examples directory.""" - root_dir = ti.package_root() + root_dir = utils.package_root examples_dir = Path(root_dir) / 'examples' return examples_dir @@ -130,10 +122,7 @@ def _get_available_examples() -> set: """Get a set of all available example names.""" examples_dir = TaichiMain._get_examples_dir() all_examples = examples_dir.rglob('*.py') - all_example_names = { - Path(f).stem: Path(f).parent - for f in all_examples - } + all_example_names = {f.stem: f.parent for f in all_examples} return all_example_names @staticmethod @@ -142,8 +131,7 @@ def support_choice_with_dot_py(choice): if choice.endswith('.py') and choice.split('.')[0] in choices: # try to find and remove python file extension return choice.split('.')[0] - else: - return choice + return choice return support_choice_with_dot_py @@ -156,7 +144,7 @@ def example(self, arguments: list = sys.argv[2:]): description=f"{self.example.__doc__}") parser.add_argument( "name", - help=f"Name of an example (supports .py extension too)\n", + help="Name of an example (supports .py extension too)\n", type=TaichiMain._example_choices_type(choices.keys()), choices=sorted(choices.keys())) parser.add_argument( @@ -180,6 +168,8 @@ def example(self, arguments: list = sys.argv[2:]): dest='save', action='store_true', help="Save source code to current directory instead of running it") + + # TODO: Pass the arguments to downstream correctly(#3216). args = parser.parse_args(arguments) examples_dir = TaichiMain._get_examples_dir() @@ -189,7 +179,8 @@ def example(self, arguments: list = sys.argv[2:]): sys.path.append(str((examples_dir / choices[args.name]).resolve())) # Short circuit for testing - if self.test_mode: return args + if self.test_mode: + return args if args.save: print(f"Saving example {args.name} to current directory...") @@ -200,7 +191,7 @@ def example(self, arguments: list = sys.argv[2:]): try: import rich.console # pylint: disable=C0415 import rich.syntax # pylint: disable=C0415 - except ImportError as e: + except ImportError: print('To make -P work, please: python3 -m pip install rich') return 1 # https://rich.readthedocs.io/en/latest/syntax.html @@ -218,15 +209,19 @@ def example(self, arguments: list = sys.argv[2:]): runpy.run_path(target, run_name='__main__') + return None + + @staticmethod @register - def changelog(self, arguments: list = sys.argv[2:]): + def changelog(arguments: list = sys.argv[2:]): """Display changelog of current version""" - changelog_md = os.path.join(ti.package_root(), 'CHANGELOG.md') + changelog_md = os.path.join(utils.package_root, 'CHANGELOG.md') with open(changelog_md) as f: print(f.read()) + @staticmethod @register - def release(self, arguments: list = sys.argv[2:]): + def release(arguments: list = sys.argv[2:]): """Make source code release""" raise RuntimeError('TBD') @@ -257,12 +252,15 @@ def gif(self, arguments: list = sys.argv[2:]): args = parser.parse_args(arguments) args.output_file = str(Path(args.input_file).with_suffix('.gif')) - ti.info(f"Converting {args.input_file} to {args.output_file}") + ti._logging.info(f"Converting {args.input_file} to {args.output_file}") # Short circuit for testing - if self.test_mode: return args + if self.test_mode: + return args video.mp4_to_gif(args.input_file, args.output_file, args.framerate) + return None + @register def video_speed(self, arguments: list = sys.argv[2:]): """Speed up video in the same directory""" @@ -290,9 +288,12 @@ def video_speed(self, arguments: list = sys.argv[2:]): )) # Short circuit for testing - if self.test_mode: return args + if self.test_mode: + return args video.accelerate_video(args.input_file, args.output_file, args.speed) + return None + @register def video_crop(self, arguments: list = sys.argv[2:]): """Crop video in the same directory""" @@ -332,10 +333,13 @@ def video_crop(self, arguments: list = sys.argv[2:]): )) # Short circuit for testing - if self.test_mode: return args + if self.test_mode: + return args video.crop_video(args.input_file, args.output_file, args.x_begin, args.x_end, args.y_begin, args.y_end) + return None + @register def video_scale(self, arguments: list = sys.argv[2:]): """Scale video resolution in the same directory""" @@ -373,10 +377,13 @@ def video_scale(self, arguments: list = sys.argv[2:]): )) # Short circuit for testing - if self.test_mode: return args + if self.test_mode: + return args video.scale_video(args.input_file, args.output_file, args.ratio_width, args.ratio_height) + return None + @register def video(self, arguments: list = sys.argv[2:]): """Make a video using *.png files in the current directory""" @@ -413,36 +420,42 @@ def video(self, arguments: list = sys.argv[2:]): assert 1 <= args.crf <= 51, "The range of the CRF scale is 1-51, where 1 is almost lossless, 20 is the default, and 51 is worst quality possible." - ti.info(f'Making video using {len(args.inputs)} png files...') - ti.info(f'frame_rate = {args.framerate}') + ti._logging.info(f'Making video using {len(args.inputs)} png files...') + ti._logging.info(f'frame_rate = {args.framerate}') # Short circuit for testing - if self.test_mode: return args + if self.test_mode: + return args video.make_video(args.inputs, output_path=str(args.output_file), crf=args.crf, frame_rate=args.framerate) - ti.info(f'Done! Output video file = {args.output_file}') + ti._logging.info(f'Done! Output video file = {args.output_file}') + return None + + @staticmethod @register - def doc(self, arguments: list = sys.argv[2:]): + def doc(arguments: list = sys.argv[2:]): """Build documentation""" raise RuntimeError('TBD') + @staticmethod @register - def format(self, arguments: list = sys.argv[2:]): + def format(arguments: list = sys.argv[2:]): """Reformat modified source files""" raise RuntimeError('Please run python misc/code_format.py instead') + @staticmethod @register - def format_all(self, arguments: list = sys.argv[2:]): + def format_all(arguments: list = sys.argv[2:]): """Reformat all source files""" raise RuntimeError('Please run python misc/code_format.py instead') @staticmethod def _display_benchmark_regression(xd, yd, args): def parse_dat(file): - dict = {} + _dict = {} with open(file) as f: for line in f.readlines(): try: @@ -452,28 +465,27 @@ def parse_dat(file): b = float(b) if abs(b % 1.0) < 1e-5: # codegen_* b = int(b) - dict[a.strip()] = b - return dict + _dict[a.strip()] = b + return _dict def parse_name(file): if file[0:5] == 'test_': return file[5:-4].replace('__test_', '::', 1) - elif file[0:10] == 'benchmark_': + if file[0:10] == 'benchmark_': return '::'.join(reversed(file[10:-4].split('__arch_'))) - else: - raise Exception(f'bad benchmark file name {file}') + raise Exception(f'bad benchmark file name {file}') - def get_dats(dir): - list = [] - for x in os.listdir(dir): + def get_dats(directory): + _list = [] + for x in os.listdir(directory): if x.endswith('.dat'): - list.append(x) - dict = {} - for x in list: + _list.append(x) + _dict = {} + for x in _list: name = parse_name(x) - path = os.path.join(dir, x) - dict[name] = parse_dat(path) - return dict + path = os.path.join(directory, x) + _dict[name] = parse_dat(path) + return _dict def plot_in_gui(scatter): @@ -508,13 +520,16 @@ def plot_in_gui(scatter): else: res = b / a scatter[key].append(res) - if res == 1: continue + if res == 1: + continue if not single_line: ret += f'{key:<30}' res -= 1 color = Fore.RESET - if res > 0: color = Fore.RED - elif res < 0: color = Fore.GREEN + if res > 0: + color = Fore.RED + elif res < 0: + color = Fore.GREEN if isinstance(a, float): a = f'{a:>7.2}' else: @@ -560,13 +575,16 @@ def regression(self, arguments: list = sys.argv[2:]): args = parser.parse_args(arguments) # Short circuit for testing - if self.test_mode: return args + if self.test_mode: + return args baseline_dir = TaichiMain._get_benchmark_baseline_dir() output_dir = TaichiMain._get_benchmark_output_dir() TaichiMain._display_benchmark_regression(baseline_dir, output_dir, args) + return None + @register def baseline(self, arguments: list = sys.argv[2:]): """Archive current benchmark result as baseline""" @@ -575,7 +593,8 @@ def baseline(self, arguments: list = sys.argv[2:]): args = parser.parse_args(arguments) # Short circuit for testing - if self.test_mode: return args + if self.test_mode: + return args baseline_dir = TaichiMain._get_benchmark_baseline_dir() output_dir = TaichiMain._get_benchmark_output_dir() @@ -583,77 +602,7 @@ def baseline(self, arguments: list = sys.argv[2:]): shutil.copytree(output_dir, baseline_dir) print(f"[benchmark] baseline data saved to {baseline_dir}") - @staticmethod - def _test_python(args): - print("\nRunning Python tests...\n") - - root_dir = ti.package_root() - test_dir = os.path.join(root_dir, 'tests') - pytest_args = [] - - # TODO: use pathlib to deal with suffix and stem name manipulation - if args.files: - # run individual tests - for f in args.files: - # auto-complete file names - if not f.startswith('test_'): - f = 'test_' + f - if not f.endswith('.py'): - f = f + '.py' - pytest_args.append(os.path.join(test_dir, f)) - else: - # run all the tests - pytest_args = [test_dir] - if args.verbose: - pytest_args += ['-v'] - if args.rerun: - pytest_args += ['--reruns', args.rerun] - try: - if args.coverage: - pytest_args += ['--cov-branch', '--cov=python/taichi'] - if args.cov_append: - pytest_args += ['--cov-append'] - if args.keys: - pytest_args += ['-k', args.keys] - if args.marks: - pytest_args += ['-m', args.marks] - if args.failed_first: - pytest_args += ['--failed-first'] - if args.fail_fast: - pytest_args += ['--exitfirst'] - except AttributeError: - pass - - try: - from multiprocessing import cpu_count # pylint: disable=C0415 - threads = min(8, cpu_count()) # To prevent running out of memory - except NotImplementedError: - threads = 2 - - if not os.environ.get('TI_DEVICE_MEMORY_GB'): - os.environ['TI_DEVICE_MEMORY_GB'] = '1.0' # Discussion: #769 - - env_threads = os.environ.get('TI_TEST_THREADS', '') - threads = args.threads or env_threads or threads - print(f'Starting {threads} testing thread(s)...') - if args.show_output: - pytest_args += ['-s'] - print( - f'Due to how pytest-xdist is implemented, the -s option does not work with multiple thread...' - ) - else: - if int(threads) > 1: - pytest_args += ['-n', str(threads)] - import pytest # pylint: disable=C0415 - return int(pytest.main(pytest_args)) - - @staticmethod - def _test_cpp(args): - # Cpp tests use the legacy non LLVM backend - ti.reset() - print("Running C++ tests...") - task = ti.Task('test') - return int(task.run(*args.files)) + return None @register def benchmark(self, arguments: list = sys.argv[2:]): @@ -689,7 +638,8 @@ def benchmark(self, arguments: list = sys.argv[2:]): args = parser.parse_args(arguments) # Short circuit for testing - if self.test_mode: return args + if self.test_mode: + return args commit_hash = _ti_core.get_commit_hash() with os.popen('git rev-parse HEAD') as f: @@ -707,136 +657,18 @@ def benchmark(self, arguments: list = sys.argv[2:]): os.system('python benchmarks/run.py') # TODO: benchmark_python(args) else: - TaichiMain._test_python(args) + # TODO: shall we replace this with the new benchmark tools? + os.system('python tests/run_tests.py') + + return None + @staticmethod @register def test(self, arguments: list = sys.argv[2:]): - """Run the tests""" - parser = argparse.ArgumentParser(prog='ti test', - description=f"{self.test.__doc__}") - parser.add_argument('files', - nargs='*', - help='Test name(s) to be run, e.g. "cli"') - parser.add_argument('-c', - '--cpp', - dest='cpp', - action='store_true', - help='Only run the C++ tests') - parser.add_argument('-s', - '--show', - dest='show_output', - action='store_true', - help='Show output (do not capture)') - parser.add_argument('-v', - '--verbose', - dest='verbose', - action='store_true', - help='Run with verbose outputs') - parser.add_argument('-r', - '--rerun', - required=False, - default=None, - dest='rerun', - type=str, - help='Rerun failed tests for given times') - parser.add_argument('-k', - '--keys', - required=False, - default=None, - dest='keys', - type=str, - help='Only run tests that match the keys') - parser.add_argument('-m', - '--marks', - required=False, - default=None, - dest='marks', - type=str, - help='Only run tests with specific marks') - parser.add_argument('-f', - '--failed-first', - required=False, - default=None, - dest='failed_first', - action='store_true', - help='Run the previously failed test first') - parser.add_argument('-x', - '--fail-fast', - required=False, - default=None, - dest='fail_fast', - action='store_true', - help='Exit instantly on the first failed test') - parser.add_argument('-C', - '--coverage', - required=False, - default=None, - dest='coverage', - action='store_true', - help='Run tests and record the coverage result') - parser.add_argument( - '-A', - '--cov-append', - required=False, - default=None, - dest='cov_append', - action='store_true', - help= - 'Append coverage result to existing one instead of overriding it') - parser.add_argument( - '-t', - '--threads', - required=False, - default=None, - dest='threads', - type=str, - help='Custom number of threads for parallel testing') - parser.add_argument( - '-a', - '--arch', - required=False, - default=None, - dest='arch', - type=str, - help='Custom the arch(s) (backend) to run tests on') - parser.add_argument( - '-n', - '--exclusive', - required=False, - default=False, - dest='exclusive', - action='store_true', - help= - 'Exclude arch(s) from test instead of include them, together with -a' + raise RuntimeError( + 'ti test is deprecated. Please run `python tests/run_tests.py` instead.' ) - args = parser.parse_args(arguments) - - if args.arch: - arch = args.arch - if args.exclusive: - arch = '^' + arch - print(f'Running on Arch={arch}') - os.environ['TI_WANTED_ARCHS'] = arch - - # Short circuit for testing - if self.test_mode: return args - - if args.files: - if args.cpp: - return TaichiMain._test_cpp(args) - else: - return TaichiMain._test_python(args) - elif args.cpp: - # Only run C++ tests - return TaichiMain._test_cpp(args) - else: - # Run both C++ and Python tests - ret = TaichiMain._test_python(args) - if ret != 0: - return ret - return TaichiMain._test_cpp(args) - @register def run(self, arguments: list = sys.argv[2:]): """Run a single script""" @@ -848,10 +680,13 @@ def run(self, arguments: list = sys.argv[2:]): args = parser.parse_args(arguments) # Short circuit for testing - if self.test_mode: return args + if self.test_mode: + return args runpy.run_path(args.filename) + return None + @register def debug(self, arguments: list = sys.argv[2:]): """Debug a single script""" @@ -864,56 +699,21 @@ def debug(self, arguments: list = sys.argv[2:]): args = parser.parse_args(arguments) # Short circuit for testing - if self.test_mode: return args + if self.test_mode: + return args _ti_core.set_core_trigger_gdb_when_crash(True) os.environ['TI_DEBUG'] = '1' runpy.run_path(args.filename, run_name='__main__') - @register - def task(self, arguments: list = sys.argv[2:]): - """Run a specific task""" - parser = argparse.ArgumentParser(prog='ti task', - description=f"{self.task.__doc__}") - parser.add_argument('taskname', - help='A single task name to run, e.g. test_math') - parser.add_argument('taskargs', - nargs='*', - help='Optional task argument(s) to run with task') - args = parser.parse_args(arguments) - - # Short circuit for testing - if self.test_mode: return args - - task = ti.Task(args.taskname) - task.run(*args.taskargs) - - @register - def dist(self, arguments: list = sys.argv[2:]): - """Build package and test in release mode""" - parser = argparse.ArgumentParser(prog='ti dist', - description=f"{self.dist.__doc__}") - parser.add_argument('mode', - nargs='?', - default='test', - choices=['upload', 'try_upload', 'test'], - help='Which mode shall we run?') - args = parser.parse_args(arguments) - - os.chdir(os.path.join(_ti_core.get_repo_dir(), 'python')) - sys.argv.pop(0) - sys.argv.append(args.mode) - runpy.run_path('build.py') + return None + @staticmethod @register - def diagnose(self, arguments: list = sys.argv[2:]): + def diagnose(arguments: list = sys.argv[2:]): """System diagnose information""" - parser = argparse.ArgumentParser( - prog='ti diagnose', description=f"{self.diagnose.__doc__}") - args = parser.parse_args(arguments) - - taichi.diagnose.main() + diagnose.main() @register def cc_compose(self, arguments: list = sys.argv[2:]): @@ -939,16 +739,13 @@ def cc_compose(self, arguments: list = sys.argv[2:]): help='Generate output C file for Emscripten instead of raw C') args = parser.parse_args(arguments) - taichi.cc_compose.main(args.fin_name, args.fout_name, args.hdrout_name, - args.emscripten) + cc_compose.main(args.fin_name, args.fout_name, args.hdrout_name, + args.emscripten) + @staticmethod @register - def repl(self, arguments: list = sys.argv[2:]): + def repl(arguments: list = sys.argv[2:]): """Start Taichi REPL / Python shell with 'import taichi as ti'""" - parser = argparse.ArgumentParser(prog='ti repl', - description=f"{self.repl.__doc__}") - args = parser.parse_args(arguments) - def local_scope(): try: @@ -956,18 +753,18 @@ def local_scope(): IPython.embed() except ImportError: import code # pylint: disable=C0415 - __name__ = '__console__' + __name__ = '__console__' # pylint: disable=W0622 code.interact(local=locals()) local_scope() + @staticmethod @register - def lint(self, arguments: list = sys.argv[2:]): + def lint(arguments: list = sys.argv[2:]): """Run pylint checker for the Python codebase of Taichi""" - parser = argparse.ArgumentParser(prog='ti lint', - description=f"{self.lint.__doc__}") # TODO: support arguments for lint specific files - args = parser.parse_args(arguments) + # parser = argparse.ArgumentParser(prog='ti lint', description=f"{self.lint.__doc__}") + # args = parser.parse_args(arguments) options = [os.path.dirname(__file__)] diff --git a/python/taichi/_snode/__init__.py b/python/taichi/_snode/__init__.py new file mode 100644 index 0000000000000..c323e8e6ab814 --- /dev/null +++ b/python/taichi/_snode/__init__.py @@ -0,0 +1,3 @@ +from taichi._snode.fields_builder import FieldsBuilder + +__all__ = ['FieldsBuilder'] diff --git a/python/taichi/snode/fields_builder.py b/python/taichi/_snode/fields_builder.py similarity index 68% rename from python/taichi/snode/fields_builder.py rename to python/taichi/_snode/fields_builder.py index 94fecba268535..4409790e1a298 100644 --- a/python/taichi/snode/fields_builder.py +++ b/python/taichi/_snode/fields_builder.py @@ -1,12 +1,10 @@ -import functools -import types from typing import Any, Optional, Sequence, Union -from taichi.core.util import ti_core as _ti_core +from taichi._lib import core as _ti_core +from taichi._snode.snode_tree import SNodeTree from taichi.lang import impl, snode -from taichi.lang.exception import InvalidOperationError -from taichi.misc.util import warning -from taichi.snode.snode_tree import SNodeTree +from taichi.lang.exception import TaichiRuntimeError +from taichi.lang.util import warning _snode_registry = _ti_core.SNodeRegistry() @@ -36,14 +34,14 @@ class FieldsBuilder: fb.finalize() """ def __init__(self): - self._ptr = _snode_registry.create_root() - self._root = snode.SNode(self._ptr) - self._finalized = False - self._empty = True + self.ptr = _snode_registry.create_root(impl.get_runtime().prog) + self.root = snode.SNode(self.ptr) + self.finalized = False + self.empty = True # TODO: move this into SNodeTree @classmethod - def finalized_roots(cls): + def _finalized_roots(cls): """Gets all the roots of the finalized SNodeTree. Returns: @@ -56,27 +54,11 @@ def finalized_roots(cls): roots_ptr.append(snode.SNode(res)) return roots_ptr - @property - def ptr(self): - return self._ptr - - @property - def root(self): - return self._root - - @property - def empty(self): - return self._empty - - @property - def finalized(self): - return self._finalized - # TODO: move this to SNodeTree class. def deactivate_all(self): """Same as :func:`taichi.lang.snode.SNode.deactivate_all`""" - if self._finalized: - self._root.deactivate_all() + if self.finalized: + self.root.deactivate_all() else: warning( """'deactivate_all()' would do nothing if FieldsBuilder is not finalized""" @@ -86,17 +68,17 @@ def dense(self, indices: Union[Sequence[_Axis], _Axis], dimensions: Union[Sequence[int], int]): """Same as :func:`taichi.lang.snode.SNode.dense`""" self._check_not_finalized() - self._empty = False - return self._root.dense(indices, dimensions) + self.empty = False + return self.root.dense(indices, dimensions) def pointer(self, indices: Union[Sequence[_Axis], _Axis], dimensions: Union[Sequence[int], int]): """Same as :func:`taichi.lang.snode.SNode.pointer`""" self._check_not_finalized() - self._empty = False - return self._root.pointer(indices, dimensions) + self.empty = False + return self.root.pointer(indices, dimensions) - def hash(self, indices, dimensions): + def _hash(self, indices, dimensions): """Same as :func:`taichi.lang.snode.SNode.hash`""" raise NotImplementedError() @@ -106,28 +88,28 @@ def dynamic(self, chunk_size: Optional[int] = None): """Same as :func:`taichi.lang.snode.SNode.dynamic`""" self._check_not_finalized() - self._empty = False - return self._root.dynamic(index, dimension, chunk_size) + self.empty = False + return self.root.dynamic(index, dimension, chunk_size) def bitmasked(self, indices: Union[Sequence[_Axis], _Axis], dimensions: Union[Sequence[int], int]): """Same as :func:`taichi.lang.snode.SNode.bitmasked`""" self._check_not_finalized() - self._empty = False - return self._root.bitmasked(indices, dimensions) + self.empty = False + return self.root.bitmasked(indices, dimensions) def bit_struct(self, num_bits: int): """Same as :func:`taichi.lang.snode.SNode.bit_struct`""" self._check_not_finalized() - self._empty = False - return self._root.bit_struct(num_bits) + self.empty = False + return self.root.bit_struct(num_bits) def bit_array(self, indices: Union[Sequence[_Axis], _Axis], dimensions: Union[Sequence[int], int], num_bits: int): """Same as :func:`taichi.lang.snode.SNode.bit_array`""" self._check_not_finalized() - self._empty = False - return self._root.bit_array(indices, dimensions, num_bits) + self.empty = False + return self.root.bit_array(indices, dimensions, num_bits) def place(self, *args: Any, @@ -135,29 +117,37 @@ def place(self, shared_exponent: bool = False): """Same as :func:`taichi.lang.snode.SNode.place`""" self._check_not_finalized() - self._empty = False - self._root.place(*args, offset=offset, shared_exponent=shared_exponent) + self.empty = False + self.root.place(*args, offset=offset, shared_exponent=shared_exponent) def lazy_grad(self): """Same as :func:`taichi.lang.snode.SNode.lazy_grad`""" # TODO: This complicates the implementation. Figure out why we need this self._check_not_finalized() - self._empty = False - self._root.lazy_grad() + self.empty = False + self.root.lazy_grad() def finalize(self, raise_warning=True): """Constructs the SNodeTree and finalizes this builder. Args: raise_warning (bool): Raise warning or not.""" + return self._finalize(raise_warning, compile_only=False) + + def _finalize_for_aot(self): + """Constructs the SNodeTree and compiles the type for AOT purpose.""" + return self._finalize(raise_warning=False, compile_only=True) + + def _finalize(self, raise_warning, compile_only): self._check_not_finalized() - if self._empty and raise_warning: + if self.empty and raise_warning: warning("Finalizing an empty FieldsBuilder!") - self._finalized = True + self.finalized = True return SNodeTree( - _ti_core.finalize_snode_tree(_snode_registry, self._ptr, - impl.get_runtime().prog)) + _ti_core.finalize_snode_tree(_snode_registry, self.ptr, + impl.get_runtime().prog, + compile_only)) def _check_not_finalized(self): - if self._finalized: - raise InvalidOperationError('FieldsBuilder finalized') + if self.finalized: + raise TaichiRuntimeError('FieldsBuilder finalized') diff --git a/python/taichi/snode/snode_tree.py b/python/taichi/_snode/snode_tree.py similarity index 70% rename from python/taichi/snode/snode_tree.py rename to python/taichi/_snode/snode_tree.py index 72c1dd7ecf400..18887421daf15 100644 --- a/python/taichi/snode/snode_tree.py +++ b/python/taichi/_snode/snode_tree.py @@ -3,9 +3,8 @@ # loaded during the import procedure, it's probably still good to delay the # access to it. -from taichi.core.util import ti_core as _ti_core from taichi.lang import impl -from taichi.lang.exception import InvalidOperationError +from taichi.lang.exception import TaichiRuntimeError class SNodeTree: @@ -15,12 +14,12 @@ def __init__(self, ptr): def destroy(self): if self.destroyed: - raise InvalidOperationError('SNode tree has been destroyed') + raise TaichiRuntimeError('SNode tree has been destroyed') self.ptr.destroy_snode_tree(impl.get_runtime().prog) self.destroyed = True @property def id(self): if self.destroyed: - raise InvalidOperationError('SNode tree has been destroyed') + raise TaichiRuntimeError('SNode tree has been destroyed') return self.ptr.id() diff --git a/python/taichi/_version_check.py b/python/taichi/_version_check.py new file mode 100644 index 0000000000000..5958aecab2455 --- /dev/null +++ b/python/taichi/_version_check.py @@ -0,0 +1,108 @@ +import datetime +import json +import os +import platform +import threading +import uuid +from urllib import request + +from taichi._lib import core as _ti_core + + +def check_version(cur_uuid): + # Check Taichi version for the user. + major = _ti_core.get_version_major() + minor = _ti_core.get_version_minor() + patch = _ti_core.get_version_patch() + version = f'{major}.{minor}.{patch}' + payload = {'version': version, 'platform': '', 'python': ''} + + system = platform.system() + if system == 'Linux': + payload['platform'] = 'manylinux1_x86_64' + elif system == 'Windows': + payload['platform'] = 'win_amd64' + elif system == 'Darwin': + if platform.release() < '19.0.0': + payload['platform'] = 'macosx_10_14_x86_64' + elif platform.machine() == 'x86_64': + payload['platform'] = 'macosx_10_15_x86_64' + else: + payload['platform'] = 'macosx_11_0_arm64' + + python_version = platform.python_version() + if python_version.startswith('3.6.'): + payload['python'] = 'cp36' + elif python_version.startswith('3.7.'): + payload['python'] = 'cp37' + elif python_version.startswith('3.8.'): + payload['python'] = 'cp38' + elif python_version.startswith('3.9.'): + payload['python'] = 'cp39' + elif python_version.startswith('3.10.'): + payload['python'] = 'cp310' + + payload['uuid'] = cur_uuid + # We do not want request exceptions break users' usage of Taichi. + try: + payload = json.dumps(payload) + payload = payload.encode() + req = request.Request('https://metadata.taichi.graphics/check_version', + method='POST') + req.add_header('Content-Type', 'application/json') + with request.urlopen(req, data=payload, timeout=5) as response: + response = json.loads(response.read().decode('utf-8')) + return response + except: + return None + + +def write_version_info(response, cur_uuid, version_info_path, cur_date): + if response is None: + return + with open(version_info_path, 'w') as f: + f.write((cur_date).strftime('%Y-%m-%d')) + f.write('\n') + if response['status'] == 1: + f.write(response['latest_version']) + else: + f.write('0.0.0') + f.write('\n') + f.write(cur_uuid) + f.write('\n') + + +def try_check_version(): + try: + os.makedirs(_ti_core.get_repo_dir(), exist_ok=True) + version_info_path = os.path.join(_ti_core.get_repo_dir(), + 'version_info') + cur_date = datetime.date.today() + if os.path.exists(version_info_path): + with open(version_info_path, 'r') as f: + version_info_file = f.readlines() + last_time = version_info_file[0].rstrip() + cur_uuid = version_info_file[2].rstrip() + if cur_date.strftime('%Y-%m-%d') > last_time: + response = check_version(cur_uuid) + write_version_info(response, cur_uuid, version_info_path, + cur_date) + else: + cur_uuid = str(uuid.uuid4()) + response = check_version(cur_uuid) + write_version_info(response, cur_uuid, version_info_path, cur_date) + # Wildcard exception to catch potential file writing errors. + except: + pass + + +def start_version_check_thread(): + skip = os.environ.get("TI_SKIP_VERSION_CHECK") + if skip != 'ON': + # We don't join this thread because we do not wish to block users. + check_version_thread = threading.Thread(target=try_check_version, + daemon=True) + check_version_thread.start() + + +__all__ = [] diff --git a/python/taichi/ad.py b/python/taichi/ad.py index 04a7ba61e6096..1a896fb3587d8 100644 --- a/python/taichi/ad.py +++ b/python/taichi/ad.py @@ -66,7 +66,7 @@ def decorated(*args, **kwargs): ) if primal.grad is not None: raise RuntimeError( - f'Primal function must be a **python** function instead of a taichi kernel. Please wrap the taichi kernel in a @ti.ad.grad_replaced decorated python function instead.' + 'Primal function must be a **python** function instead of a taichi kernel. Please wrap the taichi kernel in a @ti.ad.grad_replaced decorated python function instead.' ) primal.grad = decorated return decorated diff --git a/python/taichi/aot/module.py b/python/taichi/aot/module.py index d96b71a0328ec..696875199033b 100644 --- a/python/taichi/aot/module.py +++ b/python/taichi/aot/module.py @@ -1,9 +1,13 @@ from contextlib import contextmanager +from pathlib import Path, PurePosixPath from taichi.lang import impl, kernel_impl +from taichi.lang._ndarray import ScalarNdarray +from taichi.lang.enums import Layout from taichi.lang.field import ScalarField -from taichi.lang.matrix import MatrixField -from taichi.type.annotations import ArgAnyArray, template +from taichi.lang.matrix import MatrixField, MatrixNdarray, VectorNdarray +from taichi.types.annotations import ArgAnyArray, template +from taichi.types.primitive_types import f32 class KernelTemplate: @@ -14,11 +18,11 @@ def __init__(self, kernel_fn, aot_module): @staticmethod def keygen(v, key_p, fields): if isinstance(v, (int, float, bool)): - key_p += '=' + str(v) + '/' + key_p += '=' + str(v) + ',' return key_p for ky, val in fields: - if (val is v): - key_p += '=' + ky + '/' + if val is v: + key_p += '=' + ky + ',' return key_p raise RuntimeError('Arg type must be of type int/float/boolean' + 'or taichi field. Type ' + str(type(v)) + @@ -36,8 +40,7 @@ def instantiate(self, **kwargs): for index, (key, value) in enumerate(kwargs.items()): template_args[index] = (key, value) - for i in range(len(kernel.argument_annotations)): - anno = kernel.argument_annotations[i] + for anno in kernel.argument_annotations: if isinstance(anno, template): (k, v) = template_args[anno_index] key_p += k @@ -74,12 +77,18 @@ class Module: # for running ``foo`` and ``bar``. """ def __init__(self, arch): + """Creates a new AOT module instance + + Args: + arch: Target backend architecture. This is ignored for now. The AOT + backend still uses the one specified in :func:`~taichi.lang.init`. + """ self._arch = arch self._kernels = [] self._fields = {} - impl.get_runtime().materialize() - self._aot_builder = impl.get_runtime().prog.make_aot_module_builder( - arch) + rtm = impl.get_runtime() + rtm._finalize_root_fb_for_aot() + self._aot_builder = rtm.prog.make_aot_module_builder(arch) def add_field(self, name, field): """Add a taichi field to the AOT module. @@ -88,16 +97,15 @@ def add_field(self, name, field): name: name of taichi field field: taichi field - Example: - Usage:: - - a = ti.field(ti.f32, shape=(4,4)) - b = ti.field("something") + Example:: - m.add_field(a) - m.add_field(b) - - # Must add in sequence + >>> a = ti.field(ti.f32, shape=(4,4)) + >>> b = ti.field("something") + >>> + >>> m.add_field(a) + >>> m.add_field(b) + >>> + >>> # Must add in sequence """ is_scalar = True self._fields[name] = field @@ -109,32 +117,61 @@ def add_field(self, name, field): column_num = field.n else: assert isinstance(field, ScalarField) - self._aot_builder.add_field(name, is_scalar, field.dtype, - field.snode.shape, row_num, column_num) + self._aot_builder.add_field(name, field.snode.ptr, is_scalar, + field.dtype, field.snode.shape, row_num, + column_num) - def add_kernel(self, kernel_fn, name=None): + def add_kernel(self, kernel_fn, example_any_arrays=None, name=None): """Add a taichi kernel to the AOT module. Args: kernel_fn (Function): the function decorated by taichi `kernel`. + example_any_arrays (Dict[int, ti.ndarray]): a dict where key is arg_id and key is example any_arr input. name (str): Name to identify this kernel in the module. If not provided, uses the built-in ``__name__`` attribute of `kernel_fn`. - TODO: - * Support external array """ name = name or kernel_fn.__name__ kernel = kernel_fn._primal assert isinstance(kernel, kernel_impl.Kernel) injected_args = [] - for i in range(len(kernel.argument_annotations)): - anno = kernel.argument_annotations[i] + num_arr = len([ + anno for anno in kernel.argument_annotations + if isinstance(anno, ArgAnyArray) + ]) + assert example_any_arrays is None or num_arr == len( + example_any_arrays + ), f'Need {num_arr} example any_arr inputs but got {len(example_any_arrays)}' + i = 0 + for anno in kernel.argument_annotations: if isinstance(anno, ArgAnyArray): - raise RuntimeError( - 'Arg type `ext_arr`/`any_arr` not supported yet') + if example_any_arrays: + injected_args.append(example_any_arrays[i]) + else: + assert anno.element_shape is not None and anno.field_dim is not None, 'Please either specify element_shape & field_dim in the kernel arg annotation or provide a dict of example ndarrays.' + if anno.element_dim == 0: + injected_args.append( + ScalarNdarray(dtype=f32, + shape=(2, ) * anno.field_dim)) + elif anno.element_dim == 1: + injected_args.append( + VectorNdarray(anno.element_shape[0], + dtype=f32, + shape=(2, ) * anno.field_dim, + layout=Layout.AOS)) + elif anno.element_dim == 2: + injected_args.append( + MatrixNdarray(anno.element_shape[0], + anno.element_shape[1], + dtype=f32, + shape=(2, ) * anno.field_dim, + layout=Layout.AOS)) + else: + raise RuntimeError('') else: # For primitive types, we can just inject a dummy value. injected_args.append(0) + i = i + 1 kernel.ensure_compiled(*injected_args) self._aot_builder.add(name, kernel.kernel_cpp) @@ -148,28 +185,27 @@ def add_kernel_template(self, kernel_fn): Args: kernel_fn (Function): the function decorated by taichi `kernel`. - Example: - Usage:: - - @ti.kernel - def bar_tmpl(a: ti.template()): - x = a - # or y = a - # do something with `x` or `y` - - m = ti.aot.Module(arch) - with m.add_kernel_template(bar_tmpl) as kt: - kt.instantiate(a=x) - kt.instantiate(a=y) - - @ti.kernel - def bar_tmpl_multiple_args(a: ti.template(), b: ti.template()) - x = a - y = b - # do something with `x` and `y` - - with m.add_kernel_template(bar_tmpl) as kt: - kt.instantiate(a=x, b=y) + Example:: + + >>> @ti.kernel + >>> def bar_tmpl(a: ti.template()): + >>> x = a + >>> # or y = a + >>> # do something with `x` or `y` + >>> + >>> m = ti.aot.Module(arch) + >>> with m.add_kernel_template(bar_tmpl) as kt: + >>> kt.instantiate(a=x) + >>> kt.instantiate(a=y) + >>> + >>> @ti.kernel + >>> def bar_tmpl_multiple_args(a: ti.template(), b: ti.template()) + >>> x = a + >>> y = b + >>> # do something with `x` and `y` + >>> + >>> with m.add_kernel_template(bar_tmpl) as kt: + >>> kt.instantiate(a=x, b=y) TODO: * Support external array @@ -178,4 +214,10 @@ def bar_tmpl_multiple_args(a: ti.template(), b: ti.template()) yield kt def save(self, filepath, filename): + """ + Args: + filepath (str): path to a folder to store aot files. + filename (str): filename prefix for stored aot files. + """ + filepath = str(PurePosixPath(Path(filepath))) self._aot_builder.dump(filepath, filename) diff --git a/python/taichi/aot/record.py b/python/taichi/aot/record.py index 0d2882d4efee2..176db9337fc3c 100644 --- a/python/taichi/aot/record.py +++ b/python/taichi/aot/record.py @@ -1,6 +1,6 @@ import os -from taichi.core import ti_core +from taichi._lib import core as ti_core def record_action_entry(name, contents): @@ -52,8 +52,4 @@ def __exit__(self, *args): __all__ = [ 'start_recording', 'stop_recording', - 'record_action_hint', - 'record_action_entry', - 'record_action_config', - 'RecordKernelGroup', ] diff --git a/python/taichi/core/__init__.py b/python/taichi/core/__init__.py deleted file mode 100644 index a19fb81c1e88e..0000000000000 --- a/python/taichi/core/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from taichi.core.util import * - -__all__ = [s for s in dir() if not s.startswith('_')] diff --git a/examples/algorithm/laplace.py b/python/taichi/examples/algorithm/laplace.py similarity index 70% rename from examples/algorithm/laplace.py rename to python/taichi/examples/algorithm/laplace.py index fbbb2c16c9c49..694acb30f80b7 100644 --- a/examples/algorithm/laplace.py +++ b/python/taichi/examples/algorithm/laplace.py @@ -19,10 +19,15 @@ def laplace(): y[i, j] = 0.0 -for i in range(10): - x[i, i + 1] = 1.0 +def main(): + for i in range(10): + x[i, i + 1] = 1.0 -laplace() + laplace() -for i in range(10): - print(y[i, i + 1]) + for i in range(10): + print(y[i, i + 1]) + + +if __name__ == '__main__': + main() diff --git a/examples/algorithm/marching_squares.py b/python/taichi/examples/algorithm/marching_squares.py similarity index 100% rename from examples/algorithm/marching_squares.py rename to python/taichi/examples/algorithm/marching_squares.py diff --git a/examples/algorithm/mciso_advanced.py b/python/taichi/examples/algorithm/mciso_advanced.py similarity index 99% rename from examples/algorithm/mciso_advanced.py rename to python/taichi/examples/algorithm/mciso_advanced.py index 8cd19a41b9bde..7ea9fa4ea305d 100644 --- a/examples/algorithm/mciso_advanced.py +++ b/python/taichi/examples/algorithm/mciso_advanced.py @@ -294,13 +294,6 @@ def __init__(self, N=64, dim=3, blk_size=None): self.use_sparse = blk_size is not None self.blk_size = blk_size - et = [self.et2, self.et3][dim - 2] - self.et = ti.Vector.field(dim, int, et.shape[:2]) - - @ti.materialize_callback - def init_et(): - self.et.from_numpy(et) - self.m = ti.field(float) # field to sample self.g = ti.Vector.field(self.dim, float) # normalized gradient indices = [ti.ij, ti.ijk][dim - 2] @@ -316,6 +309,10 @@ def init_et(): dim, float, (self.N**self.dim, self.dim)) # result buffer, TODO: optimize this + et = [self.et2, self.et3][dim - 2] + self.et = ti.Vector.field(dim, int, et.shape[:2]) + self.et.from_numpy(et) + @ti.kernel def compute_grad(self): for I in ti.grouped(self.g): diff --git a/examples/algorithm/mgpcg.py b/python/taichi/examples/algorithm/mgpcg.py similarity index 100% rename from examples/algorithm/mgpcg.py rename to python/taichi/examples/algorithm/mgpcg.py diff --git a/examples/algorithm/mgpcg_advanced.py b/python/taichi/examples/algorithm/mgpcg_advanced.py similarity index 100% rename from examples/algorithm/mgpcg_advanced.py rename to python/taichi/examples/algorithm/mgpcg_advanced.py diff --git a/python/taichi/examples/algorithm/print_offset.py b/python/taichi/examples/algorithm/print_offset.py new file mode 100644 index 0000000000000..d7fa00d262194 --- /dev/null +++ b/python/taichi/examples/algorithm/print_offset.py @@ -0,0 +1,49 @@ +from taichi.lang import impl + +import taichi as ti + +ti.init(arch=ti.cpu, print_ir=True) + +n = 4 +m = 8 + +a = ti.field(dtype=ti.i32) +ti.root.dense(ti.ij, (1, 2)).dense(ti.ij, 2).dense(ti.ij, 2).place(a) + + +@ti.kernel +def fill(): + for i, j in a: + base = ti.get_addr(a.snode, [0, 0]) + a[i, j] = int(ti.get_addr(a.snode, [i, j]) - base) // 4 + + +def main(): + fill() + print(a.to_numpy()) + + impl.get_runtime().prog.visualize_layout('layout.pdf') + + gui = ti.GUI('layout', res=(256, 512), background_color=0xFFFFFF) + + while True: + for i in range(1, m): + gui.line(begin=(0, i / m), + end=(1, i / m), + radius=2, + color=0x000000) + for i in range(1, n): + gui.line(begin=(i / n, 0), + end=(i / n, 1), + radius=2, + color=0x000000) + for i in range(n): + for j in range(m): + gui.text(f'{a[i, j]}', ((i + 0.3) / n, (j + 0.75) / m), + font_size=30, + color=0x0) + gui.show() + + +if __name__ == '__main__': + main() diff --git a/python/taichi/examples/autodiff/minimization.py b/python/taichi/examples/autodiff/minimization.py new file mode 100644 index 0000000000000..1a5f2708e05e5 --- /dev/null +++ b/python/taichi/examples/autodiff/minimization.py @@ -0,0 +1,44 @@ +import random + +import taichi as ti + +ti.init(arch=ti.cpu) + +n = 8 +x = ti.field(dtype=ti.f32, shape=n, needs_grad=True) +y = ti.field(dtype=ti.f32, shape=n) +L = ti.field(dtype=ti.f32, shape=(), needs_grad=True) + + +@ti.kernel +def reduce(): + for i in range(n): + L[None] += 0.5 * (x[i] - y[i])**2 + + +@ti.kernel +def gradient_descent(): + for i in x: + x[i] -= x.grad[i] * 0.1 + + +def main(): + # Initialize vectors + for i in range(n): + x[i] = random.random() + y[i] = random.random() + + # Optimize with 100 gradient descent iterations + for k in range(100): + with ti.Tape(loss=L): + reduce() + print('Loss =', L[None]) + gradient_descent() + + for i in range(n): + # Now you should approximately have x[i] == y[i] + print(x[i], y[i]) + + +if __name__ == '__main__': + main() diff --git a/python/taichi/examples/autodiff/regression.py b/python/taichi/examples/autodiff/regression.py new file mode 100644 index 0000000000000..c70273e973903 --- /dev/null +++ b/python/taichi/examples/autodiff/regression.py @@ -0,0 +1,104 @@ +import random + +import matplotlib.pyplot as plt +import numpy as np + +import taichi as ti + +ti.init(arch=ti.cpu) + +number_coeffs = 4 +learning_rate = 1e-4 + +N = 32 +x, y = ti.field(ti.f32, shape=N, needs_grad=True), ti.field(ti.f32, + shape=N, + needs_grad=True) +coeffs = ti.field(ti.f32, shape=number_coeffs, needs_grad=True) +loss = ti.field(ti.f32, shape=(), needs_grad=True) + + +@ti.kernel +def regress(): + for i in x: + v = x[i] + est = 0.0 + for j in ti.static(range(number_coeffs)): + est += coeffs[j] * (v**j) + loss[None] += 0.5 * (y[i] - est)**2 + + +@ti.kernel +def update(): + for i in ti.static(range(number_coeffs)): + coeffs[i] -= learning_rate * coeffs.grad[i] + + +xs = [] +ys = [] + + +def initialize(): + for i in range(N): + v = random.random() * 5 - 2.5 + xs.append(v) + x[i] = v + y[i] = (v - 1) * (v - 2) * (v + 2) + random.random() - 0.5 + + regress() + + print('y') + for i in range(N): + y.grad[i] = 1 + ys.append(y[i]) + print() + + +def regress_raw(): + use_tape = True + + for i in range(1000): + if use_tape: + with ti.Tape(loss=loss): + regress() + else: + ti.clear_all_gradients() + loss[None] = 0 + loss.grad[None] = 1 + regress() + regress.grad() + print('Loss =', loss[None]) + update() + for i in range(number_coeffs): + print(coeffs[i], end=', ') + print() + + +def draw(): + curve_xs = np.arange(-2.5, 2.5, 0.01) + curve_ys = curve_xs * 0 + for i in range(number_coeffs): + curve_ys += coeffs[i] * np.power(curve_xs, i) + + plt.title( + 'Nonlinear Regression with Gradient Descent (3rd order polynomial)') + ax = plt.gca() + ax.scatter(xs, ys, label='data', color='r') + ax.plot(curve_xs, curve_ys, label='fitted') + ax.legend() + ax.grid(True) + ax.spines['left'].set_position('zero') + ax.spines['right'].set_color('none') + ax.spines['bottom'].set_position('zero') + ax.spines['top'].set_color('none') + plt.show() + + +def main(): + initialize() + regress_raw() + draw() + + +if __name__ == '__main__': + main() diff --git a/python/taichi/examples/autodiff/simple_derivative.py b/python/taichi/examples/autodiff/simple_derivative.py new file mode 100644 index 0000000000000..3a049e71e2129 --- /dev/null +++ b/python/taichi/examples/autodiff/simple_derivative.py @@ -0,0 +1,69 @@ +import matplotlib.pyplot as plt + +import taichi as ti + +ti.init(arch=ti.cpu) + +N = 2048 +x, y = ti.field(ti.f32), ti.field(ti.f32) + +ti.root.dense(ti.i, N).place(x, x.grad, y, y.grad) + + +@ti.kernel +def poly(): + for i in x: + v = x[i] + ret = 0.0 + guard = 0.2 + if v < -guard or v > guard: + ret = 4 / ti.max(v, 0.1) + else: + ret = 0 + y[i] = ret + + +xs = [] +ys = [] +grad_xs = [] + + +def initialize(): + for i in range(N): + v = ((i + 0.5) / N) * 2 - 1 + xs.append(v) + x[i] = v + + poly() + + for i in range(N): + y.grad[i] = 1 + ys.append(y[i]) + + poly.grad() + print('grad_x') + for i in range(N): + grad_xs.append(x.grad[i]) + + +def draw(): + plt.title('Auto Diff') + ax = plt.gca() + ax.plot(xs, ys, label='f(x)') + ax.plot(xs, grad_xs, label='f\'(x)') + ax.legend() + ax.grid(True) + ax.spines['left'].set_position('zero') + ax.spines['right'].set_color('none') + ax.spines['bottom'].set_position('zero') + ax.spines['top'].set_color('none') + plt.show() + + +def main(): + initialize() + draw() + + +if __name__ == '__main__': + main() diff --git a/examples/features/gui/fullscreen.py b/python/taichi/examples/features/gui/fullscreen.py similarity index 100% rename from examples/features/gui/fullscreen.py rename to python/taichi/examples/features/gui/fullscreen.py diff --git a/examples/features/gui/gui_image_io.py b/python/taichi/examples/features/gui/gui_image_io.py similarity index 100% rename from examples/features/gui/gui_image_io.py rename to python/taichi/examples/features/gui/gui_image_io.py diff --git a/examples/features/gui/gui_widgets.py b/python/taichi/examples/features/gui/gui_widgets.py similarity index 100% rename from examples/features/gui/gui_widgets.py rename to python/taichi/examples/features/gui/gui_widgets.py diff --git a/examples/features/gui/keyboard.py b/python/taichi/examples/features/gui/keyboard.py similarity index 100% rename from examples/features/gui/keyboard.py rename to python/taichi/examples/features/gui/keyboard.py diff --git a/examples/features/io/export_mesh.py b/python/taichi/examples/features/io/export_mesh.py similarity index 100% rename from examples/features/io/export_mesh.py rename to python/taichi/examples/features/io/export_mesh.py diff --git a/examples/features/io/export_ply.py b/python/taichi/examples/features/io/export_ply.py similarity index 100% rename from examples/features/io/export_ply.py rename to python/taichi/examples/features/io/export_ply.py diff --git a/examples/features/io/export_videos.py b/python/taichi/examples/features/io/export_videos.py similarity index 100% rename from examples/features/io/export_videos.py rename to python/taichi/examples/features/io/export_videos.py diff --git a/examples/features/sparse/explicit_activation.py b/python/taichi/examples/features/sparse/explicit_activation.py similarity index 100% rename from examples/features/sparse/explicit_activation.py rename to python/taichi/examples/features/sparse/explicit_activation.py diff --git a/examples/features/sparse/taichi_bitmasked.py b/python/taichi/examples/features/sparse/taichi_bitmasked.py similarity index 100% rename from examples/features/sparse/taichi_bitmasked.py rename to python/taichi/examples/features/sparse/taichi_bitmasked.py diff --git a/examples/features/sparse/taichi_dynamic.py b/python/taichi/examples/features/sparse/taichi_dynamic.py similarity index 100% rename from examples/features/sparse/taichi_dynamic.py rename to python/taichi/examples/features/sparse/taichi_dynamic.py diff --git a/examples/features/sparse/taichi_sparse.py b/python/taichi/examples/features/sparse/taichi_sparse.py similarity index 93% rename from examples/features/sparse/taichi_sparse.py rename to python/taichi/examples/features/sparse/taichi_sparse.py index 51039a55704b3..ebd8ea9091a6c 100644 --- a/examples/features/sparse/taichi_sparse.py +++ b/python/taichi/examples/features/sparse/taichi_sparse.py @@ -1,3 +1,5 @@ +from taichi.examples.patterns import taichi_logo + import taichi as ti ti.init(arch=ti.cuda) @@ -19,7 +21,7 @@ def activate(t: ti.f32): p = ti.Vector([i, j]) / n p = ti.Matrix.rotation2d(ti.sin(t)) @ (p - 0.5) + 0.5 - if ti.taichi_logo(p) == 0: + if taichi_logo(p) == 0: x[i, j] = 1 diff --git a/examples/features/sparse/tutorial.py b/python/taichi/examples/features/sparse/tutorial.py similarity index 100% rename from examples/features/sparse/tutorial.py rename to python/taichi/examples/features/sparse/tutorial.py diff --git a/examples/ggui_examples/fem128_ggui.py b/python/taichi/examples/ggui_examples/fem128_ggui.py similarity index 98% rename from examples/ggui_examples/fem128_ggui.py rename to python/taichi/examples/ggui_examples/fem128_ggui.py index 895bd180648c5..187befdcfa205 100644 --- a/examples/ggui_examples/fem128_ggui.py +++ b/python/taichi/examples/ggui_examples/fem128_ggui.py @@ -1,6 +1,7 @@ import taichi as ti -ti.init(arch=ti.gpu) +arch = ti.vulkan if ti._lib.core.with_vulkan() else ti.cuda +ti.init(arch=arch) N = 12 dt = 5e-5 diff --git a/examples/ggui_examples/fractal3d_ggui.py b/python/taichi/examples/ggui_examples/fractal3d_ggui.py similarity index 98% rename from examples/ggui_examples/fractal3d_ggui.py rename to python/taichi/examples/ggui_examples/fractal3d_ggui.py index 5c283e7e4efb8..deedb922703f4 100644 --- a/examples/ggui_examples/fractal3d_ggui.py +++ b/python/taichi/examples/ggui_examples/fractal3d_ggui.py @@ -1,6 +1,7 @@ import taichi as ti -ti.init(ti.cuda) +arch = ti.vulkan if ti._lib.core.with_vulkan() else ti.cuda +ti.init(arch=arch) @ti.func diff --git a/examples/ggui_examples/mass_spring_3d_ggui.py b/python/taichi/examples/ggui_examples/mass_spring_3d_ggui.py similarity index 97% rename from examples/ggui_examples/mass_spring_3d_ggui.py rename to python/taichi/examples/ggui_examples/mass_spring_3d_ggui.py index f9116d2ee0f58..2cfae527a3333 100644 --- a/examples/ggui_examples/mass_spring_3d_ggui.py +++ b/python/taichi/examples/ggui_examples/mass_spring_3d_ggui.py @@ -1,6 +1,7 @@ import taichi as ti -ti.init(arch=ti.cuda) # Alternatively, ti.init(arch=ti.cpu) +arch = ti.vulkan if ti._lib.core.with_vulkan() else ti.cuda +ti.init(arch=arch) N = 128 cell_size = 1.0 / N diff --git a/examples/ggui_examples/mass_spring_game_ggui.py b/python/taichi/examples/ggui_examples/mass_spring_game_ggui.py similarity index 98% rename from examples/ggui_examples/mass_spring_game_ggui.py rename to python/taichi/examples/ggui_examples/mass_spring_game_ggui.py index 59314ba062212..d3a392ad73a9e 100644 --- a/examples/ggui_examples/mass_spring_game_ggui.py +++ b/python/taichi/examples/ggui_examples/mass_spring_game_ggui.py @@ -1,6 +1,7 @@ import taichi as ti -ti.init(arch=ti.cuda) +arch = ti.vulkan if ti._lib.core.with_vulkan() else ti.cuda +ti.init(arch=arch) spring_Y = ti.field(dtype=ti.f32, shape=()) # Young's modulus paused = ti.field(dtype=ti.i32, shape=()) diff --git a/examples/ggui_examples/mpm128_ggui.py b/python/taichi/examples/ggui_examples/mpm128_ggui.py similarity index 98% rename from examples/ggui_examples/mpm128_ggui.py rename to python/taichi/examples/ggui_examples/mpm128_ggui.py index 279af236b6960..7c1eb5a6345f8 100644 --- a/examples/ggui_examples/mpm128_ggui.py +++ b/python/taichi/examples/ggui_examples/mpm128_ggui.py @@ -2,7 +2,8 @@ import taichi as ti -ti.init(arch=ti.gpu) # Try to run on GPU +arch = ti.vulkan if ti._lib.core.with_vulkan() else ti.cuda +ti.init(arch=arch) quality = 1 # Use a larger value for higher-res simulations n_particles, n_grid = 9000 * quality**2, 128 * quality diff --git a/examples/ggui_examples/mpm3d_ggui.py b/python/taichi/examples/ggui_examples/mpm3d_ggui.py similarity index 95% rename from examples/ggui_examples/mpm3d_ggui.py rename to python/taichi/examples/ggui_examples/mpm3d_ggui.py index 65639580a5366..0e48dfcefba14 100644 --- a/examples/ggui_examples/mpm3d_ggui.py +++ b/python/taichi/examples/ggui_examples/mpm3d_ggui.py @@ -2,7 +2,8 @@ import taichi as ti -ti.init(ti.cuda) +arch = ti.vulkan if ti._lib.core.with_vulkan() else ti.cuda +ti.init(arch=arch) #dim, n_grid, steps, dt = 2, 128, 20, 2e-4 #dim, n_grid, steps, dt = 2, 256, 32, 1e-4 @@ -35,8 +36,8 @@ shape=n_particles) # deformation gradient Jp = ti.field(float, n_particles) -colors = ti.Vector.field(3, float, n_particles) -colors_random = ti.Vector.field(3, float, n_particles) +colors = ti.Vector.field(4, float, n_particles) +colors_random = ti.Vector.field(4, float, n_particles) materials = ti.field(int, n_particles) grid_v = ti.Vector.field(dim, float, (n_grid, ) * dim) grid_m = ti.field(float, (n_grid, ) * dim) @@ -110,8 +111,8 @@ def substep(g_x: float, g_y: float, g_z: float): if grid_m[I] > 0: grid_v[I] /= grid_m[I] grid_v[I] += dt * ti.Vector([g_x, g_y, g_z]) - cond = I < bound and grid_v[I] < 0 or I > n_grid - bound and grid_v[ - I] > 0 + cond = (I < bound) & (grid_v[I] < 0) | \ + (I > n_grid - bound) & (grid_v[I] > 0) grid_v[I] = 0 if cond else grid_v[I] ti.block_dim(n_grid) for p in x: @@ -155,7 +156,9 @@ def init_cube_vol(first_par: int, last_par: int, x_begin: float, F[i] = ti.Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) v[i] = ti.Vector([0.0, 0.0, 0.0]) materials[i] = material - colors_random[i] = ti.Vector([ti.random(), ti.random(), ti.random()]) + colors_random[i] = ti.Vector( + [ti.random(), ti.random(), + ti.random(), ti.random()]) used[i] = 1 @@ -199,7 +202,7 @@ def set_color_by_material(material_colors: ti.ext_arr()): mat = materials[i] colors[i] = ti.Vector([ material_colors[mat, 0], material_colors[mat, 1], - material_colors[mat, 2] + material_colors[mat, 2], 1.0 ]) @@ -292,7 +295,7 @@ def show_options(): "snow color", material_colors[SNOW]) material_colors[JELLY] = window.GUI.color_edit_3( "jelly color", material_colors[JELLY]) - set_color_by_material(np.array(material_colors)) + set_color_by_material(np.array(material_colors, dtype=np.float32)) particles_radius = window.GUI.slider_float("particles radius ", particles_radius, 0, 0.1) if window.GUI.button("restart"): diff --git a/examples/ggui_examples/stable_fluid_ggui.py b/python/taichi/examples/ggui_examples/stable_fluid_ggui.py similarity index 98% rename from examples/ggui_examples/stable_fluid_ggui.py rename to python/taichi/examples/ggui_examples/stable_fluid_ggui.py index caf5fc74a6c6e..16536eb2dc9a6 100644 --- a/examples/ggui_examples/stable_fluid_ggui.py +++ b/python/taichi/examples/ggui_examples/stable_fluid_ggui.py @@ -21,7 +21,8 @@ debug = False paused = False -ti.init(arch=ti.cuda) +arch = ti.vulkan if ti._lib.core.with_vulkan() else ti.cuda +ti.init(arch=arch) _velocities = ti.Vector.field(2, float, shape=(res, res)) _new_velocities = ti.Vector.field(2, float, shape=(res, res)) diff --git a/examples/minimal.py b/python/taichi/examples/minimal.py similarity index 100% rename from examples/minimal.py rename to python/taichi/examples/minimal.py diff --git a/python/taichi/tools/patterns.py b/python/taichi/examples/patterns.py similarity index 100% rename from python/taichi/tools/patterns.py rename to python/taichi/examples/patterns.py diff --git a/examples/rendering/cornell_box.py b/python/taichi/examples/rendering/cornell_box.py similarity index 94% rename from examples/rendering/cornell_box.py rename to python/taichi/examples/rendering/cornell_box.py index 9560b09eebff5..bc3fff73805b3 100644 --- a/examples/rendering/cornell_box.py +++ b/python/taichi/examples/rendering/cornell_box.py @@ -485,33 +485,29 @@ def render(): @ti.kernel -def tonemap(accumulated: ti.f32) -> ti.f32: - sum = 0.0 - sum_sq = 0.0 - for i, j in color_buffer: - luma = color_buffer[i, j][0] * 0.2126 + color_buffer[ - i, j][1] * 0.7152 + color_buffer[i, j][2] * 0.0722 - sum += luma - sum_sq += ti.pow(luma / accumulated, 2.0) - mean = sum / (res[0] * res[1]) - var = sum_sq / (res[0] * res[1]) - ti.pow(mean / accumulated, 2.0) +def tonemap(accumulated: ti.f32): for i, j in tonemapped_buffer: - tonemapped_buffer[i, j] = ti.sqrt(color_buffer[i, j] / mean * 0.6) - return var - - -gui = ti.GUI('Cornell Box', res, fast_gui=True) -gui.fps_limit = 300 -last_t = time.time() -i = 0 -while gui.running: - render() - interval = 10 - if i % interval == 0: - var = tonemap(i) - print("{:.2f} samples/s ({} iters, var={})".format( - interval / (time.time() - last_t), i, var)) - last_t = time.time() - gui.set_image(tonemapped_buffer) - gui.show() - i += 1 + tonemapped_buffer[i, j] = ti.sqrt(color_buffer[i, j] / accumulated * + 100.0) + + +def main(): + gui = ti.GUI('Cornell Box', res, fast_gui=True) + gui.fps_limit = 300 + last_t = time.time() + i = 0 + while gui.running: + render() + interval = 10 + if i % interval == 0: + tonemap(i) + print("{:.2f} samples/s ({} iters)".format( + interval / (time.time() - last_t), i)) + last_t = time.time() + gui.set_image(tonemapped_buffer) + gui.show() + i += 1 + + +if __name__ == '__main__': + main() diff --git a/examples/rendering/rasterizer.py b/python/taichi/examples/rendering/rasterizer.py similarity index 100% rename from examples/rendering/rasterizer.py rename to python/taichi/examples/rendering/rasterizer.py diff --git a/examples/rendering/sdf_renderer.py b/python/taichi/examples/rendering/sdf_renderer.py similarity index 100% rename from examples/rendering/sdf_renderer.py rename to python/taichi/examples/rendering/sdf_renderer.py diff --git a/examples/rendering/simple_uv.py b/python/taichi/examples/rendering/simple_uv.py similarity index 100% rename from examples/rendering/simple_uv.py rename to python/taichi/examples/rendering/simple_uv.py diff --git a/python/taichi/examples/rendering/taichi_logo.py b/python/taichi/examples/rendering/taichi_logo.py new file mode 100644 index 0000000000000..40643f1f4e657 --- /dev/null +++ b/python/taichi/examples/rendering/taichi_logo.py @@ -0,0 +1,29 @@ +from taichi.examples.patterns import taichi_logo + +import taichi as ti + +ti.init() + +n = 512 +x = ti.field(dtype=ti.f32, shape=(n, n)) + + +@ti.kernel +def paint(): + for i, j in ti.ndrange(n * 4, n * 4): + # 4x4 super sampling: + ret = taichi_logo(ti.Vector([i, j]) / (n * 4)) + x[i // 4, j // 4] += ret / 16 + + +def main(): + paint() + + gui = ti.GUI('Logo', (n, n)) + while gui.running: + gui.set_image(x) + gui.show() + + +if __name__ == '__main__': + main() diff --git a/examples/simulation/ad_gravity.py b/python/taichi/examples/simulation/ad_gravity.py similarity index 82% rename from examples/simulation/ad_gravity.py rename to python/taichi/examples/simulation/ad_gravity.py index 3c3d19089b449..811346dd36756 100644 --- a/examples/simulation/ad_gravity.py +++ b/python/taichi/examples/simulation/ad_gravity.py @@ -43,10 +43,15 @@ def init(): x[i] = [ti.random(), ti.random()] -init() -gui = ti.GUI('Autodiff gravity') -while gui.running: - for i in range(50): - substep() - gui.circles(x.to_numpy(), radius=3) - gui.show() +def main(): + init() + gui = ti.GUI('Autodiff gravity') + while gui.running: + for i in range(50): + substep() + gui.circles(x.to_numpy(), radius=3) + gui.show() + + +if __name__ == "__main__": + main() diff --git a/examples/simulation/comet.py b/python/taichi/examples/simulation/comet.py similarity index 100% rename from examples/simulation/comet.py rename to python/taichi/examples/simulation/comet.py diff --git a/examples/simulation/euler.py b/python/taichi/examples/simulation/euler.py similarity index 99% rename from examples/simulation/euler.py rename to python/taichi/examples/simulation/euler.py index 994ecf74522ae..f59b8da013c2f 100644 --- a/examples/simulation/euler.py +++ b/python/taichi/examples/simulation/euler.py @@ -215,12 +215,12 @@ def HLLC_flux(qL, qR, n): HLLC = ti.Vector([0.0, 0.0, 0.0, 0.0]) if (0 <= sL): HLLC = fL - elif (sL <= 0) and (0 <= sM): + elif (0 <= sM): qsL = rL * (sL-vnL)/(sL-sM) \ * ti.Vector([1.0, sM*nx-vtL*ny,sM*ny+vtL*nx, \ qL[3]/rL + (sM-vnL)*(sM+pL/(rL*(sL-vnL)))]) HLLC = fL + sL * (qsL - qL) - elif (sM <= 0) and (0 <= sR): + elif (0 <= sR): qsR = rR * (sR-vnR)/(sR-sM) \ * ti.Vector([1.0, sM*nx-vtR*ny,sM*ny+vtR*nx, \ qR[3]/rR + (sM-vnR)*(sM+pR/(rR*(sR-vnR)))]) @@ -283,9 +283,7 @@ def sign(a): return sgn -ti.func - - +@ti.func def cosh(a): return (ti.exp(a) + ti.exp(-a)) / 2.0 diff --git a/examples/simulation/fem128.py b/python/taichi/examples/simulation/fem128.py similarity index 98% rename from examples/simulation/fem128.py rename to python/taichi/examples/simulation/fem128.py index 07a2d3b3a56a1..09062dff1cbac 100644 --- a/examples/simulation/fem128.py +++ b/python/taichi/examples/simulation/fem128.py @@ -60,7 +60,7 @@ def advance(): if disp2 <= ball_radius**2: NoV = vel[i].dot(disp) if NoV < 0: vel[i] -= NoV * disp / disp2 - cond = pos[i] < 0 and vel[i] < 0 or pos[i] > 1 and vel[i] > 0 + cond = (pos[i] < 0) & (vel[i] < 0) | (pos[i] > 1) & (vel[i] > 0) # rect boundary condition: for j in ti.static(range(pos.n)): if cond[j]: vel[i][j] = 0 diff --git a/examples/simulation/fem99.py b/python/taichi/examples/simulation/fem99.py similarity index 97% rename from examples/simulation/fem99.py rename to python/taichi/examples/simulation/fem99.py index 304d816c606cb..4683096178288 100644 --- a/examples/simulation/fem99.py +++ b/python/taichi/examples/simulation/fem99.py @@ -56,7 +56,7 @@ def advance(): NoV = vel[i].dot(disp) if NoV < 0: vel[i] -= NoV * disp / disp2 # rect boundary condition: - cond = pos[i] < 0 and vel[i] < 0 or pos[i] > 1 and vel[i] > 0 + cond = (pos[i] < 0) & (vel[i] < 0) | (pos[i] > 1) & (vel[i] > 0) for j in ti.static(range(pos.n)): if cond[j]: vel[i][j] = 0 pos[i] += dt * vel[i] diff --git a/examples/simulation/fractal.py b/python/taichi/examples/simulation/fractal.py similarity index 100% rename from examples/simulation/fractal.py rename to python/taichi/examples/simulation/fractal.py diff --git a/examples/simulation/game_of_life.py b/python/taichi/examples/simulation/game_of_life.py similarity index 56% rename from examples/simulation/game_of_life.py rename to python/taichi/examples/simulation/game_of_life.py index e5481b072e429..4d38ac319bf4a 100644 --- a/examples/simulation/game_of_life.py +++ b/python/taichi/examples/simulation/game_of_life.py @@ -58,31 +58,37 @@ def init(): alive[i, j] = 0 -gui = ti.GUI('Game of Life', (img_size, img_size)) -gui.fps_limit = 15 - -print('[Hint] Press `r` to reset') -print('[Hint] Press SPACE to pause') -print('[Hint] Click LMB, RMB and drag to add alive / dead cells') - -init() -paused = False -while gui.running: - for e in gui.get_events(gui.PRESS, gui.MOTION): - if e.key == gui.ESCAPE: - gui.running = False - elif e.key == gui.SPACE: - paused = not paused - elif e.key == 'r': - alive.fill(0) - - if gui.is_pressed(gui.LMB, gui.RMB): - mx, my = gui.get_cursor_pos() - alive[int(mx * n), int(my * n)] = gui.is_pressed(gui.LMB) - paused = True - - if not paused: - run() - - gui.set_image(ti.imresize(alive, img_size).astype(np.uint8) * 255) - gui.show() +def main(): + gui = ti.GUI('Game of Life', (img_size, img_size)) + gui.fps_limit = 15 + + print('[Hint] Press `r` to reset') + print('[Hint] Press SPACE to pause') + print('[Hint] Click LMB, RMB and drag to add alive / dead cells') + + init() + paused = False + while gui.running: + for e in gui.get_events(gui.PRESS, gui.MOTION): + if e.key == gui.ESCAPE: + gui.running = False + elif e.key == gui.SPACE: + paused = not paused + elif e.key == 'r': + alive.fill(0) + + if gui.is_pressed(gui.LMB, gui.RMB): + mx, my = gui.get_cursor_pos() + alive[int(mx * n), int(my * n)] = gui.is_pressed(gui.LMB) + paused = True + + if not paused: + run() + + gui.set_image( + ti.tools.imresize(alive, img_size).astype(np.uint8) * 255) + gui.show() + + +if __name__ == '__main__': + main() diff --git a/python/taichi/examples/simulation/implicit_fem.py b/python/taichi/examples/simulation/implicit_fem.py new file mode 100644 index 0000000000000..ad5c5c6c77637 --- /dev/null +++ b/python/taichi/examples/simulation/implicit_fem.py @@ -0,0 +1,353 @@ +import argparse + +import numpy as np +from taichi._lib import core as _ti_core + +import taichi as ti + +parser = argparse.ArgumentParser() +parser.add_argument('--exp', + choices=['implicit', 'explicit'], + default='implicit') +parser.add_argument('--dim', type=int, default=3) +parser.add_argument('--gui', choices=['auto', 'ggui', 'cpu'], default='auto') +parser.add_argument('place_holder', nargs='*') +args = parser.parse_args() + +ti.init(arch=ti.gpu, dynamic_index=True) + +if args.gui == 'auto': + if _ti_core.GGUI_AVAILABLE: + args.gui = 'ggui' + else: + args.gui = 'cpu' + +E, nu = 5e4, 0.0 +mu, la = E / (2 * (1 + nu)), E * nu / ((1 + nu) * (1 - 2 * nu)) # lambda = 0 +density = 1000.0 +dt = 2e-4 + +if args.exp == 'implicit': + dt = 1e-2 + +n_cube = np.array([5] * 3) +n_verts = np.product(n_cube) +n_cells = 5 * np.product(n_cube - 1) +dx = 1 / (n_cube.max() - 1) + +vertices = ti.Vector.field(4, dtype=ti.i32, shape=n_cells) + +x = ti.Vector.field(args.dim, dtype=ti.f32, shape=n_verts) +ox = ti.Vector.field(args.dim, dtype=ti.f32, shape=n_verts) +v = ti.Vector.field(args.dim, dtype=ti.f32, shape=n_verts) +f = ti.Vector.field(args.dim, dtype=ti.f32, shape=n_verts) +mul_ans = ti.Vector.field(args.dim, dtype=ti.f32, shape=n_verts) +m = ti.field(dtype=ti.f32, shape=n_verts) + +n_cells = (n_cube - 1).prod() * 5 +B = ti.Matrix.field(args.dim, args.dim, dtype=ti.f32, shape=n_cells) +W = ti.field(dtype=ti.f32, shape=n_cells) + + +@ti.func +def i2p(I): + return (I.x * n_cube[1] + I.y) * n_cube[2] + I.z + + +@ti.func +def set_element(e, I, verts): + for i in ti.static(range(args.dim + 1)): + vertices[e][i] = i2p(I + (([verts[i] >> k for k in range(3)] ^ I) & 1)) + + +@ti.kernel +def get_vertices(): + ''' + This kernel partitions the cube into tetrahedrons. + Each unit cube is divided into 5 tetrahedrons. + ''' + for I in ti.grouped(ti.ndrange(*(n_cube - 1))): + e = ((I.x * (n_cube[1] - 1) + I.y) * (n_cube[2] - 1) + I.z) * 5 + for i, j in ti.static(enumerate([0, 3, 5, 6])): + set_element(e + i, I, (j, j ^ 1, j ^ 2, j ^ 4)) + set_element(e + 4, I, (1, 2, 4, 7)) + for I in ti.grouped(ti.ndrange(*(n_cube))): + ox[i2p(I)] = I * dx + + +@ti.func +def Ds(verts): + return ti.Matrix.cols([x[verts[i]] - x[verts[3]] for i in range(3)]) + + +@ti.func +def ssvd(F): + U, sig, V = ti.svd(F) + if U.determinant() < 0: + for i in ti.static(range(3)): + U[i, 2] *= -1 + sig[2, 2] = -sig[2, 2] + if V.determinant() < 0: + for i in ti.static(range(3)): + V[i, 2] *= -1 + sig[2, 2] = -sig[2, 2] + return U, sig, V + + +@ti.func +def get_force_func(c, verts): + F = Ds(verts) @ B[c] + P = ti.Matrix.zero(ti.f32, 3, 3) + U, sig, V = ssvd(F) + P = 2 * mu * (F - U @ V.transpose()) + H = -W[c] * P @ B[c].transpose() + for i in ti.static(range(3)): + force = ti.Vector([H[j, i] for j in range(3)]) + f[verts[i]] += force + f[verts[3]] -= force + + +@ti.kernel +def get_force(): + for c in vertices: + get_force_func(c, vertices[c]) + for u in f: + f[u].y -= 9.8 * m[u] + + +@ti.kernel +def matmul_cell(ret: ti.template(), vel: ti.template()): + for i in ret: + ret[i] = vel[i] * m[i] + for c in vertices: + verts = vertices[c] + W_c = W[c] + B_c = B[c] + for u in range(4): + for d in range(3): + dD = ti.Matrix.zero(ti.f32, 3, 3) + if u == 3: + for j in range(3): + dD[d, j] = -1 + else: + dD[d, u] = 1 + dF = dD @ B_c + dP = 2.0 * mu * dF + dH = -W_c * dP @ B_c.transpose() + for i in range(3): + for j in range(3): + tmp = (vel[verts[i]][j] - vel[verts[3]][j]) + ret[verts[u]][d] += -dt**2 * dH[j, i] * tmp + + +@ti.kernel +def add(ans: ti.template(), a: ti.template(), k: ti.f32, b: ti.template()): + for i in ans: + ans[i] = a[i] + k * b[i] + + +@ti.kernel +def dot(a: ti.template(), b: ti.template()) -> ti.f32: + ans = 0.0 + for i in a: + ans += a[i].dot(b[i]) + return ans + + +b = ti.Vector.field(3, dtype=ti.f32, shape=n_verts) +r0 = ti.Vector.field(3, dtype=ti.f32, shape=n_verts) +p0 = ti.Vector.field(3, dtype=ti.f32, shape=n_verts) + + +@ti.kernel +def get_b(): + for i in b: + b[i] = m[i] * v[i] + dt * f[i] + + +def cg(): + def mul(x): + matmul_cell(mul_ans, x) + return mul_ans + + get_force() + get_b() + mul(v) + add(r0, b, -1, mul(v)) + + d = p0 + d.copy_from(r0) + r_2 = dot(r0, r0) + n_iter = 50 + epsilon = 1e-6 + r_2_init = r_2 + r_2_new = r_2 + for iter in range(n_iter): + q = mul(d) + alpha = r_2_new / dot(d, q) + add(v, v, alpha, d) + add(r0, r0, -alpha, q) + r_2 = r_2_new + r_2_new = dot(r0, r0) + if r_2_new <= r_2_init * epsilon**2: break + beta = r_2_new / r_2 + add(d, r0, beta, d) + f.fill(0) + add(x, x, dt, v) + + +@ti.kernel +def advect(): + for p in x: + v[p] += dt * (f[p] / m[p]) + x[p] += dt * v[p] + f[p] = ti.Vector([0, 0, 0]) + + +@ti.kernel +def init(): + for u in x: + x[u] = ox[u] + v[u] = [0.0] * 3 + f[u] = [0.0] * 3 + m[u] = 0.0 + for c in vertices: + F = Ds(vertices[c]) + B[c] = F.inverse() + W[c] = ti.abs(F.determinant()) / 6 + for i in range(4): + m[vertices[c][i]] += W[c] / 4 * density + for u in x: + x[u].y += 1.0 + + +@ti.kernel +def floor_bound(): + for u in x: + if x[u].y < 0: + x[u].y = 0 + if v[u].y < 0: + v[u].y = 0 + + +@ti.func +def check(u): + ans = 0 + rest = u + for i in ti.static(range(3)): + k = rest % n_cube[2 - i] + rest = rest // n_cube[2 - i] + if k == 0: ans |= (1 << (i * 2)) + if k == n_cube[2 - i] - 1: ans |= (1 << (i * 2 + 1)) + return ans + + +su = 0 +for i in range(3): + su += (n_cube[i] - 1) * (n_cube[(i + 1) % 3] - 1) +indices = ti.field(ti.i32, shape=2 * su * 2 * 3) + + +@ti.kernel +def get_indices(): + # calculate all the meshes on surface + cnt = 0 + for c in vertices: + if c % 5 != 4: + for i in ti.static([0, 2, 3]): + verts = [vertices[c][(i + j) % 4] for j in range(3)] + sum = check(verts[0]) & check(verts[1]) & check(verts[2]) + if sum: + m = ti.atomic_add(cnt, 1) + det = ti.Matrix.rows([ + x[verts[i]] - [0.5, 1.5, 0.5] for i in range(3) + ]).determinant() + if det < 0: + tmp = verts[1] + verts[1] = verts[2] + verts[2] = tmp + indices[m * 3] = verts[0] + indices[m * 3 + 1] = verts[1] + indices[m * 3 + 2] = verts[2] + + +def substep(): + if args.exp == 'explicit': + for i in range(10): + get_force() + advect() + else: + for i in range(1): + cg() + floor_bound() + + +if __name__ == '__main__': + get_vertices() + init() + get_indices() + + if args.gui == 'ggui': + res = (800, 600) + window = ti.ui.Window("Implicit FEM", res, vsync=True) + + frame_id = 0 + canvas = window.get_canvas() + scene = ti.ui.Scene() + camera = ti.ui.make_camera() + camera.position(2.0, 2.0, 3.95) + camera.lookat(0.5, 0.5, 0.5) + camera.fov(55) + + def render(): + camera.track_user_inputs(window, + movement_speed=0.03, + hold_key=ti.ui.RMB) + scene.set_camera(camera) + + scene.ambient_light((0.1, ) * 3) + + scene.point_light(pos=(0.5, 10.0, 0.5), color=(0.5, 0.5, 0.5)) + scene.point_light(pos=(10.0, 10.0, 10.0), color=(0.5, 0.5, 0.5)) + + scene.mesh(x, indices, color=(0.73, 0.33, 0.23)) + + canvas.scene(scene) + + while window.running: + frame_id += 1 + frame_id = frame_id % 256 + substep() + if window.is_pressed('r'): + init() + if window.is_pressed(ti.GUI.ESCAPE): + break + + render() + + window.show() + + else: + + def T(a): + + phi, theta = np.radians(28), np.radians(32) + + a = a - 0.2 + x, y, z = a[:, 0], a[:, 1], a[:, 2] + c, s = np.cos(phi), np.sin(phi) + C, S = np.cos(theta), np.sin(theta) + x, z = x * c + z * s, z * c - x * s + u, v = x, y * C + z * S + return np.array([u, v]).swapaxes(0, 1) + 0.5 + + gui = ti.GUI('Implicit FEM') + while gui.running: + substep() + if gui.get_event(ti.GUI.PRESS): + if gui.event.key in [ti.GUI.ESCAPE, ti.GUI.EXIT]: break + if gui.is_pressed('r'): + init() + gui.clear(0x000000) + gui.circles(T(x.to_numpy() / 3), radius=1.5, color=0xba543a) + gui.show() diff --git a/examples/simulation/implicit_mass_spring.py b/python/taichi/examples/simulation/implicit_mass_spring.py similarity index 90% rename from examples/simulation/implicit_mass_spring.py rename to python/taichi/examples/simulation/implicit_mass_spring.py index 3ab120f253649..698d73df096de 100644 --- a/examples/simulation/implicit_mass_spring.py +++ b/python/taichi/examples/simulation/implicit_mass_spring.py @@ -17,7 +17,7 @@ def __init__(self, N): self.initPos = ti.Vector.field(2, ti.f32, self.NV) self.vel = ti.Vector.field(2, ti.f32, self.NV) self.force = ti.Vector.field(2, ti.f32, self.NV) - self.invMass = ti.field(ti.f32, self.NV) + self.mass = ti.field(ti.f32, self.NV) self.spring = ti.Vector.field(2, ti.i32, self.NE) self.indices = ti.field(ti.i32, 2 * self.NE) @@ -55,9 +55,7 @@ def init_pos(self): [0.25, 0.25]) self.initPos[k] = self.pos[k] self.vel[k] = ti.Vector([0, 0]) - self.invMass[k] = 1.0 - self.invMass[self.N] = 0.0 - self.invMass[self.NV - 1] = 0.0 + self.mass[k] = 1.0 @ti.kernel def init_edges(self): @@ -87,12 +85,11 @@ def init_edges(self): rest_len[idx] = (pos[idx1] - pos[idx2]).norm() @ti.kernel - def init_mass_sp(self, M: ti.linalg.sparse_matrix_builder()): + def init_mass_sp(self, M: ti.types.sparse_matrix_builder()): for i in range(self.NV): - if self.invMass[i] != 0.0: - mass = 1.0 / self.invMass[i] - M[2 * i + 0, 2 * i + 0] += mass - M[2 * i + 1, 2 * i + 1] += mass + mass = self.mass[i] + M[2 * i + 0, 2 * i + 0] += mass + M[2 * i + 1, 2 * i + 1] += mass @ti.func def clear_force(self): @@ -103,8 +100,7 @@ def clear_force(self): def compute_force(self): self.clear_force() for i in self.force: - if self.invMass[i] != 0.0: - self.force[i] += self.gravity / self.invMass[i] + self.force[i] += self.gravity * self.mass[i] for i in self.spring: idx1, idx2 = self.spring[i][0], self.spring[i][1] @@ -137,11 +133,11 @@ def compute_Jacobians(self): self.Jv[i] = self.kd * I # fix point constraint hessian - self.Jf[0] = ti.Matrix([[self.kf, 0], [0, self.kf]]) - self.Jf[1] = ti.Matrix([[self.kf, 0], [0, self.kf]]) + self.Jf[0] = ti.Matrix([[-self.kf, 0], [0, -self.kf]]) + self.Jf[1] = ti.Matrix([[-self.kf, 0], [0, -self.kf]]) @ti.kernel - def assemble_K(self, K: ti.linalg.sparse_matrix_builder()): + def assemble_K(self, K: ti.types.sparse_matrix_builder()): for i in self.spring: idx1, idx2 = self.spring[i][0], self.spring[i][1] for m, n in ti.static(ti.ndrange(2, 2)): @@ -154,7 +150,7 @@ def assemble_K(self, K: ti.linalg.sparse_matrix_builder()): K[2 * (self.NV - 1) + m, 2 * (self.NV - 1) + n] += self.Jf[1][m, n] @ti.kernel - def assemble_D(self, D: ti.linalg.sparse_matrix_builder()): + def assemble_D(self, D: ti.types.sparse_matrix_builder()): for i in self.spring: idx1, idx2 = self.spring[i][0], self.spring[i][1] for m, n in ti.static(ti.ndrange(2, 2)): @@ -166,9 +162,8 @@ def assemble_D(self, D: ti.linalg.sparse_matrix_builder()): @ti.kernel def updatePosVel(self, h: ti.f32, dv: ti.ext_arr()): for i in self.pos: - if self.invMass[i] != 0.0: - self.vel[i] += ti.Vector([dv[2 * i], dv[2 * i + 1]]) - self.pos[i] += h * self.vel[i] + self.vel[i] += ti.Vector([dv[2 * i], dv[2 * i + 1]]) + self.pos[i] += h * self.vel[i] def update(self, h): self.compute_force() @@ -233,7 +228,7 @@ def displayGGUI(self, canvas, radius=0.01, color=(1.0, 1.0, 1.0)): '--use-ggui', action='store_true', help='Display with GGUI') - args = parser.parse_args() + args, unknowns = parser.parse_known_args() use_ggui = False use_ggui = args.use_ggui diff --git a/python/taichi/examples/simulation/inital_value_problem.py b/python/taichi/examples/simulation/inital_value_problem.py new file mode 100644 index 0000000000000..289c09c7a8e50 --- /dev/null +++ b/python/taichi/examples/simulation/inital_value_problem.py @@ -0,0 +1,47 @@ +import math +import time + +import numpy as np + +import taichi as ti + + +def init(): + a = [] + for i in np.linspace(0, 1, n, False): + for j in np.linspace(0, 1, n, False): + a.append([i, j]) + return np.array(a) + + +ti.init(arch=ti.gpu) +n = 50 +dirs = ti.field(dtype=float, shape=(n * n, 2)) +locations_np = init() + +locations = ti.field(dtype=float, shape=(n * n, 2)) +locations.from_numpy(locations_np) + + +@ti.kernel +def paint(t: float): + (o, p) = locations_np.shape + for i in range(0, o): # Parallelized over all pixels + x = locations[i, 0] + y = locations[i, 1] + dirs[i, 0] = ti.sin((t * x - y)) + dirs[i, 1] = ti.cos(t * y - x) + len = (dirs[i, 0]**2 + dirs[i, 1]**2)**0.5 + dirs[i, 0] /= len * 40 + dirs[i, 1] /= len * 40 + + +gui = ti.GUI("Vector Field", res=(500, 500)) + +begining = time.time_ns() +for k in range(1000000): + start_time = time.time_ns() + paint((time.time_ns() - begining) * 0.00000001) + dirs_np = dirs.to_numpy() + gui.arrows(locations_np, dirs_np, radius=1) + gui.show() diff --git a/examples/simulation/mass_spring_game.py b/python/taichi/examples/simulation/mass_spring_game.py similarity index 100% rename from examples/simulation/mass_spring_game.py rename to python/taichi/examples/simulation/mass_spring_game.py diff --git a/examples/simulation/mpm128.py b/python/taichi/examples/simulation/mpm128.py similarity index 100% rename from examples/simulation/mpm128.py rename to python/taichi/examples/simulation/mpm128.py diff --git a/examples/simulation/mpm3d.py b/python/taichi/examples/simulation/mpm3d.py similarity index 97% rename from examples/simulation/mpm3d.py rename to python/taichi/examples/simulation/mpm3d.py index 0e63972e0120b..f7584c7159a6a 100644 --- a/examples/simulation/mpm3d.py +++ b/python/taichi/examples/simulation/mpm3d.py @@ -57,8 +57,8 @@ def substep(): if grid_m[I] > 0: grid_v[I] /= grid_m[I] grid_v[I][1] -= dt * gravity - cond = I < bound and grid_v[I] < 0 or I > n_grid - bound and grid_v[ - I] > 0 + cond = (I < bound) & (grid_v[I] < 0) | \ + (I > n_grid - bound) & (grid_v[I] > 0) grid_v[I] = 0 if cond else grid_v[I] ti.block_dim(n_grid) for p in x: diff --git a/examples/simulation/mpm88.py b/python/taichi/examples/simulation/mpm88.py similarity index 100% rename from examples/simulation/mpm88.py rename to python/taichi/examples/simulation/mpm88.py diff --git a/examples/simulation/mpm99.py b/python/taichi/examples/simulation/mpm99.py similarity index 90% rename from examples/simulation/mpm99.py rename to python/taichi/examples/simulation/mpm99.py index cf71c3877bd48..a68f91b75205d 100644 --- a/examples/simulation/mpm99.py +++ b/python/taichi/examples/simulation/mpm99.py @@ -117,14 +117,19 @@ def initialize(): Jp[i] = 1 -initialize() -gui = ti.GUI("Taichi MLS-MPM-99", res=512, background_color=0x112F41) -while not gui.get_event(ti.GUI.ESCAPE, ti.GUI.EXIT): - for s in range(int(2e-3 // dt)): - substep() - gui.circles(x.to_numpy(), - radius=1.5, - palette=[0x068587, 0xED553B, 0xEEEEF0], - palette_indices=material) - gui.show( - ) # Change to gui.show(f'{frame:06d}.png') to write images to disk +def main(): + initialize() + gui = ti.GUI("Taichi MLS-MPM-99", res=512, background_color=0x112F41) + while not gui.get_event(ti.GUI.ESCAPE, ti.GUI.EXIT): + for s in range(int(2e-3 // dt)): + substep() + gui.circles(x.to_numpy(), + radius=1.5, + palette=[0x068587, 0xED553B, 0xEEEEF0], + palette_indices=material) + gui.show( + ) # Change to gui.show(f'{frame:06d}.png') to write images to disk + + +if __name__ == '__main__': + main() diff --git a/examples/simulation/mpm_lagrangian_forces.py b/python/taichi/examples/simulation/mpm_lagrangian_forces.py similarity index 100% rename from examples/simulation/mpm_lagrangian_forces.py rename to python/taichi/examples/simulation/mpm_lagrangian_forces.py diff --git a/examples/simulation/nbody.py b/python/taichi/examples/simulation/nbody.py similarity index 73% rename from examples/simulation/nbody.py rename to python/taichi/examples/simulation/nbody.py index f70d0d5e88684..c6de3021b2c11 100644 --- a/examples/simulation/nbody.py +++ b/python/taichi/examples/simulation/nbody.py @@ -80,23 +80,28 @@ def update(): pos[i] += dt * vel[i] -gui = ti.GUI('N-body problem', (800, 800)) - -initialize() -while gui.running: - - for e in gui.get_events(ti.GUI.PRESS): - if e.key in [ti.GUI.ESCAPE, ti.GUI.EXIT]: - exit() - elif e.key == 'r': - initialize() - elif e.key == ti.GUI.SPACE: - paused[None] = not paused[None] - - if not paused[None]: - for i in range(substepping): - compute_force() - update() - - gui.circles(pos.to_numpy(), color=0xffffff, radius=planet_radius) - gui.show() +def main(): + gui = ti.GUI('N-body problem', (800, 800)) + + initialize() + while gui.running: + + for e in gui.get_events(ti.GUI.PRESS): + if e.key in [ti.GUI.ESCAPE, ti.GUI.EXIT]: + exit() + elif e.key == 'r': + initialize() + elif e.key == ti.GUI.SPACE: + paused[None] = not paused[None] + + if not paused[None]: + for i in range(substepping): + compute_force() + update() + + gui.circles(pos.to_numpy(), color=0xffffff, radius=planet_radius) + gui.show() + + +if __name__ == '__main__': + main() diff --git a/examples/simulation/odop_solar.py b/python/taichi/examples/simulation/odop_solar.py similarity index 100% rename from examples/simulation/odop_solar.py rename to python/taichi/examples/simulation/odop_solar.py diff --git a/examples/simulation/pbf2d.py b/python/taichi/examples/simulation/pbf2d.py similarity index 100% rename from examples/simulation/pbf2d.py rename to python/taichi/examples/simulation/pbf2d.py diff --git a/examples/simulation/physarum.py b/python/taichi/examples/simulation/physarum.py similarity index 100% rename from examples/simulation/physarum.py rename to python/taichi/examples/simulation/physarum.py diff --git a/examples/simulation/stable_fluid.py b/python/taichi/examples/simulation/stable_fluid.py similarity index 98% rename from examples/simulation/stable_fluid.py rename to python/taichi/examples/simulation/stable_fluid.py index eb159f8de7120..864fd4204799a 100644 --- a/examples/simulation/stable_fluid.py +++ b/python/taichi/examples/simulation/stable_fluid.py @@ -12,13 +12,13 @@ # How to run: # `python stable_fluid.py`: use the jacobi iteration to solve the linear system. -# `python stable_fluid.py -s`: use a sparse matrix to do so. +# `python stable_fluid.py -S`: use a sparse matrix to do so. parser = argparse.ArgumentParser() -parser.add_argument('-s', +parser.add_argument('-S', '--use-sp-mat', action='store_true', help='Solve Poisson\'s equation by using a sparse matrix') -args = parser.parse_args() +args, unknowns = parser.parse_known_args() res = 512 dt = 0.03 @@ -69,7 +69,7 @@ def swap(self): if use_sparse_matrix: # use a sparse matrix to solve Poisson's pressure equation. @ti.kernel - def fill_laplacian_matrix(A: ti.linalg.sparse_matrix_builder()): + def fill_laplacian_matrix(A: ti.types.sparse_matrix_builder()): for i, j in ti.ndrange(res, res): row = i * res + j center = 0.0 diff --git a/examples/simulation/vortex_rings.py b/python/taichi/examples/simulation/vortex_rings.py similarity index 100% rename from examples/simulation/vortex_rings.py rename to python/taichi/examples/simulation/vortex_rings.py diff --git a/examples/simulation/waterwave.py b/python/taichi/examples/simulation/waterwave.py similarity index 100% rename from examples/simulation/waterwave.py rename to python/taichi/examples/simulation/waterwave.py diff --git a/python/taichi/experimental.py b/python/taichi/experimental.py new file mode 100644 index 0000000000000..13df4c4a9baf3 --- /dev/null +++ b/python/taichi/experimental.py @@ -0,0 +1,3 @@ +from taichi.lang.kernel_impl import real_func + +__all__ = ["real_func"] diff --git a/python/taichi/lang/README.md b/python/taichi/lang/README.md index 26cd7af7eb516..0df5e4f66cf7b 100644 --- a/python/taichi/lang/README.md +++ b/python/taichi/lang/README.md @@ -1,3 +1,3 @@ Some notes about the current implementation -There are lots of `from taichi.lang.meta import xx`. Unfortunately, this cannot be moved into the top of the file, and has to be delayed inside the function. Otherwise, it would result in some cyclic import issue that is not trivially resolvable. +There are lots of `from taichi._kernels import xx`. Unfortunately, this cannot be moved into the top of the file, and has to be delayed inside the function. Otherwise, it would result in some cyclic import issue that is not trivially resolvable. diff --git a/python/taichi/lang/__init__.py b/python/taichi/lang/__init__.py index 9c54dd987875e..f1a171b0cb9eb 100644 --- a/python/taichi/lang/__init__.py +++ b/python/taichi/lang/__init__.py @@ -1,1249 +1,32 @@ -import atexit -import functools -import os -import shutil -import tempfile -import time -from contextlib import contextmanager -from copy import deepcopy as _deepcopy +import platform -import taichi.lang.linalg_impl -import taichi.lang.meta -from taichi.core.util import locale_encode -from taichi.core.util import ti_core as _ti_core -from taichi.lang import _random, impl, types -from taichi.lang.ast.transformer import TaichiSyntaxError +from taichi._lib import core as _ti_core +from taichi.lang import impl +from taichi.lang._ndarray import * +from taichi.lang._ndrange import ndrange from taichi.lang.enums import Layout -from taichi.lang.exception import InvalidOperationError +from taichi.lang.exception import * +from taichi.lang.field import * from taichi.lang.impl import * -from taichi.lang.kernel_impl import (KernelArgError, KernelDefError, - data_oriented, func, kernel, pyfunc) -from taichi.lang.matrix import Matrix, Vector -from taichi.lang.ndrange import GroupedNDRange, ndrange -from taichi.lang.ops import * -from taichi.lang.quant_impl import quant -from taichi.lang.runtime_ops import async_flush, sync -from taichi.lang.struct import Struct -from taichi.lang.type_factory_impl import type_factory -from taichi.lang.util import (has_pytorch, is_taichi_class, python_scope, - taichi_scope, to_numpy_type, to_pytorch_type, - to_taichi_type) -from taichi.misc.util import deprecated -from taichi.profiler import KernelProfiler, get_default_kernel_profiler -from taichi.profiler.kernelmetrics import (CuptiMetric, default_cupti_metrics, - get_predefined_cupti_metrics) -from taichi.snode.fields_builder import FieldsBuilder -from taichi.type.annotations import any_arr, ext_arr, template - -import taichi as ti - -runtime = impl.get_runtime() - -i = axes(0) -j = axes(1) -k = axes(2) -l = axes(3) -ij = axes(0, 1) -ik = axes(0, 2) -il = axes(0, 3) -jk = axes(1, 2) -jl = axes(1, 3) -kl = axes(2, 3) -ijk = axes(0, 1, 2) -ijl = axes(0, 1, 3) -ikl = axes(0, 2, 3) -jkl = axes(1, 2, 3) -ijkl = axes(0, 1, 2, 3) - -outer_product = deprecated('ti.outer_product(a, b)', - 'a.outer_product(b)')(Matrix.outer_product) -cross = deprecated('ti.cross(a, b)', 'a.cross(b)')(Matrix.cross) -dot = deprecated('ti.dot(a, b)', 'a.dot(b)')(Matrix.dot) -normalized = deprecated('ti.normalized(a)', - 'a.normalized()')(Matrix.normalized) - -cfg = default_cfg() -x86_64 = _ti_core.x64 -"""The x64 CPU backend. -""" -x64 = _ti_core.x64 -"""The X64 CPU backend. -""" -arm64 = _ti_core.arm64 -"""The ARM CPU backend. -""" -cuda = _ti_core.cuda -"""The CUDA backend. -""" -metal = _ti_core.metal -"""The Apple Metal backend. -""" -opengl = _ti_core.opengl -"""The OpenGL backend. OpenGL 4.3 required. -""" -# Skip annotating this one because it is barely maintained. -cc = _ti_core.cc -wasm = _ti_core.wasm -"""The WebAssembly backend. -""" -vulkan = _ti_core.vulkan -"""The Vulkan backend. -""" -gpu = [cuda, metal, opengl, vulkan] -"""A list of GPU backends supported on the current system. - -When this is used, Taichi automatically picks the matching GPU backend. If no -GPU is detected, Taichi falls back to the CPU backend. -""" -cpu = _ti_core.host_arch() -"""A list of CPU backends supported on the current system. - -When this is used, Taichi automatically picks the matching CPU backend. -""" -timeline_clear = lambda: impl.get_runtime().prog.timeline_clear() -timeline_save = lambda fn: impl.get_runtime().prog.timeline_save(fn) - -# Legacy API -type_factory_ = _ti_core.get_type_factory_instance() - - -@deprecated('kernel_profiler_print()', 'print_kernel_profile_info()') -def kernel_profiler_print(): - return print_kernel_profile_info() - - -def print_kernel_profile_info(mode='count'): - """Print the profiling results of Taichi kernels. - - To enable this profiler, set ``kernel_profiler=True`` in ``ti.init()``. - ``'count'`` mode: print the statistics (min,max,avg time) of launched kernels, - ``'trace'`` mode: print the records of launched kernels with specific profiling metrics (time, memory load/store and core utilization etc.), - and defaults to ``'count'``. - - Args: - mode (str): the way to print profiling results. - - Example:: - - >>> import taichi as ti - - >>> ti.init(ti.cpu, kernel_profiler=True) - >>> var = ti.field(ti.f32, shape=1) - - >>> @ti.kernel - >>> def compute(): - >>> var[0] = 1.0 - - >>> compute() - >>> ti.print_kernel_profile_info() - >>> # equivalent calls : - >>> # ti.print_kernel_profile_info('count') - - >>> ti.print_kernel_profile_info('trace') - - Note: - Currently the result of `KernelProfiler` could be incorrect on OpenGL - backend due to its lack of support for `ti.sync()`. - - For advanced mode of `KernelProfiler`, please visit https://docs.taichi.graphics/docs/lang/articles/misc/profiler#advanced-mode. - """ - get_default_kernel_profiler().print_info(mode) - - -def query_kernel_profile_info(name): - """Query kernel elapsed time(min,avg,max) on devices using the kernel name. - - To enable this profiler, set `kernel_profiler=True` in `ti.init`. - - Args: - name (str): kernel name. - - Returns: - KernelProfilerQueryResult (class): with member variables(counter, min, max, avg) - - Example:: - - >>> import taichi as ti - - >>> ti.init(ti.cpu, kernel_profiler=True) - >>> n = 1024*1024 - >>> var = ti.field(ti.f32, shape=n) - - >>> @ti.kernel - >>> def fill(): - >>> for i in range(n): - >>> var[i] = 0.1 - - >>> fill() - >>> ti.clear_kernel_profile_info() #[1] - >>> for i in range(100): - >>> fill() - >>> query_result = ti.query_kernel_profile_info(fill.__name__) #[2] - >>> print("kernel excuted times =",query_result.counter) - >>> print("kernel elapsed time(min_in_ms) =",query_result.min) - >>> print("kernel elapsed time(max_in_ms) =",query_result.max) - >>> print("kernel elapsed time(avg_in_ms) =",query_result.avg) - - Note: - [1] To get the correct result, query_kernel_profile_info() must be used in conjunction with - clear_kernel_profile_info(). - - [2] Currently the result of `KernelProfiler` could be incorrect on OpenGL - backend due to its lack of support for `ti.sync()`. - """ - return get_default_kernel_profiler().query_info(name) - - -@deprecated('kernel_profiler_clear()', 'clear_kernel_profile_info()') -def kernel_profiler_clear(): - return clear_kernel_profile_info() - - -def clear_kernel_profile_info(): - """Clear all KernelProfiler records.""" - get_default_kernel_profiler().clear_info() - - -def kernel_profiler_total_time(): - """Get elapsed time of all kernels recorded in KernelProfiler. - - Returns: - time (float): total time in second. - """ - return get_default_kernel_profiler().get_total_time() - - -def set_kernel_profile_metrics(metric_list=default_cupti_metrics): - """Set metrics that will be collected by the CUPTI toolkit. - - Args: - metric_list (list): a list of :class:`~taichi.lang.CuptiMetric()` instances, default value: :data:`~taichi.lang.default_cupti_metrics`. - - Example:: - - >>> import taichi as ti - - >>> ti.init(kernel_profiler=True, arch=ti.cuda) - >>> num_elements = 128*1024*1024 - - >>> x = ti.field(ti.f32, shape=num_elements) - >>> y = ti.field(ti.f32, shape=()) - >>> y[None] = 0 - - >>> @ti.kernel - >>> def reduction(): - >>> for i in x: - >>> y[None] += x[i] - - >>> # In the case of not pramater, Taichi will print its pre-defined metrics list - >>> ti.get_predefined_cupti_metrics() - >>> # get Taichi pre-defined metrics - >>> profiling_metrics = ti.get_predefined_cupti_metrics('shared_access') - - >>> global_op_atom = ti.CuptiMetric( - >>> name='l1tex__t_set_accesses_pipe_lsu_mem_global_op_atom.sum', - >>> header=' global.atom ', - >>> format=' {:8.0f} ') - >>> # add user defined metrics - >>> profiling_metrics += [global_op_atom] - - >>> # metrics setting will be retained until the next configuration - >>> ti.set_kernel_profile_metrics(profiling_metrics) - >>> for i in range(16): - >>> reduction() - >>> ti.print_kernel_profile_info('trace') - - Note: - Metrics setting will be retained until the next configuration. - """ - get_default_kernel_profiler().set_metrics(metric_list) - - -@contextmanager -def collect_kernel_profile_metrics(metric_list=default_cupti_metrics): - """Set temporary metrics that will be collected by the CUPTI toolkit within this context. - - Args: - metric_list (list): a list of :class:`~taichi.lang.CuptiMetric()` instances, default value: :data:`~taichi.lang.default_cupti_metrics`. - - Example:: - - >>> import taichi as ti - - >>> ti.init(kernel_profiler=True, arch=ti.cuda) - >>> num_elements = 128*1024*1024 - - >>> x = ti.field(ti.f32, shape=num_elements) - >>> y = ti.field(ti.f32, shape=()) - >>> y[None] = 0 - - >>> @ti.kernel - >>> def reduction(): - >>> for i in x: - >>> y[None] += x[i] - - >>> # In the case of not pramater, Taichi will print its pre-defined metrics list - >>> ti.get_predefined_cupti_metrics() - >>> # get Taichi pre-defined metrics - >>> profiling_metrics = ti.get_predefined_cupti_metrics('device_utilization') - - >>> global_op_atom = ti.CuptiMetric( - >>> name='l1tex__t_set_accesses_pipe_lsu_mem_global_op_atom.sum', - >>> header=' global.atom ', - >>> format=' {:8.0f} ') - >>> # add user defined metrics - >>> profiling_metrics += [global_op_atom] - - >>> # metrics setting is temporary, and will be clear when exit from this context. - >>> with ti.collect_kernel_profile_metrics(profiling_metrics): - >>> for i in range(16): - >>> reduction() - >>> ti.print_kernel_profile_info('trace') - - Note: - The configuration of the ``metric_list`` will be clear when exit from this context. - """ - get_default_kernel_profiler().set_metrics(metric_list) - yield get_default_kernel_profiler() - get_default_kernel_profiler().set_metrics() - - -@deprecated('memory_profiler_print()', 'print_memory_profile_info()') -def memory_profiler_print(): - return print_memory_profile_info() - - -def print_memory_profile_info(): - """Memory profiling tool for LLVM backends with full sparse support. - - This profiler is automatically on. - """ - impl.get_runtime().materialize() - impl.get_runtime().prog.print_memory_profiler_info() - - -extension = _ti_core.Extension - - -def is_extension_supported(arch, ext): - """Checks whether an extension is supported on an arch. - - Args: - arch (taichi_core.Arch): Specified arch. - ext (taichi_core.Extension): Specified extension. - - Returns: - bool: Whether `ext` is supported on `arch`. - """ - return _ti_core.is_extension_supported(arch, ext) - - -def reset(): - """Resets Taichi to its initial state. - - This would destroy all the fields and kernels. - """ - _ti_core.reset_snode_access_flag() - impl.reset() - global runtime - runtime = impl.get_runtime() - - -class _EnvironmentConfigurator: - def __init__(self, kwargs, cfg): - self.cfg = cfg - self.kwargs = kwargs - self.keys = [] - - def add(self, key, cast=None): - cast = cast or self.bool_int - - self.keys.append(key) - - # TI_ASYNC= : no effect - # TI_ASYNC=0 : False - # TI_ASYNC=1 : True - name = 'TI_' + key.upper() - value = os.environ.get(name, '') - if len(value): - self[key] = cast(value) - if key in self.kwargs: - _ti_core.warn( - f'ti.init argument "{key}" overridden by environment variable {name}={value}' - ) - del self.kwargs[key] # mark as recognized - elif key in self.kwargs: - self[key] = self.kwargs[key] - del self.kwargs[key] # mark as recognized - - def __getitem__(self, key): - return getattr(self.cfg, key) - - def __setitem__(self, key, value): - setattr(self.cfg, key, value) - - @staticmethod - def bool_int(x): - return bool(int(x)) - - -class _SpecialConfig: - # like CompileConfig in C++, this is the configurations that belong to other submodules - def __init__(self): - self.print_preprocessed = False - self.log_level = 'info' - self.gdb_trigger = False - self.excepthook = False - self.experimental_real_function = False - - -def prepare_sandbox(): - ''' - Returns a temporary directory, which will be automatically deleted on exit. - It may contain the taichi_core shared object or some misc. files. - ''' - tmp_dir = tempfile.mkdtemp(prefix='taichi-') - atexit.register(shutil.rmtree, tmp_dir) - print(f'[Taichi] preparing sandbox at {tmp_dir}') - os.mkdir(os.path.join(tmp_dir, 'runtime/')) - return tmp_dir - - -def init(arch=None, - default_fp=None, - default_ip=None, - _test_mode=False, - **kwargs): - """Initializes the Taichi runtime. - - This should always be the entry point of your Taichi program. Most - importantly, it sets the backend used throughout the program. - - Args: - arch: Backend to use. This is usually :const:`~taichi.lang.cpu` or :const:`~taichi.lang.gpu`. - default_fp (Optional[type]): Default floating-point type. - default_fp (Optional[type]): Default integral type. - **kwargs: Taichi provides highly customizable compilation through - ``kwargs``, which allows for fine grained control of Taichi compiler - behavior. Below we list some of the most frequently used ones. For a - complete list, please check out - https://github.com/taichi-dev/taichi/blob/master/taichi/program/compile_config.h. - - * ``cpu_max_num_threads`` (int): Sets the number of threads used by the CPU thread pool. - * ``debug`` (bool): Enables the debug mode, under which Taichi does a few more things like boundary checks. - * ``print_ir`` (bool): Prints the CHI IR of the Taichi kernels. - * ``packed`` (bool): Enables the packed memory layout. See https://docs.taichi.graphics/lang/articles/advanced/layout. - """ - # Make a deepcopy in case these args reference to items from ti.cfg, which are - # actually references. If no copy is made and the args are indeed references, - # ti.reset() could override the args to their default values. - default_fp = _deepcopy(default_fp) - default_ip = _deepcopy(default_ip) - kwargs = _deepcopy(kwargs) - ti.reset() - - spec_cfg = _SpecialConfig() - env_comp = _EnvironmentConfigurator(kwargs, ti.cfg) - env_spec = _EnvironmentConfigurator(kwargs, spec_cfg) - - # configure default_fp/ip: - # TODO: move these stuff to _SpecialConfig too: - env_default_fp = os.environ.get("TI_DEFAULT_FP") - if env_default_fp: - if default_fp is not None: - _ti_core.warn( - f'ti.init argument "default_fp" overridden by environment variable TI_DEFAULT_FP={env_default_fp}' - ) - if env_default_fp == '32': - default_fp = ti.f32 - elif env_default_fp == '64': - default_fp = ti.f64 - elif env_default_fp is not None: - raise ValueError( - f'Invalid TI_DEFAULT_FP={env_default_fp}, should be 32 or 64') - - env_default_ip = os.environ.get("TI_DEFAULT_IP") - if env_default_ip: - if default_ip is not None: - _ti_core.warn( - f'ti.init argument "default_ip" overridden by environment variable TI_DEFAULT_IP={env_default_ip}' - ) - if env_default_ip == '32': - default_ip = ti.i32 - elif env_default_ip == '64': - default_ip = ti.i64 - elif env_default_ip is not None: - raise ValueError( - f'Invalid TI_DEFAULT_IP={env_default_ip}, should be 32 or 64') - - if default_fp is not None: - impl.get_runtime().set_default_fp(default_fp) - if default_ip is not None: - impl.get_runtime().set_default_ip(default_ip) - - # submodule configurations (spec_cfg): - env_spec.add('print_preprocessed') - env_spec.add('log_level', str) - env_spec.add('gdb_trigger') - env_spec.add('excepthook') - env_spec.add('experimental_real_function') - - # compiler configurations (ti.cfg): - for key in dir(ti.cfg): - if key in ['arch', 'default_fp', 'default_ip']: - continue - cast = type(getattr(ti.cfg, key)) - if cast is bool: - cast = None - env_comp.add(key, cast) - - unexpected_keys = kwargs.keys() - - if 'use_unified_memory' in unexpected_keys: - _ti_core.warn( - f'"use_unified_memory" is a deprecated option, as taichi no longer have the option of using unified memory.' - ) - del kwargs['use_unified_memory'] - - if len(unexpected_keys): - raise KeyError( - f'Unrecognized keyword argument(s) for ti.init: {", ".join(unexpected_keys)}' - ) - - # dispatch configurations that are not in ti.cfg: - if not _test_mode: - ti.set_gdb_trigger(spec_cfg.gdb_trigger) - impl.get_runtime().print_preprocessed = spec_cfg.print_preprocessed - impl.get_runtime().experimental_real_function = \ - spec_cfg.experimental_real_function - ti.set_logging_level(spec_cfg.log_level.lower()) - if spec_cfg.excepthook: - # TODO(#1405): add a way to restore old excepthook - ti.enable_excepthook() - - # select arch (backend): - env_arch = os.environ.get('TI_ARCH') - if env_arch is not None: - ti.info(f'Following TI_ARCH setting up for arch={env_arch}') - arch = _ti_core.arch_from_name(env_arch) - ti.cfg.arch = adaptive_arch_select(arch) - if ti.cfg.arch == cc: - _ti_core.set_tmp_dir(locale_encode(prepare_sandbox())) - print(f'[Taichi] Starting on arch={_ti_core.arch_name(ti.cfg.arch)}') - - if _test_mode: - return spec_cfg - - get_default_kernel_profiler().set_kernel_profiler_mode( - ti.cfg.kernel_profiler) - - # create a new program: - impl.get_runtime().create_program() - - ti.trace('Materializing runtime...') - impl.get_runtime().prog.materialize_runtime() - - impl._root_fb = FieldsBuilder() - - -def no_activate(*args): - for v in args: - _ti_core.no_activate(v.snode.ptr) - - -def block_local(*args): - """Hints Taichi to cache the fields and to enable the BLS optimization. - - Please visit https://docs.taichi.graphics/lang/articles/advanced/performance - for how BLS is used. - - Args: - *args (List[Field]): A list of sparse Taichi fields. - - Raises: - InvalidOperationError: If the ``dynamic_index`` feature (experimental) - is enabled. - """ - if ti.current_cfg().dynamic_index: - raise InvalidOperationError( - 'dynamic_index is not allowed when block_local is turned on.') - for a in args: - for v in a.get_field_members(): - _ti_core.insert_snode_access_flag( - _ti_core.SNodeAccessFlag.block_local, v.ptr) - - -@deprecated('ti.cache_shared', 'ti.block_local') -def cache_shared(*args): - block_local(*args) - - -def cache_read_only(*args): - for a in args: - for v in a.get_field_members(): - _ti_core.insert_snode_access_flag( - _ti_core.SNodeAccessFlag.read_only, v.ptr) - - -def assume_in_range(val, base, low, high): - return _ti_core.expr_assume_in_range( - Expr(val).ptr, - Expr(base).ptr, low, high) - - -def loop_unique(val, covers=None): - if covers is None: - covers = [] - if not isinstance(covers, (list, tuple)): - covers = [covers] - covers = [x.snode.ptr if isinstance(x, Expr) else x.ptr for x in covers] - return _ti_core.expr_loop_unique(Expr(val).ptr, covers) - - -parallelize = _ti_core.parallelize -serialize = lambda: parallelize(1) -vectorize = _ti_core.vectorize -bit_vectorize = _ti_core.bit_vectorize -block_dim = _ti_core.block_dim - -inversed = deprecated('ti.inversed(a)', 'a.inverse()')(Matrix.inversed) -transposed = deprecated('ti.transposed(a)', 'a.transpose()')(Matrix.transposed) - - -def polar_decompose(A, dt=None): - """Perform polar decomposition (A=UP) for arbitrary size matrix. - - Mathematical concept refers to https://en.wikipedia.org/wiki/Polar_decomposition. - This is only a wrapper for :func:`taichi.lang.linalg_impl.polar_decompose`. - - Args: - A (ti.Matrix(n, n)): input nxn matrix `A`. - dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. - - Returns: - Decomposed nxn matrices `U` and `P`. - """ - if dt is None: - dt = impl.get_runtime().default_fp - return taichi.lang.linalg_impl.polar_decompose(A, dt) - - -def svd(A, dt=None): - """Perform singular value decomposition (A=USV^T) for arbitrary size matrix. - - Mathematical concept refers to https://en.wikipedia.org/wiki/Singular_value_decomposition. - This is only a wrappers for :func:`taichi.lang.linalg_impl.svd`. - - Args: - A (ti.Matrix(n, n)): input nxn matrix `A`. - dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. - - Returns: - Decomposed nxn matrices `U`, 'S' and `V`. - """ - if dt is None: - dt = impl.get_runtime().default_fp - return taichi.lang.linalg_impl.svd(A, dt) - - -def eig(A, dt=None): - """Compute the eigenvalues and right eigenvectors of a real matrix. - - Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix. - 2D implementation refers to :func:`taichi.lang.linalg_impl.eig2x2`. - - Args: - A (ti.Matrix(n, n)): 2D Matrix for which the eigenvalues and right eigenvectors will be computed. - dt (DataType): The datatype for the eigenvalues and right eigenvectors. - - Returns: - eigenvalues (ti.Matrix(n, 2)): The eigenvalues in complex form. Each row stores one eigenvalue. The first number of the eigenvalue represents the real part and the second number represents the imaginary part. - eigenvectors (ti.Matrix(n*2, n)): The eigenvectors in complex form. Each column stores one eigenvector. Each eigenvector consists of n entries, each of which is represented by two numbers for its real part and imaginary part. - """ - if dt is None: - dt = impl.get_runtime().default_fp - if A.n == 2: - return taichi.lang.linalg_impl.eig2x2(A, dt) - raise Exception("Eigen solver only supports 2D matrices.") - - -def sym_eig(A, dt=None): - """Compute the eigenvalues and right eigenvectors of a real symmetric matrix. - - Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix. - 2D implementation refers to :func:`taichi.lang.linalg_impl.sym_eig2x2`. - - Args: - A (ti.Matrix(n, n)): Symmetric Matrix for which the eigenvalues and right eigenvectors will be computed. - dt (DataType): The datatype for the eigenvalues and right eigenvectors. - - Returns: - eigenvalues (ti.Vector(n)): The eigenvalues. Each entry store one eigen value. - eigenvectors (ti.Matrix(n, n)): The eigenvectors. Each column stores one eigenvector. - """ - assert all(A == A.transpose()), "A needs to be symmetric" - if dt is None: - dt = impl.get_runtime().default_fp - if A.n == 2: - return taichi.lang.linalg_impl.sym_eig2x2(A, dt) - raise Exception("Symmetric eigen solver only supports 2D matrices.") - - -def randn(dt=None): - """Generates a random number from standard normal distribution. - - Implementation refers to :func:`taichi.lang.random.randn`. - - Args: - dt (DataType): The datatype for the generated random number. - - Returns: - The generated random number. - """ - if dt is None: - dt = impl.get_runtime().default_fp - return _random.randn(dt) - - -determinant = deprecated('ti.determinant(a)', - 'a.determinant()')(Matrix.determinant) -tr = deprecated('ti.tr(a)', 'a.trace()')(Matrix.trace) - - -def Tape(loss, clear_gradients=True): - """Return a context manager of :class:`~taichi.lang.tape.TapeImpl`. The - context manager would catching all of the callings of functions that - decorated by :func:`~taichi.lang.kernel_impl.kernel` or - :func:`~taichi.ad.grad_replaced` under `with` statement, and calculate - all the partial gradients of a given loss variable by calling all of the - gradient function of the callings caught in reverse order while `with` - statement ended. - - See also :func:`~taichi.lang.kernel_impl.kernel` and - :func:`~taichi.ad.grad_replaced` for gradient functions. - - Args: - loss(:class:`~taichi.lang.expr.Expr`): The loss field, which shape should be (). - clear_gradients(Bool): Before `with` body start, clear all gradients or not. - - Returns: - :class:`~taichi.lang.tape.TapeImpl`: The context manager. - - Example:: - - >>> @ti.kernel - >>> def sum(a: ti.float32): - >>> for I in ti.grouped(x): - >>> y[None] += x[I] ** a - >>> - >>> with ti.Tape(loss = y): - >>> sum(2)""" - impl.get_runtime().materialize() - if len(loss.shape) != 0: - raise RuntimeError( - 'The loss of `Tape` must be a 0-D field, i.e. scalar') - if not loss.snode.ptr.has_grad(): - raise RuntimeError( - 'Gradients of loss are not allocated, please use ti.field(..., needs_grad=True)' - ' for all fields that are required by autodiff.') - if clear_gradients: - clear_all_gradients() - - taichi.lang.meta.clear_loss(loss) - - return runtime.get_tape(loss) - - -def clear_all_gradients(): - """Set all fields' gradients to 0.""" - impl.get_runtime().materialize() - - def visit(node): - places = [] - for i in range(node.ptr.get_num_ch()): - ch = node.ptr.get_ch(i) - if not ch.is_place(): - visit(SNode(ch)) - else: - if not ch.is_primal(): - places.append(ch.get_expr()) - - places = tuple(places) - if places: - taichi.lang.meta.clear_gradients(places) - - for root_fb in FieldsBuilder.finalized_roots(): - visit(root_fb) - - -def deactivate_all_snodes(): - """Recursively deactivate all SNodes.""" - for root_fb in FieldsBuilder.finalized_roots(): - root_fb.deactivate_all() - - -def benchmark(func, repeat=300, args=()): - def run_benchmark(): - compile_time = time.time() - func(*args) # compile the kernel first - ti.sync() - compile_time = time.time() - compile_time - ti.stat_write('compilation_time', compile_time) - codegen_stat = _ti_core.stat() - for line in codegen_stat.split('\n'): - try: - a, b = line.strip().split(':') - except: - continue - a = a.strip() - b = int(float(b)) - if a == 'codegen_kernel_statements': - ti.stat_write('compiled_inst', b) - if a == 'codegen_offloaded_tasks': - ti.stat_write('compiled_tasks', b) - elif a == 'launched_tasks': - ti.stat_write('launched_tasks', b) - - # Use 3 initial iterations to warm up - # instruction/data caches. Discussion: - # https://github.com/taichi-dev/taichi/pull/1002#discussion_r426312136 - for i in range(3): - func(*args) - ti.sync() - ti.clear_kernel_profile_info() - t = time.time() - for n in range(repeat): - func(*args) - ti.sync() - elapsed = time.time() - t - avg = elapsed / repeat - ti.stat_write('wall_clk_t', avg) - device_time = ti.kernel_profiler_total_time() - avg_device_time = device_time / repeat - ti.stat_write('exec_t', avg_device_time) - - run_benchmark() - - -def benchmark_plot(fn=None, - cases=None, - columns=None, - column_titles=None, - archs=None, - title=None, - bars='sync_vs_async', - bar_width=0.4, - bar_distance=0, - left_margin=0, - size=(12, 8)): - import matplotlib.pyplot as plt # pylint: disable=C0415 - import yaml # pylint: disable=C0415 - if fn is None: - fn = os.path.join(_ti_core.get_repo_dir(), 'benchmarks', 'output', - 'benchmark.yml') - - with open(fn, 'r') as f: - data = yaml.load(f, Loader=yaml.SafeLoader) - if bars != 'sync_vs_async': # need baseline - baseline_dir = os.path.join(_ti_core.get_repo_dir(), 'benchmarks', - 'baseline') - baseline_file = f'{baseline_dir}/benchmark.yml' - with open(baseline_file, 'r') as f: - baseline_data = yaml.load(f, Loader=yaml.SafeLoader) - if cases is None: - cases = list(data.keys()) - - assert len(cases) >= 1 - if len(cases) == 1: - cases = [cases[0], cases[0]] - ti.warning( - 'Function benchmark_plot does not support plotting with only one case for now. Duplicating the item to move on.' - ) - - if columns is None: - columns = list(data[cases[0]].keys()) - if column_titles is None: - column_titles = columns - normalize_to_lowest = lambda x: True - figure, subfigures = plt.subplots(len(cases), len(columns)) - if title is None: - title = 'Taichi Performance Benchmarks (Higher means more)' - figure.suptitle(title, fontweight="bold") - for col_id in range(len(columns)): - subfigures[0][col_id].set_title(column_titles[col_id]) - for case_id in range(len(cases)): - case = cases[case_id] - subfigures[case_id][0].annotate( - case, - xy=(0, 0.5), - xytext=(-subfigures[case_id][0].yaxis.labelpad - 5, 0), - xycoords=subfigures[case_id][0].yaxis.label, - textcoords='offset points', - size='large', - ha='right', - va='center') - for col_id in range(len(columns)): - col = columns[col_id] - if archs is None: - current_archs = data[case][col].keys() - else: - current_archs = [ - x for x in archs if x in data[case][col].keys() - ] - if bars == 'sync_vs_async': - y_left = [ - data[case][col][arch]['sync'] for arch in current_archs - ] - label_left = 'sync' - y_right = [ - data[case][col][arch]['async'] for arch in current_archs - ] - label_right = 'async' - elif bars == 'sync_regression': - y_left = [ - baseline_data[case][col][arch]['sync'] - for arch in current_archs - ] - label_left = 'before' - y_right = [ - data[case][col][arch]['sync'] for arch in current_archs - ] - label_right = 'after' - elif bars == 'async_regression': - y_left = [ - baseline_data[case][col][arch]['async'] - for arch in current_archs - ] - label_left = 'before' - y_right = [ - data[case][col][arch]['async'] for arch in current_archs - ] - label_right = 'after' - else: - raise RuntimeError('Unknown bars type') - if normalize_to_lowest(col): - for i in range(len(current_archs)): - maximum = max(y_left[i], y_right[i]) - y_left[i] = y_left[i] / maximum if y_left[i] != 0 else 1 - y_right[i] = y_right[i] / maximum if y_right[i] != 0 else 1 - ax = subfigures[case_id][col_id] - bar_left = ax.bar(x=[ - i - bar_width / 2 - bar_distance / 2 - for i in range(len(current_archs)) - ], - height=y_left, - width=bar_width, - label=label_left, - color=(0.47, 0.69, 0.89, 1.0)) - bar_right = ax.bar(x=[ - i + bar_width / 2 + bar_distance / 2 - for i in range(len(current_archs)) - ], - height=y_right, - width=bar_width, - label=label_right, - color=(0.68, 0.26, 0.31, 1.0)) - ax.set_xticks(range(len(current_archs))) - ax.set_xticklabels(current_archs) - figure.legend((bar_left, bar_right), (label_left, label_right), - loc='lower center') - figure.subplots_adjust(left=left_margin) - - fig = plt.gcf() - fig.set_size_inches(size) - - plt.show() - - -def stat_write(key, value): - import yaml # pylint: disable=C0415 - case_name = os.environ.get('TI_CURRENT_BENCHMARK') - if case_name is None: - return - if case_name.startswith('benchmark_'): - case_name = case_name[10:] - arch_name = _ti_core.arch_name(ti.cfg.arch) - async_mode = 'async' if ti.cfg.async_mode else 'sync' - output_dir = os.environ.get('TI_BENCHMARK_OUTPUT_DIR', '.') - filename = f'{output_dir}/benchmark.yml' - try: - with open(filename, 'r') as f: - data = yaml.load(f, Loader=yaml.SafeLoader) - except FileNotFoundError: - data = {} - data.setdefault(case_name, {}) - data[case_name].setdefault(key, {}) - data[case_name][key].setdefault(arch_name, {}) - data[case_name][key][arch_name][async_mode] = value - with open(filename, 'w') as f: - yaml.dump(data, f, Dumper=yaml.SafeDumper) - - -def is_arch_supported(arch): - """Checks whether an arch is supported on the machine. - - Args: - arch (taichi_core.Arch): Specified arch. - - Returns: - bool: Whether `arch` is supported on the machine. - """ - arch_table = { - cuda: _ti_core.with_cuda, - metal: _ti_core.with_metal, - opengl: _ti_core.with_opengl, - cc: _ti_core.with_cc, - vulkan: lambda: _ti_core.with_vulkan(), - wasm: lambda: True, - cpu: lambda: True, - } - with_arch = arch_table.get(arch, lambda: False) - try: - return with_arch() - except Exception as e: - arch = _ti_core.arch_name(arch) - _ti_core.warn( - f"{e.__class__.__name__}: '{e}' occurred when detecting " - f"{arch}, consider add `export TI_WITH_{arch.upper()}=0` " - f" to environment variables to depress this warning message.") - return False - - -def supported_archs(): - """Gets all supported archs on the machine. - - Returns: - List[taichi_core.Arch]: All supported archs on the machine. - """ - archs = set([cpu, cuda, metal, vulkan, opengl, cc]) - archs = set(filter(lambda x: is_arch_supported(x), archs)) - - wanted_archs = os.environ.get('TI_WANTED_ARCHS', '') - want_exclude = wanted_archs.startswith('^') - if want_exclude: - wanted_archs = wanted_archs[1:] - wanted_archs = wanted_archs.split(',') - # Note, ''.split(',') gives you [''], which is not an empty array. - expanded_wanted_archs = set([]) - for arch in wanted_archs: - if arch == '': - continue - if arch == 'cpu': - expanded_wanted_archs.add(cpu) - elif arch == 'gpu': - expanded_wanted_archs.update(gpu) - else: - expanded_wanted_archs.add(_ti_core.arch_from_name(arch)) - if len(expanded_wanted_archs) == 0: - return list(archs) - if want_exclude: - supported = archs - expanded_wanted_archs - else: - supported = archs & expanded_wanted_archs - return list(supported) - - -def adaptive_arch_select(arch): - if arch is None: - return cpu - if not isinstance(arch, (list, tuple)): - arch = [arch] - for a in arch: - if is_arch_supported(a): - return a - ti.warn(f'Arch={arch} is not supported, falling back to CPU') - return cpu - - -class _ArchCheckers(object): - def __init__(self): - self._checkers = [] - - def register(self, c): - self._checkers.append(c) - - def __call__(self, arch): - assert isinstance(arch, _ti_core.Arch) - return all([c(arch) for c in self._checkers]) - - -_tests_arch_checkers_argname = '_tests_arch_checkers' - - -def _get_or_make_arch_checkers(kwargs): - k = _tests_arch_checkers_argname - if k not in kwargs: - kwargs[k] = _ArchCheckers() - return kwargs[k] - - -# test with all archs -def all_archs_with(**kwargs): - kwargs = _deepcopy(kwargs) - - def decorator(test): - # @pytest.mark.parametrize decorator only knows about regular function args, - # without *args or **kwargs. By decorating with @functools.wraps, the - # signature of |test| is preserved, so that @ti.all_archs can be used after - # the parametrization decorator. - # - # Full discussion: https://github.com/pytest-dev/pytest/issues/6810 - @functools.wraps(test) - def wrapped(*test_args, **test_kwargs): - can_run_on = test_kwargs.pop(_tests_arch_checkers_argname, - _ArchCheckers()) - # Filter away archs that don't support 64-bit data. - fp = kwargs.get('default_fp', ti.f32) - ip = kwargs.get('default_ip', ti.i32) - if fp == ti.f64 or ip == ti.i64: - can_run_on.register(lambda arch: is_extension_supported( - arch, extension.data64)) - - for arch in ti.supported_archs(): - if can_run_on(arch): - print('Running test on arch={}'.format(arch)) - ti.init(arch=arch, **kwargs) - test(*test_args, **test_kwargs) - else: - print('Skipped test on arch={}'.format(arch)) - - return wrapped - - return decorator - - -# test with all archs -def all_archs(test): - return all_archs_with()(test) - - -# Exclude the given archs when running the tests -# -# Example usage: -# -# @ti.archs_excluding(ti.cuda, ti.metal) -# def test_xx(): -# ... -# -# @ti.archs_excluding(ti.cuda, default_fp=ti.f64) -# def test_yy(): -# ... -def archs_excluding(*excluded_archs, **kwargs): - # |kwargs| will be passed to all_archs_with(**kwargs) - assert all([isinstance(a, _ti_core.Arch) for a in excluded_archs]) - excluded_archs = set(excluded_archs) - - def decorator(test): - @functools.wraps(test) - def wrapped(*test_args, **test_kwargs): - def checker(arch): - return arch not in excluded_archs - - _get_or_make_arch_checkers(test_kwargs).register(checker) - return all_archs_with(**kwargs)(test)(*test_args, **test_kwargs) - - return wrapped - - return decorator - - -# Specifies the extension features the archs are required to support in order -# to run the test. -# -# Example usage: -# -# @ti.require(ti.extension.data64) -# @ti.all_archs_with(default_fp=ti.f64) -# def test_xx(): -# ... -def require(*exts): - # Because this decorator injects an arch checker, its usage must be followed - # with all_archs_with(), either directly or indirectly. - assert all([isinstance(e, _ti_core.Extension) for e in exts]) - - def decorator(test): - @functools.wraps(test) - def wrapped(*test_args, **test_kwargs): - def checker(arch): - return all([is_extension_supported(arch, e) for e in exts]) - - _get_or_make_arch_checkers(test_kwargs).register(checker) - test(*test_args, **test_kwargs) - - return wrapped - - return decorator - - -def archs_support_sparse(test, **kwargs): - wrapped = all_archs_with(**kwargs)(test) - return require(extension.sparse)(wrapped) - - -def torch_test(func): - if ti.has_pytorch(): - # OpenGL somehow crashes torch test without a reason, unforturnately - return ti.test(exclude=[opengl])(func) - else: - return lambda: None - - -def get_host_arch_list(): - return [_ti_core.host_arch()] - - -# test with host arch only -def host_arch_only(func): - @functools.wraps(func) - def test(*args, **kwargs): - archs = [_ti_core.host_arch()] - for arch in archs: - ti.init(arch=arch) - func(*args, **kwargs) - - return test - - -def archs_with(archs, **init_kwags): - """ - Run the test on the given archs with the given init args. - - Args: - archs: a list of Taichi archs - init_kwargs: kwargs passed to ti.init() - """ - def decorator(test): - @functools.wraps(test) - def wrapped(*test_args, **test_kwargs): - for arch in archs: - ti.init(arch=arch, **init_kwags) - test(*test_args, **test_kwargs) - - return wrapped - - return decorator - - -def must_throw(ex): - def decorator(func): - def func__(*args, **kwargs): - finishes = False - try: - func(*args, **kwargs) - finishes = True - except ex: - # throws. test passed - pass - except Exception as err_actual: - assert False, 'Exception {} instead of {} thrown'.format( - str(type(err_actual)), str(ex)) - if finishes: - assert False, 'Test successfully finished instead of throwing {}'.format( - str(ex)) - - return func__ - - return decorator - - -__all__ = [s for s in dir() if not s.startswith('_')] +from taichi.lang.kernel_impl import * +from taichi.lang.matrix import * +from taichi.lang.mesh import * +from taichi.lang.misc import * # pylint: disable=W0622 +from taichi.lang.ops import * # pylint: disable=W0622 +from taichi.lang.runtime_ops import * +from taichi.lang.snode import * +from taichi.lang.source_builder import * +from taichi.lang.struct import * +from taichi.types.annotations import any_arr, ext_arr, template +from taichi.types.primitive_types import f16, f32, f64, i32, i64, u32, u64 + +from taichi import _logging, _snode + +__all__ = [ + s for s in dir() if not s.startswith('_') and s not in [ + 'any_array', 'ast', 'common_ops', 'enums', 'exception', 'expr', 'impl', + 'inspect', 'kernel_arguments', 'kernel_impl', 'matrix', 'mesh', 'misc', + 'ops', 'platform', 'runtime_ops', 'shell', 'snode', 'source_builder', + 'struct', 'tape', 'util' + ] +] diff --git a/python/taichi/lang/_ndarray.py b/python/taichi/lang/_ndarray.py index f0a59a564cdbb..a1bd856dabf39 100644 --- a/python/taichi/lang/_ndarray.py +++ b/python/taichi/lang/_ndarray.py @@ -1,48 +1,35 @@ import numpy as np -from taichi.core.util import ti_core as _ti_core +from taichi._lib import core as _ti_core from taichi.lang import impl from taichi.lang.enums import Layout -from taichi.lang.util import (cook_dtype, has_pytorch, python_scope, - to_pytorch_type, to_taichi_type) - -if has_pytorch(): - import torch +from taichi.lang.util import cook_dtype, python_scope, to_numpy_type +from taichi.types import primitive_types class Ndarray: - """Taichi ndarray class implemented with a torch tensor. + """Taichi ndarray class. Args: dtype (DataType): Data type of each value. - shape (Tuple[int]): Shape of the torch tensor. + shape (Tuple[int]): Shape of the Ndarray. """ - def __init__(self, dtype, shape): - assert has_pytorch( - ), "PyTorch must be available if you want to create a Taichi ndarray." - self.arr = torch.zeros(shape, dtype=to_pytorch_type(cook_dtype(dtype))) - if impl.current_cfg().arch == _ti_core.Arch.cuda: - self.arr = self.arr.cuda() + def __init__(self, dtype, arr_shape): + self.host_accessor = None + self.dtype = cook_dtype(dtype) + self.arr = _ti_core.Ndarray(impl.get_runtime().prog, cook_dtype(dtype), + arr_shape) @property - def shape(self): - """Gets ndarray shape. + def element_shape(self): + """Gets ndarray element shape. Returns: - Tuple[Int]: Ndarray shape. + Tuple[Int]: Ndarray element shape. """ raise NotImplementedError() @property - def dtype(self): - """Gets data type of each individual value. - - Returns: - DataType: Data type of each individual value. - """ - return to_taichi_type(self.arr.dtype) - - @property - def data_handle(self): + def _data_handle(self): """Gets the pointer to underlying data. Returns: @@ -79,19 +66,44 @@ def fill(self, val): Args: val (Union[int, float]): Value to fill. """ - self.arr.fill_(val) + if impl.current_cfg().arch != _ti_core.Arch.cuda and impl.current_cfg( + ).arch != _ti_core.Arch.x64: + self._fill_by_kernel(val) + elif self.dtype == primitive_types.f32: + self.arr.fill_float(val) + elif self.dtype == primitive_types.i32: + self.arr.fill_int(val) + elif self.dtype == primitive_types.u32: + self.arr.fill_uint(val) + else: + self._fill_by_kernel(val) - @python_scope - def to_numpy(self): + def _ndarray_to_numpy(self): """Converts ndarray to a numpy array. Returns: numpy.ndarray: The result numpy array. """ - return self.arr.cpu().numpy() + arr = np.zeros(shape=self.arr.shape, dtype=to_numpy_type(self.dtype)) + from taichi._kernels import ndarray_to_ext_arr # pylint: disable=C0415 + ndarray_to_ext_arr(self, arr) + impl.get_runtime().sync() + return arr - @python_scope - def from_numpy(self, arr): + def _ndarray_matrix_to_numpy(self, as_vector): + """Converts matrix ndarray to a numpy array. + + Returns: + numpy.ndarray: The result numpy array. + """ + arr = np.zeros(shape=self.arr.shape, dtype=to_numpy_type(self.dtype)) + from taichi._kernels import \ + ndarray_matrix_to_ext_arr # pylint: disable=C0415 + ndarray_matrix_to_ext_arr(self, arr, as_vector) + impl.get_runtime().sync() + return arr + + def _ndarray_from_numpy(self, arr): """Loads all values from a numpy array. Args: @@ -103,35 +115,169 @@ def from_numpy(self, arr): raise ValueError( f"Mismatch shape: {tuple(self.arr.shape)} expected, but {tuple(arr.shape)} provided" ) - self.arr = torch.from_numpy(arr).to(self.arr.dtype) + if hasattr(arr, 'contiguous'): + arr = arr.contiguous() + + from taichi._kernels import ext_arr_to_ndarray # pylint: disable=C0415 + ext_arr_to_ndarray(arr, self) + impl.get_runtime().sync() + + def _ndarray_matrix_from_numpy(self, arr, as_vector): + """Loads all values from a numpy array. + + Args: + arr (numpy.ndarray): The source numpy array. + """ + if not isinstance(arr, np.ndarray): + raise TypeError(f"{np.ndarray} expected, but {type(arr)} provided") + if tuple(self.arr.shape) != tuple(arr.shape): + raise ValueError( + f"Mismatch shape: {tuple(self.arr.shape)} expected, but {tuple(arr.shape)} provided" + ) + if hasattr(arr, 'contiguous'): + arr = arr.contiguous() + + from taichi._kernels import \ + ext_arr_to_ndarray_matrix # pylint: disable=C0415 + ext_arr_to_ndarray_matrix(arr, self, as_vector) + impl.get_runtime().sync() + + @python_scope + def _get_element_size(self): + """Returns the size of one element in bytes. + + Returns: + Size in bytes. + """ + return self.arr.element_size() + + @python_scope + def _get_nelement(self): + """Returns the total number of elements. + + Returns: + Total number of elements. + """ + return self.arr.nelement() + + @python_scope + def copy_from(self, other): + """Copies all elements from another ndarray. + + The shape of the other ndarray needs to be the same as `self`. + + Args: + other (Ndarray): The source ndarray. + """ + assert isinstance(other, Ndarray) + assert tuple(self.arr.shape) == tuple(other.arr.shape) + from taichi._kernels import ndarray_to_ndarray # pylint: disable=C0415 + ndarray_to_ndarray(self, other) + impl.get_runtime().sync() + + def __deepcopy__(self, memo=None): + """Copies all elements to a new ndarray. + + Returns: + Ndarray: The result ndarray. + """ + raise NotImplementedError() + + def _fill_by_kernel(self, val): + """Fills ndarray with a specific scalar value using a ti.kernel. + + Args: + val (Union[int, float]): Value to fill. + """ + raise NotImplementedError() + + def _pad_key(self, key): + if key is None: + key = () + if not isinstance(key, (tuple, list)): + key = (key, ) + assert len(key) == len(self.arr.shape) + return key + + def _initialize_host_accessor(self): + if self.host_accessor: + return + impl.get_runtime().materialize() + self.host_accessor = NdarrayHostAccessor(self.arr) class ScalarNdarray(Ndarray): - """Taichi ndarray with scalar elements implemented with a torch tensor. + """Taichi ndarray with scalar elements. Args: dtype (DataType): Data type of each value. shape (Tuple[int]): Shape of the ndarray. """ - def __init__(self, dtype, shape): - super().__init__(dtype, shape) + def __init__(self, dtype, arr_shape): + super().__init__(dtype, arr_shape) + self.shape = tuple(self.arr.shape) @property - def shape(self): - return tuple(self.arr.shape) + def element_shape(self): + return () @python_scope def __setitem__(self, key, value): - self.arr.__setitem__(key, value) + self._initialize_host_accessor() + self.host_accessor.setter(value, *self._pad_key(key)) @python_scope def __getitem__(self, key): - return self.arr.__getitem__(key) + self._initialize_host_accessor() + return self.host_accessor.getter(*self._pad_key(key)) + + @python_scope + def to_numpy(self): + return self._ndarray_to_numpy() + + @python_scope + def from_numpy(self, arr): + self._ndarray_from_numpy(arr) + + def __deepcopy__(self, memo=None): + ret_arr = ScalarNdarray(self.dtype, self.shape) + ret_arr.copy_from(self) + return ret_arr + + def _fill_by_kernel(self, val): + from taichi._kernels import fill_ndarray # pylint: disable=C0415 + fill_ndarray(self, val) def __repr__(self): return '' +class NdarrayHostAccessor: + def __init__(self, ndarray): + if _ti_core.is_real(ndarray.dtype): + + def getter(*key): + return ndarray.read_float(key) + + def setter(value, *key): + ndarray.write_float(key, value) + else: + if _ti_core.is_signed(ndarray.dtype): + + def getter(*key): + return ndarray.read_int(key) + else: + + def getter(*key): + return ndarray.read_uint(key) + + def setter(value, *key): + ndarray.write_int(key, value) + + self.getter = getter + self.setter = setter + + class NdarrayHostAccess: """Class for accessing VectorNdarray/MatrixNdarray in Python scope. Args: @@ -140,14 +286,25 @@ class NdarrayHostAccess: indices_second (Tuple[Int]): Indices of second-level access (indices in the vector/matrix). """ def __init__(self, arr, indices_first, indices_second): + self.ndarr = arr self.arr = arr.arr if arr.layout == Layout.SOA: self.indices = indices_second + indices_first else: self.indices = indices_first + indices_second - def getter(self): - return self.arr[self.indices] + def getter(): + self.ndarr._initialize_host_accessor() + return self.ndarr.host_accessor.getter( + *self.ndarr._pad_key(self.indices)) + + def setter(value): + self.ndarr._initialize_host_accessor() + self.ndarr.host_accessor.setter(value, + *self.ndarr._pad_key(self.indices)) + + self.getter = getter + self.setter = setter + - def setter(self, value): - self.arr[self.indices] = value +__all__ = ["Ndarray", "ScalarNdarray"] diff --git a/python/taichi/lang/ndrange.py b/python/taichi/lang/_ndrange.py similarity index 57% rename from python/taichi/lang/ndrange.py rename to python/taichi/lang/_ndrange.py index 56b9cb445b63b..a4e8ec85d6de7 100644 --- a/python/taichi/lang/ndrange.py +++ b/python/taichi/lang/_ndrange.py @@ -1,20 +1,24 @@ -import taichi as ti +import collections.abc +from taichi.lang.exception import TaichiSyntaxError +from taichi.lang.matrix import _IntermediateMatrix -class ndrange: + +class _Ndrange: def __init__(self, *args): args = list(args) - for i in range(len(args)): - if isinstance(args[i], list): - args[i] = tuple(args[i]) - if not isinstance(args[i], tuple): - args[i] = (0, args[i]) - assert len(args[i]) == 2 + for i, arg in enumerate(args): + if not isinstance(arg, collections.abc.Sequence): + args[i] = (0, arg) + if len(args[i]) != 2: + raise TaichiSyntaxError( + "Every argument of ndrange should be a scalar or a tuple/list like (begin, end)" + ) self.bounds = args self.dimensions = [None] * len(args) - for i in range(len(self.bounds)): - self.dimensions[i] = self.bounds[i][1] - self.bounds[i][0] + for i, bound in enumerate(self.bounds): + self.dimensions[i] = bound[1] - bound[0] self.acc_dimensions = self.dimensions.copy() for i in reversed(range(len(self.bounds) - 1)): @@ -38,10 +42,17 @@ def grouped(self): return GroupedNDRange(self) +def ndrange(*args): + return _Ndrange(*args) + + class GroupedNDRange: def __init__(self, r): self.r = r def __iter__(self): for ind in self.r: - yield ti.Vector(list(ind), keep_raw=True) + yield _IntermediateMatrix(len(ind), 1, list(ind)) + + +__all__ = ['ndrange'] diff --git a/python/taichi/lang/_random.py b/python/taichi/lang/_random.py deleted file mode 100644 index af92b772083d8..0000000000000 --- a/python/taichi/lang/_random.py +++ /dev/null @@ -1,19 +0,0 @@ -import math - -from taichi.lang.kernel_impl import func - -import taichi as ti - - -@func -def randn(dt): - ''' - Generates a random number from standard normal distribution - using the Box-Muller transform. - ''' - assert dt == ti.f32 or dt == ti.f64 - u1 = ti.random(dt) - u2 = ti.random(dt) - r = ti.sqrt(-2 * ti.log(u1)) - c = ti.cos(math.tau * u2) - return r * c diff --git a/python/taichi/lang/any_array.py b/python/taichi/lang/any_array.py index 1ad0ff8abaebb..5e126f25d9023 100644 --- a/python/taichi/lang/any_array.py +++ b/python/taichi/lang/any_array.py @@ -1,4 +1,4 @@ -from taichi.core.util import ti_core as _ti_core +from taichi._lib import core as _ti_core from taichi.lang.enums import Layout from taichi.lang.expr import Expr, make_expr_group from taichi.lang.util import taichi_scope @@ -34,13 +34,11 @@ def shape(self): element_dim = len(self.element_shape) if element_dim == 0: return ret - else: - return ret[ - element_dim:] if self.layout == Layout.SOA else ret[: - -element_dim] + return ret[ + element_dim:] if self.layout == Layout.SOA else ret[:-element_dim] @taichi_scope - def loop_range(self): + def _loop_range(self): """Gets the corresponding taichi_core.Expr to serve as loop range. This is not in use now because struct fors on AnyArrays are not supported yet. @@ -71,3 +69,6 @@ def subscript(self, i, j): indices = self.indices_first + indices_second return Expr(_ti_core.subscript(self.arr.ptr, make_expr_group(*indices))) + + +__all__ = [] diff --git a/python/taichi/lang/ast/__init__.py b/python/taichi/lang/ast/__init__.py index 8b137891791fe..5787470b3631b 100644 --- a/python/taichi/lang/ast/__init__.py +++ b/python/taichi/lang/ast/__init__.py @@ -1 +1,3 @@ - +from taichi.lang.ast.ast_transformer_utils import ASTTransformerContext +from taichi.lang.ast.checkers import KernelSimplicityASTChecker +from taichi.lang.ast.transform import transform_tree diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py new file mode 100644 index 0000000000000..002cc2de76780 --- /dev/null +++ b/python/taichi/lang/ast/ast_transformer.py @@ -0,0 +1,1238 @@ +import ast +import collections.abc +import itertools +import warnings +from collections import ChainMap +from sys import version_info + +import astor +from taichi._lib import core as _ti_core +from taichi.lang import expr, impl, kernel_arguments, matrix, mesh +from taichi.lang import ops as ti_ops +from taichi.lang._ndrange import _Ndrange, ndrange +from taichi.lang.ast.ast_transformer_utils import Builder, LoopStatus +from taichi.lang.ast.symbol_resolver import ASTResolver +from taichi.lang.exception import TaichiSyntaxError +from taichi.lang.matrix import MatrixType +from taichi.lang.util import is_taichi_class, to_taichi_type +from taichi.types import annotations, primitive_types + +if version_info < (3, 9): + from astunparse import unparse +else: + from ast import unparse + + +class ASTTransformer(Builder): + @staticmethod + def build_Name(ctx, node): + node.ptr = ctx.get_var_by_name(node.id) + return node.ptr + + @staticmethod + def build_AnnAssign(ctx, node): + build_stmt(ctx, node.value) + build_stmt(ctx, node.annotation) + + is_static_assign = isinstance( + node.value, ast.Call) and node.value.func.ptr is impl.static + + node.ptr = ASTTransformer.build_assign_annotated( + ctx, node.target, node.value.ptr, is_static_assign, + node.annotation.ptr) + return node.ptr + + @staticmethod + def build_assign_annotated(ctx, target, value, is_static_assign, + annotation): + """Build an annotated assginment like this: target: annotation = value. + + Args: + ctx (ast_builder_utils.BuilderContext): The builder context. + target (ast.Name): A variable name. `target.id` holds the name as + a string. + annotation: A type we hope to assign to the target + value: A node representing the value. + is_static_assign: A boolean value indicating whether this is a static assignment + """ + is_local = isinstance(target, ast.Name) + anno = impl.expr_init(annotation) + if is_static_assign: + raise TaichiSyntaxError( + "Static assign cannot be used on annotated assignment") + if is_local and not ctx.is_var_declared(target.id): + var = ti_ops.cast(value, anno) + var = impl.expr_init(var) + ctx.create_variable(target.id, var) + else: + var = build_stmt(ctx, target) + if var.ptr.get_ret_type() != anno: + raise TaichiSyntaxError( + "Static assign cannot have type overloading") + var._assign(value) + return var + + @staticmethod + def build_Assign(ctx, node): + build_stmt(ctx, node.value) + + is_static_assign = isinstance( + node.value, ast.Call) and node.value.func.ptr is impl.static + + # Keep all generated assign statements and compose single one at last. + # The variable is introduced to support chained assignments. + # Ref https://github.com/taichi-dev/taichi/issues/2659. + for node_target in node.targets: + ASTTransformer.build_assign_unpack(ctx, node_target, + node.value.ptr, + is_static_assign) + return None + + @staticmethod + def build_assign_unpack(ctx, node_target, values, is_static_assign): + """Build the unpack assignments like this: (target1, target2) = (value1, value2). + The function should be called only if the node target is a tuple. + + Args: + ctx (ast_builder_utils.BuilderContext): The builder context. + node_target (ast.Tuple): A list or tuple object. `node_target.elts` holds a + list of nodes representing the elements. + values: A node/list representing the values. + is_static_assign: A boolean value indicating whether this is a static assignment + """ + if not isinstance(node_target, ast.Tuple): + return ASTTransformer.build_assign_basic(ctx, node_target, values, + is_static_assign) + targets = node_target.elts + tmp_tuple = values if is_static_assign else impl.expr_init_list( + values, len(targets)) + + for i, target in enumerate(targets): + ASTTransformer.build_assign_basic(ctx, target, tmp_tuple[i], + is_static_assign) + + return None + + @staticmethod + def build_assign_basic(ctx, target, value, is_static_assign): + """Build basic assginment like this: target = value. + + Args: + ctx (ast_builder_utils.BuilderContext): The builder context. + target (ast.Name): A variable name. `target.id` holds the name as + a string. + value: A node representing the value. + is_static_assign: A boolean value indicating whether this is a static assignment + """ + is_local = isinstance(target, ast.Name) + if is_static_assign: + if not is_local: + raise TaichiSyntaxError( + "Static assign cannot be used on elements in arrays") + ctx.create_variable(target.id, value) + var = value + elif is_local and not ctx.is_var_declared(target.id): + var = impl.expr_init(value) + ctx.create_variable(target.id, var) + else: + var = build_stmt(ctx, target) + try: + var._assign(value) + except AttributeError: + raise TaichiSyntaxError( + f"Variable '{unparse(target).strip()}' cannot be assigned. Maybe it is not a Taichi object?" + ) + return var + + @staticmethod + def build_NamedExpr(ctx, node): + build_stmt(ctx, node.value) + is_static_assign = isinstance( + node.value, ast.Call) and node.value.func.ptr is impl.static + node.ptr = ASTTransformer.build_assign_basic(ctx, node.target, + node.value.ptr, + is_static_assign) + return node.ptr + + @staticmethod + def is_tuple(node): + if isinstance(node, ast.Tuple): + return True + if isinstance(node, ast.Index) and isinstance(node.value.ptr, tuple): + return True + if isinstance(node.ptr, tuple): + return True + return False + + @staticmethod + def build_Subscript(ctx, node): + build_stmt(ctx, node.value) + build_stmt(ctx, node.slice) + if not ASTTransformer.is_tuple(node.slice): + node.slice.ptr = [node.slice.ptr] + node.ptr = impl.subscript(node.value.ptr, *node.slice.ptr) + return node.ptr + + @staticmethod + def build_Slice(ctx, node): + if node.lower is not None: + build_stmt(ctx, node.lower) + if node.upper is not None: + build_stmt(ctx, node.upper) + if node.step is not None: + build_stmt(ctx, node.step) + + node.ptr = slice(node.lower.ptr if node.lower else None, + node.upper.ptr if node.upper else None, + node.step.ptr if node.step else None) + return node.ptr + + @staticmethod + def build_ExtSlice(ctx, node): + build_stmts(ctx, node.dims) + node.ptr = tuple(dim.ptr for dim in node.dims) + return node.ptr + + @staticmethod + def build_Tuple(ctx, node): + build_stmts(ctx, node.elts) + node.ptr = tuple(elt.ptr for elt in node.elts) + return node.ptr + + @staticmethod + def build_List(ctx, node): + build_stmts(ctx, node.elts) + node.ptr = [elt.ptr for elt in node.elts] + return node.ptr + + @staticmethod + def build_Dict(ctx, node): + dic = {} + for key, value in zip(node.keys, node.values): + if key is None: + dic.update(build_stmt(ctx, value)) + else: + dic[build_stmt(ctx, key)] = build_stmt(ctx, value) + node.ptr = dic + return node.ptr + + @staticmethod + def process_listcomp(ctx, node, result): + result.append(build_stmt(ctx, node.elt)) + + @staticmethod + def process_dictcomp(ctx, node, result): + key = build_stmt(ctx, node.key) + value = build_stmt(ctx, node.value) + result[key] = value + + @staticmethod + def process_generators(ctx, node, now_comp, func, result): + if now_comp >= len(node.generators): + return func(ctx, node, result) + with ctx.static_scope_guard(): + _iter = build_stmt(ctx, node.generators[now_comp].iter) + for value in _iter: + with ctx.variable_scope_guard(): + ASTTransformer.build_assign_unpack( + ctx, node.generators[now_comp].target, value, True) + with ctx.static_scope_guard(): + build_stmts(ctx, node.generators[now_comp].ifs) + ASTTransformer.process_ifs(ctx, node, now_comp, 0, func, + result) + return None + + @staticmethod + def process_ifs(ctx, node, now_comp, now_if, func, result): + if now_if >= len(node.generators[now_comp].ifs): + return ASTTransformer.process_generators(ctx, node, now_comp + 1, + func, result) + cond = node.generators[now_comp].ifs[now_if].ptr + if cond: + ASTTransformer.process_ifs(ctx, node, now_comp, now_if + 1, func, + result) + + return None + + @staticmethod + def build_ListComp(ctx, node): + result = [] + ASTTransformer.process_generators(ctx, node, 0, + ASTTransformer.process_listcomp, + result) + node.ptr = result + return node.ptr + + @staticmethod + def build_DictComp(ctx, node): + result = {} + ASTTransformer.process_generators(ctx, node, 0, + ASTTransformer.process_dictcomp, + result) + node.ptr = result + return node.ptr + + @staticmethod + def build_Index(ctx, node): + + node.ptr = build_stmt(ctx, node.value) + return node.ptr + + @staticmethod + def build_Constant(ctx, node): + node.ptr = node.value + return node.ptr + + @staticmethod + def build_Num(ctx, node): + node.ptr = node.n + return node.ptr + + @staticmethod + def build_Str(ctx, node): + node.ptr = node.s + return node.ptr + + @staticmethod + def build_Bytes(ctx, node): + node.ptr = node.s + return node.ptr + + @staticmethod + def build_NameConstant(ctx, node): + node.ptr = node.value + return node.ptr + + @staticmethod + def build_keyword(ctx, node): + build_stmt(ctx, node.value) + if node.arg is None: + node.ptr = node.value.ptr + else: + node.ptr = {node.arg: node.value.ptr} + return node.ptr + + @staticmethod + def build_Starred(ctx, node): + node.ptr = build_stmt(ctx, node.value) + return node.ptr + + @staticmethod + def build_JoinedStr(ctx, node): + str_spec = '' + args = [] + for sub_node in node.values: + if isinstance(sub_node, ast.FormattedValue): + str_spec += '{}' + args.append(build_stmt(ctx, sub_node.value)) + elif isinstance(sub_node, ast.Constant): + str_spec += sub_node.value + elif isinstance(sub_node, ast.Str): + str_spec += sub_node.s + else: + raise TaichiSyntaxError("Invalid value for fstring.") + + args.insert(0, str_spec) + node.ptr = impl.ti_format(*args) + return node.ptr + + @staticmethod + def build_call_if_is_builtin(ctx, node, args, keywords): + func = node.func.ptr + replace_func = { + id(print): impl.ti_print, + id(min): ti_ops.min, + id(max): ti_ops.max, + id(int): impl.ti_int, + id(float): impl.ti_float, + id(any): ti_ops.ti_any, + id(all): ti_ops.ti_all, + id(abs): abs, + id(pow): pow, + } + if id(func) in replace_func: + node.ptr = replace_func[id(func)](*args, **keywords) + if func is min or func is max: + name = "min" if func is min else "max" + warnings.warn_explicit( + f'Calling builtin function "{name}" in Taichi scope is deprecated. ' + f'Please use "ti.{name}" instead.', DeprecationWarning, + ctx.file, node.lineno + ctx.lineno_offset) + return True + return False + + @staticmethod + def build_call_if_is_type(ctx, node, args, keywords): + func = node.func.ptr + if id(func) in primitive_types.type_ids: + if len(args) != 1 or keywords or isinstance(args[0], expr.Expr): + raise TaichiSyntaxError( + "Type annotation can only be given to a single literal.") + node.ptr = expr.Expr(args[0], dtype=func) + return True + return False + + @staticmethod + def warn_if_is_external_func(ctx, node): + func = node.func.ptr + if ctx.is_in_static_scope(): # allow external function in static scope + return + if hasattr(func, "_is_taichi_function") or hasattr( + func, "_is_wrapped_kernel"): # taichi func/kernel + return + if hasattr( + func, "__module__" + ) and func.__module__ and func.__module__.startswith("taichi."): + return + name = unparse(node.func).strip() + warnings.warn_explicit( + f'Calling non-taichi function "{name}". ' + f'Scope inside the function is not processed by the Taichi AST transformer. ' + f'The function may not work as expected. Proceed with caution! ' + f'Maybe you can consider turning it into a @ti.func?', UserWarning, + ctx.file, node.lineno + ctx.lineno_offset) + + @staticmethod + def build_Call(ctx, node): + if ASTTransformer.get_decorator(ctx, node) == 'static': + with ctx.static_scope_guard(): + build_stmt(ctx, node.func) + build_stmts(ctx, node.args) + build_stmts(ctx, node.keywords) + else: + build_stmt(ctx, node.func) + build_stmts(ctx, node.args) + build_stmts(ctx, node.keywords) + + args = [] + for arg in node.args: + if isinstance(arg, ast.Starred): + for i in arg.ptr: + args.append(i) + else: + args.append(arg.ptr) + keywords = dict(ChainMap(*[keyword.ptr for keyword in node.keywords])) + func = node.func.ptr + + if isinstance(node.func, ast.Attribute) and isinstance( + node.func.value.ptr, str) and node.func.attr == 'format': + args.insert(0, node.func.value.ptr) + node.ptr = impl.ti_format(*args, **keywords) + return node.ptr + + if ASTTransformer.build_call_if_is_builtin(ctx, node, args, keywords): + return node.ptr + + if ASTTransformer.build_call_if_is_type(ctx, node, args, keywords): + return node.ptr + + node.ptr = func(*args, **keywords) + ASTTransformer.warn_if_is_external_func(ctx, node) + + return node.ptr + + @staticmethod + def build_FunctionDef(ctx, node): + if ctx.visited_funcdef: + raise TaichiSyntaxError( + f"Function definition is not allowed in 'ti.{'kernel' if ctx.is_kernel else 'func'}'." + ) + ctx.visited_funcdef = True + + args = node.args + assert args.vararg is None + assert args.kwonlyargs == [] + assert args.kw_defaults == [] + assert args.kwarg is None + + def transform_as_kernel(): + # Treat return type + if node.returns is not None: + kernel_arguments.decl_ret(ctx.func.return_type) + + for i, arg in enumerate(args.args): + if isinstance(ctx.func.argument_annotations[i], + annotations.template): + ctx.create_variable(arg.arg, ctx.global_vars[arg.arg]) + elif isinstance(ctx.func.argument_annotations[i], + annotations.sparse_matrix_builder): + ctx.create_variable( + arg.arg, + kernel_arguments.decl_sparse_matrix( + to_taichi_type(ctx.arg_features[i]))) + elif isinstance(ctx.func.argument_annotations[i], + annotations.any_arr): + ctx.create_variable( + arg.arg, + kernel_arguments.decl_any_arr_arg( + to_taichi_type(ctx.arg_features[i][0]), + ctx.arg_features[i][1], ctx.arg_features[i][2], + ctx.arg_features[i][3])) + elif isinstance(ctx.func.argument_annotations[i], MatrixType): + ctx.create_variable( + arg.arg, + kernel_arguments.decl_matrix_arg( + ctx.func.argument_annotations[i])) + else: + ctx.global_vars[ + arg.arg] = kernel_arguments.decl_scalar_arg( + ctx.func.argument_annotations[i]) + # remove original args + node.args.args = [] + + if ctx.is_kernel: # ti.kernel + transform_as_kernel() + + else: # ti.func + if ctx.is_real_function: + transform_as_kernel() + else: + len_args = len(args.args) + len_default = len(args.defaults) + len_provided = len(ctx.argument_data) + len_minimum = len_args - len_default + if len_args < len_provided or len_args - len_default > len_provided: + if len(args.defaults): + raise TaichiSyntaxError( + f"Function receives {len_minimum} to {len_args} argument(s) and {len_provided} provided." + ) + else: + raise TaichiSyntaxError( + f"Function receives {len_args} argument(s) and {len_provided} provided." + ) + # Transform as force-inlined func + default_start = len_provided - len_minimum + ctx.argument_data = list(ctx.argument_data) + for arg in args.defaults[default_start:]: + ctx.argument_data.append(build_stmt(ctx, arg)) + assert len(args.args) == len(ctx.argument_data) + for i, (arg, + data) in enumerate(zip(args.args, ctx.argument_data)): + # Remove annotations because they are not used. + args.args[i].annotation = None + # Template arguments are passed by reference. + if isinstance(ctx.func.argument_annotations[i], + annotations.template): + ctx.create_variable(ctx.func.argument_names[i], data) + continue + # Create a copy for non-template arguments, + # so that they are passed by value. + ctx.create_variable(arg.arg, impl.expr_init_func(data)) + + with ctx.variable_scope_guard(): + build_stmts(ctx, node.body) + + return None + + @staticmethod + def build_Return(ctx, node): + if not ctx.is_real_function: + if ctx.is_in_non_static_control_flow(): + raise TaichiSyntaxError( + "Return inside non-static if/for is not supported") + build_stmt(ctx, node.value) + if ctx.is_kernel or ctx.is_real_function: + # TODO: check if it's at the end of a kernel, throw TaichiSyntaxError if not + if node.value is not None: + if ctx.func.return_type is None: + raise TaichiSyntaxError( + f'A {"kernel" if ctx.is_kernel else "function"} ' + 'with a return value must be annotated ' + 'with a return type, e.g. def func() -> ti.f32') + if id(ctx.func.return_type) in primitive_types.type_ids: + ctx.ast_builder.create_kernel_exprgroup_return( + expr.make_expr_group( + ti_ops.cast(expr.Expr(node.value.ptr), + ctx.func.return_type).ptr)) + elif isinstance(ctx.func.return_type, MatrixType): + ctx.ast_builder.create_kernel_exprgroup_return( + expr.make_expr_group([ + ti_ops.cast(exp, ctx.func.return_type.dtype) + for exp in itertools.chain.from_iterable( + node.value.ptr.to_list()) + ])) + else: + raise TaichiSyntaxError( + "The return type is not supported now!") + # For args[0], it is an ast.Attribute, because it loads the + # attribute, |ptr|, of the expression |ret_expr|. Therefore we + # only need to replace the object part, i.e. args[0].value + else: + ctx.return_data = node.value.ptr + if not ctx.is_real_function: + ctx.returned = True + return None + + @staticmethod + def build_Module(ctx, node): + with ctx.variable_scope_guard(): + # Do NOT use |build_stmts| which inserts 'del' statements to the + # end and deletes parameters passed into the module + for stmt in node.body: + build_stmt(ctx, stmt) + return None + + @staticmethod + def build_Attribute(ctx, node): + build_stmt(ctx, node.value) + node.ptr = getattr(node.value.ptr, node.attr) + return node.ptr + + @staticmethod + def build_BinOp(ctx, node): + build_stmt(ctx, node.left) + build_stmt(ctx, node.right) + op = { + ast.Add: lambda l, r: l + r, + ast.Sub: lambda l, r: l - r, + ast.Mult: lambda l, r: l * r, + ast.Div: lambda l, r: l / r, + ast.FloorDiv: lambda l, r: l // r, + ast.Mod: lambda l, r: l % r, + ast.Pow: lambda l, r: l**r, + ast.LShift: lambda l, r: l << r, + ast.RShift: lambda l, r: l >> r, + ast.BitOr: lambda l, r: l | r, + ast.BitXor: lambda l, r: l ^ r, + ast.BitAnd: lambda l, r: l & r, + ast.MatMult: lambda l, r: l @ r, + }.get(type(node.op)) + node.ptr = op(node.left.ptr, node.right.ptr) + return node.ptr + + @staticmethod + def build_AugAssign(ctx, node): + build_stmt(ctx, node.target) + build_stmt(ctx, node.value) + node.ptr = node.target.ptr._augassign(node.value.ptr, + type(node.op).__name__) + return node.ptr + + @staticmethod + def build_UnaryOp(ctx, node): + build_stmt(ctx, node.operand) + op = { + ast.UAdd: lambda l: l, + ast.USub: lambda l: -l, + ast.Not: ti_ops.logical_not, + ast.Invert: lambda l: ~l, + }.get(type(node.op)) + node.ptr = op(node.operand.ptr) + return node.ptr + + @staticmethod + def build_short_circuit_and(ast_builder, operands): + if len(operands) == 1: + return operands[0].ptr + + val = impl.expr_init(None) + lhs = operands[0].ptr + impl.begin_frontend_if(ast_builder, lhs) + + ast_builder.begin_frontend_if_true() + rhs = ASTTransformer.build_short_circuit_and(ast_builder, operands[1:]) + val._assign(rhs) + ast_builder.pop_scope() + + ast_builder.begin_frontend_if_false() + val._assign(0) + ast_builder.pop_scope() + + return val + + @staticmethod + def build_short_circuit_or(ast_builder, operands): + if len(operands) == 1: + return operands[0].ptr + + val = impl.expr_init(None) + lhs = operands[0].ptr + impl.begin_frontend_if(ast_builder, lhs) + + ast_builder.begin_frontend_if_true() + val._assign(1) + ast_builder.pop_scope() + + ast_builder.begin_frontend_if_false() + rhs = ASTTransformer.build_short_circuit_or(ast_builder, operands[1:]) + val._assign(rhs) + ast_builder.pop_scope() + + return val + + @staticmethod + def build_normal_bool_op(op): + def inner(ast_builder, operands): + result = op(operands[0].ptr, operands[1].ptr) + for i in range(2, len(operands)): + result = op(result, operands[i].ptr) + return result + + return inner + + @staticmethod + def build_static_short_circuit_and(ast_builder, operands): + for operand in operands: + if not operand.ptr: + return operand.ptr + return operands[-1].ptr + + @staticmethod + def build_static_short_circuit_or(ast_builder, operands): + for operand in operands: + if operand.ptr: + return operand.ptr + return operands[-1].ptr + + @staticmethod + def build_BoolOp(ctx, node): + build_stmts(ctx, node.values) + if ctx.is_in_static_scope(): + ops = { + ast.And: ASTTransformer.build_static_short_circuit_and, + ast.Or: ASTTransformer.build_static_short_circuit_or, + } + elif impl.get_runtime().short_circuit_operators: + ops = { + ast.And: ASTTransformer.build_short_circuit_and, + ast.Or: ASTTransformer.build_short_circuit_or, + } + else: + ops = { + ast.And: + ASTTransformer.build_normal_bool_op(ti_ops.logical_and), + ast.Or: ASTTransformer.build_normal_bool_op(ti_ops.logical_or), + } + op = ops.get(type(node.op)) + node.ptr = op(ctx.ast_builder, node.values) + return node.ptr + + @staticmethod + def build_Compare(ctx, node): + build_stmt(ctx, node.left) + build_stmts(ctx, node.comparators) + ops = { + ast.Eq: lambda l, r: l == r, + ast.NotEq: lambda l, r: l != r, + ast.Lt: lambda l, r: l < r, + ast.LtE: lambda l, r: l <= r, + ast.Gt: lambda l, r: l > r, + ast.GtE: lambda l, r: l >= r, + } + ops_static = { + ast.In: lambda l, r: l in r, + ast.NotIn: lambda l, r: l not in r, + ast.Is: lambda l, r: l is r, + ast.IsNot: lambda l, r: l is not r, + } + if ctx.is_in_static_scope(): + ops = {**ops, **ops_static} + operands = [node.left.ptr + ] + [comparator.ptr for comparator in node.comparators] + val = True + for i, node_op in enumerate(node.ops): + l = operands[i] + r = operands[i + 1] + op = ops.get(type(node_op)) + if isinstance(node_op, (ast.Is, ast.IsNot)): + name = "is" if isinstance(node_op, ast.Is) else "is not" + warnings.warn_explicit( + f'Operator "{name}" in Taichi scope is deprecated. Please avoid using it.', + DeprecationWarning, ctx.file, + node.lineno + ctx.lineno_offset) + if op is None: + if type(node_op) in ops_static: + raise TaichiSyntaxError( + f'"{type(node_op).__name__}" is only supported inside `ti.static`.' + ) + else: + raise TaichiSyntaxError( + f'"{type(node_op).__name__}" is not supported in Taichi kernels.' + ) + val = ti_ops.logical_and(val, op(l, r)) + node.ptr = val + return node.ptr + + @staticmethod + def get_decorator(ctx, node): + if not isinstance(node, ast.Call): + return '' + for wanted, name in [ + (impl.static, 'static'), + (impl.grouped, 'grouped'), + (ndrange, 'ndrange'), + ]: + if ASTResolver.resolve_to(node.func, wanted, ctx.global_vars): + return name + return '' + + @staticmethod + def get_for_loop_targets(node): + """ + Returns the list of indices of the for loop |node|. + See also: https://docs.python.org/3/library/ast.html#ast.For + """ + if isinstance(node.target, ast.Name): + return [node.target.id] + assert isinstance(node.target, ast.Tuple) + return [name.id for name in node.target.elts] + + @staticmethod + def build_static_for(ctx, node, is_grouped): + if is_grouped: + assert len(node.iter.args[0].args) == 1 + ndrange_arg = build_stmt(ctx, node.iter.args[0].args[0]) + if not isinstance(ndrange_arg, _Ndrange): + raise TaichiSyntaxError( + "Only 'ti.ndrange' is allowed in 'ti.static(ti.grouped(...))'." + ) + targets = ASTTransformer.get_for_loop_targets(node) + if len(targets) != 1: + raise TaichiSyntaxError( + f"Group for should have 1 loop target, found {len(targets)}" + ) + target = targets[0] + for value in impl.grouped(ndrange_arg): + with ctx.variable_scope_guard(): + ctx.create_variable(target, value) + build_stmts(ctx, node.body) + status = ctx.loop_status() + if status == LoopStatus.Break: + break + elif status == LoopStatus.Continue: + ctx.set_loop_status(LoopStatus.Normal) + else: + build_stmt(ctx, node.iter) + targets = ASTTransformer.get_for_loop_targets(node) + for target_values in node.iter.ptr: + if not isinstance( + target_values, + collections.abc.Sequence) or len(targets) == 1: + target_values = [target_values] + with ctx.variable_scope_guard(): + for target, target_value in zip(targets, target_values): + ctx.create_variable(target, target_value) + build_stmts(ctx, node.body) + status = ctx.loop_status() + if status == LoopStatus.Break: + break + elif status == LoopStatus.Continue: + ctx.set_loop_status(LoopStatus.Normal) + return None + + @staticmethod + def build_range_for(ctx, node): + with ctx.variable_scope_guard(): + loop_name = node.target.id + ctx.check_loop_var(loop_name) + loop_var = expr.Expr(_ti_core.make_id_expr('')) + ctx.create_variable(loop_name, loop_var) + if len(node.iter.args) not in [1, 2]: + raise TaichiSyntaxError( + f"Range should have 1 or 2 arguments, found {len(node.iter.args)}" + ) + if len(node.iter.args) == 2: + begin = ti_ops.cast( + expr.Expr(build_stmt(ctx, node.iter.args[0])), + primitive_types.i32) + end = ti_ops.cast( + expr.Expr(build_stmt(ctx, node.iter.args[1])), + primitive_types.i32) + else: + begin = ti_ops.cast(expr.Expr(0), primitive_types.i32) + end = ti_ops.cast( + expr.Expr(build_stmt(ctx, node.iter.args[0])), + primitive_types.i32) + ctx.ast_builder.begin_frontend_range_for(loop_var.ptr, begin.ptr, + end.ptr) + build_stmts(ctx, node.body) + ctx.ast_builder.end_frontend_range_for() + return None + + @staticmethod + def build_ndrange_for(ctx, node): + with ctx.variable_scope_guard(): + ndrange_var = impl.expr_init(build_stmt(ctx, node.iter)) + ndrange_begin = ti_ops.cast(expr.Expr(0), primitive_types.i32) + ndrange_end = ti_ops.cast( + expr.Expr(impl.subscript(ndrange_var.acc_dimensions, 0)), + primitive_types.i32) + ndrange_loop_var = expr.Expr(_ti_core.make_id_expr('')) + ctx.ast_builder.begin_frontend_range_for(ndrange_loop_var.ptr, + ndrange_begin.ptr, + ndrange_end.ptr) + I = impl.expr_init(ndrange_loop_var) + targets = ASTTransformer.get_for_loop_targets(node) + for i, target in enumerate(targets): + if i + 1 < len(targets): + target_tmp = impl.expr_init( + I // ndrange_var.acc_dimensions[i + 1]) + else: + target_tmp = impl.expr_init(I) + ctx.create_variable( + target, + impl.expr_init(target_tmp + impl.subscript( + impl.subscript(ndrange_var.bounds, i), 0))) + if i + 1 < len(targets): + I._assign(I - + target_tmp * ndrange_var.acc_dimensions[i + 1]) + build_stmts(ctx, node.body) + ctx.ast_builder.end_frontend_range_for() + return None + + @staticmethod + def build_grouped_ndrange_for(ctx, node): + with ctx.variable_scope_guard(): + ndrange_var = impl.expr_init(build_stmt(ctx, node.iter.args[0])) + ndrange_begin = ti_ops.cast(expr.Expr(0), primitive_types.i32) + ndrange_end = ti_ops.cast( + expr.Expr(impl.subscript(ndrange_var.acc_dimensions, 0)), + primitive_types.i32) + ndrange_loop_var = expr.Expr(_ti_core.make_id_expr('')) + ctx.ast_builder.begin_frontend_range_for(ndrange_loop_var.ptr, + ndrange_begin.ptr, + ndrange_end.ptr) + + targets = ASTTransformer.get_for_loop_targets(node) + if len(targets) != 1: + raise TaichiSyntaxError( + f"Group for should have 1 loop target, found {len(targets)}" + ) + target = targets[0] + target_var = impl.expr_init( + matrix.Vector([0] * len(ndrange_var.dimensions), + dt=primitive_types.i32)) + ctx.create_variable(target, target_var) + I = impl.expr_init(ndrange_loop_var) + for i in range(len(ndrange_var.dimensions)): + if i + 1 < len(ndrange_var.dimensions): + target_tmp = I // ndrange_var.acc_dimensions[i + 1] + else: + target_tmp = I + impl.subscript(target_var, i)._assign(target_tmp + + ndrange_var.bounds[i][0]) + if i + 1 < len(ndrange_var.dimensions): + I._assign(I - + target_tmp * ndrange_var.acc_dimensions[i + 1]) + build_stmts(ctx, node.body) + ctx.ast_builder.end_frontend_range_for() + return None + + @staticmethod + def build_struct_for(ctx, node, is_grouped): + # for i, j in x + # for I in ti.grouped(x) + targets = ASTTransformer.get_for_loop_targets(node) + + for target in targets: + ctx.check_loop_var(target) + + with ctx.variable_scope_guard(): + if is_grouped: + if len(targets) != 1: + raise TaichiSyntaxError( + f"Group for should have 1 loop target, found {len(targets)}" + ) + target = targets[0] + loop_var = build_stmt(ctx, node.iter) + loop_indices = expr.make_var_list(size=len(loop_var.shape)) + expr_group = expr.make_expr_group(loop_indices) + impl.begin_frontend_struct_for(ctx.ast_builder, expr_group, + loop_var) + ctx.create_variable( + target, matrix.Vector(loop_indices, + dt=primitive_types.i32)) + build_stmts(ctx, node.body) + ctx.ast_builder.end_frontend_struct_for() + else: + _vars = [] + for name in targets: + var = expr.Expr(_ti_core.make_id_expr("")) + _vars.append(var) + ctx.create_variable(name, var) + loop_var = node.iter.ptr + expr_group = expr.make_expr_group(*_vars) + impl.begin_frontend_struct_for(ctx.ast_builder, expr_group, + loop_var) + build_stmts(ctx, node.body) + ctx.ast_builder.end_frontend_struct_for() + return None + + @staticmethod + def build_mesh_for(ctx, node): + targets = ASTTransformer.get_for_loop_targets(node) + if len(targets) != 1: + raise TaichiSyntaxError( + "Mesh for should have 1 loop target, found {len(targets)}") + target = targets[0] + + with ctx.variable_scope_guard(): + var = expr.Expr(_ti_core.make_id_expr("")) + ctx.mesh = node.iter.ptr.mesh + assert isinstance(ctx.mesh, impl.MeshInstance) + mesh_idx = mesh.MeshElementFieldProxy(ctx.mesh, + node.iter.ptr._type, var.ptr) + ctx.create_variable(target, mesh_idx) + ctx.ast_builder.begin_frontend_mesh_for(mesh_idx.ptr, + ctx.mesh.mesh_ptr, + node.iter.ptr._type) + build_stmts(ctx, node.body) + ctx.mesh = None + ctx.ast_builder.end_frontend_mesh_for() + return None + + @staticmethod + def build_nested_mesh_for(ctx, node): + targets = ASTTransformer.get_for_loop_targets(node) + if len(targets) != 1: + raise TaichiSyntaxError( + "Nested-mesh for should have 1 loop target, found {len(targets)}" + ) + target = targets[0] + + with ctx.variable_scope_guard(): + ctx.mesh = node.iter.ptr.mesh + assert isinstance(ctx.mesh, impl.MeshInstance) + loop_name = node.target.id + '_index__' + loop_var = expr.Expr(_ti_core.make_id_expr('')) + ctx.create_variable(loop_name, loop_var) + begin = expr.Expr(0) + end = node.iter.ptr.size + ctx.ast_builder.begin_frontend_range_for(loop_var.ptr, begin.ptr, + end.ptr) + entry_expr = _ti_core.get_relation_access( + ctx.mesh.mesh_ptr, node.iter.ptr.from_index.ptr, + node.iter.ptr.to_element_type, loop_var.ptr) + entry_expr.type_check(impl.get_runtime().prog.config) + mesh_idx = mesh.MeshElementFieldProxy( + ctx.mesh, node.iter.ptr.to_element_type, entry_expr) + ctx.create_variable(target, mesh_idx) + build_stmts(ctx, node.body) + ctx.ast_builder.end_frontend_range_for() + + return None + + @staticmethod + def build_For(ctx, node): + if node.orelse: + raise TaichiSyntaxError( + "'else' clause for 'for' not supported in Taichi kernels") + decorator = ASTTransformer.get_decorator(ctx, node.iter) + double_decorator = '' + if decorator != '' and len(node.iter.args) == 1: + double_decorator = ASTTransformer.get_decorator( + ctx, node.iter.args[0]) + + if decorator == 'static': + if double_decorator == 'static': + raise TaichiSyntaxError("'ti.static' cannot be nested") + with ctx.loop_scope_guard(is_static=True): + return ASTTransformer.build_static_for( + ctx, node, double_decorator == 'grouped') + with ctx.loop_scope_guard(): + if decorator == 'ndrange': + if double_decorator != '': + raise TaichiSyntaxError( + "No decorator is allowed inside 'ti.ndrange") + return ASTTransformer.build_ndrange_for(ctx, node) + if decorator == 'grouped': + if double_decorator == 'static': + raise TaichiSyntaxError( + "'ti.static' is not allowed inside 'ti.grouped'") + elif double_decorator == 'ndrange': + return ASTTransformer.build_grouped_ndrange_for(ctx, node) + elif double_decorator == 'grouped': + raise TaichiSyntaxError("'ti.grouped' cannot be nested") + else: + return ASTTransformer.build_struct_for(ctx, + node, + is_grouped=True) + elif isinstance(node.iter, ast.Call) and isinstance( + node.iter.func, ast.Name) and node.iter.func.id == 'range': + return ASTTransformer.build_range_for(ctx, node) + else: + build_stmt(ctx, node.iter) + if isinstance(node.iter.ptr, mesh.MeshElementField): + if not _ti_core.is_extension_supported( + impl.default_cfg().arch, _ti_core.Extension.mesh): + raise Exception( + 'Backend ' + str(impl.default_cfg().arch) + + ' doesn\'t support MeshTaichi extension') + return ASTTransformer.build_mesh_for(ctx, node) + if isinstance(node.iter.ptr, mesh.MeshRelationAccessProxy): + return ASTTransformer.build_nested_mesh_for(ctx, node) + # Struct for + return ASTTransformer.build_struct_for(ctx, + node, + is_grouped=False) + + @staticmethod + def build_While(ctx, node): + if node.orelse: + raise TaichiSyntaxError( + "'else' clause for 'while' not supported in Taichi kernels") + + with ctx.loop_scope_guard(): + ctx.ast_builder.begin_frontend_while(expr.Expr(1).ptr) + while_cond = build_stmt(ctx, node.test) + impl.begin_frontend_if(ctx.ast_builder, while_cond) + ctx.ast_builder.begin_frontend_if_true() + ctx.ast_builder.pop_scope() + ctx.ast_builder.begin_frontend_if_false() + ctx.ast_builder.insert_break_stmt() + ctx.ast_builder.pop_scope() + build_stmts(ctx, node.body) + ctx.ast_builder.pop_scope() + return None + + @staticmethod + def build_If(ctx, node): + build_stmt(ctx, node.test) + is_static_if = (ASTTransformer.get_decorator(ctx, + node.test) == "static") + + if is_static_if: + if node.test.ptr: + build_stmts(ctx, node.body) + else: + build_stmts(ctx, node.orelse) + return node + + with ctx.non_static_control_flow_guard(): + impl.begin_frontend_if(ctx.ast_builder, node.test.ptr) + ctx.ast_builder.begin_frontend_if_true() + build_stmts(ctx, node.body) + ctx.ast_builder.pop_scope() + ctx.ast_builder.begin_frontend_if_false() + build_stmts(ctx, node.orelse) + ctx.ast_builder.pop_scope() + return None + + @staticmethod + def build_Expr(ctx, node): + build_stmt(ctx, node.value) + if not isinstance(node.value, ast.Call): + return None + is_taichi_function = getattr(node.value.func.ptr, + '_is_taichi_function', False) + if is_taichi_function and node.value.func.ptr._is_real_function: + func_call_result = node.value.ptr + ctx.ast_builder.insert_expr_stmt(func_call_result.ptr) + return None + + @staticmethod + def build_IfExp(ctx, node): + build_stmt(ctx, node.test) + build_stmt(ctx, node.body) + build_stmt(ctx, node.orelse) + + if is_taichi_class(node.test.ptr) or is_taichi_class( + node.body.ptr) or is_taichi_class(node.orelse.ptr): + node.ptr = ti_ops.select(node.test.ptr, node.body.ptr, + node.orelse.ptr) + warnings.warn_explicit( + 'Using conditional expression for element-wise select operation on ' + 'Taichi vectors/matrices is deprecated. ' + 'Please use "ti.select" instead.', DeprecationWarning, + ctx.file, node.lineno + ctx.lineno_offset) + return node.ptr + + is_static_if = (ASTTransformer.get_decorator(ctx, + node.test) == "static") + + if is_static_if: + if node.test.ptr: + node.ptr = build_stmt(ctx, node.body) + else: + node.ptr = build_stmt(ctx, node.orelse) + return node.ptr + + val = impl.expr_init(None) + + impl.begin_frontend_if(ctx.ast_builder, node.test.ptr) + ctx.ast_builder.begin_frontend_if_true() + val._assign(node.body.ptr) + ctx.ast_builder.pop_scope() + ctx.ast_builder.begin_frontend_if_false() + val._assign(node.orelse.ptr) + ctx.ast_builder.pop_scope() + + node.ptr = val + return node.ptr + + @staticmethod + def _is_string_mod_args(msg): + # 1. str % (a, b, c, ...) + # 2. str % single_item + # Note that |msg.right| may not be a tuple. + if not isinstance(msg, ast.BinOp): + return False + if not isinstance(msg.op, ast.Mod): + return False + if isinstance(msg.left, ast.Str): + return True + if isinstance(msg.left, ast.Constant) and isinstance( + msg.left.value, str): + return True + return False + + @staticmethod + def _handle_string_mod_args(ctx, node): + msg = build_stmt(ctx, node.left) + args = build_stmt(ctx, node.right) + if not isinstance(args, collections.abc.Sequence): + args = (args, ) + return msg, args + + @staticmethod + def build_Assert(ctx, node): + extra_args = [] + if node.msg is not None: + if isinstance(node.msg, ast.Constant): + msg = node.msg.value + elif isinstance(node.msg, ast.Str): + msg = node.msg.s + elif ASTTransformer._is_string_mod_args(node.msg): + msg, extra_args = ASTTransformer._handle_string_mod_args( + ctx, node.msg) + else: + raise ValueError( + f"assert info must be constant, not {ast.dump(node.msg)}") + else: + msg = astor.to_source(node.test) + test = build_stmt(ctx, node.test) + impl.ti_assert(test, msg.strip(), extra_args) + return None + + @staticmethod + def build_Break(ctx, node): + if ctx.is_in_static_for(): + ctx.set_loop_status(LoopStatus.Break) + else: + ctx.ast_builder.insert_break_stmt() + return None + + @staticmethod + def build_Continue(ctx, node): + if ctx.is_in_static_for(): + ctx.set_loop_status(LoopStatus.Continue) + else: + ctx.ast_builder.insert_continue_stmt() + return None + + @staticmethod + def build_Pass(ctx, node): + return None + + +build_stmt = ASTTransformer() + + +def build_stmts(ctx, stmts): + with ctx.variable_scope_guard(): + for stmt in stmts: + if ctx.returned or ctx.loop_status() != LoopStatus.Normal: + break + else: + build_stmt(ctx, stmt) + return stmts diff --git a/python/taichi/lang/ast/ast_transformer_utils.py b/python/taichi/lang/ast/ast_transformer_utils.py new file mode 100644 index 0000000000000..900351cc75555 --- /dev/null +++ b/python/taichi/lang/ast/ast_transformer_utils.py @@ -0,0 +1,264 @@ +import ast +import builtins +import traceback +from enum import Enum +from sys import version_info +from textwrap import TextWrapper + +from taichi.lang.exception import (TaichiCompilationError, TaichiNameError, + TaichiSyntaxError, + handle_exception_from_cpp) + + +class Builder: + def __call__(self, ctx, node): + method = getattr(self, 'build_' + node.__class__.__name__, None) + try: + if method is None: + error_msg = f'Unsupported node "{node.__class__.__name__}"' + raise TaichiSyntaxError(error_msg) + return method(ctx, node) + except Exception as e: + if ctx.raised or not isinstance(node, (ast.stmt, ast.expr)): + raise e.with_traceback(None) + ctx.raised = True + e = handle_exception_from_cpp(e) + if not isinstance(e, TaichiCompilationError): + msg = ctx.get_pos_info(node) + traceback.format_exc() + raise TaichiCompilationError(msg) from None + msg = ctx.get_pos_info(node) + str(e) + raise type(e)(msg) from None + + +class VariableScopeGuard: + def __init__(self, scopes): + self.scopes = scopes + + def __enter__(self): + self.scopes.append({}) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.scopes.pop() + + +class StaticScopeStatus: + def __init__(self): + self.is_in_static_scope = False + + +class StaticScopeGuard: + def __init__(self, status): + self.status = status + + def __enter__(self): + self.prev = self.status.is_in_static_scope + self.status.is_in_static_scope = True + + def __exit__(self, exc_type, exc_val, exc_tb): + self.status.is_in_static_scope = self.prev + + +class NonStaticControlFlowStatus: + def __init__(self): + self.is_in_non_static_control_flow = False + + +class NonStaticControlFlowGuard: + def __init__(self, status): + self.status = status + + def __enter__(self): + self.prev = self.status.is_in_non_static_control_flow + self.status.is_in_non_static_control_flow = True + + def __exit__(self, exc_type, exc_val, exc_tb): + self.status.is_in_non_static_control_flow = self.prev + + +class LoopStatus(Enum): + Normal = 0 + Break = 1 + Continue = 2 + + +class LoopScopeAttribute: + def __init__(self, is_static): + self.is_static = is_static + self.status = LoopStatus.Normal + + +class LoopScopeGuard: + def __init__(self, scopes, non_static_guard=None): + self.scopes = scopes + self.non_static_guard = non_static_guard + + def __enter__(self): + self.scopes.append(LoopScopeAttribute(self.non_static_guard is None)) + if self.non_static_guard: + self.non_static_guard.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.scopes.pop() + if self.non_static_guard: + self.non_static_guard.__exit__(exc_type, exc_val, exc_tb) + + +class ASTTransformerContext: + def __init__(self, + excluded_parameters=(), + is_kernel=True, + func=None, + arg_features=None, + global_vars=None, + argument_data=None, + file=None, + src=None, + start_lineno=None, + ast_builder=None, + is_real_function=False): + self.func = func + self.local_scopes = [] + self.loop_scopes = [] + self.excluded_parameters = excluded_parameters + self.is_kernel = is_kernel + self.arg_features = arg_features + self.returns = None + self.global_vars = global_vars + self.argument_data = argument_data + self.return_data = None + self.file = file + self.src = src + self.indent = 0 + for c in self.src[0]: + if c == ' ': + self.indent += 1 + else: + break + self.lineno_offset = start_lineno - 1 + self.raised = False + self.non_static_control_flow_status = NonStaticControlFlowStatus() + self.static_scope_status = StaticScopeStatus() + self.returned = False + self.ast_builder = ast_builder + self.visited_funcdef = False + self.is_real_function = is_real_function + + # e.g.: FunctionDef, Module, Global + def variable_scope_guard(self): + return VariableScopeGuard(self.local_scopes) + + # e.g.: For, While + def loop_scope_guard(self, is_static=False): + if is_static: + return LoopScopeGuard(self.loop_scopes) + return LoopScopeGuard(self.loop_scopes, + self.non_static_control_flow_guard()) + + def non_static_control_flow_guard(self): + return NonStaticControlFlowGuard(self.non_static_control_flow_status) + + def static_scope_guard(self): + return StaticScopeGuard(self.static_scope_status) + + def current_scope(self): + return self.local_scopes[-1] + + def current_loop_scope(self): + return self.loop_scopes[-1] + + def loop_status(self): + if self.loop_scopes: + return self.loop_scopes[-1].status + return LoopStatus.Normal + + def set_loop_status(self, status): + self.loop_scopes[-1].status = status + + def is_in_static_for(self): + if self.loop_scopes: + return self.loop_scopes[-1].is_static + return False + + def is_in_non_static_control_flow(self): + return self.non_static_control_flow_status.is_in_non_static_control_flow + + def is_in_static_scope(self): + return self.static_scope_status.is_in_static_scope + + def is_var_declared(self, name): + for s in self.local_scopes: + if name in s: + return True + return False + + def create_variable(self, name, var): + if name in self.current_scope(): + raise TaichiSyntaxError("Recreating variables is not allowed") + self.current_scope()[name] = var + + def check_loop_var(self, loop_var): + if self.is_var_declared(loop_var): + raise TaichiSyntaxError( + f"Variable '{loop_var}' is already declared in the outer scope and cannot be used as loop variable" + ) + + def get_var_by_name(self, name): + for s in reversed(self.local_scopes): + if name in s: + return s[name] + if name in self.global_vars: + return self.global_vars[name] + try: + return getattr(builtins, name) + except AttributeError: + raise TaichiNameError(f'Name "{name}" is not defined') + + def get_pos_info(self, node): + msg = f'On line {node.lineno + self.lineno_offset} of file "{self.file}", in {self.func.func.__name__}:\n' + if version_info < (3, 8): + msg += self.src[node.lineno - 1] + "\n" + return msg + col_offset = self.indent + node.col_offset + end_col_offset = self.indent + node.end_col_offset + + wrapper = TextWrapper(width=80) + + def gen_line(code, hint): + hint += ' ' * (len(code) - len(hint)) + code = wrapper.wrap(code) + hint = wrapper.wrap(hint) + if not len(code): + return "\n\n" + return "".join([c + '\n' + h + '\n' for c, h in zip(code, hint)]) + + if node.lineno == node.end_lineno: + hint = ' ' * col_offset + '^' * (end_col_offset - col_offset) + msg += gen_line(self.src[node.lineno - 1], hint) + else: + node_type = node.__class__.__name__ + + if node_type in ["For", "While", "FunctionDef", "If"]: + end_lineno = max(node.body[0].lineno - 1, node.lineno) + else: + end_lineno = node.end_lineno + + for i in range(node.lineno - 1, end_lineno): + last = len(self.src[i]) + while last > 0 and (self.src[i][last - 1].isspace() or + not self.src[i][last - 1].isprintable()): + last -= 1 + first = 0 + while first < len(self.src[i]) and ( + self.src[i][first].isspace() + or not self.src[i][first].isprintable()): + first += 1 + if i == node.lineno - 1: + hint = ' ' * col_offset + '^' * (last - col_offset) + elif i == node.end_lineno - 1: + hint = ' ' * first + '^' * (end_col_offset - first) + elif first < last: + hint = ' ' * first + '^' * (last - first) + else: + hint = '' + msg += gen_line(self.src[i], hint) + return msg diff --git a/python/taichi/lang/ast/checkers.py b/python/taichi/lang/ast/checkers.py index fb8a3d62ba418..825a13ae6ab9e 100644 --- a/python/taichi/lang/ast/checkers.py +++ b/python/taichi/lang/ast/checkers.py @@ -1,6 +1,6 @@ import ast -import taichi.lang.kernel_impl +from taichi.lang.exception import TaichiSyntaxError from taichi.lang.shell import oinspect @@ -55,7 +55,8 @@ def get_error_location(self, node): lineno = self._func_lineno + node.lineno - 1 return f'file={self._func_file} kernel={self._func_name} line={lineno}' - def should_check(self, node): + @staticmethod + def should_check(node): if not isinstance(node, ast.stmt): return False # TODO(#536): Frontend pass should help make sure |func| is a valid AST for @@ -69,7 +70,7 @@ def generic_visit(self, node): return if not (self.top_level or self.current_scope.allows_more_stmt): - raise taichi.lang.kernel_impl.KernelDefError( + raise TaichiSyntaxError( f'No more statements allowed, at {self.get_error_location(node)}' ) old_top_level = self.top_level @@ -83,25 +84,23 @@ def generic_visit(self, node): if old_top_level: self._scope_guards.pop() - def visit_For(self, node): + @staticmethod + def visit_for(node): # TODO: since autodiff is enhanced, AST checker rules should be relaxed. This part should be updated. + # original code is #def visit_For(self, node) without #@staticmethod before fix pylint R0201 return - if (isinstance(node.iter, ast.Call) - and isinstance(node.iter.func, ast.Attribute) - and isinstance(node.iter.func.value, ast.Name) - and node.iter.func.value.id == 'ti' - and node.iter.func.attr == 'static'): - is_static = True - else: - is_static = False - if not (self.top_level or self.current_scope.allows_for_loop - or is_static): - raise taichi.lang.kernel_impl.KernelDefError( - f'No more for loops allowed, at {self.get_error_location(node)}' - ) - - with self.new_scope(): - super().generic_visit(node) - - if not (self.top_level or is_static): - self.current_scope.mark_no_more_stmt() + # is_static = (isinstance(node.iter, ast.Call) + # and isinstance(node.iter.func, ast.Attribute) + # and isinstance(node.iter.func.value, ast.Name) + # and node.iter.func.value.id == 'ti' + # and node.iter.func.attr == 'static') + # if not (self.top_level or self.current_scope.allows_for_loop + # or is_static): + # raise TaichiSyntaxError( + # f'No more for loops allowed, at {self.get_error_location(node)}' + # ) + # with self.new_scope(): + # super().generic_visit(node) + # + # if not (self.top_level or is_static): + # self.current_scope.mark_no_more_stmt() diff --git a/python/taichi/lang/ast/transform.py b/python/taichi/lang/ast/transform.py new file mode 100644 index 0000000000000..f9ca13b49f43e --- /dev/null +++ b/python/taichi/lang/ast/transform.py @@ -0,0 +1,7 @@ +from taichi.lang.ast.ast_transformer import ASTTransformer +from taichi.lang.ast.ast_transformer_utils import ASTTransformerContext + + +def transform_tree(tree, ctx: ASTTransformerContext): + ASTTransformer()(ctx, tree) + return ctx.return_data diff --git a/python/taichi/lang/ast/transformer.py b/python/taichi/lang/ast/transformer.py deleted file mode 100644 index 476a3cd9c7855..0000000000000 --- a/python/taichi/lang/ast/transformer.py +++ /dev/null @@ -1,98 +0,0 @@ -import ast - -import astor -from taichi.lang import impl -from taichi.lang.ast.symbol_resolver import ASTResolver -from taichi.lang.ast_builder_utils import BuilderContext -from taichi.lang.exception import TaichiSyntaxError -from taichi.lang.stmt_builder import build_stmt - -import taichi as ti - - -# Total transform -class ASTTransformerTotal(object): - def __init__(self, - func=None, - excluded_parameters=(), - is_kernel=True, - arg_features=None): - self.func = func - self.excluded_parameters = excluded_parameters - self.is_kernel = is_kernel - self.arg_features = arg_features - self.pass_Checks = ASTTransformerChecks(func=func) - - @staticmethod - def print_ast(tree, title=None): - if not impl.get_runtime().print_preprocessed: - return - if title is not None: - ti.info(f'{title}:') - print(astor.to_source(tree.body[0], indent_with=' '), flush=True) - - def visit(self, tree): - self.print_ast(tree, 'Initial AST') - ctx = BuilderContext(func=self.func, - excluded_parameters=self.excluded_parameters, - is_kernel=self.is_kernel, - arg_features=self.arg_features) - # Convert Python AST to Python code that generates Taichi C++ AST. - tree = build_stmt(ctx, tree) - ast.fix_missing_locations(tree) - self.print_ast(tree, 'Preprocessed') - self.pass_Checks.visit(tree) # does not modify the AST - - -class ASTTransformerBase(ast.NodeTransformer): - def __init__(self, func): - super().__init__() - self.func = func - - @staticmethod - def get_decorator(node): - if not isinstance(node, ast.Call): - return '' - for wanted, name in [ - (ti.static, 'static'), - (ti.grouped, 'grouped'), - (ti.ndrange, 'ndrange'), - ]: - if ASTResolver.resolve_to(node.func, wanted, globals()): - return name - return '' - - -# Performs checks at the Python AST level. Does not modify the AST. -class ASTTransformerChecks(ASTTransformerBase): - def __init__(self, func): - super().__init__(func) - self.has_return = False - self.in_static_if = False - - def visit_If(self, node): - node.test = self.visit(node.test) - - old_in_static_if = self.in_static_if - self.in_static_if = self.get_decorator(node.test) == 'static' - - node.body = list(map(self.visit, node.body)) - if node.orelse is not None: - node.orelse = list(map(self.visit, node.orelse)) - - self.in_static_if = old_in_static_if - - return node - - def visit_Return(self, node): - if self.in_static_if: # we can have multiple return in static-if branches - return node - - if not self.has_return: - self.has_return = True - else: - raise TaichiSyntaxError( - 'Taichi functions/kernels cannot have multiple returns!' - ' Consider using a local variable to walk around.') - - return node diff --git a/python/taichi/lang/ast_builder_utils.py b/python/taichi/lang/ast_builder_utils.py deleted file mode 100644 index 103bc3a4fd368..0000000000000 --- a/python/taichi/lang/ast_builder_utils.py +++ /dev/null @@ -1,91 +0,0 @@ -import ast - -from taichi.lang.exception import TaichiSyntaxError - - -class Builder(object): - def __call__(self, ctx, node): - method = getattr(self, 'build_' + node.__class__.__name__, None) - if method is None: - try: - import astpretty # pylint: disable=C0415 - error_msg = f'Unsupported node {node}:\n{astpretty.pformat(node)}' - except: - error_msg = f'Unsupported node {node}' - raise TaichiSyntaxError(error_msg) - return method(ctx, node) - - -def parse_stmt(stmt): - return ast.parse(stmt).body[0] - - -def parse_expr(expr): - return ast.parse(expr).body[0].value - - -class ScopeGuard: - def __init__(self, scopes, stmt_block=None): - self.scopes = scopes - self.stmt_block = stmt_block - - def __enter__(self): - self.scopes.append([]) - - def __exit__(self, exc_type, exc_val, exc_tb): - local = self.scopes[-1] - if self.stmt_block is not None: - for var in reversed(local): - stmt = parse_stmt('del var') - stmt.targets[0].id = var - self.stmt_block.append(stmt) - self.scopes.pop() - - -class BuilderContext: - def __init__(self, - excluded_parameters=(), - is_kernel=True, - func=None, - arg_features=None): - self.func = func - self.local_scopes = [] - self.control_scopes = [] - self.excluded_parameters = excluded_parameters - self.is_kernel = is_kernel - self.arg_features = arg_features - self.returns = None - - # e.g.: FunctionDef, Module, Global - def variable_scope(self, *args): - return ScopeGuard(self.local_scopes, *args) - - # e.g.: For, While - def control_scope(self): - return ScopeGuard(self.control_scopes) - - def current_scope(self): - return self.local_scopes[-1] - - def current_control_scope(self): - return self.control_scopes[-1] - - def var_declared(self, name): - for s in self.local_scopes: - if name in s: - return True - return False - - def is_creation(self, name): - return not self.var_declared(name) - - def create_variable(self, name): - assert name not in self.current_scope( - ), "Recreating variables is not allowed" - self.current_scope().append(name) - - def check_loop_var(self, loop_var): - if self.var_declared(loop_var): - raise TaichiSyntaxError( - "Variable '{}' is already declared in the outer scope and cannot be used as loop variable" - .format(loop_var)) diff --git a/python/taichi/lang/common_ops.py b/python/taichi/lang/common_ops.py index 28b36964b2440..38b13a29573ce 100644 --- a/python/taichi/lang/common_ops.py +++ b/python/taichi/lang/common_ops.py @@ -1,167 +1,132 @@ -import taichi as ti +import warnings + +from taichi.lang import ops class TaichiOperations: """The base class of taichi operations of expressions. Subclasses: :class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`""" + + __deprecated_atomic_ops__ = { + "atomic_add": "_atomic_add", + "atomic_and": "_atomic_and", + "atomic_or": "_atomic_or", + "atomic_sub": "_atomic_sub", + "atomic_xor": "_atomic_xor", + } + + def __getattr__(self, item): + if item in TaichiOperations.__deprecated_atomic_ops__: + warnings.warn( + f"a.{item}(b) is deprecated. Please use ti.{item}(a, b) instead.", + DeprecationWarning) + return getattr(self, + TaichiOperations.__deprecated_atomic_ops__[item]) + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{item}'") + def __neg__(self): - _taichi_skip_traceback = 1 - return ti.neg(self) + return ops.neg(self) def __abs__(self): - _taichi_skip_traceback = 1 - return ti.abs(self) + return ops.abs(self) def __add__(self, other): - _taichi_skip_traceback = 1 - return ti.add(self, other) + return ops.add(self, other) def __radd__(self, other): - _taichi_skip_traceback = 1 - return ti.add(other, self) + return ops.add(other, self) def __sub__(self, other): - _taichi_skip_traceback = 1 - return ti.sub(self, other) + return ops.sub(self, other) def __rsub__(self, other): - _taichi_skip_traceback = 1 - return ti.sub(other, self) + return ops.sub(other, self) def __mul__(self, other): - _taichi_skip_traceback = 1 - return ti.mul(self, other) + return ops.mul(self, other) def __rmul__(self, other): - _taichi_skip_traceback = 1 - return ti.mul(other, self) + return ops.mul(other, self) def __truediv__(self, other): - _taichi_skip_traceback = 1 - return ti.truediv(self, other) + return ops.truediv(self, other) def __rtruediv__(self, other): - _taichi_skip_traceback = 1 - return ti.truediv(other, self) + return ops.truediv(other, self) def __floordiv__(self, other): - _taichi_skip_traceback = 1 - return ti.floordiv(self, other) + return ops.floordiv(self, other) def __rfloordiv__(self, other): - _taichi_skip_traceback = 1 - return ti.floordiv(other, self) + return ops.floordiv(other, self) def __mod__(self, other): - _taichi_skip_traceback = 1 - return ti.mod(self, other) + return ops.mod(self, other) def __rmod__(self, other): - _taichi_skip_traceback = 1 - return ti.mod(other, self) + return ops.mod(other, self) def __pow__(self, other, modulo=None): - _taichi_skip_traceback = 1 - return ti.pow(self, other) + return ops.pow(self, other) def __rpow__(self, other, modulo=None): - _taichi_skip_traceback = 1 - return ti.pow(other, self) + return ops.pow(other, self) def __le__(self, other): - _taichi_skip_traceback = 1 - return ti.cmp_le(self, other) + return ops.cmp_le(self, other) def __lt__(self, other): - _taichi_skip_traceback = 1 - return ti.cmp_lt(self, other) + return ops.cmp_lt(self, other) def __ge__(self, other): - _taichi_skip_traceback = 1 - return ti.cmp_ge(self, other) + return ops.cmp_ge(self, other) def __gt__(self, other): - _taichi_skip_traceback = 1 - return ti.cmp_gt(self, other) + return ops.cmp_gt(self, other) def __eq__(self, other): - _taichi_skip_traceback = 1 - return ti.cmp_eq(self, other) + return ops.cmp_eq(self, other) def __ne__(self, other): - _taichi_skip_traceback = 1 - return ti.cmp_ne(self, other) + return ops.cmp_ne(self, other) def __and__(self, other): - _taichi_skip_traceback = 1 - return ti.bit_and(self, other) + return ops.bit_and(self, other) def __rand__(self, other): - _taichi_skip_traceback = 1 - return ti.bit_and(other, self) + return ops.bit_and(other, self) def __or__(self, other): - _taichi_skip_traceback = 1 - return ti.bit_or(self, other) + return ops.bit_or(self, other) def __ror__(self, other): - _taichi_skip_traceback = 1 - return ti.bit_or(other, self) + return ops.bit_or(other, self) def __xor__(self, other): - _taichi_skip_traceback = 1 - return ti.bit_xor(self, other) + return ops.bit_xor(self, other) def __rxor__(self, other): - _taichi_skip_traceback = 1 - return ti.bit_xor(other, self) + return ops.bit_xor(other, self) def __lshift__(self, other): - _taichi_skip_traceback = 1 - return ti.bit_shl(self, other) + return ops.bit_shl(self, other) def __rlshift__(self, other): - _taichi_skip_traceback = 1 - return ti.bit_shl(other, self) + return ops.bit_shl(other, self) def __rshift__(self, other): - _taichi_skip_traceback = 1 - return ti.bit_sar(self, other) + return ops.bit_sar(self, other) def __rrshift__(self, other): - _taichi_skip_traceback = 1 - return ti.bit_sar(other, self) - - def logical_and(self, other): - """Return the new expression of computing logical and between self and a given operand. - - Args: - other (Any): Given operand. - - Returns: - :class:`~taichi.lang.expr.Expr`: The computing expression of logical and.""" - _taichi_skip_traceback = 1 - return ti.logical_and(self, other) - - def logical_or(self, other): - """Return the new expression of computing logical or between self and a given operand. - - Args: - other (Any): Given operand. - - Returns: - :class:`~taichi.lang.expr.Expr`: The computing expression of logical or.""" - _taichi_skip_traceback = 1 - return ti.logical_or(self, other) + return ops.bit_sar(other, self) def __invert__(self): # ~a => a.__invert__() - _taichi_skip_traceback = 1 - return ti.bit_not(self) + return ops.bit_not(self) def __not__(self): # not a => a.__not__() - _taichi_skip_traceback = 1 - return ti.logical_not(self) + return ops.logical_not(self) - def atomic_add(self, other): + def _atomic_add(self, other): """Return the new expression of computing atomic add between self and a given operand. Args: @@ -169,10 +134,9 @@ def atomic_add(self, other): Returns: :class:`~taichi.lang.expr.Expr`: The computing expression of atomic add.""" - _taichi_skip_traceback = 1 - return ti.atomic_add(self, other) + return ops.atomic_add(self, other) - def atomic_sub(self, other): + def _atomic_sub(self, other): """Return the new expression of computing atomic sub between self and a given operand. Args: @@ -180,10 +144,9 @@ def atomic_sub(self, other): Returns: :class:`~taichi.lang.expr.Expr`: The computing expression of atomic sub.""" - _taichi_skip_traceback = 1 - return ti.atomic_sub(self, other) + return ops.atomic_sub(self, other) - def atomic_and(self, other): + def _atomic_and(self, other): """Return the new expression of computing atomic and between self and a given operand. Args: @@ -191,10 +154,9 @@ def atomic_and(self, other): Returns: :class:`~taichi.lang.expr.Expr`: The computing expression of atomic and.""" - _taichi_skip_traceback = 1 - return ti.atomic_and(self, other) + return ops.atomic_and(self, other) - def atomic_xor(self, other): + def _atomic_xor(self, other): """Return the new expression of computing atomic xor between self and a given operand. Args: @@ -202,10 +164,9 @@ def atomic_xor(self, other): Returns: :class:`~taichi.lang.expr.Expr`: The computing expression of atomic xor.""" - _taichi_skip_traceback = 1 - return ti.atomic_xor(self, other) + return ops.atomic_xor(self, other) - def atomic_or(self, other): + def _atomic_or(self, other): """Return the new expression of computing atomic or between self and a given operand. Args: @@ -213,66 +174,58 @@ def atomic_or(self, other): Returns: :class:`~taichi.lang.expr.Expr`: The computing expression of atomic or.""" - _taichi_skip_traceback = 1 - return ti.atomic_or(self, other) + return ops.atomic_or(self, other) def __iadd__(self, other): - _taichi_skip_traceback = 1 - self.atomic_add(other) + self._atomic_add(other) return self def __isub__(self, other): - _taichi_skip_traceback = 1 - self.atomic_sub(other) + self._atomic_sub(other) return self def __iand__(self, other): - _taichi_skip_traceback = 1 - self.atomic_and(other) + self._atomic_and(other) return self def __ixor__(self, other): - _taichi_skip_traceback = 1 - self.atomic_xor(other) + self._atomic_xor(other) return self def __ior__(self, other): - _taichi_skip_traceback = 1 - self.atomic_or(other) + self._atomic_or(other) return self # we don't support atomic_mul/truediv/floordiv/mod yet: def __imul__(self, other): - _taichi_skip_traceback = 1 - self.assign(ti.mul(self, other)) + self._assign(ops.mul(self, other)) return self def __itruediv__(self, other): - _taichi_skip_traceback = 1 - self.assign(ti.truediv(self, other)) + self._assign(ops.truediv(self, other)) return self def __ifloordiv__(self, other): - _taichi_skip_traceback = 1 - self.assign(ti.floordiv(self, other)) + self._assign(ops.floordiv(self, other)) return self def __imod__(self, other): - _taichi_skip_traceback = 1 - self.assign(ti.mod(self, other)) + self._assign(ops.mod(self, other)) return self def __ilshift__(self, other): - _taichi_skip_traceback = 1 - self.assign(ti.bit_shl(self, other)) + self._assign(ops.bit_shl(self, other)) return self def __irshift__(self, other): - _taichi_skip_traceback = 1 - self.assign(ti.bit_shr(self, other)) + self._assign(ops.bit_shr(self, other)) + return self + + def __ipow__(self, other): + self._assign(ops.pow(self, other)) return self - def assign(self, other): + def _assign(self, other): """Assign the expression of the given operand to self. Args: @@ -280,16 +233,15 @@ def assign(self, other): Returns: :class:`~taichi.lang.expr.Expr`: The expression after assigning.""" - _taichi_skip_traceback = 1 - return ti.assign(self, other) + return ops.assign(self, other) - def augassign(self, x, op): + # pylint: disable=R0201 + def _augassign(self, x, op): """Generate the computing expression between self and the given operand of given operator and assigned to self. Args: x (Any): Given operand. op (str): The name of operator.""" - _taichi_skip_traceback = 1 if op == 'Add': self += x elif op == 'Sub': @@ -312,13 +264,13 @@ def augassign(self, x, op): self >>= x elif op == 'LShift': self <<= x + elif op == 'Pow': + self **= x else: assert False, op def __ti_int__(self): - _taichi_skip_traceback = 1 - return ti.cast(self, int) + return ops.cast(self, int) def __ti_float__(self): - _taichi_skip_traceback = 1 - return ti.cast(self, float) + return ops.cast(self, float) diff --git a/python/taichi/lang/enums.py b/python/taichi/lang/enums.py index e7649322c9d9c..43f14d50dcccb 100644 --- a/python/taichi/lang/enums.py +++ b/python/taichi/lang/enums.py @@ -9,3 +9,6 @@ class Layout(Enum): """ AOS = 1 SOA = 2 + + +__all__ = ['Layout'] diff --git a/python/taichi/lang/exception.py b/python/taichi/lang/exception.py index 149279f6eea91..d76290638e632 100644 --- a/python/taichi/lang/exception.py +++ b/python/taichi/lang/exception.py @@ -1,7 +1,51 @@ -class TaichiSyntaxError(Exception): - def __init__(self, *args): - super().__init__(*args) +from taichi._lib import core -class InvalidOperationError(Exception): +class TaichiCompilationError(Exception): + """Base class for all compilation exceptions. + """ pass + + +class TaichiSyntaxError(TaichiCompilationError, SyntaxError): + """Thrown when a syntax error is found during compilation. + """ + pass + + +class TaichiNameError(TaichiCompilationError, NameError): + """Thrown when an undefine name is found during compilation. + """ + pass + + +class TaichiTypeError(TaichiCompilationError, TypeError): + """Thrown when a type mismatch is found during compilation. + """ + pass + + +class TaichiRuntimeError(RuntimeError): + """Thrown when the compiled program cannot be executed due to unspecified reasons. + """ + pass + + +class TaichiRuntimeTypeError(TaichiRuntimeError, TypeError): + def __init__(self, pos, needed, provided): + message = f'Argument {pos} (type={provided}) cannot be converted into required type {needed}' + super().__init__(message) + + +def handle_exception_from_cpp(exc): + if isinstance(exc, core.TaichiTypeError): + return TaichiTypeError(str(exc)) + if isinstance(exc, core.TaichiSyntaxError): + return TaichiSyntaxError(str(exc)) + return exc + + +__all__ = [ + 'TaichiSyntaxError', 'TaichiTypeError', 'TaichiCompilationError', + 'TaichiNameError', 'TaichiRuntimeError', 'TaichiRuntimeTypeError' +] diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 0f9b14957cc5b..05582cd6be9cb 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -1,17 +1,16 @@ import numpy as np -from taichi.core.util import ti_core as _ti_core +from taichi._lib import core as _ti_core from taichi.lang import impl from taichi.lang.common_ops import TaichiOperations -from taichi.lang.util import is_taichi_class, python_scope - -import taichi as ti +from taichi.lang.exception import TaichiTypeError +from taichi.lang.util import is_taichi_class, to_numpy_type, to_taichi_type +from taichi.types.primitive_types import integer_types, real_types # Scalar, basic data type class Expr(TaichiOperations): """A Python-side Expr wrapper, whose member variable `ptr` is an instance of C++ Expr class. A C++ Expr object contains member variable `expr` which holds an instance of C++ Expression class.""" - def __init__(self, *args, tb=None): - _taichi_skip_traceback = 1 + def __init__(self, *args, tb=None, dtype=None): self.tb = tb if len(args) == 1: if isinstance(args[0], _ti_core.Expr): @@ -20,21 +19,24 @@ def __init__(self, *args, tb=None): self.ptr = args[0].ptr self.tb = args[0].tb elif is_taichi_class(args[0]): - raise ValueError('cannot initialize scalar expression from ' - f'taichi class: {type(args[0])}') + raise TaichiTypeError( + 'Cannot initialize scalar expression from ' + f'taichi class: {type(args[0])}') else: # assume to be constant arg = args[0] - try: - if isinstance(arg, np.ndarray): - arg = arg.dtype(arg) - except: - pass - self.ptr = impl.make_constant_expr(arg).ptr + if isinstance(arg, np.ndarray): + if arg.shape: + raise TaichiTypeError( + "Only 0-dimensional numpy array can be used to initialize a scalar expression" + ) + arg = arg.dtype.type(arg) + self.ptr = make_constant_expr(arg, dtype).ptr else: assert False if self.tb: self.ptr.set_tb(self.tb) + self.ptr.type_check(impl.get_runtime().prog.config) def __hash__(self): return self.ptr.get_raw_address() @@ -46,22 +48,74 @@ def __repr__(self): return '' -def make_var_vector(size): +def _check_in_range(npty, val): + iif = np.iinfo(npty) + if not iif.min <= val <= iif.max: + # This isn't the case we want to deal with: |val| does't fall into the valid range of either + # the signed or the unsigned type. + raise TaichiTypeError( + f'Constant {val} has exceeded the range of {to_taichi_type(npty)}: [{iif.min}, {iif.max}]' + ) + + +def _clamp_unsigned_to_range(npty, val): + # npty: np.int32 or np.int64 + iif = np.iinfo(npty) + if iif.min <= val <= iif.max: + return val + cap = (1 << iif.bits) + assert 0 <= val < cap + new_val = val - cap + return new_val + + +def make_constant_expr(val, dtype): + if isinstance(val, (int, np.integer)): + constant_dtype = impl.get_runtime( + ).default_ip if dtype is None else dtype + if constant_dtype not in integer_types: + raise TaichiTypeError( + 'Integer literals must be annotated with a integer type. For type casting, use `ti.cast`.' + ) + _check_in_range(to_numpy_type(constant_dtype), val) + return Expr( + _ti_core.make_const_expr_int( + constant_dtype, _clamp_unsigned_to_range(np.int64, val))) + if isinstance(val, (float, np.floating)): + constant_dtype = impl.get_runtime( + ).default_fp if dtype is None else dtype + if constant_dtype not in real_types: + raise TaichiTypeError( + 'Floating-point literals must be annotated with a floating-point type. For type casting, use `ti.cast`.' + ) + return Expr(_ti_core.make_const_expr_fp(constant_dtype, val)) + raise TaichiTypeError(f'Invalid constant scalar data type: {type(val)}') + + +def make_var_list(size): exprs = [] for _ in range(size): exprs.append(_ti_core.make_id_expr('')) - return ti.Vector(exprs, disable_local_tensor=True) + return exprs def make_expr_group(*exprs): + from taichi.lang.matrix import Matrix # pylint: disable=C0415 if len(exprs) == 1: if isinstance(exprs[0], (list, tuple)): exprs = exprs[0] - elif isinstance(exprs[0], ti.Matrix): + elif isinstance(exprs[0], Matrix): mat = exprs[0] assert mat.m == 1 exprs = mat.entries expr_group = _ti_core.ExprGroup() for i in exprs: - expr_group.push_back(Expr(i).ptr) + if isinstance(i, Matrix): + assert i.local_tensor_proxy is not None + expr_group.push_back(i.local_tensor_proxy) + else: + expr_group.push_back(Expr(i).ptr) return expr_group + + +__all__ = [] diff --git a/python/taichi/lang/expr_builder.py b/python/taichi/lang/expr_builder.py deleted file mode 100644 index 8cf2082fbcaf0..0000000000000 --- a/python/taichi/lang/expr_builder.py +++ /dev/null @@ -1,277 +0,0 @@ -import ast -import warnings - -from taichi.lang.ast.symbol_resolver import ASTResolver -from taichi.lang.ast_builder_utils import * -from taichi.lang.exception import TaichiSyntaxError - -import taichi as ti - - -class ExprBuilder(Builder): - @staticmethod - def build_Subscript(ctx, node): - def get_subscript_index(node): - assert isinstance(node, ast.Subscript), type(node) - # ast.Index has been deprecated in Python 3.9, - # use the index value directly instead :) - if isinstance(node.slice, ast.Index): - return build_expr(ctx, node.slice.value) - return build_expr(ctx, node.slice) - - value = build_expr(ctx, node.value) - indices = get_subscript_index(node) - if isinstance(indices, ast.Tuple): - indices = indices.elts - else: - indices = [indices] - - call = ast.Call(func=parse_expr('ti.subscript'), - args=[value] + indices, - keywords=[]) - return ast.copy_location(call, node) - - @staticmethod - def build_Compare(ctx, node): - operands = build_exprs(ctx, [node.left] + list(node.comparators)) - operators = [] - for i in range(len(node.ops)): - if isinstance(node.ops[i], ast.Lt): - op_str = 'Lt' - elif isinstance(node.ops[i], ast.LtE): - op_str = 'LtE' - elif isinstance(node.ops[i], ast.Gt): - op_str = 'Gt' - elif isinstance(node.ops[i], ast.GtE): - op_str = 'GtE' - elif isinstance(node.ops[i], ast.Eq): - op_str = 'Eq' - elif isinstance(node.ops[i], ast.NotEq): - op_str = 'NotEq' - elif isinstance(node.ops[i], ast.In): - raise TaichiSyntaxError( - '"in" is not supported in Taichi kernels.') - elif isinstance(node.ops[i], ast.NotIn): - raise TaichiSyntaxError( - '"not in" is not supported in Taichi kernels.') - elif isinstance(node.ops[i], ast.Is): - raise TaichiSyntaxError( - '"is" is not supported in Taichi kernels.') - elif isinstance(node.ops[i], ast.IsNot): - raise TaichiSyntaxError( - '"is not" is not supported in Taichi kernels.') - else: - raise Exception(f'Unknown operator {node.ops[i]}') - operators += [ - ast.copy_location(ast.Str(s=op_str, kind=None), node) - ] - - call = ast.Call( - func=parse_expr('ti.chain_compare'), - args=[ - ast.copy_location(ast.List(elts=operands, ctx=ast.Load()), - node), - ast.copy_location(ast.List(elts=operators, ctx=ast.Load()), - node) - ], - keywords=[]) - call = ast.copy_location(call, node) - return call - - @staticmethod - def build_Call(ctx, node): - if ASTResolver.resolve_to(node.func, ti.static, globals()): - # Do not modify the expression if the function called is ti.static - return node - node.func = build_expr(ctx, node.func) - node.args = build_exprs(ctx, node.args) - for i in range(len(node.keywords)): - node.keywords[i].value = build_expr(ctx, node.keywords[i].value) - if isinstance(node.func, ast.Attribute): - attr_name = node.func.attr - if attr_name == 'format': - node.args.insert(0, node.func.value) - node.func = parse_expr('ti.ti_format') - if isinstance(node.func, ast.Name): - func_name = node.func.id - if func_name == 'print': - node.func = parse_expr('ti.ti_print') - elif func_name == 'min': - node.func = parse_expr('ti.ti_min') - elif func_name == 'max': - node.func = parse_expr('ti.ti_max') - elif func_name == 'int': - node.func = parse_expr('ti.ti_int') - elif func_name == 'float': - node.func = parse_expr('ti.ti_float') - elif func_name == 'any': - node.func = parse_expr('ti.ti_any') - elif func_name == 'all': - node.func = parse_expr('ti.ti_all') - else: - pass - - _taichi_skip_traceback = 1 - ti_func = node.func - if '_sitebuiltins' == getattr(ti_func, '__module__', '') and getattr( - getattr(ti_func, '__class__', ''), '__name__', - '') == 'Quitter': - raise TaichiSyntaxError( - f'exit or quit not supported in Taichi-scope') - if getattr(ti_func, '__module__', '') == '__main__' and not getattr( - ti_func, '__wrapped__', ''): - warnings.warn( - f'Calling into non-Taichi function {ti_func.__name__}.' - ' This means that scope inside that function will not be processed' - ' by the Taichi transformer. Proceed with caution! ' - ' Maybe you want to decorate it with @ti.func?', - UserWarning, - stacklevel=2) - - return node - - @staticmethod - def build_IfExp(ctx, node): - node.test = build_expr(ctx, node.test) - node.body = build_expr(ctx, node.body) - node.orelse = build_expr(ctx, node.orelse) - - call = ast.Call(func=parse_expr('ti.select'), - args=[node.test, node.body, node.orelse], - keywords=[]) - return ast.copy_location(call, node) - - @staticmethod - def build_UnaryOp(ctx, node): - node.operand = build_expr(ctx, node.operand) - if isinstance(node.op, ast.Not): - # Python does not support overloading logical and & or - new_node = parse_expr('ti.logical_not(0)') - new_node.args[0] = node.operand - node = new_node - return node - - @staticmethod - def build_BoolOp(ctx, node): - node.values = build_exprs(ctx, node.values) - - def make_node(a, b, token): - new_node = parse_expr('ti.logical_{}(0, 0)'.format(token)) - new_node.args[0] = a - new_node.args[1] = b - return new_node - - token = '' - if isinstance(node.op, ast.And): - token = 'and' - elif isinstance(node.op, ast.Or): - token = 'or' - else: - print(node.op) - print("BoolOp above not implemented") - exit(0) - - new_node = node.values[0] - for i in range(1, len(node.values)): - new_node = make_node(new_node, node.values[i], token) - - return new_node - - @staticmethod - def build_BinOp(ctx, node): - node.left = build_expr(ctx, node.left) - node.right = build_expr(ctx, node.right) - return node - - @staticmethod - def build_Attribute(ctx, node): - node.value = build_expr(ctx, node.value) - return node - - @staticmethod - def build_List(ctx, node): - node.elts = build_exprs(ctx, node.elts) - return node - - @staticmethod - def build_Tuple(ctx, node): - node.elts = build_exprs(ctx, node.elts) - return node - - @staticmethod - def build_Dict(ctx, node): - node.keys = build_exprs(ctx, node.keys) - node.values = build_exprs(ctx, node.values) - return node - - @staticmethod - def build_ListComp(ctx, node): - node.elt = build_expr(ctx, node.elt) - node.generators = build_exprs(ctx, node.generators) - return node - - @staticmethod - def build_DictComp(ctx, node): - node.key = build_expr(ctx, node.value) - node.value = build_expr(ctx, node.value) - node.generators = build_exprs(ctx, node.generators) - return node - - @staticmethod - def build_comprehension(ctx, node): - node.target = build_expr(ctx, node.target) - node.iter = build_expr(ctx, node.iter) - node.ifs = build_exprs(ctx, node.ifs) - return node - - @staticmethod - def build_Starred(ctx, node): - node.value = build_expr(ctx, node.value) - return node - - @staticmethod - def build_Set(ctx, node): - raise TaichiSyntaxError( - 'Python set is not supported in Taichi kernels.') - - @staticmethod - def build_Name(ctx, node): - return node - - @staticmethod - def build_NamedExpr(ctx, node): - node.value = build_expr(ctx, node.value) - return node - - @staticmethod - def build_Constant(ctx, node): - return node - - # Methods for Python 3.7 or lower - @staticmethod - def build_Num(ctx, node): - return node - - @staticmethod - def build_Str(ctx, node): - return node - - @staticmethod - def build_Bytes(ctx, node): - return node - - @staticmethod - def build_NameConstant(ctx, node): - return node - - -build_expr = ExprBuilder() - - -def build_exprs(ctx, exprs): - result = [] - # TODO(#2495): check if we really need this variable scope - with ctx.variable_scope(result): - for expr in list(exprs): - result.append(build_expr(ctx, expr)) - return result diff --git a/python/taichi/lang/field.py b/python/taichi/lang/field.py index 19f6e0ab0b939..7afe75bdc996d 100644 --- a/python/taichi/lang/field.py +++ b/python/taichi/lang/field.py @@ -1,9 +1,7 @@ import taichi.lang -from taichi.core.util import ti_core as _ti_core +from taichi._lib import core as _ti_core from taichi.lang.util import python_scope, to_numpy_type, to_pytorch_type -import taichi as ti - class Field: """Taichi field with SNode implementation. @@ -16,8 +14,8 @@ class Field: Args: vars (List[Expr]): Field members. """ - def __init__(self, vars): - self.vars = vars + def __init__(self, _vars): + self.vars = _vars self.host_accessors = None self.grad = None @@ -25,6 +23,15 @@ def __init__(self, vars): def snode(self): """Gets representative SNode for info purposes. + Returns: + SNode: Representative SNode (SNode of first field member). + """ + return self._snode + + @property + def _snode(self): + """Gets representative SNode for info purposes. + Returns: SNode: Representative SNode (SNode of first field member). """ @@ -37,7 +44,7 @@ def shape(self): Returns: Tuple[Int]: Field shape. """ - return self.snode.shape + return self._snode.shape @property def dtype(self): @@ -46,16 +53,16 @@ def dtype(self): Returns: DataType: Data type of each individual value. """ - return self.snode.dtype + return self._snode._dtype @property - def name(self): + def _name(self): """Gets field name. Returns: str: Field name. """ - return self.snode.name + return self._snode._name def parent(self, n=1): """Gets an ancestor of the representative SNode in the SNode tree. @@ -68,7 +75,7 @@ def parent(self, n=1): """ return self.snode.parent(n) - def get_field_members(self): + def _get_field_members(self): """Gets field members. Returns: @@ -76,7 +83,7 @@ def get_field_members(self): """ return self.vars - def loop_range(self): + def _loop_range(self): """Gets representative field member for loop range info. Returns: @@ -84,7 +91,7 @@ def loop_range(self): """ return self.vars[0].ptr - def set_grad(self, grad): + def _set_grad(self, grad): """Sets corresponding gradient field. Args: @@ -156,9 +163,13 @@ def copy_from(self, other): Args: other (Field): The source field. """ - assert isinstance(other, Field) - assert len(self.shape) == len(other.shape) - taichi.lang.meta.tensor_to_tensor(self, other) + if not isinstance(other, Field): + raise TypeError('Cannot copy from a non-field object') + if self.shape != other.shape: + raise ValueError(f"ti.field shape {self.shape} does not match" + f" the source field shape {other.shape}") + from taichi._kernels import tensor_to_tensor # pylint: disable=C0415 + tensor_to_tensor(self, other) @python_scope def __setitem__(self, key, value): @@ -185,12 +196,11 @@ def __getitem__(self, key): def __str__(self): if taichi.lang.impl.inside_kernel(): return self.__repr__() # make pybind11 happy, see Matrix.__str__ - if self.snode.ptr is None: + if self._snode.ptr is None: return '' - else: - return str(self.to_numpy()) + return str(self.to_numpy()) - def pad_key(self, key): + def _pad_key(self, key): if key is None: key = () if not isinstance(key, (tuple, list)): @@ -198,7 +208,7 @@ def pad_key(self, key): assert len(key) == len(self.shape) return key + ((0, ) * (_ti_core.get_max_num_indices() - len(key))) - def initialize_host_accessors(self): + def _initialize_host_accessors(self): if self.host_accessors: return taichi.lang.impl.get_runtime().materialize() @@ -206,7 +216,7 @@ def initialize_host_accessors(self): SNodeHostAccessor(e.ptr.snode()) for e in self.vars ] - def host_access(self, key): + def _host_access(self, key): return [SNodeHostAccess(e, key) for e in self.host_accessors] @@ -221,7 +231,8 @@ def __init__(self, var): @python_scope def fill(self, val): - taichi.lang.meta.fill_tensor(self, val) + from taichi._kernels import fill_tensor # pylint: disable=C0415 + fill_tensor(self, val) @python_scope def to_numpy(self, dtype=None): @@ -229,18 +240,22 @@ def to_numpy(self, dtype=None): dtype = to_numpy_type(self.dtype) import numpy as np # pylint: disable=C0415 arr = np.zeros(shape=self.shape, dtype=dtype) - taichi.lang.meta.tensor_to_ext_arr(self, arr) - ti.sync() + from taichi._kernels import tensor_to_ext_arr # pylint: disable=C0415 + tensor_to_ext_arr(self, arr) + taichi.lang.runtime_ops.sync() return arr @python_scope def to_torch(self, device=None): import torch # pylint: disable=C0415 + + # pylint: disable=E1101 arr = torch.zeros(size=self.shape, dtype=to_pytorch_type(self.dtype), device=device) - taichi.lang.meta.tensor_to_ext_arr(self, arr) - ti.sync() + from taichi._kernels import tensor_to_ext_arr # pylint: disable=C0415 + tensor_to_ext_arr(self, arr) + taichi.lang.runtime_ops.sync() return arr @python_scope @@ -254,18 +269,19 @@ def from_numpy(self, arr): f" the numpy array shape {arr.shape}") if hasattr(arr, 'contiguous'): arr = arr.contiguous() - taichi.lang.meta.ext_arr_to_tensor(arr, self) - ti.sync() + from taichi._kernels import ext_arr_to_tensor # pylint: disable=C0415 + ext_arr_to_tensor(arr, self) + taichi.lang.runtime_ops.sync() @python_scope def __setitem__(self, key, value): - self.initialize_host_accessors() - self.host_accessors[0].setter(value, *self.pad_key(key)) + self._initialize_host_accessors() + self.host_accessors[0].setter(value, *self._pad_key(key)) @python_scope def __getitem__(self, key): - self.initialize_host_accessors() - return self.host_accessors[0].getter(*self.pad_key(key)) + self._initialize_host_accessors() + return self.host_accessors[0].getter(*self._pad_key(key)) def __repr__(self): # make interactive shell happy, prevent materialization @@ -307,3 +323,6 @@ class SNodeHostAccess: def __init__(self, accessor, key): self.accessor = accessor self.key = key + + +__all__ = ["Field", "ScalarField"] diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 4811a3b399059..742cccf31792c 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -1,80 +1,84 @@ import numbers -import warnings from types import FunctionType, MethodType from typing import Iterable import numpy as np -from taichi.core.util import ti_core as _ti_core +from taichi._lib import core as _ti_core +from taichi._snode.fields_builder import FieldsBuilder from taichi.lang._ndarray import ScalarNdarray +from taichi.lang._ndrange import GroupedNDRange, _Ndrange from taichi.lang.any_array import AnyArray, AnyArrayAccess -from taichi.lang.exception import InvalidOperationError, TaichiSyntaxError +from taichi.lang.exception import TaichiRuntimeError from taichi.lang.expr import Expr, make_expr_group from taichi.lang.field import Field, ScalarField from taichi.lang.kernel_arguments import SparseMatrixProxy -from taichi.lang.matrix import MatrixField +from taichi.lang.matrix import (Matrix, MatrixField, _IntermediateMatrix, + _MatrixFieldElement) +from taichi.lang.mesh import (ConvType, MeshElementFieldProxy, MeshInstance, + MeshRelationAccessProxy, + MeshReorderedMatrixFieldProxy, + MeshReorderedScalarFieldProxy, element_type_name) from taichi.lang.snode import SNode -from taichi.lang.struct import StructField +from taichi.lang.struct import Struct, StructField, _IntermediateStruct from taichi.lang.tape import TapeImpl -from taichi.lang.util import (cook_dtype, has_pytorch, is_taichi_class, - python_scope, taichi_scope, to_pytorch_type) -from taichi.misc.util import deprecated, get_traceback, warning -from taichi.snode.fields_builder import FieldsBuilder -from taichi.type.primitive_types import f32, f64, i32, i64, u32, u64 - -import taichi as ti +from taichi.lang.util import (cook_dtype, get_traceback, is_taichi_class, + python_scope, taichi_scope, warning) +from taichi.types.primitive_types import f16, f32, f64, i32, i64 @taichi_scope def expr_init_local_tensor(shape, element_type, elements): - return _ti_core.expr_alloca_local_tensor(shape, element_type, elements) + return get_runtime().prog.current_ast_builder().expr_alloca_local_tensor( + shape, element_type, elements) @taichi_scope def expr_init(rhs): if rhs is None: - return Expr(_ti_core.expr_alloca()) - if is_taichi_class(rhs): - if rhs.local_tensor_proxy is not None: - return rhs - else: - return rhs.variable() - else: - if isinstance(rhs, list): - return [expr_init(e) for e in rhs] - elif isinstance(rhs, tuple): - return tuple(expr_init(e) for e in rhs) - elif isinstance(rhs, dict): - return dict((key, expr_init(val)) for key, val in rhs.items()) - elif isinstance(rhs, _ti_core.DataType): - return rhs - elif isinstance(rhs, _ti_core.Arch): - return rhs - elif isinstance(rhs, ti.ndrange): - return rhs - elif hasattr(rhs, '_data_oriented'): - return rhs - else: - return Expr(_ti_core.expr_var(Expr(rhs).ptr)) + return Expr(get_runtime().prog.current_ast_builder().expr_alloca()) + if isinstance(rhs, Matrix): + return Matrix(rhs.to_list()) + if isinstance(rhs, Struct): + return Struct(rhs.to_dict()) + if isinstance(rhs, list): + return [expr_init(e) for e in rhs] + if isinstance(rhs, tuple): + return tuple(expr_init(e) for e in rhs) + if isinstance(rhs, dict): + return dict((key, expr_init(val)) for key, val in rhs.items()) + if isinstance(rhs, _ti_core.DataType): + return rhs + if isinstance(rhs, _ti_core.Arch): + return rhs + if isinstance(rhs, _Ndrange): + return rhs + if isinstance(rhs, MeshElementFieldProxy): + return rhs + if isinstance(rhs, MeshRelationAccessProxy): + return rhs + if hasattr(rhs, '_data_oriented'): + return rhs + return Expr(get_runtime().prog.current_ast_builder().expr_var( + Expr(rhs).ptr)) @taichi_scope def expr_init_list(xs, expected): - if not isinstance(xs, (list, tuple, ti.Matrix)): + if not isinstance(xs, (list, tuple, Matrix)): raise TypeError(f'Cannot unpack type: {type(xs)}') - if isinstance(xs, ti.Matrix): + if isinstance(xs, Matrix): if not xs.m == 1: raise ValueError( - f'Matrices with more than one columns cannot be unpacked') + 'Matrices with more than one columns cannot be unpacked') xs = xs.entries if expected != len(xs): raise ValueError( f'Tuple assignment size mismatch: {expected} != {len(xs)}') if isinstance(xs, list): return [expr_init(e) for e in xs] - elif isinstance(xs, tuple): + if isinstance(xs, tuple): return tuple(expr_init(e) for e in xs) - else: - raise ValueError(f'Cannot unpack from {type(xs)}') + raise ValueError(f'Cannot unpack from {type(xs)}') @taichi_scope @@ -85,7 +89,7 @@ def expr_init_func( return expr_init(rhs) -def begin_frontend_struct_for(group, loop_range): +def begin_frontend_struct_for(ast_builder, group, loop_range): if not isinstance(loop_range, (AnyArray, Field, SNode, _Root)): raise TypeError( 'Can only iterate through Taichi fields/snodes (via template) or dense arrays (via any_arr)' @@ -96,10 +100,11 @@ def begin_frontend_struct_for(group, loop_range): f'({group.size()} != {len(loop_range.shape)}). Maybe you wanted to ' 'use "for I in ti.grouped(x)" to group all indices into a single vector I?' ) - _ti_core.begin_frontend_struct_for(group, loop_range.loop_range()) + ast_builder.begin_frontend_struct_for(group, loop_range._loop_range()) -def begin_frontend_if(cond): +def begin_frontend_if(ast_builder, cond): + assert ast_builder is not None if is_taichi_class(cond): raise ValueError( 'The truth value of vectors/matrices is ambiguous.\n' @@ -107,70 +112,85 @@ def begin_frontend_if(cond): ' if all(x == y):\n' 'or\n' ' if any(x != y):\n') - _ti_core.begin_frontend_if(Expr(cond).ptr) - - -def wrap_scalar(x): - if type(x) in [int, float]: - return Expr(x) - else: - return x + ast_builder.begin_frontend_if(Expr(cond).ptr) @taichi_scope -def subscript(value, *indices): - _taichi_skip_traceback = 1 +def subscript(value, *_indices, skip_reordered=False): if isinstance(value, np.ndarray): - return value.__getitem__(*indices) + return value.__getitem__(_indices) if isinstance(value, (tuple, list, dict)): - assert len(indices) == 1 - return value[indices[0]] + assert len(_indices) == 1 + return value[_indices[0]] + has_slice = False flattened_indices = [] - for i in range(len(indices)): - if is_taichi_class(indices[i]): - ind = indices[i].entries + for _index in _indices: + if is_taichi_class(_index): + ind = _index.entries + elif isinstance(_index, slice): + ind = [_index] + has_slice = True else: - ind = [indices[i]] + ind = [_index] flattened_indices += ind - indices = tuple(flattened_indices) - if isinstance(indices, tuple) and len(indices) == 1 and indices[0] is None: - indices = () - indices_expr_group = make_expr_group(*indices) - index_dim = indices_expr_group.size() + _indices = tuple(flattened_indices) + if isinstance(_indices, + tuple) and len(_indices) == 1 and _indices[0] is None: + _indices = () + + if has_slice: + if not isinstance(value, Matrix): + raise SyntaxError( + f"The type {type(value)} do not support index of slice type") + else: + indices_expr_group = make_expr_group(*_indices) + index_dim = indices_expr_group.size() if is_taichi_class(value): - return value.subscript(*indices) - elif isinstance(value, SparseMatrixProxy): - return value.subscript(*indices) - elif isinstance(value, Field): - var = value.get_field_members()[0].ptr - if var.snode() is None: - if var.is_primal(): + return value._subscript(*_indices) + if isinstance(value, MeshElementFieldProxy): + return value.subscript(*_indices) + if isinstance(value, MeshRelationAccessProxy): + return value.subscript(*_indices) + if isinstance(value, + (MeshReorderedScalarFieldProxy, + MeshReorderedMatrixFieldProxy)) and not skip_reordered: + assert index_dim == 1 + reordered_index = tuple([ + Expr( + _ti_core.get_index_conversion(value.mesh_ptr, + value.element_type, + Expr(_indices[0]).ptr, + ConvType.g2r)) + ]) + return subscript(value, *reordered_index, skip_reordered=True) + if isinstance(value, SparseMatrixProxy): + return value.subscript(*_indices) + if isinstance(value, Field): + _var = value._get_field_members()[0].ptr + if _var.snode() is None: + if _var.is_primal(): raise RuntimeError( - f"{var.get_expr_name()} has not been placed.") + f"{_var.get_expr_name()} has not been placed.") else: raise RuntimeError( - f"Gradient {var.get_expr_name()} has not been placed, check whether `needs_grad=True`" + f"Gradient {_var.get_expr_name()} has not been placed, check whether `needs_grad=True`" ) - field_dim = int(var.get_attribute("dim")) + field_dim = int(_var.get_attribute("dim")) if field_dim != index_dim: raise IndexError( f'Field with dim {field_dim} accessed with indices of dim {index_dim}' ) if isinstance(value, MatrixField): - return ti.Matrix.with_entries(value.n, value.m, [ - Expr(_ti_core.subscript(e.ptr, indices_expr_group)) - for e in value.get_field_members() - ]) - elif isinstance(value, StructField): - return ti.Struct( - {k: subscript(v, *indices) - for k, v in value.items}) - else: - return Expr(_ti_core.subscript(var, indices_expr_group)) - elif isinstance(value, AnyArray): + return _MatrixFieldElement(value, indices_expr_group) + if isinstance(value, StructField): + return _IntermediateStruct( + {k: subscript(v, *_indices) + for k, v in value._items}) + return Expr(_ti_core.subscript(_var, indices_expr_group)) + if isinstance(value, AnyArray): # TODO: deprecate using get_attribute to get dim field_dim = int(value.ptr.get_attribute("dim")) element_dim = len(value.element_shape) @@ -182,14 +202,14 @@ def subscript(value, *indices): return Expr(_ti_core.subscript(value.ptr, indices_expr_group)) n = value.element_shape[0] m = 1 if element_dim == 1 else value.element_shape[1] - any_array_access = AnyArrayAccess(value, indices) - ret = ti.Matrix.with_entries(n, m, [ + any_array_access = AnyArrayAccess(value, _indices) + ret = _IntermediateMatrix(n, m, [ any_array_access.subscript(i, j) for i in range(n) for j in range(m) ]) ret.any_array_access = any_array_access return ret - elif isinstance(value, SNode): + if isinstance(value, SNode): # When reading bit structure we only support the 0-D case for now. field_dim = 0 if field_dim != index_dim: @@ -197,105 +217,40 @@ def subscript(value, *indices): f'Field with dim {field_dim} accessed with indices of dim {index_dim}' ) return Expr(_ti_core.subscript(value.ptr, indices_expr_group)) - else: # Directly evaluate in Python for non-Taichi types - return value.__getitem__(*indices) - - -@taichi_scope -def local_subscript_with_offset(var, indices, shape): - return Expr( - _ti_core.local_subscript_with_offset(var, make_expr_group(*indices), - shape)) + # Directly evaluate in Python for non-Taichi types + return value.__getitem__(*_indices) @taichi_scope -def global_subscript_with_offset(var, indices, shape, is_aos): +def make_tensor_element_expr(_var, _indices, shape, stride): return Expr( - _ti_core.global_subscript_with_offset(var.ptr, - make_expr_group(*indices), shape, - is_aos)) - - -@taichi_scope -def chain_compare(comparators, ops): - _taichi_skip_traceback = 1 - assert len(comparators) == len(ops) + 1, \ - f'Chain comparison invoked with {len(comparators)} comparators but {len(ops)} operators' - ret = True - for i in range(len(ops)): - lhs = comparators[i] - rhs = comparators[i + 1] - if ops[i] == 'Lt': - now = lhs < rhs - elif ops[i] == 'LtE': - now = lhs <= rhs - elif ops[i] == 'Gt': - now = lhs > rhs - elif ops[i] == 'GtE': - now = lhs >= rhs - elif ops[i] == 'Eq': - now = lhs == rhs - elif ops[i] == 'NotEq': - now = lhs != rhs - else: - assert False, f'Unknown operator {ops[i]}' - ret = ti.logical_and(ret, now) - return ret - - -@taichi_scope -def insert_expr_stmt_if_ti_func(func, *args, **kwargs): - """This method is used only for real functions. It inserts a - FrontendExprStmt to the C++ AST to hold the function call if `func` is a - Taichi function. - - Args: - func: The function to be called. - args: The arguments of the function call. - kwargs: The keyword arguments of the function call. - - Returns: - The return value of the function call if it's a non-Taichi function. - Returns None if it's a Taichi function.""" - is_taichi_function = getattr(func, '_is_taichi_function', False) - # If is_taichi_function is true: call a decorated Taichi function - # in a Taichi kernel/function. - - if is_taichi_function: - # Compiles the function here. - # Invokes Func.__call__. - func_call_result = func(*args, **kwargs) - # Insert FrontendExprStmt here. - return _ti_core.insert_expr_stmt(func_call_result.ptr) - else: - # Call the non-Taichi function directly. - return func(*args, **kwargs) + _ti_core.make_tensor_element_expr(_var, make_expr_group(*_indices), + shape, stride)) class PyTaichi: def __init__(self, kernels=None): self.materialized = False self.prog = None - self.materialize_callbacks = [] self.compiled_functions = {} self.compiled_grad_functions = {} self.scope_stack = [] self.inside_kernel = False self.current_kernel = None self.global_vars = [] - self.print_preprocessed = False - self.experimental_real_function = False + self.matrix_fields = [] self.default_fp = f32 self.default_ip = i32 self.target_tape = None self.grad_replaced = False self.kernels = kernels or [] + self._signal_handler_registry = None def get_num_compiled_functions(self): return len(self.compiled_functions) + len(self.compiled_grad_functions) def set_default_fp(self, fp): - assert fp in [f32, f64] + assert fp in [f16, f32, f64] self.default_fp = fp default_cfg().default_fp = self.default_fp @@ -308,28 +263,37 @@ def create_program(self): if self.prog is None: self.prog = _ti_core.Program() - def materialize_root_fb(self, first): - if not root.finalized and not root.empty: - root.finalize() - elif first: - root.finalize(raise_warning=False) - + @staticmethod + def materialize_root_fb(is_first_call): if root.finalized: - global _root_fb - _root_fb = FieldsBuilder() + return + if not is_first_call and root.empty: + # We have to forcefully finalize when `is_first_call` is True (even + # if the root itself is empty), so that there is a valid struct + # llvm::Module, if no field has been declared before the first kernel + # invocation. Example case: + # https://github.com/taichi-dev/taichi/blob/27bb1dc3227d9273a79fcb318fdb06fd053068f5/tests/python/test_ad_basics.py#L260-L266 + return + root.finalize(raise_warning=not is_first_call) + global _root_fb + _root_fb = FieldsBuilder() - def materialize(self): - self.materialize_root_fb(not self.materialized) + @staticmethod + def _finalize_root_fb_for_aot(): + if _root_fb.finalized: + raise RuntimeError( + 'AOT: can only finalize the root FieldsBuilder once') + _root_fb._finalize_for_aot() - if self.materialized: - return + @staticmethod + def _get_tb(_var): + return getattr(_var, 'declaration_tb', str(_var.ptr)) - self.materialized = True + def _check_field_not_placed(self): not_placed = [] - for var in self.global_vars: - if var.ptr.snode() is None: - tb = getattr(var, 'declaration_tb', str(var.ptr)) - not_placed.append(tb) + for _var in self.global_vars: + if _var.ptr.snode() is None: + not_placed.append(self._get_tb(_var)) if len(not_placed): bar = '=' * 44 + '\n' @@ -339,14 +303,41 @@ def materialize(self): f'{bar}Please consider specifying a shape for them. E.g.,' + '\n\n x = ti.field(float, shape=(2, 3))') - for callback in self.materialize_callbacks: - callback() - self.materialize_callbacks = [] + def _check_matrix_field_member_shape(self): + for _field in self.matrix_fields: + shapes = [ + _field.get_scalar_field(i, j).shape for i in range(_field.n) + for j in range(_field.m) + ] + if any(shape != shapes[0] for shape in shapes): + raise RuntimeError( + 'Members of the following field have different shapes ' + + f'{shapes}:\n{self._get_tb(_field._get_field_members()[0])}' + ) + + def _calc_matrix_field_dynamic_index_stride(self): + for _field in self.matrix_fields: + _field._calc_dynamic_index_stride() + + def materialize(self): + self.materialize_root_fb(not self.materialized) + self.materialized = True + + self._check_field_not_placed() + self._check_matrix_field_member_shape() + self._calc_matrix_field_dynamic_index_stride() + self.global_vars = [] + self.matrix_fields = [] + + def _register_signal_handlers(self): + if self._signal_handler_registry is None: + self._signal_handler_registry = _ti_core.HackedSignalRegister() def clear(self): if self.prog: self.prog.finalize() self.prog = None + self._signal_handler_registry = None self.materialized = False def get_tape(self, loss=None): @@ -364,56 +355,6 @@ def get_runtime(): return pytaichi -def materialize_callback(foo): - get_runtime().materialize_callbacks.append(foo) - - -def _clamp_unsigned_to_range(npty, val): - # npty: np.int32 or np.int64 - iif = np.iinfo(npty) - if iif.min <= val <= iif.max: - return val - cap = (1 << iif.bits) - if not (0 <= val < cap): - # We let pybind11 fail intentionally, because this isn't the case we want - # to deal with: |val| does't fall into the valid range of either - # the signed or the unsigned type. - return val - new_val = val - cap - ti.warn( - f'Constant {val} has exceeded the range of {iif.bits} int, clamped to {new_val}' - ) - return new_val - - -@taichi_scope -def make_constant_expr(val): - _taichi_skip_traceback = 1 - if isinstance(val, (int, np.integer)): - if pytaichi.default_ip in {i32, u32}: - # It is not always correct to do such clamp without the type info on - # the LHS, but at least this makes assigning constant to unsigned - # int work. See https://github.com/taichi-dev/taichi/issues/2060 - return Expr( - _ti_core.make_const_expr_i32( - _clamp_unsigned_to_range(np.int32, val))) - elif pytaichi.default_ip in {i64, u64}: - return Expr( - _ti_core.make_const_expr_i64( - _clamp_unsigned_to_range(np.int64, val))) - else: - assert False - elif isinstance(val, (float, np.floating, np.ndarray)): - if pytaichi.default_fp == f32: - return Expr(_ti_core.make_const_expr_f32(val)) - elif pytaichi.default_fp == f64: - return Expr(_ti_core.make_const_expr_f64(val)) - else: - assert False - else: - raise ValueError(f'Invalid constant scalar expression: {type(val)}') - - def reset(): global pytaichi old_kernels = pytaichi.kernels @@ -431,7 +372,6 @@ def static_print(*args, __p=print, **kwargs): # we don't add @taichi_scope decorator for @ti.pyfunc to work def static_assert(cond, msg=None): - _taichi_skip_traceback = 1 if msg is not None: assert cond, msg else: @@ -451,7 +391,7 @@ def __getattr__(self, item): if item == '__qualname__': # For sphinx docstring extraction. return '_UninitializedRootFieldsBuilder' - raise InvalidOperationError('Please call init() first') + raise TaichiRuntimeError('Please call init() first') # `root` initialization must be delayed until after the program is @@ -469,26 +409,36 @@ def __getattr__(self, item): _root_fb = _UninitializedRootFieldsBuilder() +def deactivate_all_snodes(): + """Recursively deactivate all SNodes.""" + for root_fb in FieldsBuilder._finalized_roots(): + root_fb.deactivate_all() + + class _Root: """Wrapper around the default root FieldsBuilder instance.""" - def parent(self, n=1): + @staticmethod + def parent(n=1): """Same as :func:`taichi.SNode.parent`""" return _root_fb.root.parent(n) - def loop_range(self): + @staticmethod + def _loop_range(): """Same as :func:`taichi.SNode.loop_range`""" - return _root_fb.root.loop_range() + return _root_fb.root._loop_range() - def get_children(self): + @staticmethod + def _get_children(): """Same as :func:`taichi.SNode.get_children`""" - return _root_fb.root.get_children() + return _root_fb.root._get_children() # TODO: Record all of the SNodeTrees that finalized under 'ti.root' - def deactivate_all(self): + @staticmethod + def deactivate_all(): warning( """'ti.root.deactivate_all()' would deactivate all finalized snodes.""" ) - ti.deactivate_all_snodes() + deactivate_all_snodes() @property def shape(self): @@ -496,8 +446,8 @@ def shape(self): return _root_fb.root.shape @property - def id(self): - return _root_fb.root.id + def _id(self): + return _root_fb.root._id def __getattr__(self, item): return getattr(_root_fb, item) @@ -524,7 +474,7 @@ def create_field_member(dtype, name): # primal x = Expr(_ti_core.make_id_expr("")) - x.declaration_tb = get_traceback(stacklevel=2) + x.declaration_tb = get_traceback(stacklevel=4) x.ptr = _ti_core.global_new(x.ptr, dtype) x.ptr.set_name(name) x.ptr.set_is_primal(True) @@ -542,12 +492,6 @@ def create_field_member(dtype, name): return x, x_grad -@deprecated('ti.var', 'ti.field') -def var(dt, shape=None, offset=None, needs_grad=False): - _taichi_skip_traceback = 1 - return field(dt, shape, offset, needs_grad) - - @python_scope def field(dtype, shape=None, name="", offset=None, needs_grad=False): """Defines a Taichi field @@ -576,7 +520,6 @@ def field(dtype, shape=None, name="", offset=None, needs_grad=False): >>> x2 = ti.field(ti.f32) >>> ti.root.dense(ti.ij, shape=(16, 8)).place(x2) """ - _taichi_skip_traceback = 1 if isinstance(shape, numbers.Number): shape = (shape, ) @@ -590,13 +533,11 @@ def field(dtype, shape=None, name="", offset=None, needs_grad=False): ), f'The dimensionality of shape and offset must be the same ({len(shape)} != {len(offset)})' assert (offset is None or shape - is not None), f'The shape cannot be None when offset is being set' - - del _taichi_skip_traceback + is not None), 'The shape cannot be None when offset is being set' x, x_grad = create_field_member(dtype, name) x, x_grad = ScalarField(x), ScalarField(x_grad) - x.set_grad(x_grad) + x._set_grad(x_grad) if shape is not None: dim = len(shape) @@ -625,45 +566,43 @@ def ndarray(dtype, shape): @taichi_scope -def ti_print(*vars, sep=' ', end='\n'): - def entry2content(var): - if isinstance(var, str): - return var - else: - return Expr(var).ptr +def ti_print(*_vars, sep=' ', end='\n'): + def entry2content(_var): + if isinstance(_var, str): + return _var + return Expr(_var).ptr - def list_ti_repr(var): + def list_ti_repr(_var): yield '[' # distinguishing tuple & list will increase maintainance cost - for i, v in enumerate(var): + for i, v in enumerate(_var): if i: yield ', ' yield v yield ']' - def vars2entries(vars): - for var in vars: - if hasattr(var, '__ti_repr__'): - res = var.__ti_repr__() - elif isinstance(var, (list, tuple)): - res = var + def vars2entries(_vars): + for _var in _vars: + if hasattr(_var, '__ti_repr__'): + res = _var.__ti_repr__() + elif isinstance(_var, (list, tuple)): # If the first element is '__ti_format__', this list is the result of ti_format. - if len(var) > 0 and isinstance( - var[0], str) and var[0] == '__ti_format__': - res = var[1:] + if len(_var) > 0 and isinstance( + _var[0], str) and _var[0] == '__ti_format__': + res = _var[1:] else: - res = list_ti_repr(var) + res = list_ti_repr(_var) else: - yield var + yield _var continue for v in vars2entries(res): yield v - def add_separators(vars): - for i, var in enumerate(vars): + def add_separators(_vars): + for i, _var in enumerate(_vars): if i: yield sep - yield var + yield _var yield end def fused_string(entries): @@ -679,28 +618,34 @@ def fused_string(entries): if accumated: yield accumated - vars = add_separators(vars) - entries = vars2entries(vars) + _vars = add_separators(_vars) + entries = vars2entries(_vars) entries = fused_string(entries) contentries = [entry2content(entry) for entry in entries] - _ti_core.create_print(contentries) + get_runtime().prog.current_ast_builder().create_print(contentries) @taichi_scope -def ti_format(*args): +def ti_format(*args, **kwargs): content = args[0] mixed = args[1:] new_mixed = [] + new_mixed_kwargs = {} args = [] for x in mixed: - if isinstance(x, ti.Expr): + if isinstance(x, Expr): new_mixed.append('{}') args.append(x) else: new_mixed.append(x) - + for k, v in kwargs.items(): + if isinstance(v, Expr): + new_mixed_kwargs[k] = '{}' + args.append(v) + else: + new_mixed_kwargs[k] = v try: - content = content.format(*new_mixed) + content = content.format(*new_mixed, **new_mixed_kwargs) except ValueError: print('Number formatting is not supported with Taichi fields') exit(1) @@ -709,8 +654,8 @@ def ti_format(*args): args ) + 1, 'Number of args is different from number of positions provided in string' - for i in range(len(args)): - res.insert(i * 2 + 1, args[i]) + for i, arg in enumerate(args): + res.insert(i * 2 + 1, arg) res.insert(0, '__ti_format__') return res @@ -719,26 +664,22 @@ def ti_format(*args): def ti_assert(cond, msg, extra_args): # Mostly a wrapper to help us convert from Expr (defined in Python) to # _ti_core.Expr (defined in C++) - _ti_core.create_assert_stmt( + get_runtime().prog.current_ast_builder().create_assert_stmt( Expr(cond).ptr, msg, [Expr(x).ptr for x in extra_args]) @taichi_scope -def ti_int(var): - _taichi_skip_traceback = 1 - if hasattr(var, '__ti_int__'): - return var.__ti_int__() - else: - return int(var) +def ti_int(_var): + if hasattr(_var, '__ti_int__'): + return _var.__ti_int__() + return int(_var) @taichi_scope -def ti_float(var): - _taichi_skip_traceback = 1 - if hasattr(var, '__ti_float__'): - return var.__ti_float__() - else: - return float(var) +def ti_float(_var): + if hasattr(_var, '__ti_float__'): + return _var.__ti_float__() + return float(_var) @taichi_scope @@ -782,14 +723,6 @@ def axes(*x: Iterable[int]): return [_ti_core.Axis(i) for i in x] -@deprecated("ti.indices", "ti.axes") -def indices(*x): - """Same as :func:`~taichi.lang.impl.axes`.""" - return [_ti_core.Axis(i) for i in x] - - -index = indices - Axis = _ti_core.Axis @@ -833,24 +766,22 @@ def static(x, *xs): >>> print(1) >>> print(2) """ - _taichi_skip_traceback = 1 if len(xs): # for python-ish pointer assign: x, y = ti.static(y, x) return [static(x)] + [static(x) for x in xs] if isinstance(x, - (bool, int, float, range, list, tuple, enumerate, ti.ndrange, - ti.GroupedNDRange, zip, filter, map)) or x is None: + (bool, int, float, range, list, tuple, enumerate, _Ndrange, + GroupedNDRange, zip, filter, map)) or x is None: return x - elif isinstance(x, AnyArray): + if isinstance(x, AnyArray): return x - elif isinstance(x, Field): + if isinstance(x, Field): return x - elif isinstance(x, (FunctionType, MethodType)): + if isinstance(x, (FunctionType, MethodType)): return x - else: - raise ValueError( - f'Input to ti.static must be compile-time constants or global pointers, instead of {type(x)}' - ) + raise ValueError( + f'Input to ti.static must be compile-time constants or global pointers, instead of {type(x)}' + ) @taichi_scope @@ -862,21 +793,20 @@ def grouped(x): Example:: - >>> for I in ti.grouped(ti.ndrange(8, 16)): + >>> for I in ti.grouped(ndrange(8, 16)): >>> print(I[0] + I[1]) """ - if isinstance(x, ti.ndrange): + if isinstance(x, _Ndrange): return x.grouped() - else: - return x + return x def stop_grad(x): - _ti_core.stop_grad(x.snode.ptr) + get_runtime().prog.current_ast_builder().stop_grad(x.snode.ptr) def current_cfg(): - return _ti_core.current_compile_config() + return get_runtime().prog.config def default_cfg(): @@ -886,3 +816,19 @@ def default_cfg(): def call_internal(name, *args): return expr_init( _ti_core.insert_internal_func_call(name, make_expr_group(args))) + + +@taichi_scope +def mesh_relation_access(mesh, from_index, to_element_type): + # to support ti.mesh_local and access mesh attribute as field + if isinstance(from_index, MeshInstance): + return getattr(from_index, element_type_name(to_element_type)) + if isinstance(mesh, MeshInstance): + return MeshRelationAccessProxy(mesh, from_index, to_element_type) + raise RuntimeError("Relation access should be with a mesh instance!") + + +__all__ = [ + 'axes', 'deactivate_all_snodes', 'field', 'grouped', 'ndarray', 'one', + 'root', 'static', 'static_assert', 'static_print', 'stop_grad', 'zero' +] diff --git a/python/taichi/lang/kernel_arguments.py b/python/taichi/lang/kernel_arguments.py index e603a87d63c6d..aa1bf8a3c7324 100644 --- a/python/taichi/lang/kernel_arguments.py +++ b/python/taichi/lang/kernel_arguments.py @@ -1,64 +1,78 @@ import taichi.lang -from taichi.core.util import ti_core as _ti_core +from taichi._lib import core as _ti_core +from taichi.lang import impl, ops from taichi.lang.any_array import AnyArray from taichi.lang.enums import Layout from taichi.lang.expr import Expr +from taichi.lang.matrix import Matrix, MatrixType from taichi.lang.util import cook_dtype -from taichi.linalg import SparseMatrixBuilder -from taichi.type.primitive_types import u64 +from taichi.types.primitive_types import u64 class SparseMatrixEntry: - def __init__(self, ptr, i, j): + def __init__(self, ptr, i, j, dtype): self.ptr = ptr self.i = i self.j = j + self.dtype = dtype - def augassign(self, value, op): + def _augassign(self, value, op): + call_func = f"insert_triplet_{self.dtype}" if op == 'Add': - taichi.lang.impl.call_internal("insert_triplet", self.ptr, self.i, - self.j, - taichi.lang.impl.ti_float(value)) + taichi.lang.impl.call_internal(call_func, self.ptr, self.i, self.j, + ops.cast(value, self.dtype)) elif op == 'Sub': - taichi.lang.impl.call_internal("insert_triplet", self.ptr, self.i, - self.j, - -taichi.lang.impl.ti_float(value)) + taichi.lang.impl.call_internal(call_func, self.ptr, self.i, self.j, + -ops.cast(value, self.dtype)) else: - assert False, f"Only operations '+=' and '-=' are supported on sparse matrices." + assert False, "Only operations '+=' and '-=' are supported on sparse matrices." class SparseMatrixProxy: - def __init__(self, ptr): + def __init__(self, ptr, dtype): self.ptr = ptr + self.dtype = dtype def subscript(self, i, j): - return SparseMatrixEntry(self.ptr, i, j) + return SparseMatrixEntry(self.ptr, i, j, self.dtype) def decl_scalar_arg(dtype): dtype = cook_dtype(dtype) - arg_id = _ti_core.decl_arg(dtype, False) + arg_id = impl.get_runtime().prog.decl_arg(dtype, False) return Expr(_ti_core.make_arg_load_expr(arg_id, dtype)) -def decl_sparse_matrix(): +def decl_matrix_arg(matrixtype): + return Matrix( + [[decl_scalar_arg(matrixtype.dtype) for _ in range(matrixtype.m)] + for _ in range(matrixtype.n)]) + + +def decl_sparse_matrix(dtype): + value_type = cook_dtype(dtype) ptr_type = cook_dtype(u64) # Treat the sparse matrix argument as a scalar since we only need to pass in the base pointer - arg_id = _ti_core.decl_arg(ptr_type, False) - return SparseMatrixProxy(_ti_core.make_arg_load_expr(arg_id, ptr_type)) + arg_id = impl.get_runtime().prog.decl_arg(ptr_type, False) + return SparseMatrixProxy(_ti_core.make_arg_load_expr(arg_id, ptr_type), + value_type) def decl_any_arr_arg(dtype, dim, element_shape, layout): dtype = cook_dtype(dtype) - arg_id = _ti_core.decl_arg(dtype, True) element_dim = len(element_shape) + arg_id = impl.get_runtime().prog.decl_arr_arg(dtype, dim, element_shape) if layout == Layout.AOS: element_dim = -element_dim return AnyArray( - _ti_core.make_external_tensor_expr(dtype, dim, arg_id, element_dim), - element_shape, layout) + _ti_core.make_external_tensor_expr(dtype, dim, arg_id, element_dim, + element_shape), element_shape, + layout) -def decl_scalar_ret(dtype): - dtype = cook_dtype(dtype) - return _ti_core.decl_ret(dtype) +def decl_ret(dtype): + if isinstance(dtype, MatrixType): + dtype = _ti_core.decl_tensor_type([dtype.n, dtype.m], dtype.dtype) + else: + dtype = cook_dtype(dtype) + return impl.get_runtime().prog.decl_ret(dtype) diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 1e5913dc2afc5..7ccad655421ff 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -1,49 +1,33 @@ import ast -import copy import functools import inspect import re +import sys +import textwrap import numpy as np import taichi.lang -from taichi.core.util import ti_core as _ti_core -from taichi.lang import impl, util -from taichi.lang.ast.checkers import KernelSimplicityASTChecker -from taichi.lang.ast.transformer import ASTTransformerTotal +from taichi._lib import core as _ti_core +from taichi.lang import impl, runtime_ops +from taichi.lang.ast import (ASTTransformerContext, KernelSimplicityASTChecker, + transform_tree) from taichi.lang.enums import Layout -from taichi.lang.exception import TaichiSyntaxError +from taichi.lang.exception import (TaichiCompilationError, TaichiRuntimeError, + TaichiRuntimeTypeError, TaichiSyntaxError) +from taichi.lang.expr import Expr +from taichi.lang.matrix import Matrix, MatrixType from taichi.lang.shell import _shell_pop_print, oinspect -from taichi.lang.util import to_taichi_type -from taichi.linalg.sparse_matrix import sparse_matrix_builder -from taichi.misc.util import obsolete -from taichi.type import any_arr, primitive_types, template +from taichi.lang.util import has_pytorch, to_taichi_type +from taichi.types import (any_arr, primitive_types, sparse_matrix_builder, + template) -import taichi as ti +from taichi import _logging -if util.has_pytorch(): +if has_pytorch(): import torch -def _remove_indent(lines): - lines = lines.split('\n') - to_remove = 0 - for i in range(len(lines[0])): - if lines[0][i] == ' ': - to_remove = i + 1 - else: - break - - cleaned = [] - for l in lines: - cleaned.append(l[to_remove:]) - if len(l) >= to_remove: - for i in range(to_remove): - assert l[i] == ' ' - - return '\n'.join(cleaned) - - -def func(fn): +def func(fn, is_real_function=False): """Marks a function as callable in Taichi-scope. This decorator transforms a Python function into a Taichi one. Taichi @@ -51,6 +35,7 @@ def func(fn): Args: fn (Callable): The Python function to be decorated + is_real_function (bool): Whether the function is a real function Returns: Callable: The decorated function @@ -67,18 +52,21 @@ def func(fn): """ is_classfunc = _inside_class(level_of_class_stackframe=3) - _taichi_skip_traceback = 1 - fun = Func(fn, classfunc=is_classfunc) + fun = Func(fn, _classfunc=is_classfunc, is_real_function=is_real_function) @functools.wraps(fn) def decorated(*args): - _taichi_skip_traceback = 1 return fun.__call__(*args) decorated._is_taichi_function = True + decorated._is_real_function = is_real_function return decorated +def real_func(fn): + return func(fn, is_real_function=True) + + def pyfunc(fn): """Marks a function as callable in both Taichi and Python scopes. @@ -95,54 +83,99 @@ def pyfunc(fn): Callable: The decorated function """ is_classfunc = _inside_class(level_of_class_stackframe=3) - fun = Func(fn, classfunc=is_classfunc, pyfunc=True) + fun = Func(fn, _classfunc=is_classfunc, _pyfunc=True) @functools.wraps(fn) def decorated(*args): - _taichi_skip_traceback = 1 return fun.__call__(*args) decorated._is_taichi_function = True return decorated +def _get_tree_and_ctx(self, + excluded_parameters=(), + is_kernel=True, + arg_features=None, + args=None, + ast_builder=None, + is_real_function=False): + file = oinspect.getsourcefile(self.func) + src, start_lineno = oinspect.getsourcelines(self.func) + src = [textwrap.fill(line, tabsize=4, width=9999) for line in src] + tree = ast.parse(textwrap.dedent("\n".join(src))) + + func_body = tree.body[0] + func_body.decorator_list = [] + + global_vars = _get_global_vars(self.func) + + for i, arg in enumerate(func_body.args.args): + anno = arg.annotation + if isinstance(anno, ast.Name): + global_vars[anno.id] = self.argument_annotations[i] + + if isinstance(func_body.returns, ast.Name): + global_vars[func_body.returns.id] = self.return_type + + if is_kernel or is_real_function: + # inject template parameters into globals + for i in self.template_slot_locations: + template_var_name = self.argument_names[i] + global_vars[template_var_name] = args[i] + + return tree, ASTTransformerContext(excluded_parameters=excluded_parameters, + is_kernel=is_kernel, + func=self, + arg_features=arg_features, + global_vars=global_vars, + argument_data=args, + src=src, + start_lineno=start_lineno, + file=file, + ast_builder=ast_builder, + is_real_function=is_real_function) + + class Func: function_counter = 0 - def __init__(self, func, classfunc=False, pyfunc=False): - self.func = func + def __init__(self, + _func, + _classfunc=False, + _pyfunc=False, + is_real_function=False): + self.func = _func self.func_id = Func.function_counter Func.function_counter += 1 self.compiled = None - self.classfunc = classfunc - self.pyfunc = pyfunc + self.classfunc = _classfunc + self.pyfunc = _pyfunc + self.is_real_function = is_real_function self.argument_annotations = [] self.argument_names = [] - _taichi_skip_traceback = 1 + self.return_type = None self.extract_arguments() self.template_slot_locations = [] - for i in range(len(self.argument_annotations)): - if isinstance(self.argument_annotations[i], template): + for i, anno in enumerate(self.argument_annotations): + if isinstance(anno, template): self.template_slot_locations.append(i) self.mapper = TaichiCallableTemplateMapper( self.argument_annotations, self.template_slot_locations) self.taichi_functions = {} # The |Function| class in C++ def __call__(self, *args): - _taichi_skip_traceback = 1 if not impl.inside_kernel(): if not self.pyfunc: raise TaichiSyntaxError( - "Taichi functions cannot be called from Python-scope." - " Use @ti.pyfunc if you wish to call Taichi functions " - "from both Python-scope and Taichi-scope.") + "Taichi functions cannot be called from Python-scope.") return self.func(*args) - if impl.get_runtime().experimental_real_function: + if self.is_real_function: if impl.get_runtime().current_kernel.is_grad: raise TaichiSyntaxError( "Real function in gradient kernels unsupported.") - instance_id, arg_features = self.mapper.lookup(args) + instance_id, _ = self.mapper.lookup(args) key = _ti_core.FunctionKey(self.func.__name__, self.func_id, instance_id) if self.compiled is None: @@ -150,58 +183,46 @@ def __call__(self, *args): if key.instance_id not in self.compiled: self.do_compile(key=key, args=args) return self.func_call_rvalue(key=key, args=args) - else: - if self.compiled is None: - self.do_compile(key=None, args=args) - ret = self.compiled(*args) - return ret + tree, ctx = _get_tree_and_ctx( + self, + is_kernel=False, + args=args, + ast_builder=impl.get_runtime().prog.current_ast_builder(), + is_real_function=self.is_real_function) + ret = transform_tree(tree, ctx) + if not self.is_real_function: + if self.return_type and not ctx.returned: + raise TaichiSyntaxError( + "Function has a return type but does not have a return statement" + ) + return ret def func_call_rvalue(self, key, args): # Skip the template args, e.g., |self| - assert impl.get_runtime().experimental_real_function + assert self.is_real_function non_template_args = [] - for i in range(len(self.argument_annotations)): - if not isinstance(self.argument_annotations[i], template): + for i, anno in enumerate(self.argument_annotations): + if not isinstance(anno, template): non_template_args.append(args[i]) non_template_args = impl.make_expr_group(non_template_args) - return ti.Expr( + return Expr( _ti_core.make_func_call_expr( self.taichi_functions[key.instance_id], non_template_args)) def do_compile(self, key, args): - src = _remove_indent(oinspect.getsource(self.func)) - tree = ast.parse(src) - - func_body = tree.body[0] - func_body.decorator_list = [] - - visitor = ASTTransformerTotal(is_kernel=False, func=self) - visitor.visit(tree) - - ast.increment_lineno(tree, oinspect.getsourcelines(self.func)[1] - 1) + tree, ctx = _get_tree_and_ctx(self, + is_kernel=False, + args=args, + is_real_function=self.is_real_function) + fn = impl.get_runtime().prog.create_function(key) - local_vars = {} - global_vars = _get_global_vars(self.func) + def func_body(): + ctx.ast_builder = fn.ast_builder() + transform_tree(tree, ctx) - if impl.get_runtime().experimental_real_function: - # inject template parameters into globals - for i in self.template_slot_locations: - template_var_name = self.argument_names[i] - global_vars[template_var_name] = args[i] - - exec( - compile(tree, - filename=oinspect.getsourcefile(self.func), - mode='exec'), global_vars, local_vars) - - if impl.get_runtime().experimental_real_function: - self.compiled[key.instance_id] = local_vars[self.func.__name__] - self.taichi_functions[key.instance_id] = _ti_core.create_function( - key) - self.taichi_functions[key.instance_id].set_function_body( - self.compiled[key.instance_id]) - else: - self.compiled = local_vars[self.func.__name__] + self.taichi_functions[key.instance_id] = fn + self.compiled[key.instance_id] = func_body + self.taichi_functions[key.instance_id].set_function_body(func_body) def extract_arguments(self): sig = inspect.signature(self.func) @@ -212,18 +233,18 @@ def extract_arguments(self): for i, arg_name in enumerate(arg_names): param = params[arg_name] if param.kind == inspect.Parameter.VAR_KEYWORD: - raise KernelDefError( + raise TaichiSyntaxError( 'Taichi functions do not support variable keyword parameters (i.e., **kwargs)' ) if param.kind == inspect.Parameter.VAR_POSITIONAL: - raise KernelDefError( + raise TaichiSyntaxError( 'Taichi functions do not support variable positional parameters (i.e., *args)' ) if param.kind == inspect.Parameter.KEYWORD_ONLY: - raise KernelDefError( + raise TaichiSyntaxError( 'Taichi functions do not support keyword parameters') if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD: - raise KernelDefError( + raise TaichiSyntaxError( 'Taichi functions only support "positional or keyword" parameters' ) annotation = param.annotation @@ -232,16 +253,15 @@ def extract_arguments(self): annotation = template() # TODO: pyfunc also need type annotation check when real function is enabled, # but that has to happen at runtime when we know which scope it's called from. - elif not self.pyfunc and impl.get_runtime( - ).experimental_real_function: - raise KernelDefError( + elif not self.pyfunc and self.is_real_function: + raise TaichiSyntaxError( f'Taichi function `{self.func.__name__}` parameter `{arg_name}` must be type annotated' ) else: if not id(annotation ) in primitive_types.type_ids and not isinstance( annotation, template): - raise KernelDefError( + raise TaichiSyntaxError( f'Invalid type annotation (argument {i}) of Taichi function: {annotation}' ) self.argument_annotations.append(annotation) @@ -269,17 +289,23 @@ def extract_arg(arg, anno): TaichiCallableTemplateMapper.extract_arg(item, anno) for item in arg) return arg - elif isinstance(anno, any_arr): + if isinstance(anno, any_arr): if isinstance(arg, taichi.lang._ndarray.ScalarNdarray): - anno.check_element_dim(arg, 0) + anno._check_element_dim(arg, 0) + anno._check_element_shape(()) + anno._check_field_dim(len(arg.shape)) return arg.dtype, len(arg.shape), (), Layout.AOS if isinstance(arg, taichi.lang.matrix.VectorNdarray): - anno.check_element_dim(arg, 1) - anno.check_layout(arg) + anno._check_element_dim(arg, 1) + anno._check_element_shape((arg.n, )) + anno._check_field_dim(len(arg.shape)) + anno._check_layout(arg) return arg.dtype, len(arg.shape) + 1, (arg.n, ), arg.layout if isinstance(arg, taichi.lang.matrix.MatrixNdarray): - anno.check_element_dim(arg, 2) - anno.check_layout(arg) + anno._check_element_dim(arg, 2) + anno._check_element_shape((arg.n, arg.m)) + anno._check_field_dim(len(arg.shape)) + anno._check_layout(arg) return arg.dtype, len(arg.shape) + 2, (arg.n, arg.m), arg.layout # external arrays @@ -288,14 +314,17 @@ def extract_arg(arg, anno): shape = tuple(arg.shape) if len(shape) < element_dim: raise ValueError( - f"Invalid argument into ti.any_arr() - required element_dim={element_dim}, but the argument has only {len(shape)} dimensions" - ) + f"Invalid argument into ti.any_arr() - required element_dim={element_dim}, " + f"but the argument has only {len(shape)} dimensions") element_shape = ( ) if element_dim == 0 else shape[: element_dim] if layout == Layout.SOA else shape[ -element_dim:] return to_taichi_type(arg.dtype), len(shape), element_shape, layout - return (type(arg).__name__, ) + if isinstance(anno, sparse_matrix_builder): + return arg.dtype + # Use '#' as a placeholder because other kinds of arguments are not involved in template instantiation + return '#' def extract(self, args): extracted = [] @@ -305,7 +334,6 @@ def extract(self, args): def lookup(self, args): if len(args) != self.num_args: - _taichi_skip_traceback = 1 raise TypeError( f'{self.num_args} argument(s) needed but {len(args)} provided.' ) @@ -317,30 +345,25 @@ def lookup(self, args): return self.mapping[key], key -class KernelDefError(Exception): - def __init__(self, msg): - super().__init__(msg) - - -class KernelArgError(Exception): - def __init__(self, pos, needed, provided): - message = f'Argument {pos} (type={provided}) cannot be converted into required type {needed}' - super().__init__(message) - self.pos = pos - self.needed = needed - self.provided = provided +def _get_global_vars(_func): + # Discussions: https://github.com/taichi-dev/taichi/issues/282 + global_vars = _func.__globals__.copy() + freevar_names = _func.__code__.co_freevars + closure = _func.__closure__ + if closure: + freevar_values = list(map(lambda x: x.cell_contents, closure)) + for name, value in zip(freevar_names, freevar_values): + global_vars[name] = value -def _get_global_vars(func): - closure_vars = inspect.getclosurevars(func) - return {**closure_vars.globals, **closure_vars.nonlocals} + return global_vars class Kernel: counter = 0 - def __init__(self, func, is_grad, classkernel=False): - self.func = func + def __init__(self, _func, is_grad, _classkernel=False): + self.func = _func self.kernel_counter = Kernel.counter Kernel.counter += 1 self.is_grad = is_grad @@ -348,13 +371,11 @@ def __init__(self, func, is_grad, classkernel=False): self.argument_annotations = [] self.argument_names = [] self.return_type = None - self.classkernel = classkernel - _taichi_skip_traceback = 1 + self.classkernel = _classkernel self.extract_arguments() - del _taichi_skip_traceback self.template_slot_locations = [] - for i in range(len(self.argument_annotations)): - if isinstance(self.argument_annotations[i], template): + for i, anno in enumerate(self.argument_annotations): + if isinstance(anno, template): self.template_slot_locations.append(i) self.mapper = TaichiCallableTemplateMapper( self.argument_annotations, self.template_slot_locations) @@ -378,22 +399,22 @@ def extract_arguments(self): for i, arg_name in enumerate(arg_names): param = params[arg_name] if param.kind == inspect.Parameter.VAR_KEYWORD: - raise KernelDefError( + raise TaichiSyntaxError( 'Taichi kernels do not support variable keyword parameters (i.e., **kwargs)' ) if param.kind == inspect.Parameter.VAR_POSITIONAL: - raise KernelDefError( + raise TaichiSyntaxError( 'Taichi kernels do not support variable positional parameters (i.e., *args)' ) if param.default is not inspect.Parameter.empty: - raise KernelDefError( + raise TaichiSyntaxError( 'Taichi kernels do not support default values for arguments' ) if param.kind == inspect.Parameter.KEYWORD_ONLY: - raise KernelDefError( + raise TaichiSyntaxError( 'Taichi kernels do not support keyword parameters') if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD: - raise KernelDefError( + raise TaichiSyntaxError( 'Taichi kernels only support "positional or keyword" parameters' ) annotation = param.annotation @@ -401,8 +422,7 @@ def extract_arguments(self): if i == 0 and self.classkernel: # The |self| parameter annotation = template() else: - _taichi_skip_traceback = 1 - raise KernelDefError( + raise TaichiSyntaxError( 'Taichi kernels parameters must be type annotated') else: if isinstance(annotation, (template, any_arr)): @@ -411,16 +431,16 @@ def extract_arguments(self): pass elif isinstance(annotation, sparse_matrix_builder): pass + elif isinstance(annotation, MatrixType): + pass else: - _taichi_skip_traceback = 1 - raise KernelDefError( + raise TaichiSyntaxError( f'Invalid type annotation (argument {i}) of Taichi kernel: {annotation}' ) self.argument_annotations.append(annotation) self.argument_names.append(param.name) def materialize(self, key=None, args=None, arg_features=None): - _taichi_skip_traceback = 1 if key is None: key = (self.func, 0) self.runtime.materialize() @@ -429,86 +449,100 @@ def materialize(self, key=None, args=None, arg_features=None): grad_suffix = "" if self.is_grad: grad_suffix = "_grad" - kernel_name = "{}_c{}_{}{}".format(self.func.__name__, - self.kernel_counter, key[1], - grad_suffix) - ti.trace("Compiling kernel {}...".format(kernel_name)) - - src = _remove_indent(oinspect.getsource(self.func)) - tree = ast.parse(src) - - func_body = tree.body[0] - func_body.decorator_list = [] - - local_vars = {} - global_vars = _get_global_vars(self.func) - - for i, arg in enumerate(func_body.args.args): - anno = arg.annotation - if isinstance(anno, ast.Name): - global_vars[anno.id] = self.argument_annotations[i] - - if isinstance(func_body.returns, ast.Name): - global_vars[func_body.returns.id] = self.return_type + kernel_name = f"{self.func.__name__}_c{self.kernel_counter}_{key[1]}{grad_suffix}" + _logging.trace(f"Compiling kernel {kernel_name}...") - if self.is_grad: - KernelSimplicityASTChecker(self.func).visit(tree) - - visitor = ASTTransformerTotal( + tree, ctx = _get_tree_and_ctx( + self, + args=args, excluded_parameters=self.template_slot_locations, - func=self, arg_features=arg_features) - visitor.visit(tree) - - ast.increment_lineno(tree, oinspect.getsourcelines(self.func)[1] - 1) - - # inject template parameters into globals - for i in self.template_slot_locations: - template_var_name = self.argument_names[i] - global_vars[template_var_name] = args[i] - - exec( - compile(tree, - filename=oinspect.getsourcefile(self.func), - mode='exec'), global_vars, local_vars) - compiled = local_vars[self.func.__name__] + if self.is_grad: + KernelSimplicityASTChecker(self.func).visit(tree) # Do not change the name of 'taichi_ast_generator' # The warning system needs this identifier to remove unnecessary messages - def taichi_ast_generator(): - _taichi_skip_traceback = 1 + def taichi_ast_generator(kernel_cxx): if self.runtime.inside_kernel: raise TaichiSyntaxError( - "Kernels cannot call other kernels. I.e., nested kernels are not allowed. Please check if you have direct/indirect invocation of kernels within kernels. Note that some methods provided by the Taichi standard library may invoke kernels, and please move their invocations to Python-scope." - ) + "Kernels cannot call other kernels. I.e., nested kernels are not allowed. " + "Please check if you have direct/indirect invocation of kernels within kernels. " + "Note that some methods provided by the Taichi standard library may invoke kernels, " + "and please move their invocations to Python-scope.") self.runtime.inside_kernel = True self.runtime.current_kernel = self try: - compiled() + ctx.ast_builder = kernel_cxx.ast_builder() + transform_tree(tree, ctx) + if not ctx.is_real_function: + if self.return_type and not ctx.returned: + raise TaichiSyntaxError( + "Kernel has a return type but does not have a return statement" + ) finally: self.runtime.inside_kernel = False self.runtime.current_kernel = None - taichi_kernel = _ti_core.create_kernel(taichi_ast_generator, - kernel_name, self.is_grad) + taichi_kernel = impl.get_runtime().prog.create_kernel( + taichi_ast_generator, kernel_name, self.is_grad) self.kernel_cpp = taichi_kernel assert key not in self.compiled_functions self.compiled_functions[key] = self.get_function_body(taichi_kernel) + def get_torch_callbacks(self, v, has_torch, is_ndarray=True): + callbacks = [] + + def get_call_back(u, v): + def call_back(): + u.copy_(v) + + return call_back + + assert has_torch + assert isinstance(v, torch.Tensor) + if not v.is_contiguous(): + raise ValueError( + "Non contiguous tensors are not supported, please call tensor.contiguous() before passing it into taichi kernel." + ) + tmp = v + taichi_arch = self.runtime.prog.config.arch + # Ndarray means its memory is allocated on the specified taichi arch. + # Since torch only supports CPU & CUDA, torch-base ndarray only supports + # taichi cpu/cuda backend as well. + # Note I put x64/arm64/cuda here to be more specific. + assert not is_ndarray or taichi_arch in ( + _ti_core.Arch.cuda, _ti_core.Arch.x64, _ti_core.Arch.arm64 + ), "Torch-based ndarray is only supported on taichi x64/arm64/cuda backend." + + if str(v.device).startswith('cuda'): + # External tensor on cuda + if taichi_arch != _ti_core.Arch.cuda: + # copy data back to cpu + host_v = v.to(device='cpu', copy=True) + tmp = host_v + callbacks.append(get_call_back(v, host_v)) + else: + # External tensor on cpu + if taichi_arch == _ti_core.Arch.cuda: + gpu_v = v.cuda() + tmp = gpu_v + callbacks.append(get_call_back(v, gpu_v)) + return tmp, callbacks + def get_function_body(self, t_kernel): # The actual function body def func__(*args): assert len(args) == len( self.argument_annotations - ), '{} arguments needed but {} provided'.format( - len(self.argument_annotations), len(args)) + ), f'{len(self.argument_annotations)} arguments needed but {len(args)} provided' tmps = [] callbacks = [] has_external_arrays = False + has_torch = has_pytorch() actual_argument_slot = 0 launch_ctx = t_kernel.make_launch_context() @@ -520,67 +554,67 @@ def func__(*args): # Note: do not use sth like "needed == f32". That would be slow. if id(needed) in primitive_types.real_type_ids: if not isinstance(v, (float, int)): - raise KernelArgError(i, needed.to_string(), provided) + raise TaichiRuntimeTypeError(i, needed.to_string(), + provided) launch_ctx.set_arg_float(actual_argument_slot, float(v)) elif id(needed) in primitive_types.integer_type_ids: if not isinstance(v, int): - raise KernelArgError(i, needed.to_string(), provided) + raise TaichiRuntimeTypeError(i, needed.to_string(), + provided) launch_ctx.set_arg_int(actual_argument_slot, int(v)) elif isinstance(needed, sparse_matrix_builder): - # Pass only the base pointer of the ti.linalg.sparse_matrix_builder() argument - launch_ctx.set_arg_int(actual_argument_slot, v.get_addr()) - elif isinstance(needed, any_arr) and ( - self.match_ext_arr(v) - or isinstance(v, taichi.lang._ndarray.Ndarray)): - if isinstance(v, taichi.lang._ndarray.Ndarray): - v = v.arr + # Pass only the base pointer of the ti.types.sparse_matrix_builder() argument + launch_ctx.set_arg_int(actual_argument_slot, v._get_addr()) + elif isinstance(needed, any_arr) and isinstance( + v, taichi.lang._ndarray.Ndarray): + has_external_arrays = True + v = v.arr + launch_ctx.set_arg_ndarray(actual_argument_slot, v) + elif isinstance(needed, any_arr) and (self.match_ext_arr(v)): has_external_arrays = True is_numpy = isinstance(v, np.ndarray) if is_numpy: tmp = np.ascontiguousarray(v) # Purpose: DO NOT GC |tmp|! tmps.append(tmp) - launch_ctx.set_arg_external_array( + launch_ctx.set_arg_external_array_with_shape( actual_argument_slot, int(tmp.ctypes.data), - tmp.nbytes) + tmp.nbytes, v.shape) else: - - def get_call_back(u, v): - def call_back(): - u.copy_(v) - - return call_back - - assert util.has_pytorch() - assert isinstance(v, torch.Tensor) - tmp = v - taichi_arch = self.runtime.prog.config.arch - - if str(v.device).startswith('cuda'): - # External tensor on cuda - if taichi_arch != _ti_core.Arch.cuda: - # copy data back to cpu - host_v = v.to(device='cpu', copy=True) - tmp = host_v - callbacks.append(get_call_back(v, host_v)) - else: - # External tensor on cpu - if taichi_arch == _ti_core.Arch.cuda: - gpu_v = v.cuda() - tmp = gpu_v - callbacks.append(get_call_back(v, gpu_v)) - launch_ctx.set_arg_external_array( + is_ndarray = False + tmp, torch_callbacks = self.get_torch_callbacks( + v, has_torch, is_ndarray) + callbacks += torch_callbacks + launch_ctx.set_arg_external_array_with_shape( actual_argument_slot, int(tmp.data_ptr()), - tmp.element_size() * tmp.nelement()) - shape = v.shape - max_num_indices = _ti_core.get_max_num_indices() - assert len( - shape - ) <= max_num_indices, "External array cannot have > {} indices".format( - max_num_indices) - for ii, s in enumerate(shape): - launch_ctx.set_extra_arg_int(actual_argument_slot, ii, - s) + tmp.element_size() * tmp.nelement(), v.shape) + + elif isinstance(needed, MatrixType): + if id(needed.dtype) in primitive_types.real_type_ids: + for a in range(needed.n): + for b in range(needed.m): + if not isinstance(v[a, b], (int, float)): + raise TaichiRuntimeTypeError( + i, needed.dtype.to_string(), + type(v[a, b])) + launch_ctx.set_arg_float( + actual_argument_slot, float(v[a, b])) + actual_argument_slot += 1 + elif id(needed.dtype) in primitive_types.integer_type_ids: + for a in range(needed.n): + for b in range(needed.m): + if not isinstance(v[a, b], int): + raise TaichiRuntimeTypeError( + i, needed.dtype.to_string(), + type(v[a, b])) + launch_ctx.set_arg_int(actual_argument_slot, + int(v[a, b])) + actual_argument_slot += 1 + else: + raise ValueError( + f'Matrix dtype {needed.dtype} is not integer type or real type.' + ) + continue else: raise ValueError( f'Argument type mismatch. Expecting {needed}, got {type(v)}.' @@ -592,21 +626,43 @@ def call_back(): if not self.is_grad and self.runtime.target_tape and not self.runtime.grad_replaced: self.runtime.target_tape.insert(self, args) + if actual_argument_slot > 8 and ( + impl.current_cfg().arch == _ti_core.opengl + or impl.current_cfg().arch == _ti_core.cc): + raise TaichiRuntimeError( + f"The number of elements in kernel arguments is too big! Do not exceed 8 on {_ti_core.arch_name(impl.current_cfg().arch)} backend." + ) + + if actual_argument_slot > 64 and ( + (impl.current_cfg().arch != _ti_core.opengl + and impl.current_cfg().arch != _ti_core.cc)): + raise TaichiRuntimeError( + f"The number of elements in kernel arguments is too big! Do not exceed 64 on {_ti_core.arch_name(impl.current_cfg().arch)} backend." + ) + t_kernel(launch_ctx) ret = None ret_dt = self.return_type has_ret = ret_dt is not None - if has_external_arrays or has_ret: - ti.sync() + if has_ret or (impl.current_cfg().async_mode + and has_external_arrays): + runtime_ops.sync() if has_ret: if id(ret_dt) in primitive_types.integer_type_ids: ret = t_kernel.get_ret_int(0) - else: + elif id(ret_dt) in primitive_types.real_type_ids: ret = t_kernel.get_ret_float(0) - + elif id(ret_dt.dtype) in primitive_types.integer_type_ids: + it = iter(t_kernel.get_ret_int_tensor(0)) + ret = Matrix([[next(it) for _ in range(ret_dt.m)] + for _ in range(ret_dt.n)]) + else: + it = iter(t_kernel.get_ret_float_tensor(0)) + ret = Matrix([[next(it) for _ in range(ret_dt.m)] + for _ in range(ret_dt.n)]) if callbacks: for c in callbacks: c() @@ -615,9 +671,10 @@ def call_back(): return func__ - def match_ext_arr(self, v): + @staticmethod + def match_ext_arr(v): has_array = isinstance(v, np.ndarray) - if not has_array and util.has_pytorch(): + if not has_array and has_pytorch(): has_array = isinstance(v, torch.Tensor) return has_array @@ -631,7 +688,11 @@ def ensure_compiled(self, *args): # Thus this part needs to be fast. (i.e. < 3us on a 4 GHz x64 CPU) @_shell_pop_print def __call__(self, *args, **kwargs): - _taichi_skip_traceback = 1 + if self.is_grad and impl.current_cfg().opt_level == 0: + _logging.warn( + """opt_level = 1 is enforced to enable gradient computation.""" + ) + impl.current_cfg().opt_level = 1 assert len(kwargs) == 0, 'kwargs not supported for Taichi kernels' key = self.ensure_compiled(*args) return self.compiled_functions[key](*args) @@ -657,10 +718,9 @@ def __call__(self, *args, **kwargs): def _inside_class(level_of_class_stackframe): - frames = oinspect.stack() try: - maybe_class_frame = frames[level_of_class_stackframe] - statement_list = maybe_class_frame[4] + maybe_class_frame = sys._getframe(level_of_class_stackframe) + statement_list = inspect.getframeinfo(maybe_class_frame)[3] first_statment = statement_list[0].strip() for pat in _KERNEL_CLASS_STACKFRAME_STMT_RES: if pat.match(first_statment): @@ -670,16 +730,15 @@ def _inside_class(level_of_class_stackframe): return False -def _kernel_impl(func, level_of_class_stackframe, verbose=False): +def _kernel_impl(_func, level_of_class_stackframe, verbose=False): # Can decorators determine if a function is being defined inside a class? # https://stackoverflow.com/a/8793684/12003165 is_classkernel = _inside_class(level_of_class_stackframe + 1) - _taichi_skip_traceback = 1 if verbose: - print(f'kernel={func.__name__} is_classkernel={is_classkernel}') - primal = Kernel(func, is_grad=False, classkernel=is_classkernel) - adjoint = Kernel(func, is_grad=True, classkernel=is_classkernel) + print(f'kernel={_func.__name__} is_classkernel={is_classkernel}') + primal = Kernel(_func, is_grad=False, _classkernel=is_classkernel) + adjoint = Kernel(_func, is_grad=True, _classkernel=is_classkernel) # Having |primal| contains |grad| makes the tape work. primal.grad = adjoint @@ -691,22 +750,23 @@ def _kernel_impl(func, level_of_class_stackframe, verbose=False): # owning the kernel, which is not known until the kernel is accessed. # # See also: _BoundedDifferentiableMethod, data_oriented. - @functools.wraps(func) + @functools.wraps(_func) def wrapped(*args, **kwargs): - _taichi_skip_traceback = 1 # If we reach here (we should never), it means the class is not decorated # with @ti.data_oriented, otherwise getattr would have intercepted the call. clsobj = type(args[0]) assert not hasattr(clsobj, '_data_oriented') - raise KernelDefError( + raise TaichiSyntaxError( f'Please decorate class {clsobj.__name__} with @ti.data_oriented' ) else: - @functools.wraps(func) + @functools.wraps(_func) def wrapped(*args, **kwargs): - _taichi_skip_traceback = 1 - return primal(*args, **kwargs) + try: + return primal(*args, **kwargs) + except TaichiCompilationError as e: + raise type(e)('\n' + str(e)) from None wrapped.grad = adjoint @@ -745,34 +805,28 @@ def kernel(fn): >>> for i in x: >>> x[i] = i """ - _taichi_skip_traceback = 1 return _kernel_impl(fn, level_of_class_stackframe=3) -classfunc = obsolete('@ti.classfunc', '@ti.func directly') -classkernel = obsolete('@ti.classkernel', '@ti.kernel directly') - - class _BoundedDifferentiableMethod: def __init__(self, kernel_owner, wrapped_kernel_func): clsobj = type(kernel_owner) if not getattr(clsobj, '_data_oriented', False): - raise KernelDefError( + raise TaichiSyntaxError( f'Please decorate class {clsobj.__name__} with @ti.data_oriented' ) self._kernel_owner = kernel_owner self._primal = wrapped_kernel_func._primal self._adjoint = wrapped_kernel_func._adjoint self._is_staticmethod = wrapped_kernel_func._is_staticmethod + self.__name__ = None def __call__(self, *args, **kwargs): - _taichi_skip_traceback = 1 if self._is_staticmethod: return self._primal(*args, **kwargs) return self._primal(self._kernel_owner, *args, **kwargs) def grad(self, *args, **kwargs): - _taichi_skip_traceback = 1 return self._adjoint(self._kernel_owner, *args, **kwargs) @@ -806,7 +860,6 @@ def data_oriented(cls): The decorated class. """ def _getattr(self, item): - _taichi_skip_traceback = 1 method = cls.__dict__.get(item, None) is_property = method.__class__ == property is_staticmethod = method.__class__ == staticmethod @@ -835,3 +888,6 @@ def _getattr(self, item): cls._data_oriented = True return cls + + +__all__ = ["data_oriented", "func", "kernel"] diff --git a/python/taichi/lang/linalg_impl.py b/python/taichi/lang/linalg_impl.py deleted file mode 100644 index 8e2f8c2a96fa7..0000000000000 --- a/python/taichi/lang/linalg_impl.py +++ /dev/null @@ -1,262 +0,0 @@ -from taichi.core.util import ti_core as _ti_core -from taichi.lang.impl import expr_init -from taichi.lang.kernel_impl import func - -import taichi as ti - - -@func -def polar_decompose2d(A, dt): - """Perform polar decomposition (A=UP) for 2x2 matrix. - - Mathematical concept refers to https://en.wikipedia.org/wiki/Polar_decomposition. - - Args: - A (ti.Matrix(2, 2)): input 2x2 matrix `A`. - dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. - - Returns: - Decomposed 2x2 matrices `U` and `P`. - """ - x, y = A(0, 0) + A(1, 1), A(1, 0) - A(0, 1) - scale = (1.0 / ti.sqrt(x * x + y * y)) - c = x * scale - s = y * scale - r = ti.Matrix([[c, -s], [s, c]], dt=dt) - return r, r.transpose() @ A - - -@func -def polar_decompose3d(A, dt): - """Perform polar decomposition (A=UP) for 3x3 matrix. - - Mathematical concept refers to https://en.wikipedia.org/wiki/Polar_decomposition. - - Args: - A (ti.Matrix(3, 3)): input 3x3 matrix `A`. - dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. - - Returns: - Decomposed 3x3 matrices `U` and `P`. - """ - U, sig, V = ti.svd(A, dt) - return U @ V.transpose(), V @ sig @ V.transpose() - - -# https://www.seas.upenn.edu/~cffjiang/research/svd/svd.pdf -@func -def svd2d(A, dt): - """Perform singular value decomposition (A=USV^T) for 2x2 matrix. - - Mathematical concept refers to https://en.wikipedia.org/wiki/Singular_value_decomposition. - - Args: - A (ti.Matrix(2, 2)): input 2x2 matrix `A`. - dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. - - Returns: - Decomposed 2x2 matrices `U`, 'S' and `V`. - """ - R, S = polar_decompose2d(A, dt) - c, s = ti.cast(0.0, dt), ti.cast(0.0, dt) - s1, s2 = ti.cast(0.0, dt), ti.cast(0.0, dt) - if abs(S[0, 1]) < 1e-5: - c, s = 1, 0 - s1, s2 = S[0, 0], S[1, 1] - else: - tao = ti.cast(0.5, dt) * (S[0, 0] - S[1, 1]) - w = ti.sqrt(tao**2 + S[0, 1]**2) - t = ti.cast(0.0, dt) - if tao > 0: - t = S[0, 1] / (tao + w) - else: - t = S[0, 1] / (tao - w) - c = 1 / ti.sqrt(t**2 + 1) - s = -t * c - s1 = c**2 * S[0, 0] - 2 * c * s * S[0, 1] + s**2 * S[1, 1] - s2 = s**2 * S[0, 0] + 2 * c * s * S[0, 1] + c**2 * S[1, 1] - V = ti.Matrix.zero(dt, 2, 2) - if s1 < s2: - tmp = s1 - s1 = s2 - s2 = tmp - V = ti.Matrix([[-s, c], [-c, -s]], dt=dt) - else: - V = ti.Matrix([[c, s], [-s, c]], dt=dt) - U = R @ V - return U, ti.Matrix([[s1, ti.cast(0, dt)], [ti.cast(0, dt), s2]], dt=dt), V - - -def svd3d(A, dt, iters=None): - """Perform singular value decomposition (A=USV^T) for 3x3 matrix. - - Mathematical concept refers to https://en.wikipedia.org/wiki/Singular_value_decomposition. - - Args: - A (ti.Matrix(3, 3)): input 3x3 matrix `A`. - dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. - iters (int): iteration number to control algorithm precision. - - Returns: - Decomposed 3x3 matrices `U`, 'S' and `V`. - """ - assert A.n == 3 and A.m == 3 - inputs = tuple([e.ptr for e in A.entries]) - assert dt in [ti.f32, ti.f64] - if iters is None: - if dt == ti.f32: - iters = 5 - else: - iters = 8 - if dt == ti.f32: - rets = _ti_core.sifakis_svd_f32(*inputs, iters) - else: - rets = _ti_core.sifakis_svd_f64(*inputs, iters) - assert len(rets) == 21 - U_entries = rets[:9] - V_entries = rets[9:18] - sig_entries = rets[18:] - U = expr_init(ti.Matrix.zero(dt, 3, 3)) - V = expr_init(ti.Matrix.zero(dt, 3, 3)) - sigma = expr_init(ti.Matrix.zero(dt, 3, 3)) - for i in range(3): - for j in range(3): - U(i, j).assign(U_entries[i * 3 + j]) - V(i, j).assign(V_entries[i * 3 + j]) - sigma(i, i).assign(sig_entries[i]) - return U, sigma, V - - -@func -def eig2x2(A, dt): - """Compute the eigenvalues and right eigenvectors (Av=lambda v) of a 2x2 real matrix. - - Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix. - - Args: - A (ti.Matrix(2, 2)): input 2x2 matrix `A`. - dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. - - Returns: - eigenvalues (ti.Matrix(2, 2)): The eigenvalues in complex form. Each row stores one eigenvalue. The first number of the eigenvalue represents the real part and the second number represents the imaginary part. - eigenvectors: (ti.Matrix(4, 2)): The eigenvectors in complex form. Each column stores one eigenvector. Each eigenvector consists of 2 entries, each of which is represented by two numbers for its real part and imaginary part. - """ - tr = A.trace() - det = A.determinant() - gap = tr**2 - 4 * det - lambda1 = ti.Vector.zero(dt, 2) - lambda2 = ti.Vector.zero(dt, 2) - v1 = ti.Vector.zero(dt, 4) - v2 = ti.Vector.zero(dt, 4) - if gap > 0: - lambda1 = ti.Vector([tr + ti.sqrt(gap), 0.0], dt=dt) * 0.5 - lambda2 = ti.Vector([tr - ti.sqrt(gap), 0.0], dt=dt) * 0.5 - A1 = A - lambda1[0] * ti.Matrix.identity(dt, 2) - A2 = A - lambda2[0] * ti.Matrix.identity(dt, 2) - if all(A1 == ti.Matrix.zero(dt, 2, 2)) and all( - A1 == ti.Matrix.zero(dt, 2, 2)): - v1 = ti.Vector([0.0, 0.0, 1.0, 0.0]).cast(dt) - v2 = ti.Vector([1.0, 0.0, 0.0, 0.0]).cast(dt) - else: - v1 = ti.Vector([A2[0, 0], 0.0, A2[1, 0], 0.0], dt=dt).normalized() - v2 = ti.Vector([A1[0, 0], 0.0, A1[1, 0], 0.0], dt=dt).normalized() - else: - lambda1 = ti.Vector([tr, ti.sqrt(-gap)], dt=dt) * 0.5 - lambda2 = ti.Vector([tr, -ti.sqrt(-gap)], dt=dt) * 0.5 - A1r = A - lambda1[0] * ti.Matrix.identity(dt, 2) - A1i = -lambda1[1] * ti.Matrix.identity(dt, 2) - A2r = A - lambda2[0] * ti.Matrix.identity(dt, 2) - A2i = -lambda2[1] * ti.Matrix.identity(dt, 2) - v1 = ti.Vector([A2r[0, 0], A2i[0, 0], A2r[1, 0], A2i[1, 0]], - dt=dt).normalized() - v2 = ti.Vector([A1r[0, 0], A1i[0, 0], A1r[1, 0], A1i[1, 0]], - dt=dt).normalized() - eigenvalues = ti.Matrix.rows([lambda1, lambda2]) - eigenvectors = ti.Matrix.cols([v1, v2]) - - return eigenvalues, eigenvectors - - -@func -def sym_eig2x2(A, dt): - """Compute the eigenvalues and right eigenvectors (Av=lambda v) of a 2x2 real symmetric matrix. - - Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix. - - Args: - A (ti.Matrix(2, 2)): input 2x2 symmetric matrix `A`. - dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. - - Returns: - eigenvalues (ti.Vector(2)): The eigenvalues. Each entry store one eigen value. - eigenvectors (ti.Matrix(2, 2)): The eigenvectors. Each column stores one eigenvector. - """ - tr = A.trace() - det = A.determinant() - gap = tr**2 - 4 * det - lambda1 = (tr + ti.sqrt(gap)) * 0.5 - lambda2 = (tr - ti.sqrt(gap)) * 0.5 - eigenvalues = ti.Vector([lambda1, lambda2], dt=dt) - - A1 = A - lambda1 * ti.Matrix.identity(dt, 2) - A2 = A - lambda2 * ti.Matrix.identity(dt, 2) - v1 = ti.Vector.zero(dt, 2) - v2 = ti.Vector.zero(dt, 2) - if all(A1 == ti.Matrix.zero(dt, 2, 2)) and all( - A1 == ti.Matrix.zero(dt, 2, 2)): - v1 = ti.Vector([0.0, 1.0]).cast(dt) - v2 = ti.Vector([1.0, 0.0]).cast(dt) - else: - v1 = ti.Vector([A2[0, 0], A2[1, 0]], dt=dt).normalized() - v2 = ti.Vector([A1[0, 0], A1[1, 0]], dt=dt).normalized() - eigenvectors = ti.Matrix.cols([v1, v2]) - return eigenvalues, eigenvectors - - -@func -def svd(A, dt): - """Perform singular value decomposition (A=USV^T) for arbitrary size matrix. - - Mathematical concept refers to https://en.wikipedia.org/wiki/Singular_value_decomposition. - 2D implementation refers to :func:`taichi.lang.linalg_impl.svd2d`. - 3D implementation refers to :func:`taichi.lang.linalg_impl.svd3d`. - - Args: - A (ti.Matrix(n, n)): input nxn matrix `A`. - dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. - - Returns: - Decomposed nxn matrices `U`, 'S' and `V`. - """ - if ti.static(A.n == 2): - ret = svd2d(A, dt) - return ret - elif ti.static(A.n == 3): - return svd3d(A, dt) - else: - raise Exception("SVD only supports 2D and 3D matrices.") - - -@func -def polar_decompose(A, dt): - """Perform polar decomposition (A=UP) for arbitrary size matrix. - - Mathematical concept refers to https://en.wikipedia.org/wiki/Polar_decomposition. - 2D implementation refers to :func:`taichi.lang.linalg_impl.polar_decompose2d`. - 3D implementation refers to :func:`taichi.lang.linalg_impl.polar_decompose3d`. - - Args: - A (ti.Matrix(n, n)): input nxn matrix `A`. - dt (DataType): date type of elements in matrix `A`, typically accepts ti.f32 or ti.f64. - - Returns: - Decomposed nxn matrices `U` and `P`. - """ - if ti.static(A.n == 2): - ret = polar_decompose2d(A, dt) - return ret - elif ti.static(A.n == 3): - return polar_decompose3d(A, dt) - else: - raise Exception( - "Polar decomposition only supports 2D and 3D matrices.") diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index ec27c56bce229..02efa2e6aedbd 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1,47 +1,38 @@ -import copy import numbers from collections.abc import Iterable import numpy as np -import taichi.lang +from taichi._lib import core as ti_core from taichi.lang import expr, impl -from taichi.lang import kernel_impl as kern_mod from taichi.lang import ops as ops_mod +from taichi.lang import runtime_ops from taichi.lang._ndarray import Ndarray, NdarrayHostAccess from taichi.lang.common_ops import TaichiOperations from taichi.lang.enums import Layout from taichi.lang.exception import TaichiSyntaxError from taichi.lang.field import Field, ScalarField, SNodeHostAccess -from taichi.lang.ops import cast -from taichi.lang.types import CompoundType -from taichi.lang.util import (cook_dtype, in_python_scope, is_taichi_class, - python_scope, taichi_scope, to_numpy_type, - to_pytorch_type) -from taichi.misc.util import deprecated, warning - -import taichi as ti +from taichi.lang.util import (cook_dtype, in_python_scope, python_scope, + taichi_scope, to_numpy_type, to_pytorch_type, + warning) +from taichi.types import primitive_types +from taichi.types.compound_types import CompoundType class Matrix(TaichiOperations): """The matrix class. Args: - n (int): the first dimension of a matrix. + n (Union[int, list, tuple, np.ndarray]): the first dimension of a matrix. m (int): the second dimension of a matrix. - dt (DataType): the elmement data type. - keep_raw (Bool, optional): Keep the contents in `n` as is. + dt (DataType): the element data type. """ - is_taichi_class = True - - def __init__(self, - n=1, - m=1, - dt=None, - keep_raw=False, - disable_local_tensor=False): + _is_taichi_class = True + + def __init__(self, n=1, m=1, dt=None, suppress_warning=False): self.local_tensor_proxy = None self.any_array_access = None self.grad = None + self.dynamic_index_stride = None if isinstance(n, (list, tuple, np.ndarray)): if len(n) == 0: @@ -49,60 +40,69 @@ def __init__(self, elif isinstance(n[0], Matrix): raise Exception( 'cols/rows required when using list of vectors') - elif not isinstance(n[0], Iterable): - if impl.inside_kernel(): - # wrap potential constants with Expr - if keep_raw: - mat = [list([x]) for x in n] - else: - if in_python_scope( - ) or disable_local_tensor or not ti.current_cfg( - ).dynamic_index: - mat = [list([expr.Expr(x)]) for x in n] - else: - if not ti.is_extension_supported( - ti.cfg.arch, ti.extension.dynamic_index): - raise Exception( - 'Backend ' + str(ti.cfg.arch) + - ' doesn\'t support dynamic index') - if dt is None: - if isinstance(n[0], int): - dt = impl.get_runtime().default_ip - elif isinstance(n[0], float): - dt = impl.get_runtime().default_fp - else: - raise Exception( - 'dt required when using dynamic_index for local tensor' - ) - self.local_tensor_proxy = impl.expr_init_local_tensor( - [len(n)], dt, - expr.make_expr_group([expr.Expr(x) - for x in n])) - mat = [] - for i in range(len(n)): - mat.append( - list([ - ti.local_subscript_with_offset( - self.local_tensor_proxy, (i, ), - (len(n), )) - ])) - else: + elif not isinstance(n[0], Iterable): # now init a Vector + if in_python_scope(): mat = [[x] for x in n] - else: - if in_python_scope( - ) or disable_local_tensor or not ti.current_cfg( - ).dynamic_index: - mat = [list(r) for r in n] + elif not impl.current_cfg().dynamic_index: + mat = [[impl.expr_init(x)] for x in n] else: - if not ti.is_extension_supported( - ti.cfg.arch, ti.extension.dynamic_index): - raise Exception('Backend ' + str(ti.cfg.arch) + - ' doesn\'t support dynamic index') + if not ti_core.is_extension_supported( + impl.current_cfg().arch, + ti_core.Extension.dynamic_index): + raise Exception( + f"Backend {impl.current_cfg().arch} doesn't support dynamic index" + ) if dt is None: - if isinstance(n[0][0], int): + if isinstance(n[0], (int, np.integer)): + dt = impl.get_runtime().default_ip + elif isinstance(n[0], float): + dt = impl.get_runtime().default_fp + elif isinstance(n[0], expr.Expr): + dt = n[0].ptr.get_ret_type() + if dt == ti_core.DataType_unknown: + raise TypeError( + 'Element type of the matrix cannot be inferred. Please set dt instead for now.' + ) + else: + raise Exception( + 'dt required when using dynamic_index for local tensor' + ) + self.local_tensor_proxy = impl.expr_init_local_tensor( + [len(n)], dt, + expr.make_expr_group([expr.Expr(x) for x in n])) + self.dynamic_index_stride = 1 + mat = [] + for i in range(len(n)): + mat.append( + list([ + impl.make_tensor_element_expr( + self.local_tensor_proxy, (expr.Expr( + i, dtype=primitive_types.i32), ), + (len(n), ), self.dynamic_index_stride) + ])) + else: # now init a Matrix + if in_python_scope(): + mat = [list(row) for row in n] + elif not impl.current_cfg().dynamic_index: + mat = [[impl.expr_init(x) for x in row] for row in n] + else: + if not ti_core.is_extension_supported( + impl.current_cfg().arch, + ti_core.Extension.dynamic_index): + raise Exception( + f"Backend {impl.current_cfg().arch} doesn't support dynamic index" + ) + if dt is None: + if isinstance(n[0][0], (int, np.integer)): dt = impl.get_runtime().default_ip elif isinstance(n[0][0], float): dt = impl.get_runtime().default_fp + elif isinstance(n[0][0], expr.Expr): + dt = n[0][0].ptr.get_ret_type() + if dt == ti_core.DataType_unknown: + raise TypeError( + 'Element type of the matrix cannot be inferred. Please set dt instead for now.' + ) else: raise Exception( 'dt required when using dynamic_index for local tensor' @@ -111,14 +111,18 @@ def __init__(self, [len(n), len(n[0])], dt, expr.make_expr_group( [expr.Expr(x) for row in n for x in row])) + self.dynamic_index_stride = 1 mat = [] for i in range(len(n)): mat.append([]) for j in range(len(n[0])): mat[i].append( - ti.local_subscript_with_offset( - self.local_tensor_proxy, (i, j), - (len(n), len(n[0])))) + impl.make_tensor_element_expr( + self.local_tensor_proxy, + (expr.Expr(i, dtype=primitive_types.i32), + expr.Expr(j, dtype=primitive_types.i32)), + (len(n), len(n[0])), + self.dynamic_index_stride)) self.n = len(mat) if len(mat) > 0: self.m = len(mat[0]) @@ -134,10 +138,10 @@ def __init__(self, self.m = m else: raise ValueError( - "Declaring matrix fields using `ti.Matrix(n, m, dt, shape)` is no longer supported. Use `ti.Matrix.field(n, m, dtype, shape)` instead." - ) + "Declaring matrix fields using `ti.Matrix(n, m, dt, shape)` is no longer supported. " + "Use `ti.Matrix.field(n, m, dtype, shape)` instead.") - if self.n * self.m > 32: + if self.n * self.m > 32 and not suppress_warning: warning( f'Taichi matrices/vectors with {self.n}x{self.m} > 32 entries are not suggested.' ' Matrices/vectors will be automatically unrolled at compile-time for performance.' @@ -149,65 +153,42 @@ def __init__(self, UserWarning, stacklevel=2) - def element_wise_binary(self, foo, other): - _taichi_skip_traceback = 1 - ret = self.empty_copy() - if isinstance(other, (list, tuple)): - other = Matrix(other) - if isinstance(other, Matrix): - assert self.m == other.m and self.n == other.n, f"Dimension mismatch between shapes ({self.n}, {self.m}), ({other.n}, {other.m})" - for i in range(self.n * self.m): - ret.entries[i] = foo(self.entries[i], other.entries[i]) - else: # assumed to be scalar - for i in range(self.n * self.m): - ret.entries[i] = foo(self.entries[i], other) - return ret + def _element_wise_binary(self, foo, other): + other = self._broadcast_copy(other) + return Matrix([[foo(self(i, j), other(i, j)) for j in range(self.m)] + for i in range(self.n)]) - def broadcast_copy(self, other): + def _broadcast_copy(self, other): if isinstance(other, (list, tuple)): other = Matrix(other) if not isinstance(other, Matrix): - ret = self.empty_copy() - ret.entries = [other for _ in ret.entries] - other = ret + other = Matrix([[other for _ in range(self.m)] + for _ in range(self.n)]) assert self.m == other.m and self.n == other.n, f"Dimension mismatch between shapes ({self.n}, {self.m}), ({other.n}, {other.m})" return other - def element_wise_ternary(self, foo, other, extra): - ret = self.empty_copy() - other = self.broadcast_copy(other) - extra = self.broadcast_copy(extra) - for i in range(self.n * self.m): - ret.entries[i] = foo(self.entries[i], other.entries[i], - extra.entries[i]) - return ret + def _element_wise_ternary(self, foo, other, extra): + other = self._broadcast_copy(other) + extra = self._broadcast_copy(extra) + return Matrix([[ + foo(self(i, j), other(i, j), extra(i, j)) for j in range(self.m) + ] for i in range(self.n)]) - def element_wise_writeback_binary(self, foo, other): - ret = self.empty_copy() - if isinstance(other, (list, tuple)): - other = Matrix(other) - if is_taichi_class(other): - other = other.variable() - if foo.__name__ == 'assign' and not isinstance(other, Matrix): + def _element_wise_writeback_binary(self, foo, other): + if foo.__name__ == 'assign' and not isinstance(other, + (list, tuple, Matrix)): raise TaichiSyntaxError( 'cannot assign scalar expr to ' f'taichi class {type(self)}, maybe you want to use `a.fill(b)` instead?' ) - if isinstance(other, Matrix): - assert self.m == other.m and self.n == other.n, f"Dimension mismatch between shapes ({self.n}, {self.m}), ({other.n}, {other.m})" - for i in range(self.n * self.m): - ret.entries[i] = foo(self.entries[i], other.entries[i]) - else: # assumed to be scalar - for i in range(self.n * self.m): - ret.entries[i] = foo(self.entries[i], other) - return ret + other = self._broadcast_copy(other) + entries = [[foo(self(i, j), other(i, j)) for j in range(self.m)] + for i in range(self.n)] + return self if foo.__name__ == 'assign' else Matrix(entries) - def element_wise_unary(self, foo): - _taichi_skip_traceback = 1 - ret = self.empty_copy() - for i in range(self.n * self.m): - ret.entries[i] = foo(self.entries[i]) - return ret + def _element_wise_unary(self, foo): + return Matrix([[foo(self(i, j)) for j in range(self.m)] + for i in range(self.n)]) def __matmul__(self, other): """Matrix-matrix or matrix-vector multiply. @@ -219,26 +200,24 @@ def __matmul__(self, other): The matrix-matrix product or matrix-vector product. """ - _taichi_skip_traceback = 1 assert isinstance(other, Matrix), "rhs of `@` is not a matrix / vector" assert self.m == other.n, f"Dimension mismatch between shapes ({self.n}, {self.m}), ({other.n}, {other.m})" - del _taichi_skip_traceback - ret = Matrix.new(self.n, other.m) + entries = [] for i in range(self.n): + entries.append([]) for j in range(other.m): acc = self(i, 0) * other(0, j) for k in range(1, other.n): acc = acc + self(i, k) * other(k, j) - ret.set_entry(i, j, acc) - return ret + entries[i].append(acc) + return Matrix(entries) - def linearize_entry_id(self, *args): + def _linearize_entry_id(self, *args): assert 1 <= len(args) <= 2 if len(args) == 1 and isinstance(args[0], (list, tuple)): args = args[0] if len(args) == 1: args = args + (0, ) - _taichi_skip_traceback = 1 # TODO(#1004): See if it's possible to support indexing at runtime for i, a in enumerate(args): if not isinstance(a, int): @@ -253,25 +232,24 @@ def linearize_entry_id(self, *args): 'See https://docs.taichi.graphics/lang/articles/advanced/meta#when-to-use-for-loops-with-tistatic for more details.' ) assert 0 <= args[0] < self.n, \ - f"The 0-th matrix index is out of range: 0 <= {args[0]} < {self.n}" + f"The 0-th matrix index is out of range: 0 <= {args[0]} < {self.n}" assert 0 <= args[1] < self.m, \ - f"The 1-th matrix index is out of range: 0 <= {args[1]} < {self.m}" + f"The 1-th matrix index is out of range: 0 <= {args[1]} < {self.m}" return args[0] * self.m + args[1] def __call__(self, *args, **kwargs): - _taichi_skip_traceback = 1 assert kwargs == {} - ret = self.entries[self.linearize_entry_id(*args)] + ret = self.entries[self._linearize_entry_id(*args)] if isinstance(ret, SNodeHostAccess): ret = ret.accessor.getter(*ret.key) elif isinstance(ret, NdarrayHostAccess): ret = ret.getter() return ret - def set_entry(self, i, j, e): - idx = self.linearize_entry_id(i, j) + def _set_entry(self, i, j, e): + idx = self._linearize_entry_id(i, j) if impl.inside_kernel(): - self.entries[idx].assign(e) + self.entries[idx]._assign(e) else: if isinstance(self.entries[idx], SNodeHostAccess): self.entries[idx].accessor.setter(e, *self.entries[idx].key) @@ -280,106 +258,103 @@ def set_entry(self, i, j, e): else: self.entries[idx] = e + def _get_slice(self, a, b): + if not isinstance(a, slice): + a = [a] + else: + a = range(a.start or 0, a.stop or self.n, a.step or 1) + if not isinstance(b, slice): + b = [b] + else: + b = range(b.start or 0, b.stop or self.m, b.step or 1) + return Matrix([[self(i, j) for j in b] for i in a]) + @taichi_scope - def subscript(self, *indices): - _taichi_skip_traceback = 1 + def _subscript(self, *indices): assert len(indices) in [1, 2] i = indices[0] j = 0 if len(indices) == 1 else indices[1] + if isinstance(i, slice) or isinstance(j, slice): + for a in (i, j): + if isinstance(a, slice): + if isinstance(a.start, expr.Expr) or isinstance( + a.step, expr.Expr) or isinstance( + a.stop, expr.Expr): + raise TaichiSyntaxError( + "The element type of slice of Matrix/Vector index must be a compile-time constant integer!" + ) + return self._get_slice(i, j) if self.any_array_access: return self.any_array_access.subscript(i, j) - elif self.local_tensor_proxy is not None: + if self.local_tensor_proxy is not None: + assert self.dynamic_index_stride is not None if len(indices) == 1: - return ti.local_subscript_with_offset(self.local_tensor_proxy, - (i, ), (self.n, )) - else: - return ti.local_subscript_with_offset(self.local_tensor_proxy, - (i, j), (self.n, self.m)) - # ptr.is_global_ptr() will check whether it's an element in the field (which is different from ptr.is_global_var()). - elif isinstance(self.entries[0], - ti.Expr) and self.entries[0].ptr.is_global_ptr( - ) and ti.current_cfg().dynamic_index: - # TODO: Add API to query whether AOS or SOA - return ti.global_subscript_with_offset(self.entries[0], (i, j), - (self.n, self.m), True) - else: - return self(i, j) + return impl.make_tensor_element_expr(self.local_tensor_proxy, + (i, ), (self.n, ), + self.dynamic_index_stride) + return impl.make_tensor_element_expr(self.local_tensor_proxy, + (i, j), (self.n, self.m), + self.dynamic_index_stride) + if impl.current_cfg().dynamic_index and isinstance( + self, + _MatrixFieldElement) and self.dynamic_index_stride is not None: + return impl.make_tensor_element_expr(self.entries[0].ptr, (i, j), + (self.n, self.m), + self.dynamic_index_stride) + return self(i, j) @property def x(self): """Get the first element of a matrix.""" - _taichi_skip_traceback = 1 if impl.inside_kernel(): - return self.subscript(0) - else: - return self[0] + return self._subscript(0) + return self[0] @property def y(self): """Get the second element of a matrix.""" - _taichi_skip_traceback = 1 if impl.inside_kernel(): - return self.subscript(1) - else: - return self[1] + return self._subscript(1) + return self[1] @property def z(self): """Get the third element of a matrix.""" - _taichi_skip_traceback = 1 if impl.inside_kernel(): - return self.subscript(2) - else: - return self[2] + return self._subscript(2) + return self[2] @property def w(self): """Get the fourth element of a matrix.""" - _taichi_skip_traceback = 1 if impl.inside_kernel(): - return self.subscript(3) - else: - return self[3] + return self._subscript(3) + return self[3] # since Taichi-scope use v.x.assign() instead @x.setter @python_scope def x(self, value): - _taichi_skip_traceback = 1 self[0] = value @y.setter @python_scope def y(self, value): - _taichi_skip_traceback = 1 self[1] = value @z.setter @python_scope def z(self, value): - _taichi_skip_traceback = 1 self[2] = value @w.setter @python_scope def w(self, value): - _taichi_skip_traceback = 1 self[3] = value - @property - @python_scope - def value(self): - if isinstance(self.entries[0], SNodeHostAccess): - # fetch values from SNodeHostAccessor - ret = self.empty_copy() - for i in range(self.n): - for j in range(self.m): - ret.entries[i * self.m + j] = self(i, j) - else: - # is local python-scope matrix - ret = self.entries - return ret + def to_list(self): + return [[self(i, j) for j in range(self.m)] for i in range(self.n)] # host access & python scope operation @python_scope @@ -398,6 +373,8 @@ def __getitem__(self, indices): assert len(indices) in [1, 2] i = indices[0] j = 0 if len(indices) == 1 else indices[1] + if isinstance(i, slice) or isinstance(j, slice): + return self._get_slice(i, j) return self(i, j) @python_scope @@ -413,7 +390,7 @@ def __setitem__(self, indices, item): assert len(indices) in [1, 2] i = indices[0] j = 0 if len(indices) == 1 else indices[1] - self.set_entry(i, j, item) + self._set_entry(i, j, item) def __len__(self): """Get the length of each row of a matrix""" @@ -422,11 +399,10 @@ def __len__(self): def __iter__(self): if self.m == 1: return (self(i) for i in range(self.n)) - else: - return ([self(i, j) for j in range(self.m)] for i in range(self.n)) + return ([self(i, j) for j in range(self.m)] for i in range(self.n)) @python_scope - def set_entries(self, value): + def _set_entries(self, value): if not isinstance(value, (list, tuple)): value = list(value) if not isinstance(value[0], (list, tuple)): @@ -435,20 +411,6 @@ def set_entries(self, value): for j in range(self.m): self[i, j] = value[i][j] - def empty_copy(self): - return Matrix.empty(self.n, self.m) - - def copy(self): - ret = self.empty_copy() - ret.entries = copy.copy(self.entries) - return ret - - @taichi_scope - def variable(self): - ret = self.copy() - ret.entries = [impl.expr_init(e) for e in ret.entries] - return ret - @taichi_scope def cast(self, dtype): """Cast the matrix element data type. @@ -460,11 +422,9 @@ def cast(self, dtype): A new matrix with each element's type is dtype. """ - _taichi_skip_traceback = 1 - ret = self.copy() - for i in range(len(self.entries)): - ret.entries[i] = ops_mod.cast(ret.entries[i], dtype) - return ret + return Matrix( + [[ops_mod.cast(self(i, j), dtype) for j in range(self.m)] + for i in range(self.n)]) def trace(self): """The sum of a matrix diagonal elements. @@ -474,10 +434,10 @@ def trace(self): """ assert self.n == self.m - sum = self(0, 0) + _sum = self(0, 0) for i in range(1, self.n): - sum = sum + self(i, i) - return sum + _sum = _sum + self(i, i) + return _sum @taichi_scope def inverse(self): @@ -495,14 +455,12 @@ def inverse(self): """ assert self.n == self.m, 'Only square matrices are invertible' if self.n == 1: - return Matrix([1 / self(0, 0)], disable_local_tensor=True) - elif self.n == 2: - inv_det = impl.expr_init(1.0 / self.determinant()) - # Discussion: https://github.com/taichi-dev/taichi/pull/943#issuecomment-626344323 - return inv_det * Matrix([[self(1, 1), -self(0, 1)], - [-self(1, 0), self(0, 0)]], - disable_local_tensor=True).variable() - elif self.n == 3: + return Matrix([1 / self(0, 0)]) + if self.n == 2: + inv_determinant = impl.expr_init(1.0 / self.determinant()) + return inv_determinant * Matrix([[self( + 1, 1), -self(0, 1)], [-self(1, 0), self(0, 0)]]) + if self.n == 3: n = 3 inv_determinant = impl.expr_init(1.0 / self.determinant()) entries = [[0] * n for _ in range(n)] @@ -512,11 +470,11 @@ def E(x, y): for i in range(n): for j in range(n): - entries[j][i] = impl.expr_init( - inv_determinant * (E(i + 1, j + 1) * E(i + 2, j + 2) - - E(i + 2, j + 1) * E(i + 1, j + 2))) - return Matrix(entries, disable_local_tensor=True) - elif self.n == 4: + entries[j][i] = inv_determinant * ( + E(i + 1, j + 1) * E(i + 2, j + 2) - + E(i + 2, j + 1) * E(i + 1, j + 2)) + return Matrix(entries) + if self.n == 4: n = 4 inv_determinant = impl.expr_init(1.0 / self.determinant()) entries = [[0] * n for _ in range(n)] @@ -526,25 +484,18 @@ def E(x, y): for i in range(n): for j in range(n): - entries[j][i] = impl.expr_init( - inv_determinant * (-1)**(i + j) * - ((E(i + 1, j + 1) * - (E(i + 2, j + 2) * E(i + 3, j + 3) - - E(i + 3, j + 2) * E(i + 2, j + 3)) - - E(i + 2, j + 1) * - (E(i + 1, j + 2) * E(i + 3, j + 3) - - E(i + 3, j + 2) * E(i + 1, j + 3)) + - E(i + 3, j + 1) * - (E(i + 1, j + 2) * E(i + 2, j + 3) - - E(i + 2, j + 2) * E(i + 1, j + 3))))) - return Matrix(entries, disable_local_tensor=True) - else: - raise Exception( - "Inversions of matrices with sizes >= 5 are not supported") - - inversed = deprecated('a.inversed()', 'a.inverse()')(inverse) + entries[j][i] = inv_determinant * (-1)**(i + j) * (( + E(i + 1, j + 1) * + (E(i + 2, j + 2) * E(i + 3, j + 3) - + E(i + 3, j + 2) * E(i + 2, j + 3)) - E(i + 2, j + 1) * + (E(i + 1, j + 2) * E(i + 3, j + 3) - + E(i + 3, j + 2) * E(i + 1, j + 3)) + E(i + 3, j + 1) * + (E(i + 1, j + 2) * E(i + 2, j + 3) - + E(i + 2, j + 2) * E(i + 1, j + 3)))) + return Matrix(entries) + raise Exception( + "Inversions of matrices with sizes >= 5 are not supported") - @kern_mod.pyfunc def normalized(self, eps=0): """Normalize a vector. @@ -567,16 +518,6 @@ def normalized(self, eps=0): invlen = 1 / (self.norm() + eps) return invlen * self - @staticmethod - @deprecated('ti.Matrix.transposed(a)', 'a.transpose()') - def transposed(a): - return a.transpose() - - @deprecated('a.T()', 'a.transpose()') - def T(self): - return self.transpose() - - @kern_mod.pyfunc def transpose(self): """Get the transpose of a matrix. @@ -584,10 +525,8 @@ def transpose(self): Get the transpose of a matrix. """ - ret = Matrix([[self[i, j] for i in range(self.n)] - for j in range(self.m)], - disable_local_tensor=True) - return ret + from taichi._funcs import _matrix_transpose # pylint: disable=C0415 + return _matrix_transpose(self) @taichi_scope def determinant(a): @@ -605,11 +544,11 @@ def determinant(a): """ if a.n == 2 and a.m == 2: return a(0, 0) * a(1, 1) - a(0, 1) * a(1, 0) - elif a.n == 3 and a.m == 3: + if a.n == 3 and a.m == 3: return a(0, 0) * (a(1, 1) * a(2, 2) - a(2, 1) * a(1, 2)) - a( 1, 0) * (a(0, 1) * a(2, 2) - a(2, 1) * a(0, 2)) + a( 2, 0) * (a(0, 1) * a(1, 2) - a(1, 1) * a(0, 2)) - elif a.n == 4 and a.m == 4: + if a.n == 4 and a.m == 4: n = 4 def E(x, y): @@ -626,9 +565,8 @@ def E(x, y): E(i + 3, 1) * (E(i + 1, 2) * E(i + 2, 3) - E(i + 2, 2) * E(i + 1, 3)))) return det - else: - raise Exception( - "Determinants of matrices with sizes >= 5 are not supported") + raise Exception( + "Determinants of matrices with sizes >= 5 are not supported") @staticmethod def diag(dim, val): @@ -636,7 +574,7 @@ def diag(dim, val): Args: dim (int): the dimension of a square matrix. - val (TypeVar): the diagonal elment value. + val (TypeVar): the diagonal element value. Returns: The constructed diagonal square matrix. @@ -646,9 +584,9 @@ def diag(dim, val): for i in range(dim): for j in range(dim): if i == j: - ret.set_entry(i, j, val) + ret._set_entry(i, j, val) else: - ret.set_entry(i, j, 0 * val) + ret._set_entry(i, j, 0 * val) # TODO: need a more systematic way to create a "0" with the right type return ret @@ -659,7 +597,6 @@ def sum(self): ret = ret + self.entries[i] return ret - @kern_mod.pyfunc def norm(self, eps=0): """Return the square root of the sum of the absolute squares of its elements. @@ -678,7 +615,6 @@ def norm(self, eps=0): """ return ops_mod.sqrt(self.norm_sqr() + eps) - @kern_mod.pyfunc def norm_inv(self, eps=0): """Return the inverse of the matrix/vector `norm`. For `norm`: please see :func:`~taichi.lang.matrix.Matrix.norm`. @@ -691,20 +627,17 @@ def norm_inv(self, eps=0): """ return ops_mod.rsqrt(self.norm_sqr() + eps) - @kern_mod.pyfunc def norm_sqr(self): """Return the sum of the absolute squares of its elements.""" - return (self**2).sum() + return (self * self).sum() - @kern_mod.pyfunc def max(self): """Return the maximum element value.""" - return ops_mod.ti_max(*self.entries) + return ops_mod.max(*self.entries) - @kern_mod.pyfunc def min(self): - """Return the minumum element value.""" - return ops_mod.ti_min(*self.entries) + """Return the minimum element value.""" + return ops_mod.min(*self.entries) def any(self): """Test whether any element not equal zero. @@ -713,10 +646,10 @@ def any(self): bool: True if any element is not equal zero, False otherwise. """ - ret = ti.cmp_ne(self.entries[0], 0) + ret = ops_mod.cmp_ne(self.entries[0], 0) for i in range(1, len(self.entries)): - ret = ret + ti.cmp_ne(self.entries[i], 0) - return -ti.cmp_lt(ret, 0) + ret = ret + ops_mod.cmp_ne(self.entries[i], 0) + return -ops_mod.cmp_lt(ret, 0) def all(self): """Test whether all element not equal zero. @@ -725,10 +658,10 @@ def all(self): bool: True if all elements are not equal zero, False otherwise. """ - ret = ti.cmp_ne(self.entries[0], 0) + ret = ops_mod.cmp_ne(self.entries[0], 0) for i in range(1, len(self.entries)): - ret = ret + ti.cmp_ne(self.entries[i], 0) - return -ti.cmp_eq(ret, -len(self.entries)) + ret = ret + ops_mod.cmp_ne(self.entries[i], 0) + return -ops_mod.cmp_eq(ret, -len(self.entries)) @taichi_scope def fill(self, val): @@ -738,9 +671,9 @@ def fill(self, val): val (Union[int, float]): Value to fill. """ def assign_renamed(x, y): - return ti.assign(x, y) + return ops_mod.assign(x, y) - return self.element_wise_writeback_binary(assign_renamed, val) + return self._element_wise_writeback_binary(assign_renamed, val) @python_scope def to_numpy(self, keep_dims=False): @@ -755,7 +688,7 @@ def to_numpy(self, keep_dims=False): """ as_vector = self.m == 1 and not keep_dims shape_ext = (self.n, ) if as_vector else (self.n, self.m) - return np.array(self.value).reshape(shape_ext) + return np.array(self.to_list()).reshape(shape_ext) @taichi_scope def __ti_repr__(self): @@ -788,8 +721,7 @@ def __str__(self): So we have to make it happy with a dummy string... ''' return f'<{self.n}x{self.m} ti.Matrix>' - else: - return str(self.to_numpy()) + return str(self.to_numpy()) def __repr__(self): return str(self.to_numpy()) @@ -809,12 +741,9 @@ def zero(dt, n, m=None): """ if m is None: - return Vector([ti.cast(0, dt) for _ in range(n)], - disable_local_tensor=True) - else: - return Matrix([[ti.cast(0, dt) for _ in range(m)] - for _ in range(n)], - disable_local_tensor=True) + return Vector([ops_mod.cast(0, dt) for _ in range(n)]) + return Matrix([[ops_mod.cast(0, dt) for _ in range(m)] + for _ in range(n)]) @staticmethod @taichi_scope @@ -831,12 +760,9 @@ def one(dt, n, m=None): """ if m is None: - return Vector([ti.cast(1, dt) for _ in range(n)], - disable_local_tensor=True) - else: - return Matrix([[ti.cast(1, dt) for _ in range(m)] - for _ in range(n)], - disable_local_tensor=True) + return Vector([ops_mod.cast(1, dt) for _ in range(n)]) + return Matrix([[ops_mod.cast(1, dt) for _ in range(m)] + for _ in range(n)]) @staticmethod @taichi_scope @@ -855,8 +781,7 @@ def unit(n, i, dt=None): if dt is None: dt = int assert 0 <= i < n - return Matrix([ti.cast(int(j == i), dt) for j in range(n)], - disable_local_tensor=True) + return Vector([ops_mod.cast(int(j == i), dt) for j in range(n)]) @staticmethod @taichi_scope @@ -871,14 +796,14 @@ def identity(dt, n): :class:`~taichi.lang.matrix.Matrix`: A n x n identity :class:`~taichi.lang.matrix.Matrix` instance. """ - return Matrix([[ti.cast(int(i == j), dt) for j in range(n)] - for i in range(n)], - disable_local_tensor=True) + return Matrix([[ops_mod.cast(int(i == j), dt) for j in range(n)] + for i in range(n)]) @staticmethod def rotation2d(alpha): - return Matrix([[ti.cos(alpha), -ti.sin(alpha)], - [ti.sin(alpha), ti.cos(alpha)]]) + return Matrix([[ops_mod.cos(alpha), -ops_mod.sin(alpha)], + [ops_mod.sin(alpha), + ops_mod.cos(alpha)]]) @classmethod @python_scope @@ -932,7 +857,8 @@ def field(cls, entries, entries_grad = zip(*entries) entries, entries_grad = MatrixField(entries, n, m), MatrixField( entries_grad, n, m) - entries.set_grad(entries_grad) + entries._set_grad(entries_grad) + impl.get_runtime().matrix_fields.append(entries) if shape is None: assert offset is None, "shape cannot be None when offset is being set" @@ -950,43 +876,27 @@ def field(cls, dim = len(shape) if layout == Layout.SOA: - for e in entries.get_field_members(): - ti.root.dense(impl.index_nd(dim), - shape).place(ScalarField(e), offset=offset) + for e in entries._get_field_members(): + impl.root.dense(impl.index_nd(dim), + shape).place(ScalarField(e), offset=offset) if needs_grad: - for e in entries_grad.get_field_members(): - ti.root.dense(impl.index_nd(dim), - shape).place(ScalarField(e), - offset=offset) + for e in entries_grad._get_field_members(): + impl.root.dense(impl.index_nd(dim), + shape).place(ScalarField(e), + offset=offset) else: - ti.root.dense(impl.index_nd(dim), shape).place(entries, - offset=offset) + impl.root.dense(impl.index_nd(dim), shape).place(entries, + offset=offset) if needs_grad: - ti.root.dense(impl.index_nd(dim), - shape).place(entries_grad, offset=offset) + impl.root.dense(impl.index_nd(dim), + shape).place(entries_grad, offset=offset) return entries - @classmethod - @python_scope - @deprecated('ti.Matrix.var', 'ti.Matrix.field') - def var(cls, n, m, dt, *args, **kwargs): - """ti.Matrix.var""" - _taichi_skip_traceback = 1 - return cls.field(n, m, dt, *args, **kwargs) - @classmethod def _Vector_field(cls, n, dtype, *args, **kwargs): """ti.Vector.field""" - _taichi_skip_traceback = 1 return cls.field(n, 1, dtype, *args, **kwargs) - @classmethod - @deprecated('ti.Vector.var', 'ti.Vector.field') - def _Vector_var(cls, n, dt, *args, **kwargs): - """ti.Vector.var""" - _taichi_skip_traceback = 1 - return cls._Vector_field(n, dt, *args, **kwargs) - @classmethod @python_scope def ndarray(cls, n, m, dtype, shape, layout=Layout.AOS): @@ -1030,7 +940,7 @@ def _Vector_ndarray(cls, n, dtype, shape, layout=Layout.AOS): @staticmethod def rows(rows): - """Construct a Matrix instance by concactinating Vectors/lists row by row. + """Construct a Matrix instance by concatenating Vectors/lists row by row. Args: rows (List): A list of Vector (1-D Matrix) or a list of list. @@ -1063,7 +973,7 @@ def rows(rows): @staticmethod def cols(cols): - """Construct a Matrix instance by concactinating Vectors/lists column by column. + """Construct a Matrix instance by concatenating Vectors/lists column by column. Args: cols (List): A list of Vector (1-D Matrix) or a list of list. @@ -1074,51 +984,12 @@ def cols(cols): """ return Matrix.rows(cols).transpose() - @classmethod - def empty(cls, n, m): - """Clear the matrix and fill None. - - Args: - n (int): The number of the row of the matrix. - m (int): The number of the column of the matrix. - - Returns: - :class:`~taichi.lang.matrix.Matrix`: A :class:`~taichi.lang.matrix.Matrix` instance filled with None. - - """ - return cls([[None] * m for _ in range(n)], disable_local_tensor=True) - - @classmethod - def with_entries(cls, n, m, entries): - """Construct a Matrix instance by giving all entries. - - Args: - n (int): Number of rows of the matrix. - m (int): Number of columns of the matrix. - entries (List[Any]): Given entries. - - Returns: - Matrix: A :class:`~taichi.lang.matrix.Matrix` instance filled with given entries. - """ - assert n * m == len(entries), "Number of entries doesn't match n * m" - mat = cls.empty(n, m) - mat.entries = entries - return mat - - @classmethod - def new(cls, n, m): - if impl.inside_kernel(): - return cls(n, m) - else: - return cls.empty(n, m) - def __hash__(self): # TODO: refactor KernelTemplateMapper # If not, we get `unhashable type: Matrix` when # using matrices as template arguments. return id(self) - @kern_mod.pyfunc def dot(self, other): """Perform the dot product with the input Vector (1-D Matrix). @@ -1135,20 +1006,13 @@ def dot(self, other): impl.static_assert(other.m == 1, "rhs for dot is not a vector")) return (self * other).sum() - @kern_mod.pyfunc def _cross3d(self, other): - ret = Matrix([ - self[1] * other[2] - self[2] * other[1], - self[2] * other[0] - self[0] * other[2], - self[0] * other[1] - self[1] * other[0], - ], - disable_local_tensor=True) - return ret + from taichi._funcs import _matrix_cross3d # pylint: disable=C0415 + return _matrix_cross3d(self, other) - @kern_mod.pyfunc def _cross2d(self, other): - ret = self[0] * other[1] - self[1] * other[0] - return ret + from taichi._funcs import _matrix_cross2d # pylint: disable=C0415 + return _matrix_cross2d(self, other) def cross(self, other): """Perform the cross product with the input Vector (1-D Matrix). @@ -1163,15 +1027,12 @@ def cross(self, other): if self.n == 3 and self.m == 1 and other.n == 3 and other.m == 1: return self._cross3d(other) - elif self.n == 2 and self.m == 1 and other.n == 2 and other.m == 1: + if self.n == 2 and self.m == 1 and other.n == 2 and other.m == 1: return self._cross2d(other) - else: - raise ValueError( - "Cross product is only supported between pairs of 2D/3D vectors" - ) + raise ValueError( + "Cross product is only supported between pairs of 2D/3D vectors") - @kern_mod.pyfunc def outer_product(self, other): """Perform the outer product with the input Vector (1-D Matrix). @@ -1182,23 +1043,16 @@ def outer_product(self, other): :class:`~taichi.lang.matrix.Matrix`: The outer product result (Matrix) of the two Vectors. """ - impl.static( - impl.static_assert(self.m == 1, - "lhs for outer_product is not a vector")) - impl.static( - impl.static_assert(other.m == 1, - "rhs for outer_product is not a vector")) - ret = Matrix([[self[i] * other[j] for j in range(other.n)] - for i in range(self.n)], - disable_local_tensor=True) - return ret + from taichi._funcs import \ + _matrix_outer_product # pylint: disable=C0415 + return _matrix_outer_product(self, other) def Vector(n, dt=None, **kwargs): """Construct a `Vector` instance i.e. 1-D Matrix. Args: - n (int): The desired number of entries of the Vector. + n (Union[int, list, tuple], np.ndarray): The desired number of entries of the Vector. dt (DataType, optional): The desired data type of the Vector. Returns: @@ -1208,7 +1062,6 @@ def Vector(n, dt=None, **kwargs): return Matrix(n, 1, dt=dt, **kwargs) -Vector.var = Matrix._Vector_var Vector.field = Matrix._Vector_field Vector.ndarray = Matrix._Vector_ndarray Vector.zero = Matrix.zero @@ -1220,6 +1073,41 @@ def Vector(n, dt=None, **kwargs): Vector.normalized = Matrix.normalized +class _IntermediateMatrix(Matrix): + """Intermediate matrix class for compiler internal use only. + + Args: + n (int): Number of rows of the matrix. + m (int): Number of columns of the matrix. + entries (List[Expr]): All entries of the matrix. + """ + def __init__(self, n, m, entries): + assert isinstance(entries, list) + assert n * m == len(entries), "Number of entries doesn't match n * m" + self.n = n + self.m = m + self.entries = entries + self.local_tensor_proxy = None + self.any_array_access = None + self.grad = None + self.dynamic_index_stride = None + + +class _MatrixFieldElement(_IntermediateMatrix): + """Matrix field element class for compiler internal use only. + + Args: + field (MatrixField): The matrix field. + indices (taichi_core.ExprGroup): Indices of the element. + """ + def __init__(self, field, indices): + super().__init__(field.n, field.m, [ + expr.Expr(ti_core.subscript(e.ptr, indices)) + for e in field._get_field_members() + ]) + self.dynamic_index_stride = field.dynamic_index_stride + + class MatrixField(Field): """Taichi matrix field with SNode implementation. @@ -1228,15 +1116,12 @@ class MatrixField(Field): n (Int): Number of rows. m (Int): Number of columns. """ - def __init__(self, vars, n, m): - assert len(vars) == n * m - super().__init__(vars) + def __init__(self, _vars, n, m): + assert len(_vars) == n * m + super().__init__(_vars) self.n = n self.m = m - - @deprecated('x(i, j)', 'x.get_scalar_field(i, j)') - def __call__(self, *indices): - return self.get_scalar_field(*indices) + self.dynamic_index_stride = None def get_scalar_field(self, *indices): """Creates a ScalarField using a specific field member. Only used for quant. @@ -1252,6 +1137,37 @@ def get_scalar_field(self, *indices): j = 0 if len(indices) == 1 else indices[1] return ScalarField(self.vars[i * self.m + j]) + def _calc_dynamic_index_stride(self): + # Algorithm: https://github.com/taichi-dev/taichi/issues/3810 + paths = [ScalarField(var).snode._path_from_root() for var in self.vars] + num_members = len(paths) + if num_members == 1: + self.dynamic_index_stride = 0 + return + length = len(paths[0]) + if any( + len(path) != length or ti_core.is_custom_type(path[length - + 1]._dtype) + for path in paths): + return + for i in range(length): + if any(path[i] != paths[0][i] for path in paths): + depth_below_lca = i + break + for i in range(depth_below_lca, length - 1): + if any(path[i].ptr.type != ti_core.SNodeType.dense + or path[i]._cell_size_bytes != paths[0][i]._cell_size_bytes + or path[i + 1]._offset_bytes_in_parent_cell != paths[0][ + i + 1]._offset_bytes_in_parent_cell for path in paths): + return + stride = paths[1][depth_below_lca]._offset_bytes_in_parent_cell - \ + paths[0][depth_below_lca]._offset_bytes_in_parent_cell + for i in range(2, num_members): + if stride != paths[i][depth_below_lca]._offset_bytes_in_parent_cell \ + - paths[i - 1][depth_below_lca]._offset_bytes_in_parent_cell: + return + self.dynamic_index_stride = stride + @python_scope def fill(self, val): """Fills `self` with specific values. @@ -1266,7 +1182,7 @@ def fill(self, val): (list, tuple)) and isinstance(val[0], numbers.Number): assert self.m == 1 val = tuple([(v, ) for v in val]) - elif isinstance(val, ti.Matrix): + elif isinstance(val, Matrix): val_tuple = [] for i in range(val.n): row = [] @@ -1277,10 +1193,11 @@ def fill(self, val): val = tuple(val_tuple) assert len(val) == self.n assert len(val[0]) == self.m - taichi.lang.meta.fill_matrix(self, val) + from taichi._kernels import fill_matrix # pylint: disable=C0415 + fill_matrix(self, val) @python_scope - def to_numpy(self, keep_dims=False, as_vector=None, dtype=None): + def to_numpy(self, keep_dims=False, dtype=None): """Converts the field instance to a NumPy array. Args: @@ -1288,28 +1205,19 @@ def to_numpy(self, keep_dims=False, as_vector=None, dtype=None): When keep_dims=True, on an n-D matrix field, the numpy array always has n+2 dims, even for 1x1, 1xn, nx1 matrix fields. When keep_dims=False, the resulting numpy array should skip the matrix dims with size 1. For example, a 4x1 or 1x4 matrix field with 5x6x7 elements results in an array of shape 5x6x7x4. - as_vector (bool, deprecated): Whether to make the returned numpy array as a vector, i.e., with shape (n,) rather than (n, 1). - Note that this argument has been deprecated. - More discussion about `as_vector`: https://github.com/taichi-dev/taichi/pull/1046#issuecomment-633548858. dtype (DataType, optional): The desired data type of returned numpy array. Returns: numpy.ndarray: The result NumPy array. """ - if as_vector is not None: - warning( - 'v.to_numpy(as_vector=True) is deprecated, ' - 'please use v.to_numpy() directly instead', - DeprecationWarning, - stacklevel=3) if dtype is None: dtype = to_numpy_type(self.dtype) as_vector = self.m == 1 and not keep_dims shape_ext = (self.n, ) if as_vector else (self.n, self.m) - import numpy as np # pylint: disable=C0415 arr = np.zeros(self.shape + shape_ext, dtype=dtype) - taichi.lang.meta.matrix_to_ext_arr(self, arr, as_vector) - ti.sync() + from taichi._kernels import matrix_to_ext_arr # pylint: disable=C0415 + matrix_to_ext_arr(self, arr, as_vector) + runtime_ops.sync() return arr def to_torch(self, device=None, keep_dims=False): @@ -1326,11 +1234,13 @@ def to_torch(self, device=None, keep_dims=False): import torch # pylint: disable=C0415 as_vector = self.m == 1 and not keep_dims shape_ext = (self.n, ) if as_vector else (self.n, self.m) + # pylint: disable=E1101 arr = torch.empty(self.shape + shape_ext, dtype=to_pytorch_type(self.dtype), device=device) - taichi.lang.meta.matrix_to_ext_arr(self, arr, as_vector) - ti.sync() + from taichi._kernels import matrix_to_ext_arr # pylint: disable=C0415 + matrix_to_ext_arr(self, arr, as_vector) + runtime_ops.sync() return arr @python_scope @@ -1343,19 +1253,22 @@ def from_numpy(self, arr): assert len(arr.shape) == len(self.shape) + 2 dim_ext = 1 if as_vector else 2 assert len(arr.shape) == len(self.shape) + dim_ext - taichi.lang.meta.ext_arr_to_matrix(arr, self, as_vector) - ti.sync() + from taichi._kernels import ext_arr_to_matrix # pylint: disable=C0415 + ext_arr_to_matrix(arr, self, as_vector) + runtime_ops.sync() @python_scope def __setitem__(self, key, value): - self.initialize_host_accessors() - self[key].set_entries(value) + self._initialize_host_accessors() + self[key]._set_entries(value) @python_scope def __getitem__(self, key): - self.initialize_host_accessors() - key = self.pad_key(key) - return Matrix.with_entries(self.n, self.m, self.host_access(key)) + self._initialize_host_accessors() + key = self._pad_key(key) + _host_access = self._host_access(key) + return Matrix([[_host_access[i * self.m + j] for j in range(self.m)] + for i in range(self.n)]) def __repr__(self): # make interactive shell happy, prevent materialization @@ -1376,7 +1289,7 @@ def __call__(self, *args): elif len(args) == 1: # fill a single scalar if isinstance(args[0], (numbers.Number, expr.Expr)): - return self.scalar_filled(args[0]) + return self.filled_with_scalar(args[0]) # fill a single vector or matrix entries = args[0] else: @@ -1396,36 +1309,28 @@ def __call__(self, *args): mat = self.cast(Matrix(entries, dt=self.dtype)) return mat - def cast(self, mat, in_place=False): - if not in_place: - mat = mat.copy() + def cast(self, mat): # sanity check shape if self.m != mat.m or self.n != mat.n: raise TaichiSyntaxError( f"Incompatible arguments for the custom vector/matrix type: ({self.n}, {self.m}), ({mat.n}, {mat.m})" ) if in_python_scope(): - mat.entries = [ - int(x) if self.dtype in ti.integer_types else x - for x in mat.entries - ] - else: - # only performs casting in Taichi scope - mat.entries = [cast(x, self.dtype) for x in mat.entries] - return mat + return Matrix([[ + int(mat(i, j)) if self.dtype in primitive_types.integer_types + else float(mat(i, j)) for j in range(self.m) + ] for i in range(self.n)]) + return mat.cast(self.dtype) - def empty(self): - """ - Create an empty instance of the given compound type. - """ - return Matrix.empty(self.n, self.m) + def filled_with_scalar(self, value): + return Matrix([[value for _ in range(self.m)] for _ in range(self.n)]) def field(self, **kwargs): return Matrix.field(self.n, self.m, dtype=self.dtype, **kwargs) class MatrixNdarray(Ndarray): - """Taichi ndarray with matrix elements implemented with a torch tensor. + """Taichi ndarray with matrix elements. Args: n (int): Number of rows of the matrix. @@ -1436,21 +1341,16 @@ class MatrixNdarray(Ndarray): """ def __init__(self, n, m, dtype, shape, layout): self.layout = layout + self.shape = shape + self.n = n + self.m = m arr_shape = (n, m) + shape if layout == Layout.SOA else shape + (n, m) super().__init__(dtype, arr_shape) @property - def n(self): - return self.arr.shape[0 if self.layout == Layout.SOA else -2] - - @property - def m(self): - return self.arr.shape[1 if self.layout == Layout.SOA else -1] - - @property - def shape(self): + def element_shape(self): arr_shape = tuple(self.arr.shape) - return arr_shape[2:] if self.layout == Layout.SOA else arr_shape[:-2] + return arr_shape[:2] if self.layout == Layout.SOA else arr_shape[-2:] @python_scope def __setitem__(self, key, value): @@ -1466,17 +1366,35 @@ def __setitem__(self, key, value): def __getitem__(self, key): key = () if key is None else ( key, ) if isinstance(key, numbers.Number) else tuple(key) - return Matrix.with_entries(self.n, self.m, [ - NdarrayHostAccess(self, key, (i, j)) for i in range(self.n) - for j in range(self.m) - ]) + return Matrix( + [[NdarrayHostAccess(self, key, (i, j)) for j in range(self.m)] + for i in range(self.n)]) + + @python_scope + def to_numpy(self): + return self._ndarray_matrix_to_numpy(as_vector=0) + + @python_scope + def from_numpy(self, arr): + self._ndarray_matrix_from_numpy(arr, as_vector=0) + + def __deepcopy__(self, memo=None): + ret_arr = MatrixNdarray(self.n, self.m, self.dtype, self.shape, + self.layout) + ret_arr.copy_from(self) + return ret_arr + + def _fill_by_kernel(self, val): + from taichi._kernels import \ + fill_ndarray_matrix # pylint: disable=C0415 + fill_ndarray_matrix(self, val) def __repr__(self): return f'<{self.n}x{self.m} {self.layout} ti.Matrix.ndarray>' class VectorNdarray(Ndarray): - """Taichi ndarray with vector elements implemented with a torch tensor. + """Taichi ndarray with vector elements. Args: n (int): Size of the vector. @@ -1486,17 +1404,15 @@ class VectorNdarray(Ndarray): """ def __init__(self, n, dtype, shape, layout): self.layout = layout + self.shape = shape + self.n = n arr_shape = (n, ) + shape if layout == Layout.SOA else shape + (n, ) super().__init__(dtype, arr_shape) @property - def n(self): - return self.arr.shape[0 if self.layout == Layout.SOA else -1] - - @property - def shape(self): + def element_shape(self): arr_shape = tuple(self.arr.shape) - return arr_shape[1:] if self.layout == Layout.SOA else arr_shape[:-1] + return arr_shape[:1] if self.layout == Layout.SOA else arr_shape[-1:] @python_scope def __setitem__(self, key, value): @@ -1509,9 +1425,29 @@ def __setitem__(self, key, value): def __getitem__(self, key): key = () if key is None else ( key, ) if isinstance(key, numbers.Number) else tuple(key) - return Matrix.with_entries( - self.n, 1, + return Vector( [NdarrayHostAccess(self, key, (i, )) for i in range(self.n)]) + @python_scope + def to_numpy(self): + return self._ndarray_matrix_to_numpy(as_vector=1) + + @python_scope + def from_numpy(self, arr): + self._ndarray_matrix_from_numpy(arr, as_vector=1) + + def __deepcopy__(self, memo=None): + ret_arr = VectorNdarray(self.n, self.dtype, self.shape, self.layout) + ret_arr.copy_from(self) + return ret_arr + + def _fill_by_kernel(self, val): + from taichi._kernels import \ + fill_ndarray_matrix # pylint: disable=C0415 + fill_ndarray_matrix(self, val) + def __repr__(self): return f'<{self.n} {self.layout} ti.Vector.ndarray>' + + +__all__ = ["Matrix", "Vector", "MatrixField", "MatrixNdarray", "VectorNdarray"] diff --git a/python/taichi/lang/mesh.py b/python/taichi/lang/mesh.py new file mode 100644 index 0000000000000..850ece441b5cf --- /dev/null +++ b/python/taichi/lang/mesh.py @@ -0,0 +1,537 @@ +import json + +import numpy as np +from taichi._lib import core as _ti_core +from taichi.lang import impl +from taichi.lang.enums import Layout +from taichi.lang.exception import TaichiSyntaxError +from taichi.lang.field import Field, ScalarField +from taichi.lang.matrix import (MatrixField, _IntermediateMatrix, + _MatrixFieldElement) +from taichi.lang.struct import StructField +from taichi.lang.util import python_scope +from taichi.types import i32 +from taichi.types.compound_types import CompoundType + +from taichi import lang + +MeshTopology = _ti_core.MeshTopology +MeshElementType = _ti_core.MeshElementType +MeshRelationType = _ti_core.MeshRelationType +ConvType = _ti_core.ConvType +element_order = _ti_core.element_order +from_end_element_order = _ti_core.from_end_element_order +to_end_element_order = _ti_core.to_end_element_order +relation_by_orders = _ti_core.relation_by_orders +inverse_relation = _ti_core.inverse_relation +element_type_name = _ti_core.element_type_name + + +class MeshAttrType: + def __init__(self, name, dtype, reorder, needs_grad): + self.name = name + self.dtype = dtype + self.reorder = reorder + self.needs_grad = needs_grad + + +class MeshReorderedScalarFieldProxy(ScalarField): + def __init__(self, field: ScalarField, mesh_ptr: _ti_core.MeshPtr, + element_type: MeshElementType, g2r_field: ScalarField): + self.vars = field.vars + self.host_accessors = field.host_accessors + self.grad = field.grad + + self.mesh_ptr = mesh_ptr + self.element_type = element_type + self.g2r_field = g2r_field + + @python_scope + def __setitem__(self, key, value): + self._initialize_host_accessors() + key = self.g2r_field[key] + self.host_accessors[0].setter(value, *self._pad_key(key)) + + @python_scope + def __getitem__(self, key): + self._initialize_host_accessors() + key = self.g2r_field[key] + return self.host_accessors[0].getter(*self._pad_key(key)) + + +class MeshReorderedMatrixFieldProxy(MatrixField): + def __init__(self, field: MatrixField, mesh_ptr: _ti_core.MeshPtr, + element_type: MeshElementType, g2r_field: ScalarField): + self.vars = field.vars + self.host_accessors = field.host_accessors + self.grad = field.grad + self.n = field.n + self.m = field.m + self.dynamic_index_stride = field.dynamic_index_stride + + self.mesh_ptr = mesh_ptr + self.element_type = element_type + self.g2r_field = g2r_field + + @python_scope + def __setitem__(self, key, value): + self._initialize_host_accessors() + self[key]._set_entries(value) + + @python_scope + def __getitem__(self, key): + self._initialize_host_accessors() + key = self.g2r_field[key] + key = self._pad_key(key) + return _IntermediateMatrix(self.n, self.m, self._host_access(key)) + + +class MeshElementField: + def __init__(self, mesh_instance, _type, attr_dict, field_dict, g2r_field): + self.mesh = mesh_instance + self._type = _type + self.attr_dict = attr_dict + self.field_dict = field_dict + self.g2r_field = g2r_field + + self._register_fields() + + @property + def keys(self): + return list(self.field_dict.keys()) + + @property + def _members(self): + return list(self.field_dict.values()) + + @property + def _items(self): + return self.field_dict.items() + + @staticmethod + def _make_getter(key): + def getter(self): + if key not in self.getter_dict: + if self.attr_dict[key].reorder: + if isinstance(self.field_dict[key], ScalarField): + self.getter_dict[key] = MeshReorderedScalarFieldProxy( + self.field_dict[key], self.mesh.mesh_ptr, + self._type, self.g2r_field) + elif isinstance(self.field_dict[key], MatrixField): + self.getter_dict[key] = MeshReorderedMatrixFieldProxy( + self.field_dict[key], self.mesh.mesh_ptr, + self._type, self.g2r_field) + else: + self.getter_dict[key] = self.field_dict[key] + """Get an entry from custom struct by name.""" + return self.getter_dict[key] + + return getter + + def _register_fields(self): + self.getter_dict = {} + for k in self.keys: + setattr(MeshElementField, k, + property(fget=MeshElementField._make_getter(k))) + + def _get_field_members(self): + field_members = [] + for m in self._members: + assert isinstance(m, Field) + field_members += m._get_field_members() + return field_members + + @python_scope + def copy_from(self, other): + assert isinstance(other, Field) + assert set(self.keys) == set(other.keys) + for k in self.keys: + self.field_dict[k].copy_from(other[k]) + + @python_scope + def fill(self, val): + for v in self._members: + v.fill(val) + + def _initialize_host_accessors(self): + for v in self._members: + v._initialize_host_accessors() + + def get_member_field(self, key): + return self.field_dict[key] + + @python_scope + def from_numpy(self, array_dict): + for k, v in self._items: + v.from_numpy(array_dict[k]) + + @python_scope + def from_torch(self, array_dict): + for k, v in self._items: + v.from_torch(array_dict[k]) + + @python_scope + def to_numpy(self): + return {k: v.to_numpy() for k, v in self._items} + + @python_scope + def to_torch(self, device=None): + return {k: v.to_torch(device=device) for k, v in self._items} + + @python_scope + def __len__(self): + return _ti_core.get_num_elements(self.mesh.mesh_ptr, self._type) + + +class MeshElement: + def __init__(self, _type, builder): + self.builder = builder + self._type = _type + self.layout = Layout.SOA + self.attr_dict = {} + + def _SOA(self, soa=True): # AOS/SOA + self.layout = Layout.SOA if soa else Layout.AOS + + def _AOS(self, aos=True): + self.layout = Layout.AOS if aos else Layout.SOA + + SOA = property(fget=_SOA) + AOS = property(fget=_AOS) + + def place( + self, + members, + reorder=False, + needs_grad=False, + ): + self.builder.elements.add(self._type) + for key, dtype in members.items(): + if key in {'verts', 'edges', 'faces', 'cells'}: + raise TaichiSyntaxError( + f"'{key}' cannot use as attribute name. It has been reserved as ti.Mesh's keyword." + ) + self.attr_dict[key] = MeshAttrType(key, dtype, reorder, needs_grad) + + def build(self, mesh_instance, size, g2r_field): + field_dict = {} + + for key, attr in self.attr_dict.items(): + if isinstance(attr.dtype, CompoundType): + field_dict[key] = attr.dtype.field(shape=None, + needs_grad=attr.needs_grad) + else: + field_dict[key] = impl.field(attr.dtype, + shape=None, + needs_grad=attr.needs_grad) + + if self.layout == Layout.SOA: + for key, field in field_dict.items(): + impl.root.dense(impl.axes(0), size).place(field) + if self.attr_dict[key].needs_grad: + impl.root.dense(impl.axes(0), size).place(field.grad) + elif len(field_dict) > 0: + impl.root.dense(impl.axes(0), + size).place(*tuple(field_dict.values())) + grads = [] + for key, field in field_dict.items(): + if self.attr_dict[key].needs_grad: + grads.append(field.grad) + if len(grads) > 0: + impl.root.dense(impl.axes(0), size).place(*grads) + + return MeshElementField(mesh_instance, self._type, self.attr_dict, + field_dict, g2r_field) + + def link(self, element): + assert isinstance(element, MeshElement) + assert element.builder == self.builder + self.builder.relations.add(tuple([self._type, element._type])) + self.builder.elements.add(self._type) + self.builder.elements.add(element._type) + + +# Define the instance of the Mesh Type, stores the field (type and data) info +class MeshInstance: + def __init__(self, _type): + self._type = _type + self.mesh_ptr = _ti_core.create_mesh() + + def set_owned_offset(self, element_type: MeshElementType, + owned_offset: ScalarField): + _ti_core.set_owned_offset(self.mesh_ptr, element_type, + owned_offset.vars[0].ptr.snode()) + + def set_total_offset(self, element_type: MeshElementType, + total_offset: ScalarField): + _ti_core.set_total_offset(self.mesh_ptr, element_type, + total_offset.vars[0].ptr.snode()) + + def set_index_mapping(self, element_type: MeshElementType, + conv_type: ConvType, mapping: ScalarField): + _ti_core.set_index_mapping(self.mesh_ptr, element_type, conv_type, + mapping.vars[0].ptr.snode()) + + def set_num_patches(self, num_patches: int): + _ti_core.set_num_patches(self.mesh_ptr, num_patches) + + def set_patch_max_element_num(self, element_type: MeshElementType, + max_element_num: int): + _ti_core.set_patch_max_element_num(self.mesh_ptr, element_type, + max_element_num) + + def set_relation_fixed(self, rel_type: MeshRelationType, + value: ScalarField): + _ti_core.set_relation_fixed(self.mesh_ptr, rel_type, + value.vars[0].ptr.snode()) + + def set_relation_dynamic(self, rel_type: MeshRelationType, + value: ScalarField, offset: ScalarField): + _ti_core.set_relation_dynamic(self.mesh_ptr, rel_type, + value.vars[0].ptr.snode(), + offset.vars[0].ptr.snode()) + + def add_mesh_attribute(self, element_type, snode, reorder_type): + _ti_core.add_mesh_attribute(self.mesh_ptr, element_type, snode, + reorder_type) + + +class MeshMetadata: + def __init__(self, data): + self.num_patches = data["num_patches"] + + self.element_fields = {} + self.relation_fields = {} + self.num_elements = {} + self.max_num_per_patch = {} + + for element in data["elements"]: + element_type = MeshElementType(element["order"]) + self.num_elements[element_type] = element["num"] + self.max_num_per_patch[element_type] = element["max_num_per_patch"] + + element["l2g_mapping"] = np.array(element["l2g_mapping"]) + element["l2r_mapping"] = np.array(element["l2r_mapping"]) + element["g2r_mapping"] = np.array(element["g2r_mapping"]) + self.element_fields[element_type] = {} + self.element_fields[element_type]["owned"] = impl.field( + dtype=i32, shape=self.num_patches + 1) + self.element_fields[element_type]["total"] = impl.field( + dtype=i32, shape=self.num_patches + 1) + self.element_fields[element_type]["l2g"] = impl.field( + dtype=i32, shape=element["l2g_mapping"].shape[0]) + self.element_fields[element_type]["l2r"] = impl.field( + dtype=i32, shape=element["l2r_mapping"].shape[0]) + self.element_fields[element_type]["g2r"] = impl.field( + dtype=i32, shape=element["g2r_mapping"].shape[0]) + + for relation in data["relations"]: + from_order = relation["from_order"] + to_order = relation["to_order"] + rel_type = MeshRelationType( + relation_by_orders(from_order, to_order)) + self.relation_fields[rel_type] = {} + self.relation_fields[rel_type]["value"] = impl.field( + dtype=i32, shape=len(relation["value"])) + if from_order <= to_order: + self.relation_fields[rel_type]["offset"] = impl.field( + dtype=i32, shape=len(relation["offset"])) + + for element in data["elements"]: + element_type = MeshElementType(element["order"]) + self.element_fields[element_type]["owned"].from_numpy( + np.array(element["owned_offsets"])) + self.element_fields[element_type]["total"].from_numpy( + np.array(element["total_offsets"])) + self.element_fields[element_type]["l2g"].from_numpy( + element["l2g_mapping"]) + self.element_fields[element_type]["l2r"].from_numpy( + element["l2r_mapping"]) + self.element_fields[element_type]["g2r"].from_numpy( + element["g2r_mapping"]) + + for relation in data["relations"]: + from_order = relation["from_order"] + to_order = relation["to_order"] + rel_type = MeshRelationType( + relation_by_orders(from_order, to_order)) + self.relation_fields[rel_type]["value"].from_numpy( + np.array(relation["value"])) + if from_order <= to_order: + self.relation_fields[rel_type]["offset"].from_numpy( + np.array(relation["offset"])) + + self.attrs = {} + self.attrs["x"] = np.array(data["attrs"]["x"]).reshape(-1, 3) + + +# Define the Mesh Type, stores the field type info +class MeshBuilder: + def __init__(self, topology): + if not lang.misc.is_extension_supported(impl.current_cfg().arch, + lang.extension.mesh): + raise Exception('Backend ' + str(impl.current_cfg().arch) + + ' doesn\'t support MeshTaichi extension') + + self.topology = topology + self.verts = MeshElement(MeshElementType.Vertex, self) + self.edges = MeshElement(MeshElementType.Edge, self) + self.faces = MeshElement(MeshElementType.Face, self) + if topology == MeshTopology.Tetrahedron: + self.cells = MeshElement(MeshElementType.Cell, self) + + self.elements = set() + self.relations = set() + + def build(self, metadata: MeshMetadata): + instance = MeshInstance(self) + instance.fields = {} + + instance.set_num_patches(metadata.num_patches) + + for element in self.elements: + _ti_core.set_num_elements(instance.mesh_ptr, element, + metadata.num_elements[element]) + instance.set_patch_max_element_num( + element, metadata.max_num_per_patch[element]) + + element_name = element_type_name(element) + setattr( + instance, element_name, + getattr(self, element_name).build( + instance, metadata.num_elements[element], + metadata.element_fields[element]["g2r"])) + instance.fields[element] = getattr(instance, element_name) + + instance.set_owned_offset( + element, metadata.element_fields[element]["owned"]) + instance.set_total_offset( + element, metadata.element_fields[element]["total"]) + instance.set_index_mapping(element, ConvType.l2g, + metadata.element_fields[element]["l2g"]) + instance.set_index_mapping(element, ConvType.l2r, + metadata.element_fields[element]["l2r"]) + instance.set_index_mapping(element, ConvType.g2r, + metadata.element_fields[element]["g2r"]) + + for relation in self.relations: + from_order = element_order(relation[0]) + to_order = element_order(relation[1]) + rel_type = MeshRelationType( + relation_by_orders(from_order, to_order)) + if from_order <= to_order: + instance.set_relation_dynamic( + rel_type, metadata.relation_fields[rel_type]["value"], + metadata.relation_fields[rel_type]["offset"]) + else: + instance.set_relation_fixed( + rel_type, metadata.relation_fields[rel_type]["value"]) + + if "x" in instance.verts.attr_dict: # pylint: disable=E1101 + instance.verts.x.from_numpy(metadata.attrs["x"]) # pylint: disable=E1101 + + return instance + + +# Mesh First Class +class Mesh: + def __init__(self): + pass + + @staticmethod + def Tet(): + return MeshBuilder(MeshTopology.Tetrahedron) + + @staticmethod + def Tri(): + return MeshBuilder(MeshTopology.Triangle) + + @staticmethod + def load_meta(filename): + with open(filename, "r") as fi: + data = json.loads(fi.read()) + return MeshMetadata(data) + + @staticmethod + def generate_meta(data): + return MeshMetadata(data) + + +def TriMesh(): + return Mesh.Tri() + + +def TetMesh(): + return Mesh.Tet() + + +class MeshElementFieldProxy: + def __init__(self, mesh: MeshInstance, element_type: MeshElementType, + entry_expr: impl.Expr): + self.mesh = mesh + self.element_type = element_type + self.entry_expr = entry_expr + + element_field = self.mesh.fields[self.element_type] + for key, attr in element_field.field_dict.items(): + global_entry_expr = impl.Expr( + _ti_core.get_index_conversion( + self.mesh.mesh_ptr, element_type, entry_expr, + ConvType.l2r if element_field.attr_dict[key].reorder else + ConvType.l2g)) # transform index space + global_entry_expr_group = impl.make_expr_group( + *tuple([global_entry_expr])) + if isinstance(attr, MatrixField): + setattr(self, key, + _MatrixFieldElement(attr, global_entry_expr_group)) + elif isinstance(attr, StructField): + raise RuntimeError('ti.Mesh has not support StructField yet') + else: # isinstance(attr, Field) + var = attr._get_field_members()[0].ptr + setattr( + self, key, + impl.Expr(_ti_core.subscript(var, + global_entry_expr_group))) + + for element_type in self.mesh._type.elements: + setattr(self, element_type_name(element_type), + impl.mesh_relation_access(self.mesh, self, element_type)) + + @property + def ptr(self): + return self.entry_expr + + @property + def id(self): # return the global non-reordered index + l2g_expr = impl.Expr( + _ti_core.get_index_conversion(self.mesh.mesh_ptr, + self.element_type, self.entry_expr, + ConvType.l2g)) + return l2g_expr + + +class MeshRelationAccessProxy: + def __init__(self, mesh: MeshInstance, from_index: impl.Expr, + to_element_type: MeshElementType): + self.mesh = mesh + self.from_index = from_index + self.to_element_type = to_element_type + + @property + def size(self): + return impl.Expr( + _ti_core.get_relation_size(self.mesh.mesh_ptr, self.from_index.ptr, + self.to_element_type)) + + def subscript(self, *indices): + assert len(indices) == 1 + entry_expr = _ti_core.get_relation_access(self.mesh.mesh_ptr, + self.from_index.ptr, + self.to_element_type, + impl.Expr(indices[0]).ptr) + entry_expr.type_check(impl.get_runtime().prog.config) + return MeshElementFieldProxy(self.mesh, self.to_element_type, + entry_expr) + + +__all__ = ["Mesh", "TetMesh", "TriMesh"] diff --git a/python/taichi/lang/meta.py b/python/taichi/lang/meta.py deleted file mode 100644 index 5dc3a080d9348..0000000000000 --- a/python/taichi/lang/meta.py +++ /dev/null @@ -1,135 +0,0 @@ -from taichi.core import get_os_name -from taichi.lang import impl -from taichi.lang.expr import Expr -from taichi.lang.field import ScalarField -from taichi.lang.kernel_impl import kernel -from taichi.type.annotations import ext_arr, template - -import taichi as ti - -# A set of helper (meta)functions - - -@kernel -def fill_tensor(tensor: template(), val: template()): - for I in ti.grouped(tensor): - tensor[I] = val - - -@kernel -def tensor_to_ext_arr(tensor: template(), arr: ext_arr()): - for I in ti.grouped(tensor): - arr[I] = tensor[I] - - -@kernel -def vector_to_fast_image(img: template(), out: ext_arr()): - # FIXME: Why is ``for i, j in img:`` slower than: - for i, j in ti.ndrange(*img.shape): - r, g, b = 0, 0, 0 - color = img[i, img.shape[1] - 1 - j] - if ti.static(img.dtype in [ti.f32, ti.f64]): - r, g, b = min(255, max(0, int(color * 255))) - else: - impl.static_assert(img.dtype == ti.u8) - r, g, b = color - idx = j * img.shape[0] + i - # We use i32 for |out| since OpenGL and Metal doesn't support u8 types - if ti.static(get_os_name() != 'osx'): - out[idx] = (r << 16) + (g << 8) + b - else: - # What's -16777216? - # - # On Mac, we need to set the alpha channel to 0xff. Since Mac's GUI - # is big-endian, the color is stored in ABGR order, and we need to - # add 0xff000000, which is -16777216 in I32's legit range. (Albeit - # the clarity, adding 0xff000000 doesn't work.) - alpha = -16777216 - out[idx] = (b << 16) + (g << 8) + r + alpha - - -@kernel -def tensor_to_image(tensor: template(), arr: ext_arr()): - for I in ti.grouped(tensor): - t = ti.cast(tensor[I], ti.f32) - arr[I, 0] = t - arr[I, 1] = t - arr[I, 2] = t - - -@kernel -def vector_to_image(mat: template(), arr: ext_arr()): - for I in ti.grouped(mat): - for p in ti.static(range(mat.n)): - arr[I, p] = ti.cast(mat[I][p], ti.f32) - if ti.static(mat.n <= 2): - arr[I, 2] = 0 - - -@kernel -def tensor_to_tensor(tensor: template(), other: template()): - for I in ti.grouped(tensor): - tensor[I] = other[I] - - -@kernel -def ext_arr_to_tensor(arr: ext_arr(), tensor: template()): - for I in ti.grouped(tensor): - tensor[I] = arr[I] - - -@kernel -def matrix_to_ext_arr(mat: template(), arr: ext_arr(), as_vector: template()): - for I in ti.grouped(mat): - for p in ti.static(range(mat.n)): - for q in ti.static(range(mat.m)): - if ti.static(as_vector): - arr[I, p] = mat[I][p] - else: - arr[I, p, q] = mat[I][p, q] - - -@kernel -def ext_arr_to_matrix(arr: ext_arr(), mat: template(), as_vector: template()): - for I in ti.grouped(mat): - for p in ti.static(range(mat.n)): - for q in ti.static(range(mat.m)): - if ti.static(as_vector): - mat[I][p] = arr[I, p] - else: - mat[I][p, q] = arr[I, p, q] - - -@kernel -def clear_gradients(vars: template()): - for I in ti.grouped(ScalarField(Expr(vars[0]))): - for s in ti.static(vars): - ScalarField(Expr(s))[I] = 0 - - -@kernel -def clear_loss(l: template()): - # Using SNode writers would result in a forced sync, therefore we wrap these - # writes into a kernel. - l[None] = 0 - l.grad[None] = 1 - - -@kernel -def fill_matrix(mat: template(), vals: template()): - for I in ti.grouped(mat): - for p in ti.static(range(mat.n)): - for q in ti.static(range(mat.m)): - mat[I][p, q] = vals[p][q] - - -@kernel -def snode_deactivate(b: template()): - for I in ti.grouped(b): - ti.deactivate(b, I) - - -@kernel -def snode_deactivate_dynamic(b: template()): - for I in ti.grouped(b.parent()): - ti.deactivate(b, I) diff --git a/python/taichi/lang/misc.py b/python/taichi/lang/misc.py new file mode 100644 index 0000000000000..8d54644f427ba --- /dev/null +++ b/python/taichi/lang/misc.py @@ -0,0 +1,667 @@ +import atexit +import functools +import os +import shutil +import tempfile +import warnings +from copy import deepcopy as _deepcopy + +from taichi._lib import core as _ti_core +from taichi._lib.utils import locale_encode +from taichi.lang import impl +from taichi.lang.expr import Expr +from taichi.lang.impl import axes, get_runtime +from taichi.lang.snode import SNode +from taichi.profiler.kernel_profiler import get_default_kernel_profiler +from taichi.types.primitive_types import f32, f64, i32, i64 + +from taichi import _logging, _snode, _version_check + +warnings.filterwarnings("once", category=DeprecationWarning, module="taichi") + +# ---------------------- +i = axes(0) +"""Axis 0. For multi-dimensional arrays it's the direction downward the rows. +For a 1d array it's the direction along this array. +""" +# ---------------------- + +j = axes(1) +"""Axis 1. For multi-dimensional arrays it's the direction across the columns. +""" +# ---------------------- + +k = axes(2) +"""Axis 2. For arrays of dimension `d` >= 3, view each cell as an array of +lower dimension d-2, it's the first axis of this cell. +""" +# ---------------------- + +l = axes(3) +"""Axis 3. For arrays of dimension `d` >= 4, view each cell as an array of +lower dimension d-2, it's the second axis of this cell. +""" +# ---------------------- + +ij = axes(0, 1) +"""Axes (0, 1). +""" +# ---------------------- + +ik = axes(0, 2) +"""Axes (0, 2). +""" +# ---------------------- + +il = axes(0, 3) +"""Axes (0, 3). +""" +# ---------------------- + +jk = axes(1, 2) +"""Axes (1, 2). +""" +# ---------------------- + +jl = axes(1, 3) +"""Axes (1, 3). +""" +# ---------------------- + +kl = axes(2, 3) +"""Axes (2, 3). +""" +# ---------------------- + +ijk = axes(0, 1, 2) +"""Axes (0, 1, 2). +""" +# ---------------------- + +ijl = axes(0, 1, 3) +"""Axes (0, 1, 3). +""" +# ---------------------- + +ikl = axes(0, 2, 3) +"""Axes (0, 2, 3). +""" +# ---------------------- + +jkl = axes(1, 2, 3) +"""Axes (1, 2, 3). +""" +# ---------------------- + +ijkl = axes(0, 1, 2, 3) +"""Axes (0, 1, 2, 3). +""" +# ---------------------- + +# ---------------------- + +x86_64 = _ti_core.x64 +"""The x64 CPU backend. +""" +# ---------------------- + +x64 = _ti_core.x64 +"""The X64 CPU backend. +""" +# ---------------------- + +arm64 = _ti_core.arm64 +"""The ARM CPU backend. +""" +# ---------------------- + +cuda = _ti_core.cuda +"""The CUDA backend. +""" +# ---------------------- + +metal = _ti_core.metal +"""The Apple Metal backend. +""" +# ---------------------- + +opengl = _ti_core.opengl +"""The OpenGL backend. OpenGL 4.3 required. +""" +# ---------------------- + +# Skip annotating this one because it is barely maintained. +cc = _ti_core.cc + +# ---------------------- + +wasm = _ti_core.wasm +"""The WebAssembly backend. +""" +# ---------------------- + +vulkan = _ti_core.vulkan +"""The Vulkan backend. +""" +# ---------------------- + +dx11 = _ti_core.dx11 +"""The DX11 backend. +""" +# ---------------------- + +gpu = [cuda, metal, opengl, vulkan, dx11] +"""A list of GPU backends supported on the current system. + +When this is used, Taichi automatically picks the matching GPU backend. If no +GPU is detected, Taichi falls back to the CPU backend. +""" +# ---------------------- + +cpu = _ti_core.host_arch() +"""A list of CPU backends supported on the current system. + +When this is used, Taichi automatically picks the matching CPU backend. +""" +# ---------------------- + +timeline_clear = lambda: impl.get_runtime().prog.timeline_clear() # pylint: disable=unnecessary-lambda +timeline_save = lambda fn: impl.get_runtime().prog.timeline_save(fn) # pylint: disable=unnecessary-lambda + +# Legacy API +type_factory_ = _ti_core.get_type_factory_instance() + +extension = _ti_core.Extension +"""An instance of Taichi extension. + +The list of currently available extensions is ['sparse', 'async_mode', 'quant', \ + 'mesh', 'quant_basic', 'data64', 'adstack', 'bls', 'assertion', \ + 'extfunc', 'packed', 'dynamic_index']. +""" + + +def is_extension_supported(arch, ext): + """Checks whether an extension is supported on an arch. + + Args: + arch (taichi_core.Arch): Specified arch. + ext (taichi_core.Extension): Specified extension. + + Returns: + bool: Whether `ext` is supported on `arch`. + """ + return _ti_core.is_extension_supported(arch, ext) + + +def reset(): + """Resets Taichi to its initial state. + This will destroy all the allocated fields and kernels, and restore + the runtime to its default configuration. + + Example:: + + >>> a = ti.field(ti.i32, shape=()) + >>> a[None] = 1 + >>> print("before reset: ", a) + before rest: 1 + >>> + >>> ti.reset() + >>> print("after reset: ", a) + # will raise error because a is unavailable after reset. + """ + impl.reset() + global runtime + runtime = impl.get_runtime() + + +class _EnvironmentConfigurator: + def __init__(self, kwargs, _cfg): + self.cfg = _cfg + self.kwargs = kwargs + self.keys = [] + + def add(self, key, _cast=None): + _cast = _cast or self.bool_int + + self.keys.append(key) + + # TI_ASYNC= : no effect + # TI_ASYNC=0 : False + # TI_ASYNC=1 : True + name = 'TI_' + key.upper() + value = os.environ.get(name, '') + if len(value): + self[key] = _cast(value) + if key in self.kwargs: + _ti_core.warn( + f'ti.init argument "{key}" overridden by environment variable {name}={value}' + ) + del self.kwargs[key] # mark as recognized + elif key in self.kwargs: + self[key] = self.kwargs[key] + del self.kwargs[key] # mark as recognized + + def __getitem__(self, key): + return getattr(self.cfg, key) + + def __setitem__(self, key, value): + setattr(self.cfg, key, value) + + @staticmethod + def bool_int(x): + return bool(int(x)) + + +class _SpecialConfig: + # like CompileConfig in C++, this is the configurations that belong to other submodules + def __init__(self): + self.log_level = 'info' + self.gdb_trigger = False + self.short_circuit_operators = False + + +def prepare_sandbox(): + ''' + Returns a temporary directory, which will be automatically deleted on exit. + It may contain the taichi_core shared object or some misc. files. + ''' + tmp_dir = tempfile.mkdtemp(prefix='taichi-') + atexit.register(shutil.rmtree, tmp_dir) + print(f'[Taichi] preparing sandbox at {tmp_dir}') + os.mkdir(os.path.join(tmp_dir, 'runtime/')) + return tmp_dir + + +def check_require_version(require_version): + ''' + Check if installed version meets the requirements. + Allow to specify .... + . is optional. If not match, raise an exception. + ''' + # Extract version number part (i.e. toss any revision / hash parts). + version_number_str = require_version + for c_idx, c in enumerate(require_version): + if not (c.isdigit() or c == "."): + version_number_str = require_version[:c_idx] + break + # Get required version. + try: + version_number_tuple = tuple( + [int(n) for n in version_number_str.split(".")]) + major = version_number_tuple[0] + minor = version_number_tuple[1] + patch = 0 + if len(version_number_tuple) > 2: + patch = version_number_tuple[2] + except: + raise Exception("The require_version should be formatted following PEP 440, " \ + "and inlucdes major, minor, and patch number, " \ + "e.g., major.minor.patch.") from None + # Get installed version + versions = [ + int(_ti_core.get_version_major()), + int(_ti_core.get_version_minor()), + int(_ti_core.get_version_patch()), + ] + # Match installed version and required version. + match = major == versions[0] and ( + minor < versions[1] or minor == versions[1] and patch <= versions[2]) + + if not match: + raise Exception( + f"Taichi version mismatch. Required version >= {major}.{minor}.{patch}, installed version = {_ti_core.get_version_string()}." + ) + + +def init(arch=None, + default_fp=None, + default_ip=None, + _test_mode=False, + enable_fallback=True, + require_version=None, + **kwargs): + """Initializes the Taichi runtime. + + This should always be the entry point of your Taichi program. Most + importantly, it sets the backend used throughout the program. + + Args: + arch: Backend to use. This is usually :const:`~taichi.lang.cpu` or :const:`~taichi.lang.gpu`. + default_fp (Optional[type]): Default floating-point type. + default_ip (Optional[type]): Default integral type. + require_version (Optional[string]): A version string. + **kwargs: Taichi provides highly customizable compilation through + ``kwargs``, which allows for fine grained control of Taichi compiler + behavior. Below we list some of the most frequently used ones. For a + complete list, please check out + https://github.com/taichi-dev/taichi/blob/master/taichi/program/compile_config.h. + + * ``cpu_max_num_threads`` (int): Sets the number of threads used by the CPU thread pool. + * ``debug`` (bool): Enables the debug mode, under which Taichi does a few more things like boundary checks. + * ``print_ir`` (bool): Prints the CHI IR of the Taichi kernels. + * ``packed`` (bool): Enables the packed memory layout. See https://docs.taichi.graphics/lang/articles/advanced/layout. + """ + # Check version for users every 7 days if not disabled by users. + _version_check.start_version_check_thread() + + cfg = impl.default_cfg() + # Check if installed version meets the requirements. + if require_version is not None: + check_require_version(require_version) + + # Make a deepcopy in case these args reference to items from ti.cfg, which are + # actually references. If no copy is made and the args are indeed references, + # ti.reset() could override the args to their default values. + default_fp = _deepcopy(default_fp) + default_ip = _deepcopy(default_ip) + kwargs = _deepcopy(kwargs) + reset() + + spec_cfg = _SpecialConfig() + env_comp = _EnvironmentConfigurator(kwargs, cfg) + env_spec = _EnvironmentConfigurator(kwargs, spec_cfg) + + # configure default_fp/ip: + # TODO: move these stuff to _SpecialConfig too: + env_default_fp = os.environ.get("TI_DEFAULT_FP") + if env_default_fp: + if default_fp is not None: + _ti_core.warn( + f'ti.init argument "default_fp" overridden by environment variable TI_DEFAULT_FP={env_default_fp}' + ) + if env_default_fp == '32': + default_fp = f32 + elif env_default_fp == '64': + default_fp = f64 + elif env_default_fp is not None: + raise ValueError( + f'Invalid TI_DEFAULT_FP={env_default_fp}, should be 32 or 64') + + env_default_ip = os.environ.get("TI_DEFAULT_IP") + if env_default_ip: + if default_ip is not None: + _ti_core.warn( + f'ti.init argument "default_ip" overridden by environment variable TI_DEFAULT_IP={env_default_ip}' + ) + if env_default_ip == '32': + default_ip = i32 + elif env_default_ip == '64': + default_ip = i64 + elif env_default_ip is not None: + raise ValueError( + f'Invalid TI_DEFAULT_IP={env_default_ip}, should be 32 or 64') + + if default_fp is not None: + impl.get_runtime().set_default_fp(default_fp) + if default_ip is not None: + impl.get_runtime().set_default_ip(default_ip) + + # submodule configurations (spec_cfg): + env_spec.add('log_level', str) + env_spec.add('gdb_trigger') + env_spec.add('short_circuit_operators') + + # compiler configurations (ti.cfg): + for key in dir(cfg): + if key in ['arch', 'default_fp', 'default_ip']: + continue + _cast = type(getattr(cfg, key)) + if _cast is bool: + _cast = None + env_comp.add(key, _cast) + + unexpected_keys = kwargs.keys() + + if len(unexpected_keys): + raise KeyError( + f'Unrecognized keyword argument(s) for ti.init: {", ".join(unexpected_keys)}' + ) + + # dispatch configurations that are not in ti.cfg: + if not _test_mode: + _ti_core.set_core_trigger_gdb_when_crash(spec_cfg.gdb_trigger) + impl.get_runtime().short_circuit_operators = \ + spec_cfg.short_circuit_operators + _logging.set_logging_level(spec_cfg.log_level.lower()) + + # select arch (backend): + env_arch = os.environ.get('TI_ARCH') + if env_arch is not None: + _logging.info(f'Following TI_ARCH setting up for arch={env_arch}') + arch = _ti_core.arch_from_name(env_arch) + cfg.arch = adaptive_arch_select(arch, enable_fallback, cfg.use_gles) + if cfg.arch == cc: + _ti_core.set_tmp_dir(locale_encode(prepare_sandbox())) + print(f'[Taichi] Starting on arch={_ti_core.arch_name(cfg.arch)}') + + # user selected visible device + visible_device = os.environ.get("TI_VISIBLE_DEVICE") + if visible_device and (cfg.arch == vulkan or _ti_core.GGUI_AVAILABLE): + _ti_core.set_vulkan_visible_device(visible_device) + + if _test_mode: + return spec_cfg + + get_default_kernel_profiler().set_kernel_profiler_mode(cfg.kernel_profiler) + + # create a new program: + impl.get_runtime().create_program() + + _logging.trace('Materializing runtime...') + impl.get_runtime().prog.materialize_runtime() + + impl._root_fb = _snode.FieldsBuilder() + + if not os.environ.get("TI_DISABLE_SIGNAL_HANDLERS", False): + impl.get_runtime()._register_signal_handlers() + + return None + + +def no_activate(*args): + for v in args: + get_runtime().prog.no_activate(v._snode.ptr) + + +def block_local(*args): + """Hints Taichi to cache the fields and to enable the BLS optimization. + + Please visit https://docs.taichi.graphics/lang/articles/advanced/performance + for how BLS is used. + + Args: + *args (List[Field]): A list of sparse Taichi fields. + """ + if impl.current_cfg().opt_level == 0: + _logging.warn("""opt_level = 1 is enforced to enable bls analysis.""") + impl.current_cfg().opt_level = 1 + for a in args: + for v in a._get_field_members(): + get_runtime().prog.current_ast_builder().insert_snode_access_flag( + _ti_core.SNodeAccessFlag.block_local, v.ptr) + + +def mesh_local(*args): + for a in args: + for v in a._get_field_members(): + get_runtime().prog.current_ast_builder().insert_snode_access_flag( + _ti_core.SNodeAccessFlag.mesh_local, v.ptr) + + +def cache_read_only(*args): + for a in args: + for v in a._get_field_members(): + get_runtime().prog.current_ast_builder().insert_snode_access_flag( + _ti_core.SNodeAccessFlag.read_only, v.ptr) + + +def assume_in_range(val, base, low, high): + return _ti_core.expr_assume_in_range( + Expr(val).ptr, + Expr(base).ptr, low, high) + + +def loop_unique(val, covers=None): + if covers is None: + covers = [] + if not isinstance(covers, (list, tuple)): + covers = [covers] + covers = [x.snode.ptr if isinstance(x, Expr) else x.ptr for x in covers] + return _ti_core.expr_loop_unique(Expr(val).ptr, covers) + + +def parallelize(v): + get_runtime().prog.current_ast_builder().parallelize(v) + + +serialize = lambda: parallelize(1) + + +def block_dim(dim): + """Set the number of threads in a block to `dim`. + """ + get_runtime().prog.current_ast_builder().block_dim(dim) + + +def global_thread_idx(): + return impl.get_runtime().prog.current_ast_builder( + ).insert_thread_idx_expr() + + +def mesh_patch_idx(): + return impl.get_runtime().prog.current_ast_builder().insert_patch_idx_expr( + ) + + +def Tape(loss, clear_gradients=True): + """Return a context manager of :class:`~taichi.lang.tape.TapeImpl`. The + context manager would catching all of the callings of functions that + decorated by :func:`~taichi.lang.kernel_impl.kernel` or + :func:`~taichi.ad.grad_replaced` under `with` statement, and calculate + all the partial gradients of a given loss variable by calling all of the + gradient function of the callings caught in reverse order while `with` + statement ended. + + See also :func:`~taichi.lang.kernel_impl.kernel` and + :func:`~taichi.ad.grad_replaced` for gradient functions. + + Args: + loss(:class:`~taichi.lang.expr.Expr`): The loss field, which shape should be (). + clear_gradients(Bool): Before `with` body start, clear all gradients or not. + + Returns: + :class:`~taichi.lang.tape.TapeImpl`: The context manager. + + Example:: + + >>> @ti.kernel + >>> def sum(a: ti.float32): + >>> for I in ti.grouped(x): + >>> y[None] += x[I] ** a + >>> + >>> with ti.Tape(loss = y): + >>> sum(2) + """ + impl.get_runtime().materialize() + if len(loss.shape) != 0: + raise RuntimeError( + 'The loss of `Tape` must be a 0-D field, i.e. scalar') + if not loss.snode.ptr.has_grad(): + raise RuntimeError( + 'Gradients of loss are not allocated, please use ti.field(..., needs_grad=True)' + ' for all fields that are required by autodiff.') + if clear_gradients: + clear_all_gradients() + + from taichi._kernels import clear_loss # pylint: disable=C0415 + clear_loss(loss) + + return impl.get_runtime().get_tape(loss) + + +def clear_all_gradients(): + """Set the gradients of all fields to zero. + """ + impl.get_runtime().materialize() + + def visit(node): + places = [] + for _i in range(node.ptr.get_num_ch()): + ch = node.ptr.get_ch(_i) + if not ch.is_place(): + visit(SNode(ch)) + else: + if not ch.is_primal(): + places.append(ch.get_expr()) + + places = tuple(places) + if places: + from taichi._kernels import \ + clear_gradients # pylint: disable=C0415 + clear_gradients(places) + + for root_fb in _snode.FieldsBuilder._finalized_roots(): + visit(root_fb) + + +def is_arch_supported(arch, use_gles=False): + """Checks whether an arch is supported on the machine. + + Args: + arch (taichi_core.Arch): Specified arch. + use_gles (bool): If True, check is GLES is available otherwise + check if GLSL is available. Only effective when `arch` is `ti.opengl`. + Default is `False`. + + Returns: + bool: Whether `arch` is supported on the machine. + """ + + arch_table = { + cuda: _ti_core.with_cuda, + metal: _ti_core.with_metal, + opengl: functools.partial(_ti_core.with_opengl, use_gles), + cc: _ti_core.with_cc, + vulkan: _ti_core.with_vulkan, + dx11: _ti_core.with_dx11, + wasm: lambda: True, + cpu: lambda: True, + } + with_arch = arch_table.get(arch, lambda: False) + try: + return with_arch() + except Exception as e: + arch = _ti_core.arch_name(arch) + _ti_core.warn( + f"{e.__class__.__name__}: '{e}' occurred when detecting " + f"{arch}, consider adding `TI_ENABLE_{arch.upper()}=0` " + f" to environment variables to suppress this warning message.") + return False + + +def adaptive_arch_select(arch, enable_fallback, use_gles): + if arch is None: + return cpu + if not isinstance(arch, (list, tuple)): + arch = [arch] + for a in arch: + if is_arch_supported(a, use_gles): + return a + if not enable_fallback: + raise RuntimeError(f'Arch={arch} is not supported') + _logging.warn(f'Arch={arch} is not supported, falling back to CPU') + return cpu + + +def get_host_arch_list(): + return [_ti_core.host_arch()] + + +__all__ = [ + 'i', 'ij', 'ijk', 'ijkl', 'ijl', 'ik', 'ikl', 'il', 'j', 'jk', 'jkl', 'jl', + 'k', 'kl', 'l', 'x86_64', 'x64', 'dx11', 'wasm', 'arm64', 'cc', 'cpu', + 'cuda', 'gpu', 'metal', 'opengl', 'vulkan', 'extension', 'parallelize', + 'block_dim', 'global_thread_idx', 'Tape', 'assume_in_range', 'block_local', + 'cache_read_only', 'clear_all_gradients', 'init', 'mesh_local', + 'no_activate', 'reset', 'mesh_patch_idx' +] diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index e8635e31efc04..b0a67774bf848 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -1,16 +1,12 @@ import builtins -import ctypes import functools import math import operator as _bt_ops_mod # bt for builtin import traceback -from taichi.core.util import ti_core as _ti_core -from taichi.lang import impl, matrix +from taichi._lib import core as _ti_core +from taichi.lang import expr, impl from taichi.lang.exception import TaichiSyntaxError -from taichi.lang.expr import Expr, make_expr_group -from taichi.lang.field import Field -from taichi.lang.snode import SNode from taichi.lang.util import cook_dtype, is_taichi_class, taichi_scope unary_ops = [] @@ -28,27 +24,23 @@ def stack_info(): def is_taichi_expr(a): - return isinstance(a, Expr) + return isinstance(a, expr.Expr) def wrap_if_not_expr(a): - _taichi_skip_traceback = 1 - return Expr(a) if not is_taichi_expr(a) else a + return expr.Expr(a) if not is_taichi_expr(a) else a def unary(foo): @functools.wraps(foo) def imp_foo(x): - _taichi_skip_traceback = 2 return foo(x) @functools.wraps(foo) def wrapped(a): - _taichi_skip_traceback = 1 if is_taichi_class(a): - return a.element_wise_unary(imp_foo) - else: - return imp_foo(a) + return a._element_wise_unary(imp_foo) + return imp_foo(a) return wrapped @@ -59,23 +51,19 @@ def wrapped(a): def binary(foo): @functools.wraps(foo) def imp_foo(x, y): - _taichi_skip_traceback = 2 return foo(x, y) @functools.wraps(foo) def rev_foo(x, y): - _taichi_skip_traceback = 2 return foo(y, x) @functools.wraps(foo) def wrapped(a, b): - _taichi_skip_traceback = 1 if is_taichi_class(a): - return a.element_wise_binary(imp_foo, b) - elif is_taichi_class(b): - return b.element_wise_binary(rev_foo, a) - else: - return imp_foo(a, b) + return a._element_wise_binary(imp_foo, b) + if is_taichi_class(b): + return b._element_wise_binary(rev_foo, a) + return imp_foo(a, b) binary_ops.append(wrapped) return wrapped @@ -87,30 +75,25 @@ def wrapped(a, b): def ternary(foo): @functools.wraps(foo) def abc_foo(a, b, c): - _taichi_skip_traceback = 2 return foo(a, b, c) @functools.wraps(foo) def bac_foo(b, a, c): - _taichi_skip_traceback = 2 return foo(a, b, c) @functools.wraps(foo) def cab_foo(c, a, b): - _taichi_skip_traceback = 2 return foo(a, b, c) @functools.wraps(foo) def wrapped(a, b, c): - _taichi_skip_traceback = 1 if is_taichi_class(a): - return a.element_wise_ternary(abc_foo, b, c) - elif is_taichi_class(b): - return b.element_wise_ternary(bac_foo, a, c) - elif is_taichi_class(c): - return c.element_wise_ternary(cab_foo, a, b) - else: - return abc_foo(a, b, c) + return a._element_wise_ternary(abc_foo, b, c) + if is_taichi_class(b): + return b._element_wise_ternary(bac_foo, a, c) + if is_taichi_class(c): + return c._element_wise_ternary(cab_foo, a, b) + return abc_foo(a, b, c) ternary_ops.append(wrapped) return wrapped @@ -122,15 +105,13 @@ def wrapped(a, b, c): def writeback_binary(foo): @functools.wraps(foo) def imp_foo(x, y): - _taichi_skip_traceback = 2 return foo(x, wrap_if_not_expr(y)) @functools.wraps(foo) def wrapped(a, b): - _taichi_skip_traceback = 1 if is_taichi_class(a): - return a.element_wise_writeback_binary(imp_foo, b) - elif is_taichi_class(b): + return a._element_wise_writeback_binary(imp_foo, b) + if is_taichi_class(b): raise TaichiSyntaxError( f'cannot augassign taichi class {type(b)} to scalar expr') else: @@ -141,233 +122,441 @@ def wrapped(a, b): def cast(obj, dtype): - _taichi_skip_traceback = 1 + """Copy and cast a scalar or a matrix to a specified data type. + Must be called in Taichi scope. + + Args: + obj (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + Input scalar or matrix. + + dtype (:mod:`~taichi.types.primitive_types`): A primitive type defined in :mod:`~taichi.types.primitive_types`. + + Returns: + A copy of `obj`, casted to the specified data type `dtype`. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = ti.Matrix([0, 1, 2], ti.i32) + >>> y = ti.cast(x, ti.f32) + >>> print(y) + >>> + >>> test() + [0.0, 1.0, 2.0] + """ dtype = cook_dtype(dtype) if is_taichi_class(obj): # TODO: unify with element_wise_unary return obj.cast(dtype) - else: - return Expr(_ti_core.value_cast(Expr(obj).ptr, dtype)) + return expr.Expr(_ti_core.value_cast(expr.Expr(obj).ptr, dtype)) def bit_cast(obj, dtype): - _taichi_skip_traceback = 1 + """Copy and cast a scalar to a specified data type with its underlying + bits preserved. Must be called in taichi scope. + + This function is equivalent to `reinterpret_cast` in C++. + + Args: + obj (:mod:`~taichi.types.primitive_types`): Input scalar. + + dtype (:mod:`~taichi.types.primitive_types`): Target data type, must have \ + the same precision bits as the input (hence `f32` -> `f64` is not allowed). + + Returns: + A copy of `obj`, casted to the specified data type `dtype`. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = 3.14 + >>> y = ti.bit_cast(x, ti.i32) + >>> print(y) # 1078523331 + >>> + >>> z = ti.bit_cast(y, ti.f32) + >>> print(z) # 3.14 + """ dtype = cook_dtype(dtype) if is_taichi_class(obj): raise ValueError('Cannot apply bit_cast on Taichi classes') else: - return Expr(_ti_core.bits_cast(Expr(obj).ptr, dtype)) + return expr.Expr(_ti_core.bits_cast(expr.Expr(obj).ptr, dtype)) def _unary_operation(taichi_op, python_op, a): - _taichi_skip_traceback = 1 if is_taichi_expr(a): - return Expr(taichi_op(a.ptr), tb=stack_info()) - else: - return python_op(a) + return expr.Expr(taichi_op(a.ptr), tb=stack_info()) + return python_op(a) def _binary_operation(taichi_op, python_op, a, b): - _taichi_skip_traceback = 1 if is_taichi_expr(a) or is_taichi_expr(b): a, b = wrap_if_not_expr(a), wrap_if_not_expr(b) - return Expr(taichi_op(a.ptr, b.ptr), tb=stack_info()) - else: - return python_op(a, b) + return expr.Expr(taichi_op(a.ptr, b.ptr), tb=stack_info()) + return python_op(a, b) def _ternary_operation(taichi_op, python_op, a, b, c): - _taichi_skip_traceback = 1 if is_taichi_expr(a) or is_taichi_expr(b) or is_taichi_expr(c): a, b, c = wrap_if_not_expr(a), wrap_if_not_expr(b), wrap_if_not_expr(c) - return Expr(taichi_op(a.ptr, b.ptr, c.ptr), tb=stack_info()) - else: - return python_op(a, b, c) + return expr.Expr(taichi_op(a.ptr, b.ptr, c.ptr), tb=stack_info()) + return python_op(a, b, c) @unary -def neg(a): - """The negate function. +def neg(x): + """Numerical negative, element-wise. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix. + x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + Input scalar or matrix. Returns: - The negative value of `a`. + Matrix or scalar `y`, so that `y = -x`. `y` has the same type as `x`. + + Example:: + >>> x = ti.Matrix([1, -1]) + >>> y = ti.neg(a) + >>> y + [-1, 1] """ - return _unary_operation(_ti_core.expr_neg, _bt_ops_mod.neg, a) + return _unary_operation(_ti_core.expr_neg, _bt_ops_mod.neg, x) @unary -def sin(a): - """The sine function. +def sin(x): + """Trigonometric sine, element-wise. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix. + x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + Angle, in radians. Returns: - Sine of `a`. + The sine of each element of `x`. + + Example:: + + >>> from math import pi + >>> x = ti.Matrix([-pi/2., 0, pi/2.]) + >>> ti.sin(x) + [-1., 0., 1.] """ - return _unary_operation(_ti_core.expr_sin, math.sin, a) + return _unary_operation(_ti_core.expr_sin, math.sin, x) @unary -def cos(a): - """The cosine function. +def cos(x): + """Trigonometric cosine, element-wise. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix. + x (Union[:mod:`~taichi.type.primitive_types`, :class:`~taichi.Matrix`]): \ + Angle, in radians. Returns: - Cosine of `a`. + The cosine of each element of `x`. + + Example:: + + >>> from math import pi + >>> x = ti.Matrix([-pi, 0, pi/2.]) + >>> ti.cos(x) + [-1., 1., 0.] """ - return _unary_operation(_ti_core.expr_cos, math.cos, a) + return _unary_operation(_ti_core.expr_cos, math.cos, x) @unary -def asin(a): - """The inverses function of sine. +def asin(x): + """Trigonometric inverse sine, element-wise. + + The inverse of `sin` so that, if `y = sin(x)`, then `x = asin(y)`. + + For input `x` not in the domain `[-1, 1]`, this function returns `nan` if \ + it's called in taichi scope, or raises exception if it's called in python scope. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix with elements in [-1,1]. + x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + A scalar or a matrix with elements in [-1, 1]. Returns: - The inverses function of sine of `a`. + The inverse sine of each element in `x`, in radians and in the closed \ + interval `[-pi/2, pi/2]`. + + Example:: + + >>> from math import pi + >>> ti.asin(ti.Matrix([-1.0, 0.0, 1.0])) * 180 / pi + [-90., 0., 90.] """ - return _unary_operation(_ti_core.expr_asin, math.asin, a) + return _unary_operation(_ti_core.expr_asin, math.asin, x) @unary -def acos(a): - """The inverses function of cosine. +def acos(x): + """Trigonometric inverse cosine, element-wise. + + The inverse of `cos` so that, if `y = cos(x)`, then `x = acos(y)`. + + For input `x` not in the domain `[-1, 1]`, this function returns `nan` if \ + it's called in taichi scope, or raises exception if it's called in python scope. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix with elements in [-1,1]. + x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + A scalar or a matrix with elements in [-1, 1]. Returns: - The inverses function of cosine of `a`. + The inverse cosine of each element in `x`, in radians and in the closed \ + interval `[0, pi]`. This is a scalar if `x` is a scalar. + + Example:: + + >>> from math import pi + >>> ti.acos(ti.Matrix([-1.0, 0.0, 1.0])) * 180 / pi + [180., 90., 0.] """ - return _unary_operation(_ti_core.expr_acos, math.acos, a) + return _unary_operation(_ti_core.expr_acos, math.acos, x) @unary -def sqrt(a): - """The square root function. +def sqrt(x): + """Return the non-negative square-root of a scalar or a matrix, + element wise. If `x < 0` an exception is raised. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix with elements not less than zero. + x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + The scalar or matrix whose square-roots are required. Returns: - `x` such that `x>=0` and `x^2=a`. + The square-root `y` so that `y >= 0` and `y^2 = x`. `y` has the same type as `x`. + + Example:: + + >>> x = ti.Matrix([1., 4., 9.]) + >>> y = ti.sqrt(x) + >>> y + [1.0, 2.0, 3.0] """ - return _unary_operation(_ti_core.expr_sqrt, math.sqrt, a) + return _unary_operation(_ti_core.expr_sqrt, math.sqrt, x) @unary -def rsqrt(a): +def rsqrt(x): """The reciprocal of the square root function. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix. + x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + A scalar or a matrix. Returns: - The reciprocal of `sqrt(a)`. + The reciprocal of `sqrt(x)`. """ - def _rsqrt(a): - return 1 / math.sqrt(a) + def _rsqrt(x): + return 1 / math.sqrt(x) - return _unary_operation(_ti_core.expr_rsqrt, _rsqrt, a) + return _unary_operation(_ti_core.expr_rsqrt, _rsqrt, x) @unary -def floor(a): - """The floor function. +def round(x): # pylint: disable=redefined-builtin + """Round to the nearest integer, element-wise. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix. + x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + A scalar or a matrix. Returns: - The greatest integer less than or equal to `a`. + The nearest integer of `x`. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = ti.Vector([-1.5, 1.2, 2.7]) + >>> print(ti.round(x)) + [-2., 1., 3.] """ - return _unary_operation(_ti_core.expr_floor, math.floor, a) + return _unary_operation(_ti_core.expr_round, builtins.round, x) @unary -def ceil(a): - """The ceil function. +def floor(x): + """Return the floor of the input, element-wise. + + The floor of the scalar `x` is the largest integer `k`, such that `k <= x`. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix. + x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + Input scalar or matrix. Returns: - The least integer greater than or equal to `a`. + The floor of each element in `x`, with float type. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = ti.Matrix([3.14, -1.5]) + >>> y = ti.floor(x) + >>> print(y) # [3.0, -2.0] """ - return _unary_operation(_ti_core.expr_ceil, math.ceil, a) + return _unary_operation(_ti_core.expr_floor, math.floor, x) @unary -def tan(a): - """The tangent function. +def ceil(x): + """Return the ceiling of the input, element-wise. + + The ceil of the scalar `x` is the smallest integer `k`, such that `k >= x`. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix. + x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + Input scalar or matrix. Returns: - Tangent of `a`. + The ceiling of each element in `x`, with float dtype. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = ti.Matrix([3.14, -1.5]) + >>> y = ti.ceil(x) + >>> print(y) # [4.0, -1.0] """ - return _unary_operation(_ti_core.expr_tan, math.tan, a) + return _unary_operation(_ti_core.expr_ceil, math.ceil, x) @unary -def tanh(a): - """The hyperbolic tangent function. +def tan(x): + """Trigonometric tangent function, element-wise. + + Equivalent to `ti.sin(x)/ti.cos(x)` element-wise. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix. + x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + Input scalar or matrix. Returns: - `(e**x - e**(-x)) / (e**x + e**(-x))`. + The tangent values of `x`. + + Example:: + + >>> from math import pi + >>> @ti.kernel + >>> def test(): + >>> x = ti.Matrix([-pi, pi/2, pi]) + >>> y = ti.tan(x) + >>> print(y) + >>> + >>> test() + [-0.0, -22877334.0, 0.0] """ - return _unary_operation(_ti_core.expr_tanh, math.tanh, a) + return _unary_operation(_ti_core.expr_tan, math.tan, x) @unary -def exp(a): - """The exp function. +def tanh(x): + """Compute the hyperbolic tangent of `x`, element-wise. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix. + x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + Input scalar or matrix. + + Returns: + The corresponding hyperbolic tangent values. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = ti.Matrix([-1.0, 0.0, 1.0]) + >>> y = ti.tanh(x) + >>> print(y) + >>> + >>> test() + [-0.761594, 0.000000, 0.761594] + """ + return _unary_operation(_ti_core.expr_tanh, math.tanh, x) + + +@unary +def exp(x): + """Compute the exponential of all elements in `x`, element-wise. + + Args: + x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + Input scalar or matrix. Returns: - `e` to the `a`. + Element-wise exponential of `x`. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = ti.Matrix([-1.0, 0.0, 1.0]) + >>> y = ti.exp(x) + >>> print(y) + >>> + >>> test() + [0.367879, 1.000000, 2.718282] """ - return _unary_operation(_ti_core.expr_exp, math.exp, a) + return _unary_operation(_ti_core.expr_exp, math.exp, x) @unary -def log(a): - """The natural logarithm function. +def log(x): + """Compute the natural logarithm, element-wise. + + The natural logarithm `log` is the inverse of the exponential function, + so that `log(exp(x)) = x`. The natural logarithm is logarithm in base `e`. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix with elements greater than zero. + x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + Input scalar or matrix. Returns: - The natural logarithm of `a`. + The natural logarithm of `x`, element-wise. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = ti.Vector([-1.0, 0.0, 1.0]) + >>> y = ti.log(x) + >>> print(y) + >>> + >>> test() + [-nan, -inf, 0.000000] """ - return _unary_operation(_ti_core.expr_log, math.log, a) + return _unary_operation(_ti_core.expr_log, math.log, x) @unary -def abs(a): - """The absolute value function. +def abs(x): # pylint: disable=W0622 + """Compute the absolute value :math:`|x|` of `x`, element-wise. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix. + x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + Input scalar or matrix. Returns: - The absolute value of `a`. + The absolute value of each element in `x`. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = ti.Vector([-1.0, 0.0, 1.0]) + >>> y = ti.abs(x) + >>> print(y) + >>> + >>> test() + [1.0, 0.0, 1.0] """ - return _unary_operation(_ti_core.expr_abs, builtins.abs, a) + return _unary_operation(_ti_core.expr_abs, builtins.abs, x) @unary @@ -397,16 +586,41 @@ def logical_not(a): def random(dtype=float): - """The random function. + """Return a single random float/integer according to the specified data type. + Must be called in taichi scope. + + If the required `dtype` is float type, this function returns a random number + sampled from the uniform distribution in the half-open interval [0, 1). + + For integer types this function returns a random integer in the + half-open interval [0, 2^32) if a 32-bit integer is required, + or a random integer in the half-open interval [0, 2^64) if a + 64-bit integer is required. Args: - dtype (DataType): Type of the random variable. + dtype (:mod:`~taichi.types.primitive_types`): Type of the required random value. Returns: - A random variable whose type is `dtype`. + A random value with type `dtype`. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = ti.random(float) + >>> print(x) # 0.090257 + >>> + >>> y = ti.random(ti.f64) + >>> print(y) # 0.716101627301 + >>> + >>> i = ti.random(ti.i32) + >>> print(i) # -963722261 + >>> + >>> j = ti.random(ti.i64) + >>> print(j) # 73412986184350777 """ dtype = cook_dtype(dtype) - x = Expr(_ti_core.make_rand_expr(dtype)) + x = expr.Expr(_ti_core.make_rand_expr(dtype)) return impl.expr_init(x) @@ -456,37 +670,74 @@ def mul(a, b): @binary -def mod(a, b): - """The remainder function. +def mod(x1, x2): + """Returns the element-wise remainder of division. + + This is equivalent to the Python modulus operator `x1 % x2` and + has the same sign as the divisor x2. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix. - b (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix with elements not equal to zero. + x1 (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + Dividend scalar or matrix. + + x2 (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + Divisor scalar or matrix. When both `x1` and `x2` are matrices they must have the same shape. Returns: - The remainder of `a` divided by `b`. + The element-wise remainder of the quotient `floordiv(x1, x2)`. This is a scalar \ + if both `x1` and `x2` are scalars. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = ti.Matrix([3.0, 4.0, 5.0]) + >>> y = 3 + >>> z = ti.mod(y, x) + >>> print(z) + >>> + >>> test() + [1.0, 0.0, 4.0] """ def expr_python_mod(a, b): # a % b = a - (a // b) * b - quotient = Expr(_ti_core.expr_floordiv(a, b)) - multiply = Expr(_ti_core.expr_mul(b, quotient.ptr)) + quotient = expr.Expr(_ti_core.expr_floordiv(a, b)) + multiply = expr.Expr(_ti_core.expr_mul(b, quotient.ptr)) return _ti_core.expr_sub(a, multiply.ptr) - return _binary_operation(expr_python_mod, _bt_ops_mod.mod, a, b) + return _binary_operation(expr_python_mod, _bt_ops_mod.mod, x1, x2) @binary -def pow(a, b): - """The power function. +def pow(x, a): # pylint: disable=W0622 + """First array elements raised to powers from second array :math:`x^a`, element-wise. + + Negative values raised to a non-integral value will return `nan`. + A zero value raised to a negative value will return `inf`. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix. - b (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix. + x (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + The bases. + a (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + The exponents. Returns: - `a` to the `b`. + The bases in `x1` raised to the exponents in `x2`. This is a scalar if both \ + `x1` and `x2` are scalars. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = ti.Matrix([-2.0, 0.0, 2.0]) + >>> y = -2.2 + >>> z = ti.pow(x, y) + >>> print(z) + >>> + >>> test() + [-nan, inf, 0.217638] """ - return _binary_operation(_ti_core.expr_pow, _bt_ops_mod.pow, a, b) + return _binary_operation(_ti_core.expr_pow, _bt_ops_mod.pow, x, a) @binary @@ -519,7 +770,7 @@ def truediv(a, b): @binary -def max(a, b): +def max_impl(a, b): """The maxnimum function. Args: @@ -533,7 +784,7 @@ def max(a, b): @binary -def min(a, b): +def min_impl(a, b): """The minimum function. Args: @@ -547,54 +798,89 @@ def min(a, b): @binary -def atan2(a, b): - """The inverses of the tangent function. +def atan2(x1, x2): + """Element-wise arc tangent of `x1/x2`. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix. - b (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix with elements not equal to zero. + x1 (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + y-coordinates. + x2 (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + x-coordinates. Returns: - The inverses function of tangent of `b/a`. + Angles in radians, in the range `[-pi, pi]`. + This is a scalar if both `x1` and `x2` are scalars. + + Example:: + + >>> from math import pi + >>> @ti.kernel + >>> def test(): + >>> x = ti.Matrix([-1.0, 1.0, -1.0, 1.0]) + >>> y = ti.Matrix([-1.0, -1.0, 1.0, 1.0]) + >>> z = ti.atan2(y, x) * 180 / pi + >>> print(z) + >>> + >>> test() + [-135.0, -45.0, 135.0, 45.0] """ - return _binary_operation(_ti_core.expr_atan2, math.atan2, a, b) + return _binary_operation(_ti_core.expr_atan2, math.atan2, x1, x2) @binary -def raw_div(a, b): - """Raw_div function. +def raw_div(x1, x2): + """Return `x1 // x2` if both `x1`, `x2` are integers, otherwise return `x1/x2`. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix. - b (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix with elements not equal to zero. + x1 (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): Dividend. + x2 (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): Divisor. Returns: - If `a` is a `int` and `b` is a `int`, then return `a//b`. Else return `a/b`. + Return `x1 // x2` if both `x1`, `x2` are integers, otherwise return `x1/x2`. + + Example:: + + >>> @ti.kernel + >>> def main(): + >>> x = 5 + >>> y = 3 + >>> print(raw_div(x, y)) # 1 + >>> z = 4.0 + >>> print(raw_div(x, z)) # 1.25 """ def c_div(a, b): if isinstance(a, int) and isinstance(b, int): return a // b - else: - return a / b + return a / b - return _binary_operation(_ti_core.expr_div, c_div, a, b) + return _binary_operation(_ti_core.expr_div, c_div, x1, x2) @binary -def raw_mod(a, b): - """Raw_mod function. Both `a` and `b` can be `float`. +def raw_mod(x1, x2): + """Return the remainder of `x1/x2`, element-wise. + This is the C-style `mod` function. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix. - b (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): A number or a matrix with elements not equal to zero. + x1 (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + The dividend. + x2 (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + The divisor. Returns: - The remainder of `a` divided by `b`. + The remainder of `x1` divided by `x2`. + + Example:: + + >>> @ti.kernel + >>> def main(): + >>> print(ti.mod(-4, 3)) # 2 + >>> print(ti.raw_mod(-4, 3)) # -1 """ - def c_mod(a, b): - return a - b * int(float(a) / b) + def c_mod(x, y): + return x - y * int(float(x) / y) - return _binary_operation(_ti_core.expr_mod, c_mod, a, b) + return _binary_operation(_ti_core.expr_mod, c_mod, x1, x2) @binary @@ -770,18 +1056,31 @@ def bit_sar(a, b): @taichi_scope @binary -def bit_shr(a, b): - """Compute bitwise shift right (in taichi scope) +def bit_shr(x1, x2): + """Elements in `x1` shifted to the right by number of bits in `x2`. + Both `x1`, `x2` must have integer type. Args: - a (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): value LHS - b (Union[:class:`~taichi.lang.expr.Expr`, :class:`~taichi.lang.matrix.Matrix`]): value RHS + x1 (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + Input data. + x2 (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + Number of bits to remove at the right of `x1`. Returns: - Union[:class:`~taichi.lang.expr.Expr`, int]: LHS >> RHS - + Return `x1` with bits shifted `x2` times to the right. + This is a scalar if both `x1` and `x2` are scalars. + + Example:: + >>> @ti.kernel + >>> def main(): + >>> x = ti.Matrix([7, 8]) + >>> y = ti.Matrix([1, 2]) + >>> print(ti.bit_shr(x, y)) + >>> + >>> main() + [3, 2] """ - return _binary_operation(_ti_core.expr_bit_shr, _bt_ops_mod.rshift, a, b) + return _binary_operation(_ti_core.expr_bit_shr, _bt_ops_mod.rshift, x1, x2) # We don't have logic_and/or instructions yet: @@ -790,177 +1089,341 @@ def bit_shr(a, b): @ternary -def select(cond, a, b): +def select(cond, x1, x2): + """Return an array drawn from elements in `x1` or `x2`, + depending on the conditions in `cond`. + + Args: + cond (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + The array of conditions. + x1, x2 (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + The arrays where the output elements are taken from. + + Returns: + The output at position `k` is the k-th element of `x1` if the k-th element + in `cond` is `True`, otherwise it's the k-th element of `x2`. + + Example:: + + >>> @ti.kernel + >>> def main(): + >>> cond = ti.Matrix([0, 1, 0, 1]) + >>> x = ti.Matrix([1, 2, 3, 4]) + >>> y = ti.Matrix([-1, -2, -3, -4]) + >>> print(ti.select(cond, x, y)) + >>> + >>> main() + [-1, 2, -3, 4] + """ # TODO: systematically resolve `-1 = True` problem by introducing u1: cond = logical_not(logical_not(cond)) - def py_select(cond, a, b): - return a * cond + b * (1 - cond) + def py_select(cond, x1, x2): + return x1 * cond + x2 * (1 - cond) - return _ternary_operation(_ti_core.expr_select, py_select, cond, a, b) + return _ternary_operation(_ti_core.expr_select, py_select, cond, x1, x2) @writeback_binary -def atomic_add(a, b): - return impl.expr_init( - Expr(_ti_core.expr_atomic_add(a.ptr, b.ptr), tb=stack_info())) +def atomic_add(x, y): + """Atomically compute `x + y`, store the result in `x`, + and return the old value of `x`. + `x` must be a writable target, constant expressions or scalars + are not allowed. -@writeback_binary -def atomic_sub(a, b): - return impl.expr_init( - Expr(_ti_core.expr_atomic_sub(a.ptr, b.ptr), tb=stack_info())) - + Args: + x, y (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + The input. -@writeback_binary -def atomic_min(a, b): + Returns: + The old value of `x`. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = ti.Vector([0, 0, 0]) + >>> y = ti.Vector([1, 2, 3]) + >>> z = ti.atomic_add(x, y) + >>> print(x) # [1, 2, 3] the new value of x + >>> print(z) # [0, 0, 0], the old value of x + >>> + >>> ti.atomic_add(1, x) # will raise TaichiSyntaxError + """ return impl.expr_init( - Expr(_ti_core.expr_atomic_min(a.ptr, b.ptr), tb=stack_info())) + expr.Expr(_ti_core.expr_atomic_add(x.ptr, y.ptr), tb=stack_info())) @writeback_binary -def atomic_max(a, b): - return impl.expr_init( - Expr(_ti_core.expr_atomic_max(a.ptr, b.ptr), tb=stack_info())) +def atomic_sub(x, y): + """Atomically subtract `x` by `y`, store the result in `x`, + and return the old value of `x`. + `x` must be a writable target, constant expressions or scalars + are not allowed. -@writeback_binary -def atomic_and(a, b): + Args: + x, y (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + The input. + + Returns: + The old value of `x`. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = ti.Vector([0, 0, 0]) + >>> y = ti.Vector([1, 2, 3]) + >>> z = ti.atomic_sub(x, y) + >>> print(x) # [-1, -2, -3] the new value of x + >>> print(z) # [0, 0, 0], the old value of x + >>> + >>> ti.atomic_sub(1, x) # will raise TaichiSyntaxError + """ return impl.expr_init( - Expr(_ti_core.expr_atomic_bit_and(a.ptr, b.ptr), tb=stack_info())) + expr.Expr(_ti_core.expr_atomic_sub(x.ptr, y.ptr), tb=stack_info())) @writeback_binary -def atomic_or(a, b): - return impl.expr_init( - Expr(_ti_core.expr_atomic_bit_or(a.ptr, b.ptr), tb=stack_info())) +def atomic_min(x, y): + """Atomically compute the minimum of `x` and `y`, element-wise. + Store the result in `x`, and return the old value of `x`. + `x` must be a writable target, constant expressions or scalars + are not allowed. -@writeback_binary -def atomic_xor(a, b): + Args: + x, y (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + The input. + + Returns: + The old value of `x`. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = 2 + >>> y = 1 + >>> z = ti.atomic_min(x, y) + >>> print(x) # 1 the new value of x + >>> print(z) # 2, the old value of x + >>> + >>> ti.atomic_min(1, x) # will raise TaichiSyntaxError + """ return impl.expr_init( - Expr(_ti_core.expr_atomic_bit_xor(a.ptr, b.ptr), tb=stack_info())) + expr.Expr(_ti_core.expr_atomic_min(x.ptr, y.ptr), tb=stack_info())) @writeback_binary -def assign(a, b): - _ti_core.expr_assign(a.ptr, b.ptr, stack_info()) - return a +def atomic_max(x, y): + """Atomically compute the maximum of `x` and `y`, element-wise. + Store the result in `x`, and return the old value of `x`. + `x` must be a writable target, constant expressions or scalars + are not allowed. -def ti_max(*args): - num_args = len(args) - assert num_args >= 1 - if num_args == 1: - return args[0] - elif num_args == 2: - return max(args[0], args[1]) - else: - return max(args[0], ti_max(*args[1:])) + Args: + x, y (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + The input. + Returns: + The old value of `x`. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = 1 + >>> y = 2 + >>> z = ti.atomic_max(x, y) + >>> print(x) # 2 the new value of x + >>> print(z) # 1, the old value of x + >>> + >>> ti.atomic_max(1, x) # will raise TaichiSyntaxError + """ + return impl.expr_init( + expr.Expr(_ti_core.expr_atomic_max(x.ptr, y.ptr), tb=stack_info())) -def ti_min(*args): - num_args = len(args) - assert num_args >= 1 - if num_args == 1: - return args[0] - elif num_args == 2: - return min(args[0], args[1]) - else: - return min(args[0], ti_min(*args[1:])) +@writeback_binary +def atomic_and(x, y): + """Atomically compute the bit-wise AND of `x` and `y`, element-wise. + Store the result in `x`, and return the old value of `x`. -def ti_any(a): - return a.any() + `x` must be a writable target, constant expressions or scalars + are not allowed. + Args: + x, y (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + The input. When both are matrices they must have the same shape. -def ti_all(a): - return a.all() + Returns: + The old value of `x`. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = ti.Vector([-1, 0, 1]) + >>> y = ti.Vector([1, 2, 3]) + >>> z = ti.atomic_and(x, y) + >>> print(x) # [1, 0, 1] the new value of x + >>> print(z) # [-1, 0, 1], the old value of x + >>> + >>> ti.atomic_and(1, x) # will raise TaichiSyntaxError + """ + return impl.expr_init( + expr.Expr(_ti_core.expr_atomic_bit_and(x.ptr, y.ptr), tb=stack_info())) -def append(l, indices, val): - a = impl.expr_init( - _ti_core.insert_append(l.snode.ptr, make_expr_group(indices), - Expr(val).ptr)) - return a +@writeback_binary +def atomic_or(x, y): + """Atomically compute the bit-wise OR of `x` and `y`, element-wise. + Store the result in `x`, and return the old value of `x`. + `x` must be a writable target, constant expressions or scalars + are not allowed. -def external_func_call(func, args=[], outputs=[]): - func_addr = ctypes.cast(func, ctypes.c_void_p).value - _ti_core.insert_external_func_call(func_addr, '', make_expr_group(args), - make_expr_group(outputs)) + Args: + x, y (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + The input. When both are matrices they must have the same shape. + Returns: + The old value of `x`. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = ti.Vector([-1, 0, 1]) + >>> y = ti.Vector([1, 2, 3]) + >>> z = ti.atomic_or(x, y) + >>> print(x) # [-1, 2, 3] the new value of x + >>> print(z) # [-1, 0, 1], the old value of x + >>> + >>> ti.atomic_or(1, x) # will raise TaichiSyntaxError + """ + return impl.expr_init( + expr.Expr(_ti_core.expr_atomic_bit_or(x.ptr, y.ptr), tb=stack_info())) -def asm(source, inputs=[], outputs=[]): - _ti_core.insert_external_func_call(0, source, make_expr_group(inputs), - make_expr_group(outputs)) +@writeback_binary +def atomic_xor(x, y): + """Atomically compute the bit-wise XOR of `x` and `y`, element-wise. + Store the result in `x`, and return the old value of `x`. -def is_active(l, indices): - return Expr( - _ti_core.insert_is_active(l.snode.ptr, make_expr_group(indices))) + `x` must be a writable target, constant expressions or scalars + are not allowed. + Args: + x, y (Union[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + The input. When both are matrices they must have the same shape. -def activate(l, indices): - _ti_core.insert_activate(l.snode.ptr, make_expr_group(indices)) + Returns: + The old value of `x`. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> x = ti.Vector([-1, 0, 1]) + >>> y = ti.Vector([1, 2, 3]) + >>> z = ti.atomic_xor(x, y) + >>> print(x) # [-2, 2, 2] the new value of x + >>> print(z) # [-1, 0, 1], the old value of x + >>> + >>> ti.atomic_xor(1, x) # will raise TaichiSyntaxError + """ + return impl.expr_init( + expr.Expr(_ti_core.expr_atomic_bit_xor(x.ptr, y.ptr), tb=stack_info())) -def deactivate(l, indices): - _ti_core.insert_deactivate(l.snode.ptr, make_expr_group(indices)) +@writeback_binary +def assign(a, b): + impl.get_runtime().prog.current_ast_builder().expr_assign( + a.ptr, b.ptr, stack_info()) + return a -def length(l, indices): - return Expr(_ti_core.insert_len(l.snode.ptr, make_expr_group(indices))) +def max(*args): # pylint: disable=W0622 + """Compute the maximum of the arguments, element-wise. + This function takes no effect on a single argument, even it's array-like. + When there are both scalar and matrix arguments in `args`, the matrices + must have the same shape, and scalars will be broadcasted to the same shape as the matrix. -def rescale_index(a, b, I): - """Rescales the index 'I' of field (or SNode) 'a' to match the shape of SNode 'b' + Args: + args: (List[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + The input. - Parameters - ---------- - a: ti.field(), ti.Vector.field, ti.Matrix.field() - input taichi field or snode - b: ti.field(), ti.Vector.field, ti.Matrix.field() - output taichi field or snode - I: ti.Vector() - grouped loop index + Returns: + Maximum of the inputs. - Returns - ------- - Ib: ti.Vector() - rescaled grouped loop index + Example:: + >>> @ti.kernel + >>> def foo(): + >>> x = ti.Vector([0, 1, 2]) + >>> y = ti.Vector([3, 4, 5]) + >>> z = ti.max(x, y, 4) + >>> print(z) # [4, 4, 5] """ - assert isinstance( - a, (Field, SNode)), "The first argument must be a field or an SNode" - assert isinstance( - b, (Field, SNode)), "The second argument must be a field or an SNode" - if isinstance(I, list): - I = matrix.Vector(I) - else: - assert isinstance( - I, matrix.Matrix - ), f"The third argument must be an index (list or ti.Vector)" - Ib = I.copy() - for n in range(min(I.n, min(len(a.shape), len(b.shape)))): - if a.shape[n] > b.shape[n]: - Ib.entries[n] = I.entries[n] // (a.shape[n] // b.shape[n]) - if a.shape[n] < b.shape[n]: - Ib.entries[n] = I.entries[n] * (b.shape[n] // a.shape[n]) - return Ib + num_args = len(args) + assert num_args >= 1 + if num_args == 1: + return args[0] + if num_args == 2: + return max_impl(args[0], args[1]) + return max_impl(args[0], max(*args[1:])) -def get_addr(f, indices): - """Query the memory address (on CUDA/x64) of field `f` at index `indices`. +def min(*args): # pylint: disable=W0622 + """Compute the minimum of the arguments, element-wise. - Currently, this function can only be called inside a taichi kernel. + This function takes no effect on a single argument, even it's array-like. + When there are both scalar and matrix arguments in `args`, the matrices + must have the same shape, and scalars will be broadcasted to the same shape as the matrix. Args: - f (Union[ti.field, ti.Vector.field, ti.Matrix.field]): Input taichi field for memory address query. - indices (Union[int, ti.Vector()]): The specified field indices of the query. + args: (List[:mod:`~taichi.types.primitive_types`, :class:`~taichi.Matrix`]): \ + The input. Returns: - ti.u64: The memory address of `f[indices]`. + Minimum of the inputs. + Example:: + + >>> @ti.kernel + >>> def foo(): + >>> x = ti.Vector([0, 1, 2]) + >>> y = ti.Vector([3, 4, 5]) + >>> z = ti.min(x, y, 1) + >>> print(z) # [0, 1, 1] """ - return Expr(_ti_core.expr_get_addr(f.snode.ptr, make_expr_group(indices))) + num_args = len(args) + assert num_args >= 1 + if num_args == 1: + return args[0] + if num_args == 2: + return min_impl(args[0], args[1]) + return min_impl(args[0], min(*args[1:])) + + +def ti_any(a): + return a.any() + + +def ti_all(a): + return a.all() + + +__all__ = [ + "acos", "asin", "atan2", "atomic_and", "atomic_or", "atomic_xor", + "atomic_max", "atomic_sub", "atomic_min", "atomic_add", "bit_cast", + "bit_shr", "cast", "ceil", "cos", "exp", "floor", "log", "random", + "raw_mod", "raw_div", "round", "rsqrt", "sin", "sqrt", "tan", "tanh", + "max", "min", "select", "abs", "pow" +] diff --git a/python/taichi/lang/quant_impl.py b/python/taichi/lang/quant_impl.py deleted file mode 100644 index 2b369d975f230..0000000000000 --- a/python/taichi/lang/quant_impl.py +++ /dev/null @@ -1,77 +0,0 @@ -from taichi.lang import impl -from taichi.lang import type_factory_impl as tf_impl - -import taichi as ti - - -class Quant: - """Generator of quantized types. - - For more details, read https://yuanming.taichi.graphics/publication/2021-quantaichi/quantaichi.pdf. - """ - @staticmethod - def int(bits, signed=False, compute=None): - """Generates a quantized type for integers. - - Args: - bits (int): Number of bits. - signed (bool): Signed or unsigned. - compute (DataType): Type for computation. - - Returns: - DataType: The specified type. - """ - if compute is None: - compute = impl.get_runtime().default_ip - return tf_impl.type_factory.custom_int(bits, signed, compute) - - @staticmethod - def fixed(frac, signed=True, range=1.0, compute=None): - """Generates a quantized type for fixed-point real numbers. - - Args: - frac (int): Number of bits. - signed (bool): Signed or unsigned. - range (float): Range of the number. - compute (DataType): Type for computation. - - Returns: - DataType: The specified type. - """ - # TODO: handle cases with frac > 32 - frac_type = Quant.int(bits=frac, signed=signed, compute=ti.i32) - if signed: - scale = range / 2**(frac - 1) - else: - scale = range / 2**frac - if compute is None: - compute = impl.get_runtime().default_fp - return tf_impl.type_factory.custom_float(frac_type, None, compute, - scale) - - @staticmethod - def float(exp, frac, signed=True, compute=None): - """Generates a quantized type for floating-point real numbers. - - Args: - exp (int): Number of exponent bits. - frac (int): Number of fraction bits. - signed (bool): Signed or unsigned. - compute (DataType): Type for computation. - - Returns: - DataType: The specified type. - """ - # Exponent is always unsigned - exp_type = Quant.int(bits=exp, signed=False, compute=ti.i32) - # TODO: handle cases with frac > 32 - frac_type = Quant.int(bits=frac, signed=signed, compute=ti.i32) - if compute is None: - compute = impl.get_runtime().default_fp - return tf_impl.type_factory.custom_float(significand_type=frac_type, - exponent_type=exp_type, - compute_type=compute) - - -# Unstable API -quant = Quant diff --git a/python/taichi/lang/runtime_ops.py b/python/taichi/lang/runtime_ops.py index aff86c95eeac2..a1103bc3439dd 100644 --- a/python/taichi/lang/runtime_ops.py +++ b/python/taichi/lang/runtime_ops.py @@ -7,3 +7,6 @@ def sync(): def async_flush(): impl.get_runtime().prog.async_flush() + + +__all__ = ['sync'] diff --git a/python/taichi/lang/shell.py b/python/taichi/lang/shell.py index c21262917edaa..87ea33eafcf21 100644 --- a/python/taichi/lang/shell.py +++ b/python/taichi/lang/shell.py @@ -1,20 +1,19 @@ -import atexit import functools import os import sys +from taichi._lib import core as _ti_core from taichi._logging import info, warn -from taichi.core.util import ti_core as _ti_core try: - import sourceinspect as oinspect + import sourceinspect as oinspect # pylint: disable=unused-import except ImportError: warn('`sourceinspect` not installed!') warn( 'Without this package Taichi may not function well in Python IDLE interactive shell, ' 'Blender scripting module and Python native shell.') warn('Please run `python3 -m pip install sourceinspect` to install.') - import inspect as oinspect + import inspect as oinspect # pylint: disable=unused-import pybuf_enabled = False _env_enable_pybuf = os.environ.get('TI_ENABLE_PYBUF', '1') @@ -35,7 +34,6 @@ def _shell_pop_print(old_call): @functools.wraps(old_call) def new_call(*args, **kwargs): - _taichi_skip_traceback = 1 ret = old_call(*args, **kwargs) # print's in kernel won't take effect until ti.sync(), discussion: # https://github.com/taichi-dev/taichi/pull/1303#discussion_r444897102 diff --git a/python/taichi/lang/snode.py b/python/taichi/lang/snode.py index e0b9b31510dd5..f457c8428a7bd 100644 --- a/python/taichi/lang/snode.py +++ b/python/taichi/lang/snode.py @@ -1,15 +1,8 @@ import numbers -# The reason we import just the taichi.core.util module, instead of the ti_core -# object within it, is that ti_core is stateful. While in practice ti_core is -# loaded during the import procedure, it's probably still good to delay the -# access to it. -import taichi.lang -from taichi.core.util import ti_core as _ti_core -from taichi.lang import impl -from taichi.lang.expr import Expr +from taichi._lib import core as _ti_core +from taichi.lang import expr, impl, matrix from taichi.lang.field import Field -from taichi.misc.util import deprecated class SNode: @@ -59,13 +52,15 @@ def pointer(self, axes, dimensions): self.ptr.pointer(axes, dimensions, impl.current_cfg().packed)) - def hash(self, axes, dimensions): + @staticmethod + def _hash(axes, dimensions): + # original code is #def hash(self,axes, dimensions) without #@staticmethod before fix pylint R0201 """Not supported.""" raise RuntimeError('hash not yet supported') - if isinstance(dimensions, int): - dimensions = [dimensions] * len(axes) - return SNode(self.ptr.hash(axes, dimensions, - impl.current_cfg().packed)) + # if isinstance(dimensions, int): + # dimensions = [dimensions] * len(axes) + # return SNode(self.ptr.hash(axes, dimensions, + # impl.current_cfg().packed)) def dynamic(self, axis, dimension, chunk_size=None): """Adds a dynamic SNode as a child component of `self`. @@ -101,10 +96,6 @@ def bitmasked(self, axes, dimensions): self.ptr.bitmasked(axes, dimensions, impl.current_cfg().packed)) - @deprecated('_bit_struct', 'bit_struct') - def _bit_struct(self, num_bits): - return self.bit_struct(num_bits) - def bit_struct(self, num_bits: int): """Adds a bit_struct SNode as a child component of `self`. @@ -116,10 +107,6 @@ def bit_struct(self, num_bits: int): """ return SNode(self.ptr.bit_struct(num_bits, impl.current_cfg().packed)) - @deprecated('_bit_array', 'bit_array') - def _bit_array(self, axes, dimensions, num_bits): - return self.bit_array(axes, dimensions, num_bits) - def bit_array(self, axes, dimensions, num_bits): """Adds a bit_array SNode as a child component of `self`. @@ -157,7 +144,7 @@ def place(self, *args, offset=None, shared_exponent=False): for arg in args: if isinstance(arg, Field): - for var in arg.get_field_members(): + for var in arg._get_field_members(): self.ptr.place(var.ptr, offset) elif isinstance(arg, list): for x in arg: @@ -199,8 +186,22 @@ def parent(self, n=1): return impl.root return SNode(p) + def _path_from_root(self): + """Gets the path from root to `self` in the SNode tree. + + Returns: + List[Union[_Root, SNode]]: The list of SNodes on the path from root to `self`. + """ + p = self + res = [p] + while p != impl.root: + p = p.parent() + res.append(p) + res.reverse() + return res + @property - def dtype(self): + def _dtype(self): """Gets the data type of `self`. Returns: @@ -208,16 +209,8 @@ def dtype(self): """ return self.ptr.data_type() - @deprecated('x.data_type()', 'x.dtype') - def data_type(self): - return self.dtype - - @deprecated('x.dim()', 'len(x.shape)') - def dim(self): - return len(self.shape) - @property - def id(self): + def _id(self): """Gets the id of `self`. Returns: @@ -233,21 +226,11 @@ def shape(self): Tuple[int]: The number of elements from root in each axis of `self`. """ dim = self.ptr.num_active_indices() - ret = [self.ptr.get_shape_along_axis(i) for i in range(dim)] - - class callable_tuple(tuple): - @deprecated('x.shape()', 'x.shape') - def __call__(self): - return self + ret = tuple(self.ptr.get_shape_along_axis(i) for i in range(dim)) - ret = callable_tuple(ret) return ret - @deprecated('x.get_shape(i)', 'x.shape[i]') - def get_shape(self, i): - return self.shape[i] - - def loop_range(self): + def _loop_range(self): """Gets the taichi_core.Expr wrapping the taichi_core.GlobalVariableExpression corresponding to `self` to serve as loop range. Returns: @@ -256,7 +239,7 @@ def loop_range(self): return _ti_core.global_var_expr_from_snode(self.ptr) @property - def name(self): + def _name(self): """Gets the name of `self`. Returns: @@ -265,24 +248,14 @@ def name(self): return self.ptr.name() @property - def snode(self): + def _snode(self): """Gets `self`. - Returns: SNode: `self`. """ return self - @property - def needs_grad(self): - """Checks whether `self` has a corresponding gradient :class:`~taichi.lang.SNode`. - - Returns: - bool: Whether `self` has a corresponding gradient :class:`~taichi.lang.SNode`. - """ - return self.ptr.has_grad() - - def get_children(self): + def _get_children(self): """Gets all children components of `self`. Returns: @@ -294,30 +267,38 @@ def get_children(self): return children @property - def num_dynamically_allocated(self): + def _num_dynamically_allocated(self): runtime = impl.get_runtime() - runtime.materialize() + runtime.materialize_root_fb(False) return runtime.prog.get_snode_num_dynamically_allocated(self.ptr) @property - def cell_size_bytes(self): - runtime = impl.get_runtime() - runtime.materialize() + def _cell_size_bytes(self): + impl.get_runtime().materialize_root_fb(False) return self.ptr.cell_size_bytes + @property + def _offset_bytes_in_parent_cell(self): + impl.get_runtime().materialize_root_fb(False) + return self.ptr.offset_bytes_in_parent_cell + def deactivate_all(self): """Recursively deactivate all children components of `self`.""" - ch = self.get_children() + ch = self._get_children() for c in ch: c.deactivate_all() SNodeType = _ti_core.SNodeType if self.ptr.type == SNodeType.pointer or self.ptr.type == SNodeType.bitmasked: - taichi.lang.meta.snode_deactivate(self) + from taichi._kernels import \ + snode_deactivate # pylint: disable=C0415 + snode_deactivate(self) if self.ptr.type == SNodeType.dynamic: # Note that dynamic nodes are different from other sparse nodes: # instead of deactivating each element, we only need to deactivate # its parent, whose linked list of chunks of elements will be deleted. - taichi.lang.meta.snode_deactivate_dynamic(self) + from taichi._kernels import \ + snode_deactivate_dynamic # pylint: disable=C0415 + snode_deactivate_dynamic(self) def __repr__(self): type_ = str(self.ptr.type)[len('SNodeType.'):] @@ -334,7 +315,7 @@ def __str__(self): def __eq__(self, other): return self.ptr == other.ptr - def physical_index_position(self): + def _physical_index_position(self): """Gets mappings from virtual axes to physical axes. Returns: @@ -346,3 +327,90 @@ def physical_index_position(self): if physical != -1: ret[virtual] = physical return ret + + +def rescale_index(a, b, I): + """Rescales the index 'I' of field (or SNode) 'a' to match the shape of SNode 'b' + + Parameters + ---------- + a: ti.field(), ti.Vector.field, ti.Matrix.field() + input taichi field or snode + b: ti.field(), ti.Vector.field, ti.Matrix.field() + output taichi field or snode + I: ti.Vector() + grouped loop index + + Returns + ------- + Ib: ti.Vector() + rescaled grouped loop index + + """ + assert isinstance( + a, (Field, SNode)), "The first argument must be a field or an SNode" + assert isinstance( + b, (Field, SNode)), "The second argument must be a field or an SNode" + if isinstance(I, list): + I = matrix.Vector(I) + else: + assert isinstance( + I, matrix.Matrix + ), "The third argument must be an index (list or ti.Vector)" + entries = [I(i) for i in range(I.n)] + for n in range(min(I.n, min(len(a.shape), len(b.shape)))): + if a.shape[n] > b.shape[n]: + entries[n] = I(n) // (a.shape[n] // b.shape[n]) + if a.shape[n] < b.shape[n]: + entries[n] = I(n) * (b.shape[n] // a.shape[n]) + return matrix.Vector(entries) + + +def append(l, indices, val): + a = impl.expr_init( + _ti_core.insert_append(l._snode.ptr, expr.make_expr_group(indices), + expr.Expr(val).ptr)) + return a + + +def is_active(l, indices): + return expr.Expr( + _ti_core.insert_is_active(l._snode.ptr, expr.make_expr_group(indices))) + + +def activate(l, indices): + impl.get_runtime().prog.current_ast_builder().insert_activate( + l._snode.ptr, expr.make_expr_group(indices)) + + +def deactivate(l, indices): + impl.get_runtime().prog.current_ast_builder().insert_deactivate( + l._snode.ptr, expr.make_expr_group(indices)) + + +def length(l, indices): + return expr.Expr( + _ti_core.insert_len(l._snode.ptr, expr.make_expr_group(indices))) + + +def get_addr(f, indices): + """Query the memory address (on CUDA/x64) of field `f` at index `indices`. + + Currently, this function can only be called inside a taichi kernel. + + Args: + f (Union[ti.field, ti.Vector.field, ti.Matrix.field]): Input taichi field for memory address query. + indices (Union[int, ti.Vector()]): The specified field indices of the query. + + Returns: + ti.u64: The memory address of `f[indices]`. + + """ + return expr.Expr( + _ti_core.expr_get_addr(f._snode.ptr, expr.make_expr_group(indices))) + + +__all__ = [ + 'activate', 'append', 'deactivate', 'get_addr', 'is_active', 'length', + 'rescale_index', "SNode" +] diff --git a/python/taichi/lang/source_builder.py b/python/taichi/lang/source_builder.py new file mode 100644 index 0000000000000..5a151ef8149d1 --- /dev/null +++ b/python/taichi/lang/source_builder.py @@ -0,0 +1,148 @@ +import atexit +import ctypes +import os +import shutil +import subprocess +import tempfile + +from taichi._lib import core as _ti_core +from taichi.lang import impl +from taichi.lang.exception import TaichiSyntaxError +from taichi.lang.expr import make_expr_group +from taichi.lang.util import get_clangpp + + +class SourceBuilder: + def __init__(self): + self.bc = None + self.so = None + self.mode = None + self.td = None + + def cleanup(): + if self.td is not None: + shutil.rmtree(self.td) + + atexit.register(cleanup) + + @classmethod + def from_file(cls, filename, compile_fn=None, _temp_dir=None): + self = cls() + self.td = _temp_dir + if self.td is None: + self.td = tempfile.mkdtemp() + + if filename.endswith((".cpp", ".c", ".cc")): + if impl.current_cfg().arch not in [ + _ti_core.Arch.x64, _ti_core.Arch.cuda + ]: + raise TaichiSyntaxError( + "Unsupported arch for external function call") + if compile_fn is None: + + def compile_fn_impl(filename): + if impl.current_cfg().arch == _ti_core.Arch.x64: + subprocess.call(get_clangpp() + ' -flto -c ' + + filename + ' -o ' + + os.path.join(self.td, 'source.bc'), + shell=True) + else: + subprocess.call(get_clangpp() + ' -flto -c ' + + filename + ' -o ' + + os.path.join(self.td, 'source.bc') + + ' -target nvptx64-nvidia-cuda', + shell=True) + return os.path.join(self.td, 'source.bc') + + compile_fn = compile_fn_impl + self.bc = compile_fn(filename) + self.mode = 'bc' + elif filename.endswith(".cu"): + if impl.current_cfg().arch not in [_ti_core.Arch.cuda]: + raise TaichiSyntaxError( + "Unsupported arch for external function call") + if compile_fn is None: + shutil.copy(filename, os.path.join(self.td, 'source.cu')) + + def compile_fn_impl(filename): + # Cannot use -o to specify multiple output files + subprocess.call( + get_clangpp() + ' ' + + os.path.join(self.td, 'source.cu') + + ' -c -emit-llvm -std=c++17 --cuda-gpu-arch=sm_50 -nocudalib', + cwd=self.td, + shell=True) + return os.path.join( + self.td, 'source-cuda-nvptx64-nvidia-cuda-sm_50.bc') + + compile_fn = compile_fn_impl + self.bc = compile_fn(filename) + self.mode = 'bc' + elif filename.endswith((".so", ".dylib", ".dll")): + if impl.current_cfg().arch not in [_ti_core.Arch.x64]: + raise TaichiSyntaxError( + "Unsupported arch for external function call") + self.so = ctypes.CDLL(filename) + self.mode = 'so' + elif filename.endswith(".ll"): + if impl.current_cfg().arch not in [ + _ti_core.Arch.x64, _ti_core.Arch.cuda + ]: + raise TaichiSyntaxError( + "Unsupported arch for external function call") + subprocess.call('llvm-as ' + filename + ' -o ' + + os.path.join(self.td, 'source.bc'), + shell=True) + self.bc = os.path.join(self.td, 'source.bc') + self.mode = 'bc' + elif filename.endswith(".bc"): + if impl.current_cfg().arch not in [ + _ti_core.Arch.x64, _ti_core.Arch.cuda + ]: + raise TaichiSyntaxError( + "Unsupported arch for external function call") + self.bc = filename + self.mode = 'bc' + else: + raise TaichiSyntaxError( + 'Unsupported file type for external function call.') + return self + + @classmethod + def from_source(cls, source_code, compile_fn=None): + if impl.current_cfg().arch not in [ + _ti_core.Arch.x64, _ti_core.Arch.cuda + ]: + raise TaichiSyntaxError( + "Unsupported arch for external function call") + _temp_dir = tempfile.mkdtemp() + _temp_source = os.path.join(_temp_dir, '_temp_source.cpp') + with open(_temp_source, 'w') as f: + f.write(source_code) + return SourceBuilder.from_file(_temp_source, compile_fn, _temp_dir) + + def __getattr__(self, item): + def bitcode_func_call_wrapper(*args): + impl.get_runtime().prog.current_ast_builder( + ).insert_external_func_call(0, '', self.bc, item, + make_expr_group(args), + make_expr_group([])) + + if self.mode == 'bc': + return bitcode_func_call_wrapper + + def external_func_call_wrapper(args=[], outputs=[]): + func_addr = ctypes.cast(self.so.__getattr__(item), + ctypes.c_void_p).value + impl.get_runtime().prog.current_ast_builder( + ).insert_external_func_call(func_addr, '', '', '', + make_expr_group(args), + make_expr_group(outputs)) + + if self.mode == 'so': + return external_func_call_wrapper + + raise TaichiSyntaxError('Error occurs when calling external function.') + + +__all__ = [] diff --git a/python/taichi/lang/stmt_builder.py b/python/taichi/lang/stmt_builder.py deleted file mode 100644 index fdc34d875c79c..0000000000000 --- a/python/taichi/lang/stmt_builder.py +++ /dev/null @@ -1,717 +0,0 @@ -import ast -import copy - -import astor -from taichi.lang import impl -from taichi.lang.ast.symbol_resolver import ASTResolver -from taichi.lang.ast_builder_utils import * -from taichi.lang.exception import TaichiSyntaxError -from taichi.lang.expr_builder import build_expr, build_exprs -from taichi.lang.util import to_taichi_type - -import taichi as ti - - -class StmtBuilder(Builder): - @staticmethod - def set_subscript_index(node, value): - assert isinstance(node, ast.Subscript), type(node) - if isinstance(node.slice, ast.Index): - node.slice.value = value - else: - node.slice = value - - @staticmethod - def make_single_statement(stmts): - template = 'if 1: pass' - t = ast.parse(template).body[0] - t.body = stmts - return t - - @staticmethod - def make_constant(value): - # Do not use ast.Constant which does not exist in python3.5 - node = parse_expr('0') - node.value = value - return node - - @staticmethod - def build_AugAssign(ctx, node): - node.target = build_expr(ctx, node.target) - node.value = build_expr(ctx, node.value) - template = 'x.augassign(0, 0)' - t = ast.parse(template).body[0] - t.value.func.value = node.target - t.value.func.value.ctx = ast.Load() - t.value.args[0] = node.value - t.value.args[1] = ast.Str(s=type(node.op).__name__, - ctx=ast.Load(), - kind=None) - return ast.copy_location(t, node) - - @staticmethod - def _is_string_mod_args(msg): - # 1. str % (a, b, c, ...) - # 2. str % single_item - # Note that |msg.right| may not be a tuple. - return isinstance(msg, ast.BinOp) and isinstance( - msg.left, ast.Str) and isinstance(msg.op, ast.Mod) - - @staticmethod - def _handle_string_mod_args(ctx, msg): - assert StmtBuilder._is_string_mod_args(msg) - s = msg.left.s - t = None - if isinstance(msg.right, ast.Tuple): - t = msg.right - else: - # assuming the format is `str % single_item` - t = ast.Tuple(elts=[msg.right], ctx=ast.Load()) - t = build_expr(ctx, t) - return s, t - - @staticmethod - def build_Assert(ctx, node): - extra_args = ast.List(elts=[], ctx=ast.Load()) - if node.msg is not None: - if isinstance(node.msg, ast.Constant): - msg = node.msg.value - elif isinstance(node.msg, ast.Str): - msg = node.msg.s - elif StmtBuilder._is_string_mod_args(node.msg): - msg = build_expr(ctx, node.msg) - msg, extra_args = StmtBuilder._handle_string_mod_args(ctx, msg) - else: - raise ValueError( - f"assert info must be constant, not {ast.dump(node.msg)}") - else: - msg = astor.to_source(node.test) - node.test = build_expr(ctx, node.test) - - new_node = parse_stmt('ti.ti_assert(0, 0, [])') - new_node.value.args[0] = node.test - new_node.value.args[1] = parse_expr("'{}'".format(msg.strip())) - new_node.value.args[2] = extra_args - new_node = ast.copy_location(new_node, node) - return new_node - - @staticmethod - def build_Assign(ctx, node): - node.value = build_expr(ctx, node.value) - node.targets = build_exprs(ctx, node.targets) - - is_static_assign = isinstance( - node.value, ast.Call) and ASTResolver.resolve_to( - node.value.func, ti.static, globals()) - if is_static_assign: - return node - - # Keep all generated assign statements and compose single one at last. - # The variable is introduced to support chained assignments. - # Ref https://github.com/taichi-dev/taichi/issues/2659. - assign_stmts = [] - for node_target in node.targets: - if isinstance(node_target, ast.Tuple): - assign_stmts.append( - StmtBuilder.build_assign_unpack(ctx, node, node_target)) - else: - assign_stmts.append( - StmtBuilder.build_assign_basic(ctx, node, node_target, - node.value)) - return StmtBuilder.make_single_statement(assign_stmts) - - @staticmethod - def build_assign_unpack(ctx, node, node_target): - """Build the unpack assignments like this: (target1, target2) = (value1, value2). - The function should be called only if the node target is a tuple. - - Args: - ctx (ast_builder_utils.BuilderContext): The builder context. - node (ast.Assign): An assignment. targets is a list of nodes, - and value is a single node. - node_target (ast.Tuple): A list or tuple object. elts holds a - list of nodes representing the elements. - """ - - targets = node_target.elts - - # Create - stmts = [] - - # Create a temp list and keep values in it, delete it after the initialization is finished. - holder = parse_stmt('__tmp_tuple = ti.expr_init_list(0, ' - f'{len(targets)})') - holder.value.args[0] = node.value - - stmts.append(holder) - - def tuple_indexed(i): - indexing = parse_stmt('__tmp_tuple[0]') - StmtBuilder.set_subscript_index(indexing.value, parse_expr(f"{i}")) - return indexing.value - - # Generate assign statements for every target, then merge them into one. - for i, target in enumerate(targets): - stmts.append( - StmtBuilder.build_assign_basic(ctx, node, target, - tuple_indexed(i))) - stmts.append(parse_stmt('del __tmp_tuple')) - return StmtBuilder.make_single_statement(stmts) - - @staticmethod - def build_assign_basic(ctx, node, target, value): - """Build basic assginment like this: target = value. - - Args: - ctx (ast_builder_utils.BuilderContext): The builder context. - node (ast.Assign): An assignment. targets is a list of nodes, - and value is a single node. - target (ast.Name): A variable name. id holds the name as - a string. - value: A node representing the value. - """ - is_local = isinstance(target, ast.Name) - if is_local and ctx.is_creation(target.id): - var_name = target.id - target.ctx = ast.Store() - # Create, no AST resolution needed - init = ast.Attribute(value=ast.Name(id='ti', ctx=ast.Load()), - attr='expr_init', - ctx=ast.Load()) - rhs = ast.Call( - func=init, - args=[value], - keywords=[], - ) - ctx.create_variable(var_name) - return ast.copy_location( - ast.Assign(targets=[target], value=rhs, type_comment=None), - node) - else: - # Assign - target.ctx = ast.Load() - func = ast.Attribute(value=target, attr='assign', ctx=ast.Load()) - call = ast.Call(func=func, args=[value], keywords=[]) - return ast.copy_location(ast.Expr(value=call), node) - - @staticmethod - def build_Try(ctx, node): - raise TaichiSyntaxError( - "Keyword 'try' not supported in Taichi kernels") - - @staticmethod - def build_While(ctx, node): - if node.orelse: - raise TaichiSyntaxError( - "'else' clause for 'while' not supported in Taichi kernels") - - with ctx.control_scope(): - ctx.current_control_scope().append('while') - - template = ''' -if 1: - ti.core.begin_frontend_while(ti.Expr(1).ptr) - __while_cond = 0 - if __while_cond: - pass - else: - break - ti.core.pop_scope() -''' - cond = node.test - t = ast.parse(template).body[0] - t.body[1].value = cond - t.body = t.body[:3] + node.body + t.body[3:] - - t.body = build_stmts(ctx, t.body) - return ast.copy_location(t, node) - - @staticmethod - def build_If(ctx, node): - node.test = build_expr(ctx, node.test) - node.body = build_stmts(ctx, node.body) - node.orelse = build_stmts(ctx, node.orelse) - - is_static_if = isinstance(node.test, ast.Call) and isinstance( - node.test.func, ast.Attribute) - if is_static_if: - attr = node.test.func - if attr.attr == 'static': - is_static_if = True - else: - is_static_if = False - - if is_static_if: - # Do nothing - return node - - template = ''' -if 1: - __cond = 0 - ti.begin_frontend_if(__cond) - ti.core.begin_frontend_if_true() - ti.core.pop_scope() - ti.core.begin_frontend_if_false() - ti.core.pop_scope() -''' - t = ast.parse(template).body[0] - cond = node.test - t.body[0].value = cond - t.body = t.body[:5] + node.orelse + t.body[5:] - t.body = t.body[:3] + node.body + t.body[3:] - return ast.copy_location(t, node) - - @staticmethod - def get_for_loop_targets(node): - """ - Returns the list of indices of the for loop |node|. - See also: https://docs.python.org/3/library/ast.html#ast.For - """ - if isinstance(node.target, ast.Name): - return [node.target.id] - else: - assert isinstance(node.target, ast.Tuple) - return [name.id for name in node.target.elts] - - @staticmethod - def get_decorator(node): - if not isinstance(node, ast.Call): - return '' - for wanted, name in [ - (ti.static, 'static'), - (ti.grouped, 'grouped'), - (ti.ndrange, 'ndrange'), - ]: - if ASTResolver.resolve_to(node.func, wanted, globals()): - return name - return '' - - @staticmethod - def build_static_for(ctx, node, is_grouped): - # for i in ti.static(range(n)) - # for i, j in ti.static(ti.ndrange(n)) - # for I in ti.static(ti.grouped(ti.ndrange(n, m))) - - ctx.current_control_scope().append('static') - node.body = build_stmts(ctx, node.body) - if is_grouped: - assert len(node.iter.args[0].args) == 1 - template = ''' -if 1: - __ndrange_arg = 0 - from taichi.lang.exception import TaichiSyntaxError - if not isinstance(__ndrange_arg, ti.ndrange): - raise TaichiSyntaxError("Only 'ti.ndrange' is allowed in 'ti.static(ti.grouped(...))'.") - pass - del a - ''' - t = ast.parse(template).body[0] - t.body[0].value = node.iter.args[0].args[0] - t.body[3] = node - t.body[3].iter.args[0].args[0] = parse_expr('__ndrange_arg') - else: - t = parse_stmt('if 1: pass; del a') - t.body[0] = node - target = copy.deepcopy(node.target) - target.ctx = ast.Del() - if isinstance(target, ast.Tuple): - for tar in target.elts: - tar.ctx = ast.Del() - t.body[-1].targets = [target] - return t - - @staticmethod - def build_range_for(ctx, node): - # for i in range(n) - node.body = build_stmts(ctx, node.body) - loop_var = node.target.id - ctx.check_loop_var(loop_var) - template = ''' -if 1: - {} = ti.Expr(ti.core.make_id_expr('')) - ___begin = ti.Expr(0) - ___end = ti.Expr(0) - ___begin = ti.cast(___begin, ti.i32) - ___end = ti.cast(___end, ti.i32) - ti.core.begin_frontend_range_for({}.ptr, ___begin.ptr, ___end.ptr) - ti.core.end_frontend_range_for() - '''.format(loop_var, loop_var) - t = ast.parse(template).body[0] - - assert len(node.iter.args) in [1, 2] - if len(node.iter.args) == 2: - bgn = build_expr(ctx, node.iter.args[0]) - end = build_expr(ctx, node.iter.args[1]) - else: - bgn = StmtBuilder.make_constant(value=0) - end = build_expr(ctx, node.iter.args[0]) - - t.body[1].value.args[0] = bgn - t.body[2].value.args[0] = end - t.body = t.body[:6] + node.body + t.body[6:] - t.body.append(parse_stmt('del {}'.format(loop_var))) - return ast.copy_location(t, node) - - @staticmethod - def build_ndrange_for(ctx, node): - # for i, j in ti.ndrange(n) - template = f''' -if ti.static(1): - __ndrange{id(node)} = 0 - for __ndrange_I{id(node)} in range(0): - __I = __ndrange_I{id(node)} - ''' - t = ast.parse(template).body[0] - t.body[0].value = node.iter - t_loop = t.body[1] - t_loop.iter.args[0] = parse_expr( - f'__ndrange{id(node)}.acc_dimensions[0]') - targets = StmtBuilder.get_for_loop_targets(node) - targets_tmp = ['__' + name for name in targets] - loop_body = t_loop.body - for i in range(len(targets)): - if i + 1 < len(targets): - stmt = '{} = __I // __ndrange{}.acc_dimensions[{}]'.format( - targets_tmp[i], id(node), i + 1) - else: - stmt = '{} = __I'.format(targets_tmp[i]) - loop_body.append(parse_stmt(stmt)) - stmt = '{} = {} + __ndrange{}.bounds[{}][0]'.format( - targets[i], targets_tmp[i], id(node), i) - loop_body.append(parse_stmt(stmt)) - if i + 1 < len(targets): - stmt = '__I = __I - {} * __ndrange{}.acc_dimensions[{}]'.format( - targets_tmp[i], id(node), i + 1) - loop_body.append(parse_stmt(stmt)) - loop_body += node.body - - node = ast.copy_location(t, node) - return build_stmt(ctx, node) # further translate as a range for - - @staticmethod - def build_grouped_ndrange_for(ctx, node): - # for I in ti.grouped(ti.ndrange(n, m)) - node.body = build_stmts(ctx, node.body) - target = node.target.id - template = ''' -if ti.static(1): - __ndrange = 0 - {} = ti.expr_init(ti.Vector([0] * len(__ndrange.dimensions), disable_local_tensor=True)) - ___begin = ti.Expr(0) - ___end = __ndrange.acc_dimensions[0] - ___begin = ti.cast(___begin, ti.i32) - ___end = ti.cast(___end, ti.i32) - __ndrange_I = ti.Expr(ti.core.make_id_expr('')) - ti.core.begin_frontend_range_for(__ndrange_I.ptr, ___begin.ptr, ___end.ptr) - __I = __ndrange_I - for __grouped_I in range(len(__ndrange.dimensions)): - __grouped_I_tmp = 0 - if __grouped_I + 1 < len(__ndrange.dimensions): - __grouped_I_tmp = __I // __ndrange.acc_dimensions[__grouped_I + 1] - else: - __grouped_I_tmp = __I - ti.subscript({}, __grouped_I).assign(__grouped_I_tmp + __ndrange.bounds[__grouped_I][0]) - if __grouped_I + 1 < len(__ndrange.dimensions): - __I = __I - __grouped_I_tmp * __ndrange.acc_dimensions[__grouped_I + 1] - ti.core.end_frontend_range_for() - '''.format(target, target) - t = ast.parse(template).body[0] - node.iter.args[0].args = build_exprs(ctx, node.iter.args[0].args) - t.body[0].value = node.iter.args[0] - cut = len(t.body) - 1 - t.body = t.body[:cut] + node.body + t.body[cut:] - return ast.copy_location(t, node) - - @staticmethod - def build_struct_for(ctx, node, is_grouped): - # for i, j in x - # for I in ti.grouped(x) - node.body = build_stmts(ctx, node.body) - targets = StmtBuilder.get_for_loop_targets(node) - - for loop_var in targets: - ctx.check_loop_var(loop_var) - - var_decl = ''.join( - ' {} = ti.Expr(ti.core.make_id_expr(""))\n'.format(name) - for name in targets) # indent: 4 spaces - vars = ', '.join(targets) - if is_grouped: - template = ''' -if 1: - ___loop_var = 0 - {} = ti.lang.expr.make_var_vector(size=len(___loop_var.shape)) - ___expr_group = ti.lang.expr.make_expr_group({}) - ti.begin_frontend_struct_for(___expr_group, ___loop_var) - ti.core.end_frontend_range_for() - '''.format(vars, vars) - t = ast.parse(template).body[0] - cut = 4 - t.body[0].value = node.iter - t.body = t.body[:cut] + node.body + t.body[cut:] - else: - template = ''' -if 1: -{} - ___loop_var = 0 - ___expr_group = ti.lang.expr.make_expr_group({}) - ti.begin_frontend_struct_for(___expr_group, ___loop_var) - ti.core.end_frontend_range_for() - '''.format(var_decl, vars) - t = ast.parse(template).body[0] - cut = len(targets) + 3 - t.body[cut - 3].value = node.iter - t.body = t.body[:cut] + node.body + t.body[cut:] - for loop_var in reversed(targets): - t.body.append(parse_stmt('del {}'.format(loop_var))) - return ast.copy_location(t, node) - - @staticmethod - def build_For(ctx, node): - if node.orelse: - raise TaichiSyntaxError( - "'else' clause for 'for' not supported in Taichi kernels") - - with ctx.control_scope(): - ctx.current_control_scope().append('for') - - decorator = StmtBuilder.get_decorator(node.iter) - double_decorator = '' - if decorator != '' and len(node.iter.args) == 1: - double_decorator = StmtBuilder.get_decorator(node.iter.args[0]) - ast.fix_missing_locations(node) - - if decorator == 'static': - if double_decorator == 'static': - raise TaichiSyntaxError("'ti.static' cannot be nested") - return StmtBuilder.build_static_for( - ctx, node, double_decorator == 'grouped') - elif decorator == 'ndrange': - if double_decorator != '': - raise TaichiSyntaxError( - "No decorator is allowed inside 'ti.ndrange") - return StmtBuilder.build_ndrange_for(ctx, node) - elif decorator == 'grouped': - if double_decorator == 'static': - raise TaichiSyntaxError( - "'ti.static' is not allowed inside 'ti.grouped'") - elif double_decorator == 'ndrange': - return StmtBuilder.build_grouped_ndrange_for(ctx, node) - elif double_decorator == 'grouped': - raise TaichiSyntaxError("'ti.grouped' cannot be nested") - else: - return StmtBuilder.build_struct_for(ctx, - node, - is_grouped=True) - elif isinstance(node.iter, ast.Call) and isinstance( - node.iter.func, ast.Name) and node.iter.func.id == 'range': - return StmtBuilder.build_range_for(ctx, node) - else: # Struct for - return StmtBuilder.build_struct_for(ctx, - node, - is_grouped=False) - - @staticmethod - def build_Break(ctx, node): - if 'static' in ctx.current_control_scope(): - return node - else: - return parse_stmt('ti.core.insert_break_stmt()') - - @staticmethod - def build_Continue(ctx, node): - if 'static' in ctx.current_control_scope(): - return node - else: - return parse_stmt('ti.core.insert_continue_stmt()') - - @staticmethod - def build_FunctionDef(ctx, node): - args = node.args - assert args.vararg is None - assert args.kwonlyargs == [] - assert args.kw_defaults == [] - assert args.kwarg is None - - arg_decls = [] - - def transform_as_kernel(): - # Treat return type - if node.returns is not None: - ret_init = parse_stmt( - 'ti.lang.kernel_arguments.decl_scalar_ret(0)') - ret_init.value.args[0] = node.returns - ctx.returns = node.returns - arg_decls.append(ret_init) - node.returns = None - - for i, arg in enumerate(args.args): - # Directly pass in template arguments, - # such as class instances ("self"), fields, SNodes, etc. - if isinstance(ctx.func.argument_annotations[i], ti.template): - continue - if isinstance(ctx.func.argument_annotations[i], - ti.linalg.sparse_matrix_builder): - arg_init = parse_stmt( - 'x = ti.lang.kernel_arguments.decl_sparse_matrix()') - arg_init.targets[0].id = arg.arg - ctx.create_variable(arg.arg) - arg_decls.append(arg_init) - elif isinstance(ctx.func.argument_annotations[i], ti.any_arr): - arg_init = parse_stmt( - 'x = ti.lang.kernel_arguments.decl_any_arr_arg(0, 0, 0, 0)' - ) - arg_init.targets[0].id = arg.arg - ctx.create_variable(arg.arg) - array_dt = ctx.arg_features[i][0] - array_dim = ctx.arg_features[i][1] - array_element_shape = ctx.arg_features[i][2] - array_layout = ctx.arg_features[i][3] - array_dt = to_taichi_type(array_dt) - dt_expr = 'ti.' + ti.core.data_type_name(array_dt) - dt = parse_expr(dt_expr) - arg_init.value.args[0] = dt - arg_init.value.args[1] = parse_expr("{}".format(array_dim)) - arg_init.value.args[2] = parse_expr( - "{}".format(array_element_shape)) - arg_init.value.args[3] = parse_expr( - "ti.{}".format(array_layout)) - arg_decls.append(arg_init) - else: - arg_init = parse_stmt( - 'x = ti.lang.kernel_arguments.decl_scalar_arg(0)') - arg_init.targets[0].id = arg.arg - dt = arg.annotation - arg_init.value.args[0] = dt - arg_decls.append(arg_init) - # remove original args - node.args.args = [] - - if ctx.is_kernel: # ti.kernel - for decorator in node.decorator_list: - if ASTResolver.resolve_to(decorator, ti.func, globals()): - raise TaichiSyntaxError( - "Function definition not allowed in 'ti.kernel'.") - transform_as_kernel() - - else: # ti.func - for decorator in node.decorator_list: - if ASTResolver.resolve_to(decorator, ti.func, globals()): - raise TaichiSyntaxError( - "Function definition not allowed in 'ti.func'.") - if impl.get_runtime().experimental_real_function: - transform_as_kernel() - else: - # Transform as force-inlined func - arg_decls = [] - for i, arg in enumerate(args.args): - # Remove annotations because they are not used. - args.args[i].annotation = None - # Template arguments are passed by reference. - if isinstance(ctx.func.argument_annotations[i], - ti.template): - ctx.create_variable(ctx.func.argument_names[i]) - continue - # Create a copy for non-template arguments, - # so that they are passed by value. - arg_init = parse_stmt('x = ti.expr_init_func(0)') - arg_init.targets[0].id = arg.arg - ctx.create_variable(arg.arg) - arg_init.value.args[0] = parse_expr(arg.arg + - '_by_value__') - args.args[i].arg += '_by_value__' - arg_decls.append(arg_init) - - with ctx.variable_scope(): - node.body = build_stmts(ctx, node.body) - - node.body = arg_decls + node.body - node.body = [parse_stmt('import taichi as ti')] + node.body - return node - - @staticmethod - def build_Return(ctx, node): - node.value = build_expr(ctx, node.value) - if ctx.is_kernel or impl.get_runtime().experimental_real_function: - # TODO: check if it's at the end of a kernel, throw TaichiSyntaxError if not - if node.value is not None: - if ctx.returns is None: - raise TaichiSyntaxError( - f'A {"kernel" if ctx.is_kernel else "function"} ' - 'with a return value must be annotated ' - 'with a return type, e.g. def func() -> ti.f32') - ret_expr = parse_expr('ti.cast(ti.Expr(0), 0)') - ret_expr.args[0].args[0] = node.value - ret_expr.args[1] = ctx.returns - ret_stmt = parse_stmt('ti.core.create_kernel_return(ret.ptr)') - # For args[0], it is an ast.Attribute, because it loads the - # attribute, |ptr|, of the expression |ret_expr|. Therefore we - # only need to replace the object part, i.e. args[0].value - ret_stmt.value.args[0].value = ret_expr - return ret_stmt - return node - - @staticmethod - def build_Module(ctx, node): - with ctx.variable_scope(): - # Do NOT use |build_stmts| which inserts 'del' statements to the - # end and deletes parameters passed into the module - node.body = [build_stmt(ctx, stmt) for stmt in list(node.body)] - return node - - @staticmethod - def build_Global(ctx, node): - raise TaichiSyntaxError( - "Keyword 'global' not supported in Taichi kernels") - - @staticmethod - def build_Nonlocal(ctx, node): - raise TaichiSyntaxError( - "Keyword 'nonlocal' not supported in Taichi kernels") - - @staticmethod - def build_Raise(ctx, node): - node.exc = build_expr(ctx, node.exc) - return node - - @staticmethod - def build_Expr(ctx, node): - if not isinstance(node.value, ast.Call): - # A statement with a single expression. - return node - - # A function call. - node.value = build_expr(ctx, node.value) - # Note that we can only return an ast.Expr instead of an ast.Call. - - if impl.get_runtime().experimental_real_function: - # Generates code that inserts a FrontendExprStmt if the function - # called is a Taichi function. - # We cannot insert the FrontendExprStmt here because we do not - # know if the function is a Taichi function now. - node.value.args = [node.value.func] + node.value.args - node.value.func = parse_expr('ti.insert_expr_stmt_if_ti_func') - return node - - @staticmethod - def build_Import(ctx, node): - return node - - @staticmethod - def build_ImportFrom(ctx, node): - return node - - @staticmethod - def build_Pass(ctx, node): - return node - - -build_stmt = StmtBuilder() - - -def build_stmts(ctx, stmts): - result = [] - with ctx.variable_scope(result): - for stmt in list(stmts): - result.append(build_stmt(ctx, stmt)) - return result diff --git a/python/taichi/lang/struct.py b/python/taichi/lang/struct.py index 9dd7a6cadb93a..f4f582f57b733 100644 --- a/python/taichi/lang/struct.py +++ b/python/taichi/lang/struct.py @@ -1,29 +1,46 @@ -import copy import numbers -from numpy import broadcast -from taichi.lang import expr, impl +from taichi.lang import expr, impl, ops from taichi.lang.common_ops import TaichiOperations from taichi.lang.enums import Layout from taichi.lang.exception import TaichiSyntaxError from taichi.lang.field import Field, ScalarField, SNodeHostAccess from taichi.lang.matrix import Matrix -from taichi.lang.ops import cast -from taichi.lang.types import CompoundType from taichi.lang.util import (cook_dtype, in_python_scope, is_taichi_class, python_scope, taichi_scope) - -import taichi as ti +from taichi.types import primitive_types +from taichi.types.compound_types import CompoundType class Struct(TaichiOperations): """The Struct type class. - Args: - entries (Dict[str, Union[Dict, Expr, Matrix, Struct]]): keys and values for struct members. + + A struct is a dictionary-like data structure that stores members as + (key, value) pairs. Valid data members of a struct can be scalars, + matrices or other dictionary-like stuctures. """ - is_taichi_class = True + _is_taichi_class = True def __init__(self, *args, **kwargs): + """ + Args: + entries (Dict[str, Union[Dict, Expr, Matrix, Struct]]): \ + keys and values for struct members. + + Returns: + An instance of this struct. + + Example:: + + >>> vec3 = ti.types.vector(3, ti.f32) + >>> a = ti.Struct(v=vec3([0, 0, 0]), t=1.0) + >>> print(a.items) + dict_items([('v', [0. 0. 0.]), ('t', 1.0)]) + >>> + >>> B = ti.Struct(v=vec3([0., 0., 0.]), t=1.0, A=a) + >>> print(B.items) + dict_items([('v', [0. 0. 0.]), ('t', 1.0), ('A', {'v': [[0.], [0.], [0.]], 't': 1.0})]) + """ # converts lists to matrices and dicts to structs if len(args) == 1 and kwargs == {} and isinstance(args[0], dict): self.entries = args[0] @@ -38,47 +55,43 @@ def __init__(self, *args, **kwargs): v = Matrix(v) if isinstance(v, dict): v = Struct(v) - self.entries[k] = v - self.register_members() - self.local_tensor_proxy = None - self.any_array_access = None + self.entries[k] = v if in_python_scope() else impl.expr_init(v) + self._register_members() @property def keys(self): return list(self.entries.keys()) @property - def members(self): + def _members(self): return list(self.entries.values()) @property def items(self): return self.entries.items() - def register_members(self): + def _register_members(self): for k in self.keys: setattr(Struct, k, property( - Struct.make_getter(k), - Struct.make_setter(k), + Struct._make_getter(k), + Struct._make_setter(k), )) def __getitem__(self, key): - _taichi_skip_traceback = 1 ret = self.entries[key] if isinstance(ret, SNodeHostAccess): ret = ret.accessor.getter(*ret.key) return ret def __setitem__(self, key, value): - _taichi_skip_traceback = 1 if isinstance(self.entries[key], SNodeHostAccess): self.entries[key].accessor.setter(value, *self.entries[key].key) else: if in_python_scope(): if isinstance(self.entries[key], Struct) or isinstance( self.entries[key], Matrix): - self.entries[key].set_entries(value) + self.entries[key]._set_entries(value) else: if isinstance(value, numbers.Number): self.entries[key] = value @@ -89,121 +102,89 @@ def __setitem__(self, key, value): else: self.entries[key] = value - def set_entries(self, value): + def _set_entries(self, value): if isinstance(value, dict): value = Struct(value) for k in self.keys: self[k] = value[k] @staticmethod - def make_getter(key): + def _make_getter(key): def getter(self): """Get an entry from custom struct by name.""" - _taichi_skip_traceback = 1 return self[key] return getter @staticmethod - def make_setter(key): + def _make_setter(key): @python_scope def setter(self, value): - _taichi_skip_traceback = 1 self[key] = value return setter - def element_wise_unary(self, foo): - _taichi_skip_traceback = 1 - ret = self.empty_copy() + def _element_wise_unary(self, foo): + entries = {} for k, v in self.items: - if isinstance(v, expr.Expr): - ret.entries[k] = foo(v) + if is_taichi_class(v): + entries[k] = v._element_wise_unary(foo) else: - ret.entries[k] = v.element_wise_unary(foo) - return ret + entries[k] = foo(v) + return Struct(entries) - def element_wise_binary(self, foo, other): - _taichi_skip_traceback = 1 - ret = self.empty_copy() - if isinstance(other, (dict)): - other = Struct(other) - if isinstance(other, Struct): - if self.entries.keys() != other.entries.keys(): - raise TypeError( - f"Member mismatch between structs {self.keys}, {other.keys}" - ) - for k, v in self.items: - if isinstance(v, expr.Expr): - ret.entries[k] = foo(v, other.entries[k]) - else: - ret.entries[k] = v.element_wise_binary( - foo, other.entries[k]) - else: # assumed to be scalar - for k, v in self.items: - if isinstance(v, expr.Expr): - ret.entries[k] = foo(v, other) - else: - ret.entries[k] = v.element_wise_binary(foo, other) - return ret + def _element_wise_binary(self, foo, other): + other = self._broadcast_copy(other) + entries = {} + for k, v in self.items: + if is_taichi_class(v): + entries[k] = v._element_wise_binary(foo, other.entries[k]) + else: + entries[k] = foo(v, other.entries[k]) + return Struct(entries) - def broadcast_copy(self, other): + def _broadcast_copy(self, other): if isinstance(other, dict): other = Struct(other) if not isinstance(other, Struct): - ret = self.empty_copy() - for k, v in ret.items: - if isinstance(v, (Matrix, Struct)): - ret.entries[k] = v.broadcast_copy(other) + entries = {} + for k, v in self.items: + if is_taichi_class(v): + entries[k] = v._broadcast_copy(other) else: - ret.entries[k] = other - other = ret + entries[k] = other + other = Struct(entries) if self.entries.keys() != other.entries.keys(): raise TypeError( f"Member mismatch between structs {self.keys}, {other.keys}") return other - def element_wise_writeback_binary(self, foo, other): - ret = self.empty_copy() - if isinstance(other, (dict)): - other = Struct(other) - if is_taichi_class(other): - other = other.variable() - if foo.__name__ == 'assign' and not isinstance(other, Struct): + def _element_wise_writeback_binary(self, foo, other): + if foo.__name__ == 'assign' and not isinstance(other, (dict, Struct)): raise TaichiSyntaxError( 'cannot assign scalar expr to ' f'taichi class {type(self)}, maybe you want to use `a.fill(b)` instead?' ) - if isinstance(other, Struct): - if self.entries.keys() != other.entries.keys(): - raise TypeError( - f"Member mismatch between structs {self.keys}, {other.keys}" - ) - for k, v in self.items: - if isinstance(v, expr.Expr): - ret.entries[k] = foo(v, other.entries[k]) - else: - ret.entries[k] = v.element_wise_binary( - foo, other.entries[k]) - else: # assumed to be scalar - for k, v in self.items: - if isinstance(v, expr.Expr): - ret.entries[k] = foo(v, other) - else: - ret.entries[k] = v.element_wise_binary(foo, other) - return ret + other = self._broadcast_copy(other) + entries = {} + for k, v in self.items: + if is_taichi_class(v): + entries[k] = v._element_wise_binary(foo, other.entries[k]) + else: + entries[k] = foo(v, other.entries[k]) + return self if foo.__name__ == 'assign' else Struct(entries) - def element_wise_ternary(self, foo, other, extra): - ret = self.empty_copy() - other = self.broadcast_copy(other) - extra = self.broadcast_copy(extra) + def _element_wise_ternary(self, foo, other, extra): + other = self._broadcast_copy(other) + extra = self._broadcast_copy(extra) + entries = {} for k, v in self.items: - if isinstance(v, expr.Expr): - ret.entries[k] = foo(v, other.entries[k], extra.entries[k]) + if is_taichi_class(v): + entries[k] = v._element_wise_ternary(foo, other.entries[k], + extra.entries[k]) else: - ret.entries[k] = v.element_wise_ternary( - foo, other.entries[k], extra.entries[k]) - return ret + entries[k] = foo(v, other.entries[k], extra.entries[k]) + return Struct(entries) @taichi_scope def fill(self, val): @@ -213,35 +194,9 @@ def fill(self, val): val (Union[int, float]): Value to fill. """ def assign_renamed(x, y): - return ti.assign(x, y) - - return self.element_wise_writeback_binary(assign_renamed, val) + return ops.assign(x, y) - def empty_copy(self): - """ - Nested structs and matrices need to be recursively handled. - """ - struct = Struct.empty(self.keys) - for k, v in self.items: - if isinstance(v, (Struct, Matrix)): - struct.entries[k] = v.empty_copy() - return struct - - def copy(self): - ret = self.empty_copy() - ret.entries = copy.copy(self.entries) - return ret - - @taichi_scope - def variable(self): - ret = self.copy() - ret.entries = { - k: impl.expr_init(v) if isinstance(v, - (numbers.Number, - expr.Expr)) else v.variable() - for k, v in ret.items - } - return ret + return self._element_wise_writeback_binary(assign_renamed, val) def __len__(self): """Get the number of entries in a custom struct""" @@ -256,13 +211,11 @@ def __str__(self): item_str = ", ".join( [str(k) + "=" + str(v) for k, v in self.items]) return f'' - else: - return str(self.to_dict()) + return str(self.to_dict()) def __repr__(self): return str(self.to_dict()) - @python_scope def to_dict(self): """Converts the Struct to a dictionary. @@ -271,19 +224,11 @@ def to_dict(self): Returns: Dict: The result dictionary. """ - return self.entries - - @classmethod - def empty(cls, entries): - """Clear the struct and fill None. - - Args: - members (Dict[str, DataType]): the names and data types for struct members. - Returns: - :class:`~taichi.lang.struct.Struct`: A :class:`~taichi.lang.struct.Struct` instance filled with None. - - """ - return cls({k: None for k in entries}) + return { + k: v.to_dict() if isinstance(v, Struct) else + v.to_list() if isinstance(v, Matrix) else v + for k, v in self.entries.items() + } @classmethod @python_scope @@ -328,23 +273,35 @@ def field(cls, dim = len(shape) if layout == Layout.SOA: for e in field_dict.values(): - ti.root.dense(impl.index_nd(dim), - shape).place(e, offset=offset) + impl.root.dense(impl.index_nd(dim), + shape).place(e, offset=offset) if needs_grad: for e in field_dict.values(): - ti.root.dense(impl.index_nd(dim), - shape).place(e.grad, offset=offset) + impl.root.dense(impl.index_nd(dim), + shape).place(e.grad, offset=offset) else: - ti.root.dense(impl.index_nd(dim), - shape).place(*tuple(field_dict.values()), - offset=offset) + impl.root.dense(impl.index_nd(dim), + shape).place(*tuple(field_dict.values()), + offset=offset) if needs_grad: grads = tuple(e.grad for e in field_dict.values()) - ti.root.dense(impl.index_nd(dim), - shape).place(*grads, offset=offset) + impl.root.dense(impl.index_nd(dim), + shape).place(*grads, offset=offset) return StructField(field_dict, name=name) +class _IntermediateStruct(Struct): + """Intermediate struct class for compiler internal use only. + + Args: + entries (Dict[str, Union[Expr, Matrix, Struct]]): keys and values for struct members. + """ + def __init__(self, entries): + assert isinstance(entries, dict) + self.entries = entries + self._register_members() + + class StructField(Field): """Taichi struct field with SNode implementation. Instead of directly contraining Expr entries, the StructField object @@ -357,80 +314,74 @@ class StructField(Field): def __init__(self, field_dict, name=None): # will not call Field initializer self.field_dict = field_dict - self._name = name - self.register_fields() - - @property - def name(self): - return self._name + self.name = name + self._register_fields() @property def keys(self): return list(self.field_dict.keys()) @property - def members(self): + def _members(self): return list(self.field_dict.values()) @property - def items(self): + def _items(self): return self.field_dict.items() @staticmethod - def make_getter(key): + def _make_getter(key): def getter(self): """Get an entry from custom struct by name.""" - _taichi_skip_traceback = 1 return self.field_dict[key] return getter @staticmethod - def make_setter(key): + def _make_setter(key): @python_scope def setter(self, value): - _taichi_skip_traceback = 1 self.field_dict[key] = value return setter - def register_fields(self): + def _register_fields(self): for k in self.keys: setattr( StructField, k, property( - StructField.make_getter(k), - StructField.make_setter(k), + StructField._make_getter(k), + StructField._make_setter(k), )) - def get_field_members(self): + def _get_field_members(self): """Get A flattened list of all struct elements. Returns: A list of struct elements. """ field_members = [] - for m in self.members: + for m in self._members: assert isinstance(m, Field) - field_members += m.get_field_members() + field_members += m._get_field_members() return field_members @property - def snode(self): + def _snode(self): """Gets representative SNode for info purposes. Returns: SNode: Representative SNode (SNode of first field member). """ - return self.members[0].snode + return self._members[0]._snode - def loop_range(self): + def _loop_range(self): """Gets representative field member for loop range info. Returns: taichi_core.Expr: Representative (first) field member. """ - return self.members[0].loop_range() + return self._members[0]._loop_range() @python_scope def copy_from(self, other): @@ -453,12 +404,12 @@ def fill(self, val): Args: val (Union[int, float]): Value to fill. """ - for v in self.members: + for v in self._members: v.fill(val) - def initialize_host_accessors(self): - for v in self.members: - v.initialize_host_accessors() + def _initialize_host_accessors(self): + for v in self._members: + v._initialize_host_accessors() def get_member_field(self, key): """Creates a ScalarField using a specific field member. Only used for quant. @@ -473,12 +424,12 @@ def get_member_field(self, key): @python_scope def from_numpy(self, array_dict): - for k, v in self.items: + for k, v in self._items: v.from_numpy(array_dict[k]) @python_scope def from_torch(self, array_dict): - for k, v in self.items: + for k, v in self._items: v.from_torch(array_dict[k]) @python_scope @@ -490,7 +441,7 @@ def to_numpy(self): Returns: Dict[str, Union[numpy.ndarray, Dict]]: The result NumPy array. """ - return {k: v.to_numpy() for k, v in self.items} + return {k: v.to_numpy() for k, v in self._items} @python_scope def to_torch(self, device=None): @@ -502,21 +453,21 @@ def to_torch(self, device=None): Returns: Dict[str, Union[torch.Tensor, Dict]]: The result PyTorch tensor. """ - return {k: v.to_torch(device=device) for k, v in self.items} + return {k: v.to_torch(device=device) for k, v in self._items} @python_scope def __setitem__(self, indices, element): - self.initialize_host_accessors() - self[indices].set_entries(element) + self._initialize_host_accessors() + self[indices]._set_entries(element) @python_scope def __getitem__(self, indices): - self.initialize_host_accessors() + self._initialize_host_accessors() # scalar fields does not instantiate SNodeHostAccess by default entries = { - k: v.host_access(self.pad_key(indices))[0] if isinstance( + k: v._host_access(self._pad_key(indices))[0] if isinstance( v, ScalarField) else v[indices] - for k, v in self.items + for k, v in self._items } return Struct(entries) @@ -542,43 +493,43 @@ def __call__(self, *args, **kwargs): elif len(args) == 1: # fill a single scalar if isinstance(args[0], (numbers.Number, expr.Expr)): - entries = self.scalar_filled(args[0]) + entries = self.filled_with_scalar(args[0]) else: - # fill a single vector or matrix # initialize struct members by dictionary entries = Struct(args[0]) struct = self.cast(entries) return struct - def cast(self, struct, in_place=False): - if not in_place: - struct = struct.copy() + def cast(self, struct): # sanity check members if self.members.keys() != struct.entries.keys(): raise TaichiSyntaxError( "Incompatible arguments for custom struct members!") + entries = {} for k, dtype in self.members.items(): if isinstance(dtype, CompoundType): - struct.entries[k] = dtype.cast(struct.entries[k]) + entries[k] = dtype.cast(struct.entries[k]) else: if in_python_scope(): v = struct.entries[k] - struct.entries[k] = int( - v) if dtype in ti.integer_types else float(v) + entries[k] = int( + v + ) if dtype in primitive_types.integer_types else float(v) else: - struct.entries[k] = cast(struct.entries[k], dtype) - return struct + entries[k] = ops.cast(struct.entries[k], dtype) + return Struct(entries) - def empty(self): - """ - Create an empty instance of the given compound type. - Nested structs and matrices need to be recursively handled. - """ - struct = Struct.empty(self.members.keys()) + def filled_with_scalar(self, value): + entries = {} for k, dtype in self.members.items(): if isinstance(dtype, CompoundType): - struct.entries[k] = dtype.empty() - return struct + entries[k] = dtype.filled_with_scalar(value) + else: + entries[k] = value + return Struct(entries) def field(self, **kwargs): return Struct.field(self.members, **kwargs) + + +__all__ = ["Struct", "StructField"] diff --git a/python/taichi/lang/tape.py b/python/taichi/lang/tape.py index 74be931ba3f41..c9101d306cbd6 100644 --- a/python/taichi/lang/tape.py +++ b/python/taichi/lang/tape.py @@ -11,7 +11,7 @@ def __enter__(self): assert not self.entered, "Tape can be entered only once." self.entered = True - def __exit__(self, type, value, tb): + def __exit__(self, _type, value, tb): # print('# kernel calls', len(self.calls)) self.runtime.target_tape = None if self.eval_on_exit: diff --git a/python/taichi/lang/type_factory_impl.py b/python/taichi/lang/type_factory_impl.py deleted file mode 100644 index cd26e1433a540..0000000000000 --- a/python/taichi/lang/type_factory_impl.py +++ /dev/null @@ -1,54 +0,0 @@ -from taichi.core.util import ti_core as _ti_core -from taichi.lang import impl - - -class TypeFactory: - """A Python-side TypeFactory wrapper.""" - def __init__(self): - self.core = _ti_core.get_type_factory_instance() - - def custom_int(self, bits, signed=True, compute_type=None): - """Generates a custom int type. - - Args: - bits (int): Number of bits. - signed (bool): Signed or unsigned. - compute_type (DataType): Type for computation. - - Returns: - DataType: The specified type. - """ - if compute_type is None: - compute_type = impl.get_runtime().default_ip - if isinstance(compute_type, _ti_core.DataType): - compute_type = compute_type.get_ptr() - return self.core.get_custom_int_type(bits, signed, compute_type) - - def custom_float(self, - significand_type, - exponent_type=None, - compute_type=None, - scale=1.0): - """Generates a custom float type. - - Args: - significand_type (DataType): Type of significand. - exponent_type (DataType): Type of exponent. - compute_type (DataType): Type for computation. - scale (float): Scaling factor. - - Returns: - DataType: The specified type. - """ - if compute_type is None: - compute_type = impl.get_runtime().default_fp - if isinstance(compute_type, _ti_core.DataType): - compute_type = compute_type.get_ptr() - return self.core.get_custom_float_type(significand_type, - exponent_type, - compute_type, - scale=scale) - - -# Unstable API -type_factory = TypeFactory() diff --git a/python/taichi/lang/types.py b/python/taichi/lang/types.py deleted file mode 100644 index 1299e378b43f5..0000000000000 --- a/python/taichi/lang/types.py +++ /dev/null @@ -1,31 +0,0 @@ -import numbers - -import taichi.lang.matrix -from taichi.lang.exception import TaichiSyntaxError - - -class CompoundType: - def empty(self): - """ - Create an empty instance of the given compound type. - """ - raise NotImplementedError - - def scalar_filled(self, value): - instance = self.empty() - return instance.broadcast_copy(value) - - def field(self, **kwargs): - raise NotImplementedError - - -def matrix(m, n, dtype=None): - return taichi.lang.matrix.MatrixType(m, n, dtype=dtype) - - -def vector(m, dtype=None): - return taichi.lang.matrix.MatrixType(m, 1, dtype=dtype) - - -def struct(**kwargs): - return taichi.lang.struct.StructType(**kwargs) diff --git a/python/taichi/lang/util.py b/python/taichi/lang/util.py index 6d9fd2d578d8c..feb18831d91cb 100644 --- a/python/taichi/lang/util.py +++ b/python/taichi/lang/util.py @@ -1,11 +1,13 @@ import functools import os +import traceback import numpy as np -from taichi.core.util import ti_core as _ti_core +from colorama import Fore, Style +from taichi._lib import core as _ti_core from taichi.lang import impl - -import taichi as ti +from taichi.types.primitive_types import (f16, f32, f64, i8, i16, i32, i64, u8, + u16, u32, u64) _has_pytorch = False @@ -28,10 +30,29 @@ def has_pytorch(): return _has_pytorch +from distutils.spawn import find_executable + +# Taichi itself uses llvm-10.0.0 to compile. +# There will be some issues compiling CUDA with other clang++ version. +_clangpp_candidates = ['clang++-10'] +_clangpp_presence = None +for c in _clangpp_candidates: + if find_executable(c) is not None: + _clangpp_presence = find_executable(c) + + +def has_clangpp(): + return _clangpp_presence is not None + + +def get_clangpp(): + return _clangpp_presence + + def is_taichi_class(rhs): taichi_class = False try: - if rhs.is_taichi_class: + if rhs._is_taichi_class: taichi_class = True except: pass @@ -48,28 +69,29 @@ def to_numpy_type(dt): DataType: The counterpart data type in numpy. """ - if dt == ti.f32: + if dt == f32: return np.float32 - elif dt == ti.f64: + if dt == f64: return np.float64 - elif dt == ti.i32: + if dt == i32: return np.int32 - elif dt == ti.i64: + if dt == i64: return np.int64 - elif dt == ti.i8: + if dt == i8: return np.int8 - elif dt == ti.i16: + if dt == i16: return np.int16 - elif dt == ti.u8: + if dt == u8: return np.uint8 - elif dt == ti.u16: + if dt == u16: return np.uint16 - elif dt == ti.u32: + if dt == u32: return np.uint32 - elif dt == ti.u64: + if dt == u64: return np.uint64 - else: - assert False + if dt == f16: + return np.half + assert False def to_pytorch_type(dt): @@ -82,28 +104,27 @@ def to_pytorch_type(dt): DataType: The counterpart data type in torch. """ - if dt == ti.f32: + # pylint: disable=E1101 + if dt == f32: return torch.float32 - elif dt == ti.f64: + if dt == f64: return torch.float64 - elif dt == ti.i32: + if dt == i32: return torch.int32 - elif dt == ti.i64: + if dt == i64: return torch.int64 - elif dt == ti.i8: + if dt == i8: return torch.int8 - elif dt == ti.i16: + if dt == i16: return torch.int16 - elif dt == ti.u8: + if dt == u8: return torch.uint8 - elif dt == ti.u16: - return torch.uint16 - elif dt == ti.u32: - return torch.uint32 - elif dt == ti.u64: - return torch.uint64 - else: - assert False + if dt == f16: + return torch.float16 + if dt in (u16, u32, u64): + raise RuntimeError( + f'PyTorch doesn\'t support {dt.to_string()} data type.') + assert False def to_taichi_type(dt): @@ -120,63 +141,63 @@ def to_taichi_type(dt): return dt if dt == np.float32: - return ti.f32 - elif dt == np.float64: - return ti.f64 - elif dt == np.int32: - return ti.i32 - elif dt == np.int64: - return ti.i64 - elif dt == np.int8: - return ti.i8 - elif dt == np.int16: - return ti.i16 - elif dt == np.uint8: - return ti.u8 - elif dt == np.uint16: - return ti.u16 - elif dt == np.uint32: - return ti.u32 - elif dt == np.uint64: - return ti.u64 + return f32 + if dt == np.float64: + return f64 + if dt == np.int32: + return i32 + if dt == np.int64: + return i64 + if dt == np.int8: + return i8 + if dt == np.int16: + return i16 + if dt == np.uint8: + return u8 + if dt == np.uint16: + return u16 + if dt == np.uint32: + return u32 + if dt == np.uint64: + return u64 + if dt == np.half: + return f16 if has_pytorch(): + # pylint: disable=E1101 if dt == torch.float32: - return ti.f32 - elif dt == torch.float64: - return ti.f64 - elif dt == torch.int32: - return ti.i32 - elif dt == torch.int64: - return ti.i64 - elif dt == torch.int8: - return ti.i8 - elif dt == torch.int16: - return ti.i16 - elif dt == torch.uint8: - return ti.u8 - elif dt == torch.uint16: - return ti.u16 - elif dt == torch.uint32: - return ti.u32 - elif dt == torch.uint64: - return ti.u64 - - raise AssertionError("Unknown type {}".format(dt)) + return f32 + if dt == torch.float64: + return f64 + if dt == torch.int32: + return i32 + if dt == torch.int64: + return i64 + if dt == torch.int8: + return i8 + if dt == torch.int16: + return i16 + if dt == torch.uint8: + return u8 + if dt == torch.float16: + return f16 + if dt in (u16, u32, u64): + raise RuntimeError( + f'PyTorch doesn\'t support {dt.to_string()} data type.') + + raise AssertionError(f"Unknown type {dt}") def cook_dtype(dtype): - _taichi_skip_traceback = 1 if isinstance(dtype, _ti_core.DataType): return dtype - elif isinstance(dtype, _ti_core.Type): + if isinstance(dtype, _ti_core.Type): return _ti_core.DataType(dtype) - elif dtype is float: + if dtype is float: return impl.get_runtime().default_fp - elif dtype is int: + if dtype is int: return impl.get_runtime().default_ip - else: - raise ValueError(f'Invalid data type {dtype}') + raise ValueError(f'Invalid data type {dtype}') def in_taichi_scope(): @@ -190,7 +211,6 @@ def in_python_scope(): def taichi_scope(func): @functools.wraps(func) def wrapped(*args, **kwargs): - _taichi_skip_traceback = 1 assert in_taichi_scope(), \ f'{func.__name__} cannot be called in Python-scope' return func(*args, **kwargs) @@ -201,9 +221,32 @@ def wrapped(*args, **kwargs): def python_scope(func): @functools.wraps(func) def wrapped(*args, **kwargs): - _taichi_skip_traceback = 1 assert in_python_scope(), \ f'{func.__name__} cannot be called in Taichi-scope' return func(*args, **kwargs) return wrapped + + +def warning(msg, warning_type=UserWarning, stacklevel=1, print_stack=True): + """Print a warning message. Note that the builtin `warnings` module is + unreliable since it may be suppressed by other packages such as IPython. + + Args: + msg (str): message to print. + warning_type (Warning): type of warning. + stacklevel (int): warning stack level from the caller. + print_stack (bool): whether to print the stack + """ + msg = f'{warning_type.__name__}: {msg}' + if print_stack: + msg += f'\n{get_traceback(stacklevel)}' + print(Fore.YELLOW + Style.BRIGHT + msg + Style.RESET_ALL) + + +def get_traceback(stacklevel=1): + s = traceback.extract_stack()[:-1 - stacklevel] + return ''.join(traceback.format_list(s)) + + +__all__ = [] diff --git a/python/taichi/linalg/__init__.py b/python/taichi/linalg/__init__.py index a56d77fb4846d..469762c2ee0cd 100644 --- a/python/taichi/linalg/__init__.py +++ b/python/taichi/linalg/__init__.py @@ -1,3 +1,2 @@ -from taichi.linalg.sparse_matrix import (SparseMatrix, SparseMatrixBuilder, - sparse_matrix_builder) +from taichi.linalg.sparse_matrix import * from taichi.linalg.sparse_solver import SparseSolver diff --git a/python/taichi/linalg/sparse_matrix.py b/python/taichi/linalg/sparse_matrix.py index 9e0c21a935d9c..928aec91d1382 100644 --- a/python/taichi/linalg/sparse_matrix.py +++ b/python/taichi/linalg/sparse_matrix.py @@ -1,7 +1,8 @@ import numpy as np -from taichi.core.util import ti_core as _ti_core from taichi.lang.field import Field -from taichi.type.primitive_types import f32 +from taichi.lang.impl import get_runtime +from taichi.lang.util import warning +from taichi.types import annotations, f32 class SparseMatrix: @@ -18,7 +19,7 @@ def __init__(self, n=None, m=None, sm=None, dtype=f32): if sm is None: self.n = n self.m = m if m else n - self.matrix = _ti_core.create_sparse_matrix(n, m) + self.matrix = get_runtime().prog.create_sparse_matrix(n, m) else: self.n = sm.num_rows() self.m = sm.num_cols() @@ -55,11 +56,13 @@ def __mul__(self, other): if isinstance(other, float): sm = self.matrix * other return SparseMatrix(sm=sm) - elif isinstance(other, SparseMatrix): + if isinstance(other, SparseMatrix): assert self.n == other.n and self.m == other.m, f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})" sm = self.matrix * other.matrix return SparseMatrix(sm=sm) + return None + def __rmul__(self, other): """Right scalar multiplication for sparse matrix. @@ -72,6 +75,8 @@ def __rmul__(self, other): sm = other * self.matrix return SparseMatrix(sm=sm) + return None + def transpose(self): """Sparse Matrix transpose. @@ -93,16 +98,15 @@ def __matmul__(self, other): assert self.m == other.n, f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})" sm = self.matrix.matmul(other.matrix) return SparseMatrix(sm=sm) - elif isinstance(other, Field): + if isinstance(other, Field): assert self.m == other.shape[ 0], f"Dimension mismatch between sparse matrix ({self.n}, {self.m}) and vector ({other.shape})" return self.matrix.mat_vec_mul(other.to_numpy()) - elif isinstance(other, np.ndarray): + if isinstance(other, np.ndarray): assert self.m == other.shape[ 0], f"Dimension mismatch between sparse matrix ({self.n}, {self.m}) and vector ({other.shape})" return self.matrix.mat_vec_mul(other) - else: - assert False, f"Sparse matrix-matrix/vector multiplication does not support {type(other)} for now. Supported types are SparseMatrix, ti.field, and numpy.ndarray." + assert False, f"Sparse matrix-matrix/vector multiplication does not support {type(other)} for now. Supported types are SparseMatrix, ti.field, and numpy.ndarray." def __getitem__(self, indices): return self.matrix.get_element(indices[0], indices[1]) @@ -117,6 +121,10 @@ def __str__(self): def __repr__(self): return self.matrix.to_string() + def shape(self): + """The shape of the sparse matrix.""" + return (self.n, self.m) + class SparseMatrixBuilder: """A python wrap around sparse matrix builder. @@ -135,11 +143,12 @@ def __init__(self, dtype=f32): self.num_rows = num_rows self.num_cols = num_cols if num_cols else num_rows + self.dtype = dtype if num_rows is not None: - self.ptr = _ti_core.create_sparse_matrix_builder( - num_rows, num_cols, max_num_triplets) + self.ptr = get_runtime().prog.create_sparse_matrix_builder( + num_rows, num_cols, max_num_triplets, dtype) - def get_addr(self): + def _get_addr(self): """Get the address of the sparse matrix""" return self.ptr.get_addr() @@ -147,11 +156,18 @@ def print_triplets(self): """Print the triplets stored in the builder""" self.ptr.print_triplets() - def build(self, dtype=f32, format='CSR'): + def build(self, dtype=f32, _format='CSR'): """Create a sparse matrix using the triplets""" sm = self.ptr.build() return SparseMatrix(sm=sm) -sparse_matrix_builder = SparseMatrixBuilder -# Alias for :class:`SparseMatrixBuilder` +# TODO: remove this in 1.0 release +class sparse_matrix_builder(annotations.sparse_matrix_builder): + def __init__(self): + warning( + 'ti.linalg.sparse_matrix_builder is deprecated. Please use ti.types.sparse_matrix_builder instead.', + DeprecationWarning) + + +__all__ = ['SparseMatrix', 'SparseMatrixBuilder', 'sparse_matrix_builder'] diff --git a/python/taichi/linalg/sparse_solver.py b/python/taichi/linalg/sparse_solver.py index 4b886fe96fa3d..67fca2e5f2ab3 100644 --- a/python/taichi/linalg/sparse_solver.py +++ b/python/taichi/linalg/sparse_solver.py @@ -1,8 +1,9 @@ import numpy as np import taichi.lang -from taichi.core.util import ti_core as _ti_core +from taichi._lib import core as _ti_core +from taichi.lang.field import Field from taichi.linalg import SparseMatrix -from taichi.type.primitive_types import f32 +from taichi.types.primitive_types import f32 class SparseSolver: @@ -20,12 +21,13 @@ def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"): if solver_type in solver_type_list and ordering in solver_ordering: taichi_arch = taichi.lang.impl.get_runtime().prog.config.arch assert taichi_arch == _ti_core.Arch.x64 or taichi_arch == _ti_core.Arch.arm64, "SparseSolver only supports CPU for now." - self.solver = _ti_core.make_sparse_solver(solver_type, ordering) + self.solver = _ti_core.make_sparse_solver(dtype, solver_type, + ordering) else: assert False, f"The solver type {solver_type} with {ordering} is not supported for now. Only {solver_type_list} with {solver_ordering} are supported." @staticmethod - def type_assert(sparse_matrix): + def _type_assert(sparse_matrix): assert False, f"The parameter type: {type(sparse_matrix)} is not supported in linear solvers for now." def compute(self, sparse_matrix): @@ -37,7 +39,7 @@ def compute(self, sparse_matrix): if isinstance(sparse_matrix, SparseMatrix): self.solver.compute(sparse_matrix.matrix) else: - self.type_assert(sparse_matrix) + self._type_assert(sparse_matrix) def analyze_pattern(self, sparse_matrix): """Reorder the nonzero elements of the matrix, such that the factorization step creates less fill-in. @@ -48,7 +50,7 @@ def analyze_pattern(self, sparse_matrix): if isinstance(sparse_matrix, SparseMatrix): self.solver.analyze_pattern(sparse_matrix.matrix) else: - self.type_assert(sparse_matrix) + self._type_assert(sparse_matrix) def factorize(self, sparse_matrix): """Do the factorization step @@ -59,7 +61,7 @@ def factorize(self, sparse_matrix): if isinstance(sparse_matrix, SparseMatrix): self.solver.factorize(sparse_matrix.matrix) else: - self.type_assert(sparse_matrix) + self._type_assert(sparse_matrix) def solve(self, b): """Computes the solution of the linear systems. @@ -69,12 +71,11 @@ def solve(self, b): Returns: numpy.array: The solution of linear systems. """ - if isinstance(b, taichi.lang.Field): + if isinstance(b, Field): return self.solver.solve(b.to_numpy()) - elif isinstance(b, np.ndarray): + if isinstance(b, np.ndarray): return self.solver.solve(b) - else: - assert False, f"The parameter type: {type(b)} is not supported in linear solvers for now." + assert False, f"The parameter type: {type(b)} is not supported in linear solvers for now." def info(self): """Check if the linear systems are solved successfully. diff --git a/python/taichi/misc/__init__.py b/python/taichi/misc/__init__.py deleted file mode 100644 index b629d01d1ee2c..0000000000000 --- a/python/taichi/misc/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .error import * -from .gui import * -from .image import * -from .task import Task -from .util import * - -__all__ = [s for s in dir() if not s.startswith('_')] diff --git a/python/taichi/misc/error.py b/python/taichi/misc/error.py deleted file mode 100644 index 7f88c853c2b58..0000000000000 --- a/python/taichi/misc/error.py +++ /dev/null @@ -1,48 +0,0 @@ -import functools -import sys -import traceback - -from colorama import Fore, Style - - -def enable_excepthook(): - def excepthook(exctype, value, tb): - skip = 0 - back = 4 - forward = 2 - bar = f'{Fore.LIGHTBLACK_EX}{"-"*44}{Fore.RESET}' - print( - f'{Fore.LIGHTBLACK_EX}========== Taichi Stack Traceback =========={Fore.RESET}' - ) - for frame, lineno in traceback.walk_tb(tb): - name = frame.f_code.co_name - filename = frame.f_code.co_filename - if '_taichi_skip_traceback' in frame.f_locals: - skip = frame.f_locals['_taichi_skip_traceback'] - if skip > 0: - skip -= 1 - continue - print( - f'In {Fore.LIGHTYELLOW_EX}{name}{Fore.RESET}() at {Fore.LIGHTMAGENTA_EX}{filename}{Fore.RESET}:{Fore.LIGHTCYAN_EX}{lineno}{Fore.RESET}:\n{bar}' - ) - with open(filename) as f: - lines = [''] + f.readlines() - if lines[lineno][-1] == '\n': - lines[lineno] = lines[lineno][:-1] - lines[lineno] = f'{Fore.LIGHTRED_EX}' + lines[ - lineno] + f' {Fore.LIGHTYELLOW_EX}<--{Fore.LIGHTBLACK_EX}\n' - line = ''.join(lines[max(1, lineno - - back):min(len(lines), lineno + - forward + 1)]) - if line[-1] != '\n': - line += '\n' - print(f'{Fore.LIGHTWHITE_EX}{line}{bar}') - value = str(value) - if len(value): - value = f': {Fore.LIGHTRED_EX}{value}' - print( - f'{Fore.LIGHTGREEN_EX}{exctype.__name__}{Fore.RESET}{value}{Fore.RESET}' - ) - - if sys.excepthook is not excepthook: - sys.excepthook = excepthook diff --git a/python/taichi/misc/gui.py b/python/taichi/misc/gui.py deleted file mode 100644 index e921f3006043c..0000000000000 --- a/python/taichi/misc/gui.py +++ /dev/null @@ -1,859 +0,0 @@ -import math -import numbers -import os - -import numpy as np -import taichi.lang -from taichi.core import ti_core as _ti_core -from taichi.lang.field import Field, ScalarField - -import taichi as ti - -from .util import core_veci, deprecated - - -class GUI: - """Taichi Graphical User Interface class. - - Args: - name (str, optional): The name of the GUI to be constructed. - Default is 'Taichi'. - res (Union[int, List[int]], optional): The resolution of created - GUI. Default is 512*512. If `res` is scalar, then width will be equal to height. - background_color (int, optional): The background color of created GUI. - Default is 0x000000. - show_gui (bool, optional): Specify whether to render the GUI. Default is True. - fullscreen (bool, optional): Specify whether to render the GUI in - fullscreen mode. Default is False. - fast_gui (bool, optional): Specify whether to use fast gui mode of - Taichi. Default is False. - - Returns: - :class:`~taichi.misc.gui.GUI` :The created taichi GUI object. - - """ - class Event: - pass - - # Event keys - SHIFT = 'Shift' - ALT = 'Alt' - CTRL = 'Control' - ESCAPE = 'Escape' - RETURN = 'Return' - TAB = 'Tab' - BACKSPACE = 'BackSpace' - SPACE = ' ' - UP = 'Up' - DOWN = 'Down' - LEFT = 'Left' - RIGHT = 'Right' - CAPSLOCK = 'Caps_Lock' - LMB = 'LMB' - MMB = 'MMB' - RMB = 'RMB' - EXIT = 'WMClose' - WHEEL = 'Wheel' - MOVE = 'Motion' - - # Event types - MOTION = _ti_core.KeyEvent.EType.Move - PRESS = _ti_core.KeyEvent.EType.Press - RELEASE = _ti_core.KeyEvent.EType.Release - - def __init__(self, - name='Taichi', - res=512, - background_color=0x0, - show_gui=True, - fullscreen=False, - fast_gui=False): - show_gui = self.get_bool_environ('TI_GUI_SHOW', show_gui) - fullscreen = self.get_bool_environ('TI_GUI_FULLSCREEN', fullscreen) - fast_gui = self.get_bool_environ('TI_GUI_FAST', fast_gui) - - self.name = name - if isinstance(res, numbers.Number): - res = (res, res) - self.res = res - self.fast_gui = fast_gui - if fast_gui: - self.img = np.ascontiguousarray( - np.zeros(self.res[0] * self.res[1], dtype=np.uint32)) - fast_buf = self.img.ctypes.data - else: - # The GUI canvas uses RGBA for storage, therefore we need NxMx4 for an image. - self.img = np.ascontiguousarray( - np.zeros(self.res + (4, ), np.float32)) - fast_buf = 0 - self.core = _ti_core.GUI(name, core_veci(*res), show_gui, fullscreen, - fast_gui, fast_buf) - self.canvas = self.core.get_canvas() - self.background_color = background_color - self.key_pressed = set() - self.event = None - self.frame = 0 - self.clear() - - def __enter__(self): - return self - - def __exit__(self, type, val, tb): - self.close() - - def __del__(self): - self.close() - - def close(self): - self.core = None # dereference to call GUI::~GUI() - - ## Widget system - - class WidgetValue: - def __init__(self, gui, wid): - self.gui = gui - self.wid = wid - - @property - def value(self): - return self.gui.core.get_widget_value(self.wid) - - @value.setter - def value(self, value): - self.gui.core.set_widget_value(self.wid, value) - - def get_bool_environ(self, key, default): - """Get an environment variable and cast to bool. - Args: - key (str): The environment variable key. - default (bool): The default value. - Return: - The environment variable value cast to bool. If the value is not found, directly return argument 'default'. - """ - if key not in os.environ: - return default - return bool(int(os.environ[key])) - - def slider(self, text, minimum, maximum, step=1): - """Create a slider object on canvas to be manipulated with. - - Args: - text (str): The title of slider. - minimum (Number): The minimum value of slider. - maximum (Number): The maximum value of slider. - step (Number, optional): The changing step of slider. Default is 1. - - Return: - :class:`~taichi.misc.gui.GUI.WidgetValue` :The created slider object. - - """ - wid = self.core.make_slider(text, minimum, minimum, maximum, step) - return GUI.WidgetValue(self, wid) - - def label(self, text): - """Create a label object on canvas. - - Args: - text (str): The title of label. - - Return: - :class:`~taichi.misc.gui.GUI.WidgetValue` :The created label object. - - """ - wid = self.core.make_label(text, 0) - return GUI.WidgetValue(self, wid) - - def button(self, text, event_name=None): - """Create a button object on canvas to be manipulated with. - - Args: - text (str): The title of button. - event_name (str, optional): The event name associated with button. - Default is WidgetButton_{text} - - Return: - The event name associated with created button. - - """ - event_name = event_name or f'WidgetButton_{text}' - self.core.make_button(text, event_name) - return event_name - - ## Drawing system - - def clear(self, color=None): - """Clear the canvas with the color provided. - - Args: - color (int, optional): Specify the color to clear the canvas. Default - is the background color of GUI. - - """ - if color is None: - color = self.background_color - self.canvas.clear(color) - - def cook_image(self, img): - if img.dtype in [np.uint8, np.uint16, np.uint32, np.uint64]: - img = img.astype(np.float32) * (1 / np.iinfo(img.dtype).max) - elif img.dtype in [np.float32, np.float64]: - img = img.astype(np.float32) - else: - raise ValueError( - f'Data type {img.dtype} not supported in GUI.set_image') - - if len(img.shape) == 2: - img = img[..., None] - - if img.shape[2] == 1: - img = img + np.zeros((1, 1, 4), np.float32) - if img.shape[2] == 3: - zeros = np.zeros((img.shape[0], img.shape[1], 1), np.float32) - img = np.concatenate([img, zeros], axis=2) - if img.shape[2] == 2: - zeros = np.zeros((img.shape[0], img.shape[1], 2), np.float32) - img = np.concatenate([img, zeros], axis=2) - - assert img.shape[2] == 4, "Image must be grayscale, RG, RGB or RGBA" - - res = img.shape[:2] - assert res == self.res, "Image resolution does not match GUI resolution" - return np.ascontiguousarray(img) - - def get_image(self): - """Get the image data. - - Returns: - :class:`numpy.array` :The image data in numpy contiguous array type. - - """ - self.img = np.ascontiguousarray(self.img) - self.core.get_img(self.img.ctypes.data) - return self.img - - def set_image(self, img): - """Draw an image on canvas. - - Args: - img (Union[ti.field, numpy.array]): The color array representing the - image to be drawn. Support greyscale, RG, RGB, and RGBA color - representations. Its shape must match GUI resolution. - - """ - - if self.fast_gui: - assert isinstance(img, taichi.lang.matrix.MatrixField), \ - "Only ti.Vector.field is supported in GUI.set_image when fast_gui=True" - assert img.shape == self.res, \ - "Image resolution does not match GUI resolution" - assert img.n in [3, 4] and img.m == 1, \ - "Only RGB images are supported in GUI.set_image when fast_gui=True" - assert img.dtype in [ti.f32, ti.f64, ti.u8], \ - "Only f32, f64, u8 are supported in GUI.set_image when fast_gui=True" - - taichi.lang.meta.vector_to_fast_image(img, self.img) - return - - if isinstance(img, ScalarField): - if _ti_core.is_integral(img.dtype) or len(img.shape) != 2: - # Images of uint is not optimized by xxx_to_image - self.img = self.cook_image(img.to_numpy()) - else: - # Type matched! We can use an optimized copy kernel. - assert img.shape \ - == self.res, "Image resolution does not match GUI resolution" - taichi.lang.meta.tensor_to_image(img, self.img) - ti.sync() - - elif isinstance(img, taichi.lang.matrix.MatrixField): - if _ti_core.is_integral(img.dtype): - self.img = self.cook_image(img.to_numpy()) - else: - # Type matched! We can use an optimized copy kernel. - assert img.shape == self.res, \ - "Image resolution does not match GUI resolution" - assert img.n in [2, 3, 4] and img.m == 1, \ - "Only greyscale, RG, RGB or RGBA images are supported in GUI.set_image" - - taichi.lang.meta.vector_to_image(img, self.img) - ti.sync() - - elif isinstance(img, np.ndarray): - self.img = self.cook_image(img) - - else: - raise ValueError( - f"GUI.set_image only takes a Taichi field or NumPy array, not {type(img)}" - ) - - self.core.set_img(self.img.ctypes.data) - - def circle(self, pos, color=0xFFFFFF, radius=1): - """Draw a single circle on canvas. - - Args: - pos (Union[List[int], numpy.array]): The position of the circle. - color (int, Optional): The color of the circle. Default is 0xFFFFFF. - radius (Number, Optional): The radius of the circle. Default is 1. - - """ - self.canvas.circle_single(pos[0], pos[1], color, radius) - - def circles(self, - pos, - radius=1, - color=0xFFFFFF, - palette=None, - palette_indices=None): - """Draw a list of circles on canvas. - - Args: - pos (numpy.array): The positions of the circles. - radius (Number, optional): The radius of the circles. Default is 1. - color (int, optional): The color of the circles. Default is 0xFFFFFF. - palette (list[int], optional): The List of colors from which to - choose to draw. Default is None. - palette_indices (Union[list[int], ti.field, numpy.array], optional): - The List of indices that choose color from palette for each - circle. Shape must match pos. Default is None. - - """ - n = pos.shape[0] - if len(pos.shape) == 3: - assert pos.shape[2] == 1 - pos = pos[:, :, 0] - - assert pos.shape == (n, 2) - pos = np.ascontiguousarray(pos.astype(np.float32)) - # Note: do not use "pos = int(pos.ctypes.data)" here - # Otherwise pos will get garbage collected by Python - # and the pointer to its data becomes invalid - pos_ptr = int(pos.ctypes.data) - - if isinstance(color, np.ndarray): - assert color.shape == (n, ) - color = np.ascontiguousarray(color.astype(np.uint32)) - color_array = int(color.ctypes.data) - color_single = 0 - elif isinstance(color, int): - color_array = 0 - color_single = color - else: - raise ValueError( - 'Color must be an ndarray or int (e.g., 0x956333)') - - if palette is not None: - assert palette_indices is not None, 'palette must be used together with palette_indices' - - if isinstance(palette_indices, Field): - ind_int = palette_indices.to_numpy().astype(np.uint32) - elif isinstance(palette_indices, list) or isinstance( - palette_indices, np.ndarray): - ind_int = np.array(palette_indices).astype(np.uint32) - else: - try: - ind_int = np.array(palette_indices) - except: - raise TypeError( - 'palette_indices must be a type that can be converted to numpy.ndarray' - ) - - assert issubclass( - ind_int.dtype.type, - np.integer), 'palette_indices must be an integer array' - assert ind_int.shape == ( - n, - ), 'palette_indices must be in 1-d shape with shape (num_particles, )' - assert min( - ind_int - ) >= 0, 'the min of palette_indices must not be less than zero' - assert max(ind_int) < len( - palette - ), 'the max of palette_indices must not exceed the length of palette' - color_array = np.array(palette, dtype=np.uint32)[ind_int] - color_array = np.ascontiguousarray(color_array) - color_array = color_array.ctypes.data - - if isinstance(radius, np.ndarray): - assert radius.shape == (n, ) - radius = np.ascontiguousarray(radius.astype(np.float32)) - radius_array = int(radius.ctypes.data) - radius_single = 0 - elif isinstance(radius, numbers.Number): - radius_array = 0 - radius_single = radius - else: - raise ValueError('Radius must be an ndarray or float (e.g., 0.4)') - - self.canvas.circles_batched(n, pos_ptr, color_single, color_array, - radius_single, radius_array) - - def triangles(self, a, b, c, color=0xFFFFFF): - """Draw a list of triangles on canvas. - - Args: - a (numpy.array): The positions of the first points of triangles. - b (numpy.array): The positions of the second points of triangles. - c (numpy.array): The positions of the thrid points of triangles. - color (Union[int, numpy.array], optional): The color or colors of triangles. - Can be either a single color or a list of colors whose shape matches - the shape of a & b & c. Default is 0xFFFFFF. - - """ - assert a.shape == b.shape - assert a.shape == c.shape - n = a.shape[0] - if len(a.shape) == 3: - assert a.shape[2] == 1 - a = a[:, :, 0] - b = b[:, :, 0] - c = c[:, :, 0] - - assert a.shape == (n, 2) - a = np.ascontiguousarray(a.astype(np.float32)) - b = np.ascontiguousarray(b.astype(np.float32)) - c = np.ascontiguousarray(c.astype(np.float32)) - # Note: do not use "a = int(a.ctypes.data)" here - # Otherwise a will get garbage collected by Python - # and the pointer to its data becomes invalid - a_ptr = int(a.ctypes.data) - b_ptr = int(b.ctypes.data) - c_ptr = int(c.ctypes.data) - - if isinstance(color, np.ndarray): - assert color.shape == (n, ) - color = np.ascontiguousarray(color.astype(np.uint32)) - color_array = int(color.ctypes.data) - color_single = 0 - elif isinstance(color, int): - color_array = 0 - color_single = color - else: - raise ValueError( - '"color" must be an ndarray or int (e.g., 0x956333)') - - self.canvas.triangles_batched(n, a_ptr, b_ptr, c_ptr, color_single, - color_array) - - def triangle(self, a, b, c, color=0xFFFFFF): - """Draw a single triangle on canvas. - - Args: - a (List[Number]): The position of the first point of triangle. Shape must be 2. - b (List[Number]): The position of the second point of triangle. Shape must be 2. - c (List[Number]): The position of the third point of triangle. Shape must be 2. - color (int, optional): The color of the triangle. Default is 0xFFFFFF. - - """ - self.canvas.triangle_single(a[0], a[1], b[0], b[1], c[0], c[1], color) - - def lines(self, begin, end, radius=1, color=0xFFFFFF): - """Draw a list of lines on canvas. - - Args: - begin (numpy.array): The positions of one end of lines. - end (numpy.array): The positions of the other end of lines. - radius (Union[Number, numpy.array], optional): The width of lines. - Can be either a single width or a list of width whose shape matches - the shape of begin & end. Default is 1. - color (Union[int, numpy.array], optional): The color or colors of lines. - Can be either a single color or a list of colors whose shape matches - the shape of begin & end. Default is 0xFFFFFF. - - """ - assert begin.shape == end.shape - n = begin.shape[0] - if len(begin.shape) == 3: - assert begin.shape[2] == 1 - begin = begin[:, :, 0] - end = end[:, :, 0] - - assert begin.shape == (n, 2) - begin = np.ascontiguousarray(begin.astype(np.float32)) - end = np.ascontiguousarray(end.astype(np.float32)) - # Note: do not use "begin = int(begin.ctypes.data)" here - # Otherwise begin will get garbage collected by Python - # and the pointer to its data becomes invalid - begin_ptr = int(begin.ctypes.data) - end_ptr = int(end.ctypes.data) - - if isinstance(color, np.ndarray): - assert color.shape == (n, ) - color = np.ascontiguousarray(color.astype(np.uint32)) - color_array = int(color.ctypes.data) - color_single = 0 - elif isinstance(color, int): - color_array = 0 - color_single = color - else: - raise ValueError( - 'Color must be an ndarray or int (e.g., 0x956333)') - - if isinstance(radius, np.ndarray): - assert radius.shape == (n, ) - radius = np.ascontiguousarray(radius.astype(np.float32)) - radius_array = int(radius.ctypes.data) - radius_single = 0 - elif isinstance(radius, numbers.Number): - radius_array = 0 - radius_single = radius - else: - raise ValueError('Radius must be an ndarray or float (e.g., 0.4)') - - self.canvas.paths_batched(n, begin_ptr, end_ptr, color_single, - color_array, radius_single, radius_array) - - def line(self, begin, end, radius=1, color=0xFFFFFF): - """Draw a single line on canvas. - - Args: - begin (List[Number]): The position of one end of line. Shape must be 2. - end (List[Number]): The position of the other end of line. Shape must be 2. - radius (Number, optional): The width of line. Default is 1. - color (int, optional): The color of line. Default is 0xFFFFFF. - - """ - self.canvas.path_single(begin[0], begin[1], end[0], end[1], color, - radius) - - @staticmethod - def _arrow_to_lines(orig, major, tip_scale=0.2, angle=45): - angle = math.radians(180 - angle) - c, s = math.cos(angle), math.sin(angle) - minor1 = np.array([ - major[:, 0] * c - major[:, 1] * s, - major[:, 0] * s + major[:, 1] * c - ]).swapaxes(0, 1) - minor2 = np.array([ - major[:, 0] * c + major[:, 1] * s, - -major[:, 0] * s + major[:, 1] * c - ]).swapaxes(0, 1) - end = orig + major - return [(orig, end), (end, end + minor1 * tip_scale), - (end, end + minor2 * tip_scale)] - - def arrows(self, orig, dir, radius=1, color=0xffffff, **kwargs): - """Draw a list arrows on canvas. - - Args: - orig (numpy.array): The positions where arrows start. - dir (numpy.array): The directions where arrows point to. - radius (Union[Number, np.array], optional): The width of arrows. Default is 1. - color (Union[int, np.array], optional): The color or colors of arrows. Default is 0xffffff. - - """ - for begin, end in self._arrow_to_lines(orig, dir, **kwargs): - self.lines(begin, end, radius, color) - - def arrow(self, orig, dir, radius=1, color=0xffffff, **kwargs): - """Draw a single arrow on canvas. - - Args: - orig (List[Number]): The position where arrow starts. Shape must be 2. - dir (List[Number]): The direction where arrow points to. Shape must be 2. - radius (Number, optional): The width of arrow. Default is 1. - color (int, optional): The color of arrow. Default is 0xFFFFFF. - - """ - orig = np.array([orig]) - dir = np.array([dir]) - for begin, end in self._arrow_to_lines(orig, dir, **kwargs): - self.line(begin[0], end[0], radius, color) - - def rect(self, topleft, bottomright, radius=1, color=0xFFFFFF): - """Draw a single rectangle on canvas. - - Args: - topleft (List[Number]): The position of the topleft corner of rectangle. - Shape must be 2. - bottomright (List[Number]): The position of the bottomright corner - of rectangle. Shape must be 2. - radius (Number, optional): The width of rectangle's sides. Default is 1. - color (int, optional): The color of rectangle. Default is 0xFFFFFF. - - """ - a = topleft[0], topleft[1] - b = bottomright[0], topleft[1] - c = bottomright[0], bottomright[1] - d = topleft[0], bottomright[1] - self.line(a, b, radius, color) - self.line(b, c, radius, color) - self.line(c, d, radius, color) - self.line(d, a, radius, color) - - def text(self, content, pos, font_size=15, color=0xFFFFFF): - """Draw texts on canvas. - - Args: - content (str): The text to be drawn on canvas. - pos (List[Number]): The position where the text is to be put. - font_size (Number, optional): The font size of the text. - color (int, optional): The color of the text. Default is 0xFFFFFF. - - """ - - # TODO: refactor Canvas::text - font_size = float(font_size) - pos = ti.core_vec(*pos) - r, g, b = hex_to_rgb(color) - color = ti.core_vec(r, g, b, 1) - self.canvas.text(content, pos, font_size, color) - - @staticmethod - def _make_field_base(w, h, bound): - x = np.linspace(bound / w, 1 - bound / w, w) - y = np.linspace(bound / h, 1 - bound / h, h) - base = np.array(np.meshgrid(x, y)) - base = base.swapaxes(0, 1).swapaxes(1, 2).swapaxes(0, 1) - return base.reshape(w * h, 2) - - def point_field(self, radius, color=0xffffff, bound=0.5): - """Draw a field of points on canvas. - - Args: - radius (np.array): The pattern and radius of the field of points. - color (Union[int, np.array], optional): The color or colors of points. - Default is 0xFFFFFF. - bound (Number, optional): The boundary of the field. Default is 0.5. - - """ - assert len(radius.shape) == 2 - base = self._make_field_base(radius.shape[0], radius.shape[1], bound) - radius = radius.reshape(radius.shape[0] * radius.shape[1]) - self.circles(base, radius=radius, color=color) - - def arrow_field(self, dir, radius=1, color=0xffffff, bound=0.5, **kwargs): - """Draw a field of arrows on canvas. - - Args: - dir (np.array): The pattern and direction of the field of arrows. - color (Union[int, np.array], optional): The color or colors of arrows. - Default is 0xFFFFFF. - bound (Number, optional): The boundary of the field. Default is 0.5. - - """ - assert len(dir.shape) == 3 - assert dir.shape[2] == 2 - base = self._make_field_base(dir.shape[0], dir.shape[1], bound) - dir = dir.reshape(dir.shape[0] * dir.shape[1], 2) - self.arrows(base, dir, radius=radius, color=color, **kwargs) - - def show(self, file=None): - """Show the frame or save current frame as a picture. - - Args: - file (str, optional): The path & name of the picture to be saved. - Default is None. - - """ - self.core.update() - if file: - self.core.screenshot(file) - self.frame += 1 - self.clear() - - ## Event system - - class EventFilter: - def __init__(self, *filter): - self.filter = set() - for ent in filter: - if isinstance(ent, (list, tuple)): - type, key = ent - ent = (type, key) - self.filter.add(ent) - - def match(self, e): - if (e.type, e.key) in self.filter: - return True - if e.type in self.filter: - return True - if e.key in self.filter: - return True - return False - - def has_key_event(self): - """Check if there are any key event registered. - - Returns: - Bool to indicate whether there is any key event registered. - - """ - return self.core.has_key_event() - - def get_event(self, *filter): - """Check if the specific event is triggered. - - Args: - *filter (ti.GUI.EVENT): The specific event to be checked. - - Returns: - Bool to indicate whether the specific event is triggered. - - """ - for e in self.get_events(*filter): - self.event = e - return True - else: - return False - - def get_events(self, *filter): - """Get a list of events that are triggered. - - Args: - *filter (List[ti.GUI.EVENT]): The type of events to be filtered. - - Returns: - :class:`~taichi.misc.gui.GUI.EVENT` :A list of events that are triggered. - - """ - filter = filter and GUI.EventFilter(*filter) or None - - while True: - if not self.has_key_event(): - break - e = self.get_key_event() - if filter is None or filter.match(e): - yield e - - def get_key_event(self): - """Get keyboard triggered event. - - Returns: - :class:`~taichi.misc.gui.GUI.EVENT` :The keyboard triggered event. - - """ - self.core.wait_key_event() - - e = GUI.Event() - event = self.core.get_key_event_head() - - e.type = event.type - e.key = event.key - e.pos = self.core.canvas_untransform(event.pos) - e.pos = (e.pos[0], e.pos[1]) - e.modifier = [] - - if e.key == GUI.WHEEL: - e.delta = event.delta - else: - e.delta = (0, 0) - - for mod in ['Shift', 'Alt', 'Control']: - if self.is_pressed(mod): - e.modifier.append(mod) - - if e.type == GUI.PRESS: - self.key_pressed.add(e.key) - else: - self.key_pressed.discard(e.key) - - self.core.pop_key_event_head() - return e - - def is_pressed(self, *keys): - """Check if the specific key or keys are pressed. - - Args: - *keys (Union[str, List[str]]): The string that stands for keys in keyboard. - - Returns: - Bool to indicate whether the key or keys are pressed. - - """ - for key in keys: - if key in ['Shift', 'Alt', 'Control']: - if key + '_L' in self.key_pressed or key + '_R' in self.key_pressed: - return True - if key in self.key_pressed: - return True - else: - return False - - def get_cursor_pos(self): - """Get the current position of mouse. - - Returns: - The current position of mouse. - - """ - pos = self.core.get_cursor_pos() - return pos[0], pos[1] - - @deprecated('gui.has_key_pressed()', 'gui.get_event()') - def has_key_pressed(self): - if self.has_key_event(): - self.get_key_event() # pop to update self.key_pressed - return len(self.key_pressed) != 0 - - @property - def running(self): - """Get the property of whether the gui is running. - - Returns: - The running property of gui(bool). - - """ - return not self.core.should_close - - @running.setter - def running(self, value): - if value: - self.core.should_close = 0 - elif not self.core.should_close: - self.core.should_close = 1 - - @property - def fps_limit(self): - """Get the property of fps limit. - - Returns: - The property of fps limit of gui. - - """ - if self.core.frame_delta_limit == 0: - return None - else: - return 1 / self.core.frame_delta_limit - - @fps_limit.setter - def fps_limit(self, value): - if value is None: - self.core.frame_delta_limit = 0 - else: - self.core.frame_delta_limit = 1 / value - - -def rgb_to_hex(c): - """Convert rgb color format to hex color format. - - Args: - c (List[int]): The rgb representation of color. - - Returns: - The hex representation of color. - - """ - to255 = lambda x: np.clip(np.int32(x * 255), 0, 255) - return (to255(c[0]) << 16) + (to255(c[1]) << 8) + to255(c[2]) - - -def hex_to_rgb(color): - """Convert hex color format to rgb color format. - - Args: - color (int): The hex representation of color. - - Returns: - The rgb representation of color. - - """ - r, g, b = (color >> 16) & 0xff, (color >> 8) & 0xff, color & 0xff - return r / 255, g / 255, b / 255 - - -__all__ = [ - 'GUI', - 'rgb_to_hex', - 'hex_to_rgb', -] diff --git a/python/taichi/misc/task.py b/python/taichi/misc/task.py deleted file mode 100644 index e06231ec69d30..0000000000000 --- a/python/taichi/misc/task.py +++ /dev/null @@ -1,36 +0,0 @@ -from taichi.core import ti_core as _ti_core -from taichi.misc.util import config_from_dict - - -def _unit(unit_name): - def decorator(target_class): - if target_class.__init__ != object.__init__: - original_init = target_class.__init__ - else: - - def dummy_init(*args, **kwargs): - pass - - original_init = dummy_init - - def new_init(self, name, *args, **kwargs): - self.c = getattr(_ti_core, 'create_' + unit_name)(name) - self.c.initialize(config_from_dict(kwargs)) - original_init(self, *args, **kwargs) - - target_class.__init__ = new_init - - def new_getattr_(self, key): - return self.c.__getattribute__(key) - - target_class.__getattr__ = new_getattr_ - - return target_class - - return decorator - - -@_unit('task') -class Task: - def run(self, *args): - return self.c.run(args) diff --git a/python/taichi/misc/util.py b/python/taichi/misc/util.py deleted file mode 100644 index 21e81253c3028..0000000000000 --- a/python/taichi/misc/util.py +++ /dev/null @@ -1,273 +0,0 @@ -import copy -import functools -import subprocess -import sys -import traceback - -from colorama import Fore, Style -from taichi.core import ti_core as _ti_core - -import taichi as ti - - -def config_from_dict(args): - d = copy.copy(args) - for k in d: - if isinstance(d[k], _ti_core.Vector2f): - d[k] = '({}, {})'.format(d[k].x, d[k].y) - if isinstance(d[k], _ti_core.Vector3f): - d[k] = '({}, {}, {})'.format(d[k].x, d[k].y, d[k].z) - d[k] = str(d[k]) - return _ti_core.config_from_dict(d) - - -def core_veci(*args): - if isinstance(args[0], _ti_core.Vector2i): - return args[0] - if isinstance(args[0], _ti_core.Vector3i): - return args[0] - if isinstance(args[0], tuple): - args = tuple(*args) - if len(args) == 2: - return _ti_core.Vector2i(int(args[0]), int(args[1])) - elif len(args) == 3: - return _ti_core.Vector3i(int(args[0]), int(args[1]), int(args[2])) - elif len(args) == 4: - return _ti_core.Vector4i(int(args[0]), int(args[1]), int(args[2]), - int(args[3])) - else: - assert False, type(args[0]) - - -def core_vec(*args): - if isinstance(args[0], _ti_core.Vector2f): - return args[0] - if isinstance(args[0], _ti_core.Vector3f): - return args[0] - if isinstance(args[0], _ti_core.Vector4f): - return args[0] - if isinstance(args[0], _ti_core.Vector2d): - return args[0] - if isinstance(args[0], _ti_core.Vector3d): - return args[0] - if isinstance(args[0], _ti_core.Vector4d): - return args[0] - if isinstance(args[0], tuple): - args = tuple(*args) - if _ti_core.get_default_float_size() == 4: - if len(args) == 2: - return _ti_core.Vector2f(float(args[0]), float(args[1])) - elif len(args) == 3: - return _ti_core.Vector3f(float(args[0]), float(args[1]), - float(args[2])) - elif len(args) == 4: - return _ti_core.Vector4f(float(args[0]), float(args[1]), - float(args[2]), float(args[3])) - else: - assert False, type(args[0]) - else: - if len(args) == 2: - return _ti_core.Vector2d(float(args[0]), float(args[1])) - elif len(args) == 3: - return _ti_core.Vector3d(float(args[0]), float(args[1]), - float(args[2])) - elif len(args) == 4: - return _ti_core.Vector4d(float(args[0]), float(args[1]), - float(args[2]), float(args[3])) - else: - assert False, type(args[0]) - - -class Tee(): - def __init__(self, name): - self.file = open(name, 'w') - self.stdout = sys.stdout - self.stderr = sys.stderr - sys.stdout = self - sys.stderr = self - - def __del__(self): - self.file.close() - - def write(self, data): - self.file.write(data) - self.stdout.write(data) - self.file.flush() - self.stdout.flush() - - def write_to_file(self, data): - self.file.write(data) - - -# The builtin `warnings` module is unreliable since it may be suppressed -# by other packages such as IPython. -def warning(msg, type=UserWarning, stacklevel=1): - """Print warning message - - Args: - msg (str): massage to print. - type (builtin warning type): type of warning. - stacklevel (int): warning stack level from the caller. - """ - s = traceback.extract_stack()[:-stacklevel] - raw = ''.join(traceback.format_list(s)) - print(Fore.YELLOW + Style.BRIGHT, end='') - print(f'{type.__name__}: {msg}') - print(f'\n{raw}') - print(Style.RESET_ALL, end='') - - -def deprecated(old, new, warning_type=DeprecationWarning): - """Mark an API as deprecated. - - Args: - old (str): old method. - new (str): new method. - warning_type (builtin warning type): type of warning. - - Example:: - - >>> @deprecated('ti.sqr(x)', 'x**2') - >>> def sqr(x): - >>> return x**2 - - Returns: - Decorated fuction with warning message - """ - def decorator(foo): - @functools.wraps(foo) - def wrapped(*args, **kwargs): - _taichi_skip_traceback = 1 - msg = f'{old} is deprecated. Please use {new} instead.' - warning(msg, warning_type, stacklevel=2) - return foo(*args, **kwargs) - - return wrapped - - return decorator - - -def obsolete(old, new): - """ - Mark an API as obsolete. Usage: - - sqr = obsolete('ti.sqr(x)', 'x**2') - """ - def wrapped(*args, **kwargs): - _taichi_skip_traceback = 1 - msg = f'{old} is obsolete. Please use {new} instead.' - raise SyntaxError(msg) - - return wrapped - - -def get_traceback(stacklevel=1): - s = traceback.extract_stack()[:-1 - stacklevel] - return ''.join(traceback.format_list(s)) - - -def duplicate_stdout_to_file(fn): - _ti_core.duplicate_stdout_to_file(fn) - - -def set_gdb_trigger(on=True): - _ti_core.set_core_trigger_gdb_when_crash(on) - - -def print_profile_info(): - """Print time elapsed on the host tasks in a hierarchical format. - - This profiler is automatically on. - - Call function imports from C++ : _ti_core.print_profile_info() - - Example:: - - >>> import taichi as ti - >>> ti.init(arch=ti.cpu) - >>> var = ti.field(ti.f32, shape=1) - >>> @ti.kernel - >>> def compute(): - >>> var[0] = 1.0 - >>> print("Setting var[0] =", var[0]) - >>> compute() - >>> ti.print_profile_info() - """ - _ti_core.print_profile_info() - - -def clear_profile_info(): - """Clear profiler's records about time elapsed on the host tasks. - - Call function imports from C++ : _ti_core.clear_profile_info() - """ - _ti_core.clear_profile_info() - - -@deprecated('ti.vec(x, y)', 'ti.core_vec(x, y)') -def vec(*args, **kwargs): - return core_vec(*args, **kwargs) - - -@deprecated('ti.veci(x, y)', 'ti.core_veci(x, y)') -def veci(*args, **kwargs): - return core_veci(*args, **kwargs) - - -def dump_dot(filepath=None, rankdir=None, embed_states_threshold=0): - d = _ti_core.dump_dot(rankdir, embed_states_threshold) - if filepath is not None: - with open(filepath, 'w') as fh: - fh.write(d) - return d - - -def dot_to_pdf(dot, filepath): - assert filepath.endswith('.pdf') - p = subprocess.Popen(['dot', '-Tpdf'], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE) - pdf_contents = p.communicate(input=dot.encode())[0] - with open(filepath, 'wb') as fh: - fh.write(pdf_contents) - - -def get_kernel_stats(): - return _ti_core.get_kernel_stats() - - -def print_async_stats(include_kernel_profiler=False): - if include_kernel_profiler: - ti.print_kernel_profile_info() - print() - stat = ti.get_kernel_stats() - counters = stat.get_counters() - print('=======================') - print('Async benchmark metrics') - print('-----------------------') - print(f'Async mode: {ti.current_cfg().async_mode}') - print(f'Kernel time: {ti.kernel_profiler_total_time():.3f} s') - print(f'Tasks launched: {int(counters["launched_tasks"])}') - print(f'Instructions emitted: {int(counters["codegen_statements"])}') - print(f'Tasks compiled: {int(counters["codegen_offloaded_tasks"])}') - NUM_FUSED_TASKS_KEY = 'num_fused_tasks' - if NUM_FUSED_TASKS_KEY in counters: - print(f'Tasks fused: {int(counters["num_fused_tasks"])}') - print('=======================') - - -__all__ = [ - 'vec', - 'veci', - 'core_vec', - 'core_veci', - 'deprecated', - 'dump_dot', - 'dot_to_pdf', - 'obsolete', - 'get_kernel_stats', - 'get_traceback', - 'set_gdb_trigger', - 'print_profile_info', - 'clear_profile_info', -] diff --git a/python/taichi/profiler/__init__.py b/python/taichi/profiler/__init__.py index 0aad976a6cfb5..780b2a937f507 100644 --- a/python/taichi/profiler/__init__.py +++ b/python/taichi/profiler/__init__.py @@ -1,3 +1,4 @@ -from taichi.profiler.kernelprofiler import \ - KernelProfiler # import for docstring-gen -from taichi.profiler.kernelprofiler import get_default_kernel_profiler +from taichi.profiler.kernel_metrics import * +from taichi.profiler.kernel_profiler import * +from taichi.profiler.memory_profiler import * +from taichi.profiler.scoped_profiler import * diff --git a/python/taichi/profiler/kernelmetrics.py b/python/taichi/profiler/kernel_metrics.py similarity index 71% rename from python/taichi/profiler/kernelmetrics.py rename to python/taichi/profiler/kernel_metrics.py index 909f1f06874fd..a8a837d3b0fcc 100644 --- a/python/taichi/profiler/kernelmetrics.py +++ b/python/taichi/profiler/kernel_metrics.py @@ -1,21 +1,18 @@ -from dataclasses import dataclass +from taichi._lib import core as _ti_core -from taichi.core import ti_core as _ti_core - -@dataclass class CuptiMetric: - """A data class to add CUPTI metric for :class:`~taichi.lang.KernelProfiler`. + """A class to add CUPTI metric for :class:`~taichi.profiler.kernel_profiler.KernelProfiler`. - This data class is designed to add user selected CUPTI metrics. + This class is designed to add user selected CUPTI metrics. Only available for the CUDA backend now, i.e. you need ``ti.init(kernel_profiler=True, arch=ti.cuda)``. - For usage of this class, see examples in func :func:`~taichi.lang.set_kernel_profile_metrics` and :func:`~taichi.lang.collect_kernel_profile_metrics`. + For usage of this class, see examples in func :func:`~taichi.profiler.set_kernel_profiler_metrics` and :func:`~taichi.profiler.collect_kernel_profiler_metrics`. Args: - name (str): name of metric that collected by CUPTI toolkit. used by :func:`~taichi.lang.set_kernel_profile_metrics` and :func:`~taichi.lang.collect_kernel_profile_metrics`. - header (str): column header of this metric, used by :func:`~taichi.lang.print_kernel_profile_info`. - format (str): format for print metric value (and unit of this value), used by :func:`~taichi.lang.print_kernel_profile_info`. - scale (float): scale of metric value, used by :func:`~taichi.lang.print_kernel_profile_info`. + name (str): name of metric that collected by CUPTI toolkit. used by :func:`~taichi.profiler.set_kernel_profiler_metrics` and :func:`~taichi.profiler.collect_kernel_profiler_metrics`. + header (str): column header of this metric, used by :func:`~taichi.profiler.print_kernel_profiler_info`. + val_format (str): format for print metric value (and unit of this value), used by :func:`~taichi.profiler.print_kernel_profiler_info`. + scale (float): scale of metric value, used by :func:`~taichi.profiler.print_kernel_profiler_info`. Example:: @@ -33,62 +30,67 @@ class CuptiMetric: >>> for i in x: >>> y[None] += x[i] - >>> global_op_atom = ti.CuptiMetric( + >>> global_op_atom = ti.profiler.CuptiMetric( >>> name='l1tex__t_set_accesses_pipe_lsu_mem_global_op_atom.sum', >>> header=' global.atom ', - >>> format=' {:8.0f} ') + >>> val_format=' {:8.0f} ') >>> # add and set user defined metrics - >>> profiling_metrics = ti.get_predefined_cupti_metrics('global_access') + [global_op_atom] - >>> ti.set_kernel_profile_metrics(profiling_metrics) + >>> profiling_metrics = ti.profiler.get_predefined_cupti_metrics('global_access') + [global_op_atom] + >>> ti.profiler.set_kernel_profile_metrics(profiling_metrics) >>> for i in range(16): >>> reduction() - >>> ti.print_kernel_profile_info('trace') + >>> ti.profiler.print_kernel_profiler_info('trace') Note: For details about using CUPTI in Taichi, please visit https://docs.taichi.graphics/docs/lang/articles/misc/profiler#advanced-mode. """ - name: str = '' - header: str = '' - format: str = '' - scale: float = 1.0 + def __init__(self, + name='', + header='unnamed_header', + val_format=' {:8.0f} ', + scale=1.0): + self.name = name + self.header = header + self.val_format = val_format + self.scale = scale # Global Memory Metrics dram_utilization = CuptiMetric( name='dram__throughput.avg.pct_of_peak_sustained_elapsed', header=' global.uti ', - format=' {:6.2f} % ') + val_format=' {:6.2f} % ') dram_bytes_sum = CuptiMetric(name='dram__bytes.sum', header=' global.R&W ', - format='{:9.3f} MB ', + val_format='{:9.3f} MB ', scale=1.0 / 1024 / 1024) dram_bytes_throughput = CuptiMetric(name='dram__bytes.sum.per_second', header=' global.R&W/s ', - format='{:8.3f} GB/s ', + val_format='{:8.3f} GB/s ', scale=1.0 / 1024 / 1024 / 1024) dram_bytes_read = CuptiMetric(name='dram__bytes_read.sum', header=' global.R ', - format='{:8.3f} MB ', + val_format='{:8.3f} MB ', scale=1.0 / 1024 / 1024) dram_read_throughput = CuptiMetric(name='dram__bytes_read.sum.per_second', header=' global.R/s ', - format='{:8.3f} GB/s ', + val_format='{:8.3f} GB/s ', scale=1.0 / 1024 / 1024 / 1024) dram_bytes_write = CuptiMetric(name='dram__bytes_write.sum', header=' global.W ', - format='{:8.3f} MB ', + val_format='{:8.3f} MB ', scale=1.0 / 1024 / 1024) dram_write_throughput = CuptiMetric(name='dram__bytes_write.sum.per_second', header=' global.W/s ', - format='{:8.3f} GB/s ', + val_format='{:8.3f} GB/s ', scale=1.0 / 1024 / 1024 / 1024) # Shared Memory Metrics @@ -96,73 +98,73 @@ class CuptiMetric: name= 'l1tex__data_pipe_lsu_wavefronts_mem_shared.avg.pct_of_peak_sustained_elapsed', header=' uti.shared ', - format=' {:6.2f} % ') + val_format=' {:6.2f} % ') shared_transactions_load = CuptiMetric( name='l1tex__data_pipe_lsu_wavefronts_mem_shared_op_ld.sum', header=' shared.trans.W ', - format=' {:10.0f} ') + val_format=' {:10.0f} ') shared_transactions_store = CuptiMetric( name='l1tex__data_pipe_lsu_wavefronts_mem_shared_op_st.sum', header=' shared.trans.R ', - format=' {:10.0f} ') + val_format=' {:10.0f} ') shared_bank_conflicts_store = CuptiMetric( name='l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum', header=' bank.conflict.W ', - format=' {:10.0f} ') + val_format=' {:10.0f} ') shared_bank_conflicts_load = CuptiMetric( name='l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum', header=' bank.conflict.R ', - format=' {:10.0f} ') + val_format=' {:10.0f} ') # Atomic Metrics global_op_atom = CuptiMetric( name='l1tex__t_set_accesses_pipe_lsu_mem_global_op_atom.sum', header=' global.atom ', - format=' {:8.0f} ') + val_format=' {:8.0f} ') global_op_reduction = CuptiMetric( name='l1tex__t_set_accesses_pipe_lsu_mem_global_op_red.sum', header=' global.red ', - format=' {:8.0f} ') + val_format=' {:8.0f} ') # Hardware Utilization Metrics sm_throughput = CuptiMetric( name='sm__throughput.avg.pct_of_peak_sustained_elapsed', header=' core.uti ', - format=' {:6.2f} % ') + val_format=' {:6.2f} % ') dram_throughput = CuptiMetric( name='gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed', header=' mem.uti ', - format=' {:6.2f} % ') + val_format=' {:6.2f} % ') l1tex_throughput = CuptiMetric( name='l1tex__throughput.avg.pct_of_peak_sustained_elapsed', header=' L1.uti ', - format=' {:6.2f} % ') + val_format=' {:6.2f} % ') l2_throughput = CuptiMetric( name='lts__throughput.avg.pct_of_peak_sustained_elapsed', header=' L2.uti ', - format=' {:6.2f} % ') + val_format=' {:6.2f} % ') # Misc Metrics l1_hit_rate = CuptiMetric(name='l1tex__t_sector_hit_rate.pct', header=' L1.hit ', - format=' {:6.2f} % ') + val_format=' {:6.2f} % ') l2_hit_rate = CuptiMetric(name='lts__t_sector_hit_rate.pct', header=' L2.hit ', - format=' {:6.2f} % ') + val_format=' {:6.2f} % ') achieved_occupancy = CuptiMetric( name='sm__warps_active.avg.pct_of_peak_sustained_active', header=' occupancy', - format=' {:6.0f} ') + val_format=' {:6.0f} ') # metric suite: global load & store global_access = [ @@ -219,9 +221,10 @@ def get_predefined_cupti_metrics(name=''): for key in predefined_cupti_metrics: _ti_core.warn(f" '{key}'") return None - else: - return predefined_cupti_metrics[name] + return predefined_cupti_metrics[name] # Default metrics list default_cupti_metrics = [dram_bytes_sum] + +__all__ = ['CuptiMetric', 'get_predefined_cupti_metrics'] diff --git a/python/taichi/profiler/kernelprofiler.py b/python/taichi/profiler/kernel_profiler.py similarity index 55% rename from python/taichi/profiler/kernelprofiler.py rename to python/taichi/profiler/kernel_profiler.py index 558bb08856d14..21615b5f06dfc 100644 --- a/python/taichi/profiler/kernelprofiler.py +++ b/python/taichi/profiler/kernel_profiler.py @@ -1,10 +1,8 @@ from contextlib import contextmanager -from taichi.core import ti_core as _ti_core +from taichi._lib import core as _ti_core from taichi.lang import impl -from taichi.profiler.kernelmetrics import default_cupti_metrics - -import taichi as ti +from taichi.profiler.kernel_metrics import default_cupti_metrics class StatisticalResult: @@ -42,7 +40,7 @@ class KernelProfiler: """Kernel profiler of Taichi. Kernel profiler acquires kernel profiling records from backend, counts records in Python scope, - and prints the results to the console by :func:`~taichi.profiler.kernelprofiler.KernelProfiler.print_info`. + and prints the results to the console by :func:`~taichi.profiler.kernel_profiler.KernelProfiler.print_info`. ``KernelProfiler`` now support detailed low-level performance metrics (such as memory bandwidth consumption) in its advanced mode. This mode is only available for the CUDA backend with CUPTI toolkit, i.e. you need ``ti.init(kernel_profiler=True, arch=ti.cuda)``. @@ -52,6 +50,7 @@ class KernelProfiler: """ def __init__(self): self._profiling_mode = False + self._profiling_toolkit = 'default' self._metric_list = [default_cupti_metrics] self._total_time_ms = 0.0 self._traced_records = [] @@ -60,7 +59,7 @@ def __init__(self): # public methods def set_kernel_profiler_mode(self, mode=False): - """Turn on or off :class:`~taichi.profiler.kernelprofiler.KernelProfiler`.""" + """Turn on or off :class:`~taichi.profiler.kernel_profiler.KernelProfiler`.""" if type(mode) is bool: self._profiling_mode = mode else: @@ -69,9 +68,22 @@ def set_kernel_profiler_mode(self, mode=False): ) def get_kernel_profiler_mode(self): - """Get status of :class:`~taichi.profiler.kernelprofiler.KernelProfiler`.""" + """Get status of :class:`~taichi.profiler.kernel_profiler.KernelProfiler`.""" return self._profiling_mode + def set_toolkit(self, toolkit_name='default'): + if self._check_not_turned_on_with_warning_message(): + return False + status = impl.get_runtime().prog.set_kernel_profiler_toolkit( + toolkit_name) + if status is True: + self._profiling_toolkit = toolkit_name + else: + _ti_core.warn( + f'Failed to set kernel profiler toolkit ({toolkit_name}) , keep using ({self._profiling_toolkit}).' + ) + return status + def get_total_time(self): """Get elapsed time of all kernels recorded in KernelProfiler. @@ -85,7 +97,7 @@ def get_total_time(self): return self._total_time_ms / 1000 # ms to s def clear_info(self): - """Clear all records both in front-end :class:`~taichi.profiler.kernelprofiler.KernelProfiler` and back-end instance ``KernelProfilerBase``. + """Clear all records both in front-end :class:`~taichi.profiler.kernel_profiler.KernelProfiler` and back-end instance ``KernelProfilerBase``. Note: The values of ``self._profiling_mode`` and ``self._metric_list`` will not be cleared. @@ -93,13 +105,15 @@ def clear_info(self): if self._check_not_turned_on_with_warning_message(): return None #sync first - impl.get_runtime().sync() + impl.get_runtime().prog.sync_kernel_profiler() #then clear backend & frontend info impl.get_runtime().prog.clear_kernel_profile_info() self._clear_frontend() + return None + def query_info(self, name): - """For docsting of this function, see :func:`~taichi.lang.query_kernel_profile_info`.""" + """For docstring of this function, see :func:`~taichi.profiler.query_kernel_profiler_info`.""" if self._check_not_turned_on_with_warning_message(): return None self._update_records() # kernel records @@ -108,7 +122,7 @@ def query_info(self, name): return impl.get_runtime().prog.query_kernel_profile_info(name) def set_metrics(self, metric_list=default_cupti_metrics): - """For docsting of this function, see :func:`~taichi.lang.set_kernel_profile_metrics`.""" + """For docstring of this function, see :func:`~taichi.profiler.set_kernel_profiler_metrics`.""" if self._check_not_turned_on_with_warning_message(): return None self._metric_list = metric_list @@ -117,11 +131,13 @@ def set_metrics(self, metric_list=default_cupti_metrics): impl.get_runtime().prog.reinit_kernel_profiler_with_metrics( metric_name_list) + return None + @contextmanager def collect_metrics_in_context(self, metric_list=default_cupti_metrics): """This function is not exposed to user now. - For usage of this function, see :func:`~taichi.lang.collect_kernel_profile_metrics`. + For usage of this function, see :func:`~taichi.profiler.collect_kernel_profiler_metrics`. """ if self._check_not_turned_on_with_warning_message(): return None @@ -129,6 +145,8 @@ def collect_metrics_in_context(self, metric_list=default_cupti_metrics): yield self self.set_metrics() #back to default metric list + return None + # mode of print_info COUNT = 'count' # print the statistical results (min,max,avg time) of Taichi kernels. TRACE = 'trace' # print the records of launched Taichi kernels with specific profiling metrics (time, memory load/store and core utilization etc.) @@ -136,7 +154,7 @@ def collect_metrics_in_context(self, metric_list=default_cupti_metrics): def print_info(self, mode=COUNT): """Print the profiling results of Taichi kernels. - For usage of this function, see :func:`~taichi.lang.print_kernel_profile_info`. + For usage of this function, see :func:`~taichi.profiler.print_kernel_profiler_info`. Args: mode (str): the way to print profiling results. @@ -154,21 +172,22 @@ def print_info(self, mode=COUNT): self._print_kernel_info() else: raise ValueError( - f'Arg `mode` must be of type \'str\', and has the value \'count\' or \'trace\'.' + 'Arg `mode` must be of type \'str\', and has the value \'count\' or \'trace\'.' ) + return None + # private methods def _check_not_turned_on_with_warning_message(self): if self._profiling_mode is False: _ti_core.warn( - f'use \'ti.init(kernel_profiler = True)\' to turn on KernelProfiler.' + 'use \'ti.init(kernel_profiler = True)\' to turn on KernelProfiler.' ) return True - else: - return False + return False def _clear_frontend(self): - """Clear member variables in :class:`~taichi.profiler.kernelprofiler.KernelProfiler`. + """Clear member variables in :class:`~taichi.profiler.kernel_profiler.KernelProfiler`. Note: The values of ``self._profiling_mode`` and ``self._metric_list`` will not be cleared. @@ -179,7 +198,7 @@ def _clear_frontend(self): def _update_records(self): """Acquires kernel records from a backend.""" - impl.get_runtime().sync() + impl.get_runtime().prog.sync_kernel_profiler() self._clear_frontend() self._traced_records = impl.get_runtime( ).prog.get_kernel_profiler_records() @@ -204,8 +223,8 @@ def _count_statistics(self): } def _make_table_header(self, mode): - header_str = f'Kernel Profiler({mode})' - arch_name = f' @ {_ti_core.arch_name(ti.cfg.arch).upper()}' + header_str = f'Kernel Profiler({mode}, {self._profiling_toolkit})' + arch_name = f' @ {_ti_core.arch_name(impl.current_cfg().arch).upper()}' device_name = impl.get_runtime().prog.get_kernel_profiler_device_name() if len(device_name) > 1: # default device_name = ' ' device_name = ' on ' + device_name @@ -334,3 +353,240 @@ def get_default_kernel_profiler(): For data retention purposes, multiple instances will be considered in the future. """ return _ti_kernel_profiler + + +def print_kernel_profiler_info(mode='count'): + """Print the profiling results of Taichi kernels. + + To enable this profiler, set ``kernel_profiler=True`` in ``ti.init()``. + ``'count'`` mode: print the statistics (min,max,avg time) of launched kernels, + ``'trace'`` mode: print the records of launched kernels with specific profiling metrics (time, memory load/store and core utilization etc.), + and defaults to ``'count'``. + + Args: + mode (str): the way to print profiling results. + + Example:: + + >>> import taichi as ti + + >>> ti.init(ti.cpu, kernel_profiler=True) + >>> var = ti.field(ti.f32, shape=1) + + >>> @ti.kernel + >>> def compute(): + >>> var[0] = 1.0 + + >>> compute() + >>> ti.profiler.print_kernel_profiler_info() + >>> # equivalent calls : + >>> # ti.profiler.print_kernel_profiler_info('count') + + >>> ti.profiler.print_kernel_profiler_info('trace') + + Note: + Currently the result of `KernelProfiler` could be incorrect on OpenGL + backend due to its lack of support for `ti.sync()`. + + For advanced mode of `KernelProfiler`, please visit https://docs.taichi.graphics/docs/lang/articles/misc/profiler#advanced-mode. + """ + get_default_kernel_profiler().print_info(mode) + + +def query_kernel_profiler_info(name): + """Query kernel elapsed time(min,avg,max) on devices using the kernel name. + + To enable this profiler, set `kernel_profiler=True` in `ti.init`. + + Args: + name (str): kernel name. + + Returns: + KernelProfilerQueryResult (class): with member variables(counter, min, max, avg) + + Example:: + + >>> import taichi as ti + + >>> ti.init(ti.cpu, kernel_profiler=True) + >>> n = 1024*1024 + >>> var = ti.field(ti.f32, shape=n) + + >>> @ti.kernel + >>> def fill(): + >>> for i in range(n): + >>> var[i] = 0.1 + + >>> fill() + >>> ti.profiler.clear_kernel_profiler_info() #[1] + >>> for i in range(100): + >>> fill() + >>> query_result = ti.profiler.query_kernel_profiler_info(fill.__name__) #[2] + >>> print("kernel excuted times =",query_result.counter) + >>> print("kernel elapsed time(min_in_ms) =",query_result.min) + >>> print("kernel elapsed time(max_in_ms) =",query_result.max) + >>> print("kernel elapsed time(avg_in_ms) =",query_result.avg) + + Note: + [1] To get the correct result, query_kernel_profiler_info() must be used in conjunction with + clear_kernel_profiler_info(). + + [2] Currently the result of `KernelProfiler` could be incorrect on OpenGL + backend due to its lack of support for `ti.sync()`. + """ + return get_default_kernel_profiler().query_info(name) + + +def clear_kernel_profiler_info(): + """Clear all KernelProfiler records.""" + get_default_kernel_profiler().clear_info() + + +def get_kernel_profiler_total_time(): + """Get elapsed time of all kernels recorded in KernelProfiler. + + Returns: + time (float): total time in second. + """ + return get_default_kernel_profiler().get_total_time() + + +def set_kernel_profiler_toolkit(toolkit_name='default'): + """Set the toolkit used by KernelProfiler. + + Currently, we only support toolkits: ``'default'`` and ``'cupti'``. + + Args: + toolkit_name (str): string of toolkit name. + + Returns: + status (bool): whether the setting is successful or not. + + Example:: + + >>> import taichi as ti + + >>> ti.init(arch=ti.cuda, kernel_profiler=True) + >>> x = ti.field(ti.f32, shape=1024*1024) + + >>> @ti.kernel + >>> def fill(): + >>> for i in x: + >>> x[i] = i + + >>> ti.profiler.set_kernel_profiler_toolkit('cupti') + >>> for i in range(100): + >>> fill() + >>> ti.profiler.print_kernel_profiler_info() + + >>> ti.profiler.set_kernel_profiler_toolkit('default') + >>> for i in range(100): + >>> fill() + >>> ti.profiler.print_kernel_profiler_info() + """ + return get_default_kernel_profiler().set_toolkit(toolkit_name) + + +def set_kernel_profiler_metrics(metric_list=default_cupti_metrics): + """Set metrics that will be collected by the CUPTI toolkit. + + Args: + metric_list (list): a list of :class:`~taichi.profiler.CuptiMetric()` instances, default value: :data:`~taichi.profiler.kernel_metrics.default_cupti_metrics`. + + Example:: + + >>> import taichi as ti + + >>> ti.init(kernel_profiler=True, arch=ti.cuda) + >>> ti.profiler.set_kernel_profiler_toolkit('cupti') + >>> num_elements = 128*1024*1024 + + >>> x = ti.field(ti.f32, shape=num_elements) + >>> y = ti.field(ti.f32, shape=()) + >>> y[None] = 0 + + >>> @ti.kernel + >>> def reduction(): + >>> for i in x: + >>> y[None] += x[i] + + >>> # In the case of not pramater, Taichi will print its pre-defined metrics list + >>> ti.profiler.get_predefined_cupti_metrics() + >>> # get Taichi pre-defined metrics + >>> profiling_metrics = ti.profiler.get_predefined_cupti_metrics('shared_access') + + >>> global_op_atom = ti.profiler.CuptiMetric( + >>> name='l1tex__t_set_accesses_pipe_lsu_mem_global_op_atom.sum', + >>> header=' global.atom ', + >>> format=' {:8.0f} ') + >>> # add user defined metrics + >>> profiling_metrics += [global_op_atom] + + >>> # metrics setting will be retained until the next configuration + >>> ti.profiler.set_kernel_profile_metrics(profiling_metrics) + >>> for i in range(16): + >>> reduction() + >>> ti.profiler.print_kernel_profiler_info('trace') + + Note: + Metrics setting will be retained until the next configuration. + """ + get_default_kernel_profiler().set_metrics(metric_list) + + +@contextmanager +def collect_kernel_profiler_metrics(metric_list=default_cupti_metrics): + """Set temporary metrics that will be collected by the CUPTI toolkit within this context. + + Args: + metric_list (list): a list of :class:`~taichi.profiler.CuptiMetric()` instances, default value: :data:`~taichi.profiler.kernel_metrics.default_cupti_metrics`. + + Example:: + + >>> import taichi as ti + + >>> ti.init(kernel_profiler=True, arch=ti.cuda) + >>> ti.profiler.set_kernel_profiler_toolkit('cupti') + >>> num_elements = 128*1024*1024 + + >>> x = ti.field(ti.f32, shape=num_elements) + >>> y = ti.field(ti.f32, shape=()) + >>> y[None] = 0 + + >>> @ti.kernel + >>> def reduction(): + >>> for i in x: + >>> y[None] += x[i] + + >>> # In the case of not pramater, Taichi will print its pre-defined metrics list + >>> ti.profiler.get_predefined_cupti_metrics() + >>> # get Taichi pre-defined metrics + >>> profiling_metrics = ti.profiler.get_predefined_cupti_metrics('device_utilization') + + >>> global_op_atom = ti.profiler.CuptiMetric( + >>> name='l1tex__t_set_accesses_pipe_lsu_mem_global_op_atom.sum', + >>> header=' global.atom ', + >>> format=' {:8.0f} ') + >>> # add user defined metrics + >>> profiling_metrics += [global_op_atom] + + >>> # metrics setting is temporary, and will be clear when exit from this context. + >>> with ti.profiler.collect_kernel_profiler_metrics(profiling_metrics): + >>> for i in range(16): + >>> reduction() + >>> ti.profiler.print_kernel_profiler_info('trace') + + Note: + The configuration of the ``metric_list`` will be clear when exit from this context. + """ + get_default_kernel_profiler().set_metrics(metric_list) + yield get_default_kernel_profiler() + get_default_kernel_profiler().set_metrics() + + +__all__ = [ + 'clear_kernel_profiler_info', 'collect_kernel_profiler_metrics', + 'get_kernel_profiler_total_time', 'print_kernel_profiler_info', + 'query_kernel_profiler_info', 'set_kernel_profiler_metrics', + 'set_kernel_profiler_toolkit' +] diff --git a/python/taichi/profiler/memory_profiler.py b/python/taichi/profiler/memory_profiler.py new file mode 100644 index 0000000000000..0c1030bf742aa --- /dev/null +++ b/python/taichi/profiler/memory_profiler.py @@ -0,0 +1,13 @@ +from taichi.lang.impl import get_runtime + + +def print_memory_profiler_info(): + """Memory profiling tool for LLVM backends with full sparse support. + + This profiler is automatically on. + """ + get_runtime().materialize() + get_runtime().prog.print_memory_profiler_info() + + +__all__ = ['print_memory_profiler_info'] diff --git a/python/taichi/profiler/scoped_profiler.py b/python/taichi/profiler/scoped_profiler.py new file mode 100644 index 0000000000000..c9884993ebe08 --- /dev/null +++ b/python/taichi/profiler/scoped_profiler.py @@ -0,0 +1,34 @@ +from taichi._lib import core as _ti_core + + +def print_scoped_profiler_info(): + """Print time elapsed on the host tasks in a hierarchical format. + + This profiler is automatically on. + + Call function imports from C++ : _ti_core.print_profile_info() + + Example:: + + >>> import taichi as ti + >>> ti.init(arch=ti.cpu) + >>> var = ti.field(ti.f32, shape=1) + >>> @ti.kernel + >>> def compute(): + >>> var[0] = 1.0 + >>> print("Setting var[0] =", var[0]) + >>> compute() + >>> ti.profiler.print_scoped_profiler_info() + """ + _ti_core.print_profile_info() + + +def clear_scoped_profiler_info(): + """Clear profiler's records about time elapsed on the host tasks. + + Call function imports from C++ : _ti_core.clear_profile_info() + """ + _ti_core.clear_profile_info() + + +__all__ = ['print_scoped_profiler_info', 'clear_scoped_profiler_info'] diff --git a/python/taichi/shaders/Circles_vk.vert b/python/taichi/shaders/Circles_vk.vert index c77358f4e7a08..712132ec4e495 100644 --- a/python/taichi/shaders/Circles_vk.vert +++ b/python/taichi/shaders/Circles_vk.vert @@ -3,7 +3,7 @@ layout(location = 0) in vec3 in_position; layout(location = 1) in vec3 in_normal; layout(location = 2) in vec2 in_texcoord; -layout(location = 3) in vec3 in_color; +layout(location = 3) in vec4 in_color; layout(binding = 0) uniform UBO { vec3 color; @@ -25,6 +25,6 @@ void main() { if (ubo.use_per_vertex_color == 0) { selected_color = ubo.color; } else { - selected_color = in_color; + selected_color = in_color.rgb; } } diff --git a/python/taichi/shaders/Circles_vk_frag.spv b/python/taichi/shaders/Circles_vk_frag.spv index e18b1cd09d69d..51146c0b7358b 100644 Binary files a/python/taichi/shaders/Circles_vk_frag.spv and b/python/taichi/shaders/Circles_vk_frag.spv differ diff --git a/python/taichi/shaders/Circles_vk_vert.spv b/python/taichi/shaders/Circles_vk_vert.spv index fc0a8c566bb5b..f511e9dc7e74d 100644 Binary files a/python/taichi/shaders/Circles_vk_vert.spv and b/python/taichi/shaders/Circles_vk_vert.spv differ diff --git a/python/taichi/shaders/Lines_vk.vert b/python/taichi/shaders/Lines_vk.vert index e009eec432945..4176ecc3ea8cc 100644 --- a/python/taichi/shaders/Lines_vk.vert +++ b/python/taichi/shaders/Lines_vk.vert @@ -3,7 +3,7 @@ layout(location = 0) in vec3 in_position; layout(location = 1) in vec3 in_normal; layout(location = 2) in vec2 in_texcoord; -layout(location = 3) in vec3 in_color; +layout(location = 3) in vec4 in_color; layout(location = 0) out vec2 frag_texcoord; @@ -25,6 +25,6 @@ void main() { if (ubo.use_per_vertex_color == 0) { selected_color = ubo.color; } else { - selected_color = in_color; + selected_color = in_color.rgb; } } diff --git a/python/taichi/shaders/Mesh_vk.frag b/python/taichi/shaders/Mesh_vk.frag index 45f788ef1b4dc..7596e10647ef9 100644 --- a/python/taichi/shaders/Mesh_vk.frag +++ b/python/taichi/shaders/Mesh_vk.frag @@ -32,10 +32,10 @@ layout(binding = 1, std430) buffer SSBO { } ssbo; -layout(location = 3) in vec3 selected_color; +layout(location = 3) in vec4 selected_color; vec3 lambertian() { - vec3 ambient = ubo.scene.ambient_light * selected_color; + vec3 ambient = ubo.scene.ambient_light * selected_color.rgb; vec3 result = ambient; for (int i = 0; i < ubo.scene.point_light_count; ++i) { @@ -50,7 +50,7 @@ vec3 lambertian() { else{ factor = max(dot(light_dir, normal), 0); } - vec3 diffuse = factor * selected_color * light_color; + vec3 diffuse = factor * selected_color.rgb * light_color; result += diffuse; } @@ -58,5 +58,5 @@ vec3 lambertian() { } void main() { - out_color = vec4(lambertian(), 1); + out_color = vec4(lambertian(), selected_color.a); } diff --git a/python/taichi/shaders/Mesh_vk.vert b/python/taichi/shaders/Mesh_vk.vert index d3b99471eb705..327d80280b62a 100644 --- a/python/taichi/shaders/Mesh_vk.vert +++ b/python/taichi/shaders/Mesh_vk.vert @@ -3,12 +3,12 @@ layout(location = 0) in vec3 in_position; layout(location = 1) in vec3 in_normal; layout(location = 2) in vec2 in_texcoord; -layout(location = 3) in vec3 in_color; +layout(location = 3) in vec4 in_color; layout(location = 0) out vec3 frag_pos; layout(location = 1) out vec3 frag_normal; layout(location = 2) out vec2 frag_texcoord; -layout(location = 3) out vec3 selected_color; +layout(location = 3) out vec4 selected_color; struct SceneUBO { vec3 camera_pos; @@ -39,7 +39,7 @@ void main() { frag_normal = in_normal; if (ubo.use_per_vertex_color == 0) { - selected_color = ubo.color; + selected_color = vec4(ubo.color, 1.0); } else { selected_color = in_color; } diff --git a/python/taichi/shaders/Mesh_vk_frag.spv b/python/taichi/shaders/Mesh_vk_frag.spv index 137dcc5a80052..3656b926bbfec 100644 Binary files a/python/taichi/shaders/Mesh_vk_frag.spv and b/python/taichi/shaders/Mesh_vk_frag.spv differ diff --git a/python/taichi/shaders/Mesh_vk_vert.spv b/python/taichi/shaders/Mesh_vk_vert.spv index 2b2893e3b8a67..9d12325b1199d 100644 Binary files a/python/taichi/shaders/Mesh_vk_vert.spv and b/python/taichi/shaders/Mesh_vk_vert.spv differ diff --git a/python/taichi/shaders/Particles_vk.frag b/python/taichi/shaders/Particles_vk.frag index 81221d730b6e2..751ee9aa616d5 100644 --- a/python/taichi/shaders/Particles_vk.frag +++ b/python/taichi/shaders/Particles_vk.frag @@ -32,7 +32,7 @@ ssbo; layout(location = 0) out vec4 out_color; layout(location = 0) in vec4 pos_camera_space; -layout(location = 1) in vec3 selected_color; +layout(location = 1) in vec4 selected_color; float project_z(float view_z) { vec4 projected = ubo.scene.projection * vec4(0, 0, view_z, 1); @@ -46,7 +46,7 @@ vec3 to_camera_space(vec3 pos) { // operates in camera space !! vec3 lambertian(vec3 frag_pos, vec3 frag_normal) { - vec3 ambient = ubo.scene.ambient_light * selected_color; + vec3 ambient = ubo.scene.ambient_light * selected_color.rgb; vec3 result = ambient; for (int i = 0; i < ubo.scene.point_light_count; ++i) { @@ -56,7 +56,7 @@ vec3 lambertian(vec3 frag_pos, vec3 frag_normal) { normalize(to_camera_space(ssbo.point_lights[i].pos) - frag_pos); vec3 normal = normalize(frag_normal); vec3 diffuse = - max(dot(light_dir, normal), 0.0) * selected_color * light_color; + max(dot(light_dir, normal), 0.0) * selected_color.rgb * light_color; result += diffuse; } @@ -80,7 +80,7 @@ void main() { pos_camera_space.xyz / pos_camera_space.w + coord_in_sphere * ubo.radius; vec3 frag_normal = coord_in_sphere; vec3 color = lambertian(frag_pos, frag_normal); - out_color = vec4(color, 1.0); + out_color = vec4(color, selected_color.a); float depth = (pos_camera_space.z / pos_camera_space.w) + z_in_sphere * ubo.radius; diff --git a/python/taichi/shaders/Particles_vk.vert b/python/taichi/shaders/Particles_vk.vert index 5dbfeb7c79cc4..6a963fe846868 100644 --- a/python/taichi/shaders/Particles_vk.vert +++ b/python/taichi/shaders/Particles_vk.vert @@ -3,7 +3,7 @@ layout(location = 0) in vec3 in_position; layout(location = 1) in vec3 in_normal; layout(location = 2) in vec2 in_texcoord; -layout(location = 3) in vec3 in_color; +layout(location = 3) in vec4 in_color; struct SceneUBO { vec3 camera_pos; @@ -25,7 +25,7 @@ layout(binding = 0) uniform UBO { ubo; layout(location = 0) out vec4 pos_camera_space; -layout(location = 1) out vec3 selected_color; +layout(location = 1) out vec4 selected_color; void main() { float distance = length(in_position - ubo.scene.camera_pos); @@ -38,7 +38,7 @@ void main() { gl_Position.y *= -1; if (ubo.use_per_vertex_color == 0) { - selected_color = ubo.color; + selected_color = vec4(ubo.color, 1.0); } else { selected_color = in_color; } diff --git a/python/taichi/shaders/Particles_vk_frag.spv b/python/taichi/shaders/Particles_vk_frag.spv index cd3dabb8dd778..5038c2506c2cf 100644 Binary files a/python/taichi/shaders/Particles_vk_frag.spv and b/python/taichi/shaders/Particles_vk_frag.spv differ diff --git a/python/taichi/shaders/Particles_vk_vert.spv b/python/taichi/shaders/Particles_vk_vert.spv index c432761fbe874..d46476649e80a 100644 Binary files a/python/taichi/shaders/Particles_vk_vert.spv and b/python/taichi/shaders/Particles_vk_vert.spv differ diff --git a/python/taichi/shaders/SetImage_vk.vert b/python/taichi/shaders/SetImage_vk.vert index 9f40ffe1ec0ca..b8610c5486662 100644 --- a/python/taichi/shaders/SetImage_vk.vert +++ b/python/taichi/shaders/SetImage_vk.vert @@ -5,7 +5,7 @@ layout(location = 0) in vec3 in_position; layout(location = 1) in vec3 in_normal; layout(location = 2) in vec2 in_texcoord; -layout(location = 3) in vec3 in_color; +layout(location = 3) in vec4 in_color; layout(location = 0) out vec2 frag_texcoord; diff --git a/python/taichi/shaders/Triangles_vk.vert b/python/taichi/shaders/Triangles_vk.vert index 1bc1f8ded4a03..914396bf003a5 100644 --- a/python/taichi/shaders/Triangles_vk.vert +++ b/python/taichi/shaders/Triangles_vk.vert @@ -3,7 +3,7 @@ layout(location = 0) in vec3 in_position; layout(location = 1) in vec3 in_normal; layout(location = 2) in vec2 in_texcoord; -layout(location = 3) in vec3 in_color; +layout(location = 3) in vec4 in_color; layout(location = 0) out vec2 frag_texcoord; layout(location = 1) out vec3 selected_color; @@ -24,6 +24,6 @@ void main() { if (ubo.use_per_vertex_color == 0) { selected_color = ubo.color; } else { - selected_color = in_color; + selected_color = in_color.rgb; } } diff --git a/python/taichi/snode/__init__.py b/python/taichi/snode/__init__.py deleted file mode 100644 index e3b852b98d876..0000000000000 --- a/python/taichi/snode/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from taichi.snode.fields_builder import FieldsBuilder diff --git a/python/taichi/tools/__init__.py b/python/taichi/tools/__init__.py index 98e9b2efbd0b3..c70d7f8f46cab 100644 --- a/python/taichi/tools/__init__.py +++ b/python/taichi/tools/__init__.py @@ -1,5 +1,6 @@ -from .np2ply import PLYWriter -from .patterns import taichi_logo -from .video import VideoManager - -__all__ = [s for s in dir() if not s.startswith('_')] +from taichi.tools.async_utils import * +from taichi.tools.cc_compose import * +from taichi.tools.diagnose import * +from taichi.tools.image import * +from taichi.tools.np2ply import * +from taichi.tools.video import * diff --git a/python/taichi/tools/async_utils.py b/python/taichi/tools/async_utils.py new file mode 100644 index 0000000000000..9cbec3391b044 --- /dev/null +++ b/python/taichi/tools/async_utils.py @@ -0,0 +1,29 @@ +import subprocess + +from taichi._lib import core as _ti_core +from taichi.lang.impl import get_runtime + + +def dump_dot(filepath=None, rankdir=None, embed_states_threshold=0): + d = get_runtime().prog.dump_dot(rankdir, embed_states_threshold) + if filepath is not None: + with open(filepath, 'w') as fh: + fh.write(d) + return d + + +def dot_to_pdf(dot, filepath): + assert filepath.endswith('.pdf') + with subprocess.Popen(['dot', '-Tpdf'], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE) as p: + pdf_contents = p.communicate(input=dot.encode())[0] + with open(filepath, 'wb') as fh: + fh.write(pdf_contents) + + +def get_kernel_stats(): + return _ti_core.get_kernel_stats() + + +__all__ = [] diff --git a/python/taichi/cc_compose.py b/python/taichi/tools/cc_compose.py similarity index 97% rename from python/taichi/cc_compose.py rename to python/taichi/tools/cc_compose.py index 99bb960811e88..01bd803d22ba8 100644 --- a/python/taichi/cc_compose.py +++ b/python/taichi/tools/cc_compose.py @@ -24,7 +24,6 @@ def do_group_begin(self, e): self.launches = [] def do_group_end(self, e): - name = e['content'] self.groups[self.current_group] = list(self.launches) self.current_group = None self.launches = [] @@ -77,7 +76,6 @@ def do_compile_layout(self, e): self.emit('') def do_allocate_buffer(self, e): - root_size = e['root_size'] gtmp_size = e['gtmp_size'] extr_size = 4 * 1024 * 1024 # pinpoint: 4 MB @@ -106,7 +104,6 @@ def do_allocate_buffer(self, e): self.emit('') def do_compile_kernel(self, e): - name = e['kernel_name'] source = e['kernel_source'] if self.emscripten: @@ -137,7 +134,7 @@ def main(fin_name, fout_name, hdrout_name, emscripten=False): import yaml # pylint: disable=C0415 with open(fin_name, 'r') as fin: warnings.filterwarnings('ignore') - obj = yaml.load(fin) + obj = yaml.load(fin, Loader=yaml.FullLoader) with open(hdrout_name, 'w') as hdrout: with open(fout_name, 'w') as fout: @@ -147,3 +144,5 @@ def main(fin_name, fout_name, hdrout_name, emscripten=False): if __name__ == '__main__': main(sys.argv[1], sys.argv[2], sys.argv[3], len(sys.argv) > 4) + +__all__ = [] diff --git a/python/taichi/diagnose.py b/python/taichi/tools/diagnose.py similarity index 95% rename from python/taichi/diagnose.py rename to python/taichi/tools/diagnose.py index 184e508e06da0..a2d06fd84e297 100644 --- a/python/taichi/diagnose.py +++ b/python/taichi/tools/diagnose.py @@ -47,7 +47,7 @@ def try_print(tag, expr): try_print('import', 'ti') print('') for arch in ['cc', 'cpu', 'metal', 'opengl', 'cuda', 'vulkan']: - try_print(arch, f'ti.is_arch_supported(ti.{arch})') + try_print(arch, f'ti.lang.misc.is_arch_supported(ti.{arch})') print('') try: @@ -112,7 +112,7 @@ def try_print(tag, expr): ti_laplace = subprocess.check_output( [executable, '-m', 'taichi', 'example', 'minimal']) except Exception as e: - print(f'`examples/laplace.py` failed: {e}') + print(f'`python/taichi/examples/algorithm/laplace.py` failed: {e}') else: print(f'{ti_laplace.decode()}') @@ -123,3 +123,5 @@ def try_print(tag, expr): if __name__ == '__main__': main() + +__all__ = [] diff --git a/python/taichi/tools/file.py b/python/taichi/tools/file.py deleted file mode 100644 index 22797a45afc13..0000000000000 --- a/python/taichi/tools/file.py +++ /dev/null @@ -1,9 +0,0 @@ -import os - - -def clear_directory_with_suffix(directory, suffix): - files = os.listdir(directory) - assert suffix[0] != '.', "No '.' needed." - for f in files: - if f.endswith('.' + suffix): - os.remove(os.path.join(directory, f)) diff --git a/python/taichi/misc/image.py b/python/taichi/tools/image.py similarity index 75% rename from python/taichi/misc/image.py rename to python/taichi/tools/image.py index 018043764324e..29e7d2d025e2b 100644 --- a/python/taichi/misc/image.py +++ b/python/taichi/tools/image.py @@ -1,7 +1,5 @@ -from io import BytesIO - import numpy as np -from taichi.core import ti_core as _ti_core +from taichi._lib import core as _ti_core import taichi as ti @@ -10,7 +8,7 @@ def cook_image_to_bytes(img): """ Takes a NumPy array or Taichi field of any type. Returns a NumPy array of uint8. - This is used by ti.imwrite and ti.imdisplay. + This is used by ti.imwrite. """ if not isinstance(img, np.ndarray): img = img.to_numpy() @@ -34,26 +32,6 @@ def cook_image_to_bytes(img): return img.swapaxes(0, 1)[::-1, :] -def imdisplay(img): - """ - Try to display image in interactive shell. - - Args: - img (Union[ti.field, np.ndarray]): A field of of array with shape `(width, height)` or `(height, width, 3)` or `(height, width, 4)`. - """ - try: - get_ipython() - except: - ti.imshow(img) - else: - import IPython.display # pylint: disable=C0415 - import PIL.Image # pylint: disable=C0415 - img = cook_image_to_bytes(img) - with BytesIO() as f: - PIL.Image.fromarray(img).save(f, 'png') - IPython.display.display(IPython.display.Image(data=f.getvalue())) - - def imresize(img, w, h=None): """Resize an image to a specific size. @@ -113,32 +91,36 @@ def imread(filename, channels=0): return img.swapaxes(0, 1)[:, ::-1, :] -def imshow(img, window_name='imshow'): - """Show image in a Taichi GUI. +def imshow(img, title='imshow'): + """Display a taichi.field or a numpy.ndarray in a Taichi GUI window or an interactive Ipython notebook. + For an interactive Ipython environment, the image will be shown in the notebook. Args: img (Union[ti.field, np.ndarray]): A field of of array with shape `(width, height)` or `(height, width, 3)` or `(height, width, 4)`. - window_name (str, optional): The title of GUI window. Default to `imshow`. + title (str, optional): The title of GUI window. Default to `imshow`. """ - if not isinstance(img, np.ndarray): - img = img.to_numpy() - assert len(img.shape) in [2, - 3], "Image must be either RGB/RGBA or greyscale" - - with ti.GUI(window_name, res=img.shape[:2]) as gui: - img = gui.cook_image(img) - while gui.running: - if gui.get_event(ti.GUI.ESCAPE): - gui.running = False - - gui.set_image(img) - gui.show() + try: # check if we are in Ipython environment + get_ipython() + except: + if not isinstance(img, np.ndarray): + img = img.to_numpy() + assert len( + img.shape) in [2, + 3], "Image must be either RGB/RGBA or greyscale" + + with ti.GUI(title, res=img.shape[:2]) as gui: + img = gui.cook_image(img) + while gui.running: + if gui.get_event(ti.GUI.ESCAPE): + gui.running = False + + gui.set_image(img) + gui.show() + else: + import IPython.display # pylint: disable=C0415 + import PIL.Image # pylint: disable=C0415 + img = cook_image_to_bytes(img) + IPython.display.display(PIL.Image.fromarray(img)) -__all__ = [ - 'imshow', - 'imread', - 'imwrite', - 'imresize', - 'imdisplay', -] +__all__ = ['imread', 'imresize', 'imshow', 'imwrite'] diff --git a/python/taichi/tools/messenger.py b/python/taichi/tools/messenger.py deleted file mode 100644 index 906ad30c0868f..0000000000000 --- a/python/taichi/tools/messenger.py +++ /dev/null @@ -1,85 +0,0 @@ -import atexit -import os -import smtplib -import socket - -import taichi as tc - -gmail_sender = 'taichi.messager@gmail.com' -gmail_passwd = '6:L+XbNOp^' - -emailed = False - - -def send_crash_report(message, receiver=None): - global emailed - if emailed: - return - emailed = True - if receiver is None: - receiver = os.environ.get('TI_MONITOR_EMAIL', None) - if receiver is None: - tc.warn('No receiver in $TI_MONITOR_EMAIL') - return - tc.warn('Emailing {}'.format(receiver)) - TO = receiver - SUBJECT = 'Report' - TEXT = message - - server = smtplib.SMTP('smtp.gmail.com', 587) - server.ehlo() - server.starttls() - server.login(gmail_sender, gmail_passwd) - - BODY = '\r\n'.join([ - 'To: %s' % TO, - 'From: %s' % gmail_sender, - 'Subject: %s' % SUBJECT, '', TEXT - ]) - - try: - server.sendmail(gmail_sender, [TO], BODY) - except: - print('Error sending mail') - server.quit() - print('Press enter or Ctrl + \\ to exit.') - - -def enable(task_name): - register_call_back(task_name) - - -crashed = False -keep = [] - - -def register_call_back(task_name): - def at_exit(): - if not crashed: - message = 'Congratulations! Your task [{}] at machine [{}] has finished.'.format( - task_name, socket.gethostname()) - send_crash_report(message) - - def email_call_back(_): - global crashed - crashed = True - tc.warn('Task has crashed.') - message = 'Your task [{}] at machine [{}] has crashed.'.format( - task_name, socket.gethostname()) - send_crash_report(message) - atexit.unregister(at_exit) - exit(-1) - - keep.append(email_call_back) - # TODO: email_call_back should be passed to Taichi core (C++). It will then called by the signal handler when Taichi crashes - # (std::function python_at_exit) - # Simply register a callback in the Python scope will not work in cases when Taichi crashes - # call_back = tc.function11(email_call_back) - # tc.core.register_at_exit(call_back) - - atexit.register(at_exit) - - -if __name__ == '__main__': - register_call_back('test') - tc.core.trigger_sig_fpe() diff --git a/python/taichi/tools/np2ply.py b/python/taichi/tools/np2ply.py index 5902fa4c7ea2a..e8ff38c75f076 100644 --- a/python/taichi/tools/np2ply.py +++ b/python/taichi/tools/np2ply.py @@ -3,8 +3,6 @@ import numpy as np -import taichi as ti - class PLYWriter: def __init__(self, @@ -25,9 +23,8 @@ def __init__(self, np.float32, np.float64 ] self.type_map = {} - for i in range(len(self.ply_supported_types)): - self.type_map[self.ply_supported_types[ - i]] = self.corresponding_numpy_types[i] + for i, ply_type in enumerate(self.ply_supported_types): + self.type_map[ply_type] = self.corresponding_numpy_types[i] self.num_vertices = num_vertices self.num_vertex_channels = 0 @@ -46,9 +43,10 @@ def __init__(self, self.face_indices = -np.ones((self.num_faces, 4), dtype=np.int32) self.comment = comment - def add_vertex_channel(self, key: str, type: str, data: np.array): - if type not in self.ply_supported_types: - print("Unknown type " + type + " detected, skipping this channel") + def add_vertex_channel(self, key: str, data_type: str, data: np.array): + if data_type not in self.ply_supported_types: + print("Unknown type " + data_type + + " detected, skipping this channel") return if data.ndim == 1: assert data.size == self.num_vertices, "The dimension of the vertex channel is not correct" @@ -56,8 +54,8 @@ def add_vertex_channel(self, key: str, type: str, data: np.array): if key in self.vertex_channels: print("WARNING: duplicate key " + key + " detected") self.vertex_channels.append(key) - self.vertex_data_type.append(type) - self.vertex_data.append(self.type_map[type](data)) + self.vertex_data_type.append(data_type) + self.vertex_data.append(self.type_map[data_type](data)) else: num_col = data.size // self.num_vertices assert data.ndim == 2 and data.size == num_col * \ @@ -69,8 +67,8 @@ def add_vertex_channel(self, key: str, type: str, data: np.array): if item_key in self.vertex_channels: print("WARNING: duplicate key " + item_key + " detected") self.vertex_channels.append(item_key) - self.vertex_data_type.append(type) - self.vertex_data.append(self.type_map[type](data[:, i])) + self.vertex_data_type.append(data_type) + self.vertex_data.append(self.type_map[data_type](data[:, i])) def add_vertex_pos(self, x: np.array, y: np.array, z: np.array): self.add_vertex_channel("x", "float", x) @@ -164,9 +162,10 @@ def add_faces(self, indices: np.array): (self.num_faces, vert_per_face)) self.face_indices = self.face_indices.astype(np.int32) - def add_face_channel(self, key: str, type: str, data: np.array): - if type not in self.ply_supported_types: - print("Unknown type " + type + " detected, skipping this channel") + def add_face_channel(self, key: str, data_type: str, data: np.array): + if data_type not in self.ply_supported_types: + print("Unknown type " + data_type + + " detected, skipping this channel") return if data.ndim == 1: assert data.size == self.num_faces, "The dimension of the face channel is not correct" @@ -174,8 +173,8 @@ def add_face_channel(self, key: str, type: str, data: np.array): if key in self.face_channels: print("WARNING: duplicate key " + key + " detected") self.face_channels.append(key) - self.face_data_type.append(type) - self.face_data.append(self.type_map[type](data)) + self.face_data_type.append(data_type) + self.face_data.append(self.type_map[data_type](data)) else: num_col = data.size // self.num_faces assert data.ndim == 2 and data.size == num_col * \ @@ -187,8 +186,8 @@ def add_face_channel(self, key: str, type: str, data: np.array): if item_key in self.face_channels: print("WARNING: duplicate key " + item_key + " detected") self.face_channels.append(item_key) - self.face_data_type.append(type) - self.face_data.append(self.type_map[type](data[:, i])) + self.face_data_type.append(data_type) + self.face_data.append(self.type_map[data_type](data[:, i])) def add_face_id(self): self.add_face_channel("id", "int", np.arange(self.num_faces)) @@ -204,17 +203,17 @@ def sanity_check(self): for idx in self.face_indices.flatten(): assert idx >= 0 and idx < self.num_vertices, "The face indices are invalid" - def print_header(self, path: str, format: str): + def print_header(self, path: str, _format: str): with open(path, "w") as f: f.writelines([ - "ply\n", "format " + format + " 1.0\n", + "ply\n", "format " + _format + " 1.0\n", "comment " + self.comment + "\n" ]) f.write("element vertex " + str(self.num_vertices) + "\n") for i in range(self.num_vertex_channels): f.write("property " + self.vertex_data_type[i] + " " + self.vertex_channels[i] + "\n") - if (self.num_faces != 0): + if self.num_faces != 0: f.write("element face " + str(self.num_faces) + "\n") f.write("property list uchar int vertex_indices\n") for i in range(self.num_face_channels): @@ -267,7 +266,7 @@ def export_frame_ascii(self, series_num: int, path: str): if last_4_char == ".ply": path = path[:-4] - real_path = path + "_" + "{0:0=6d}".format(series_num) + ".ply" + real_path = path + "_" + f"{series_num:0=6d}" + ".ply" self.export_ascii(real_path) def export_frame(self, series_num: int, path: str): @@ -276,5 +275,8 @@ def export_frame(self, series_num: int, path: str): if last_4_char == ".ply": path = path[:-4] - real_path = path + "_" + "{0:0=6d}".format(series_num) + ".ply" + real_path = path + "_" + f"{series_num:0=6d}" + ".ply" self.export(real_path) + + +__all__ = ['PLYWriter'] diff --git a/python/taichi/tools/video.py b/python/taichi/tools/video.py index 846d92fd4e663..4a9e1bcf31297 100644 --- a/python/taichi/tools/video.py +++ b/python/taichi/tools/video.py @@ -1,8 +1,8 @@ import os import shutil -from taichi.core import get_os_name -from taichi.misc.image import imwrite +from taichi._lib.utils import get_os_name +from taichi.tools.image import imwrite FRAME_FN_TEMPLATE = '%06d.png' FRAME_DIR = 'frames' @@ -10,21 +10,22 @@ # Write the frames to the disk and then make videos (mp4 or gif) if necessary -def scale_video(input, output, ratiow, ratioh): - os.system('ffmpeg -i {} -vf "scale=iw*{:.4f}:ih*{:.4f}" {}'.format( - input, ratiow, ratioh, output)) +def scale_video(input_fn, output_fn, ratiow, ratioh): + os.system( + f'ffmpeg -i {input_fn} -vf "scale=iw*{ratiow:.4f}:ih*{ratioh:.4f}" {output_fn}' + ) -def crop_video(input, output, x_begin, x_end, y_begin, y_end): +def crop_video(input_fn, output_fn, x_begin, x_end, y_begin, y_end): os.system( - 'ffmpeg -i {} -filter:v "crop=iw*{:.4f}:ih*{:.4f}:iw*{:0.4f}:ih*{:0.4f}" {}' - .format(input, x_end - x_begin, y_end - y_begin, x_begin, 1 - y_end, - output)) + f'ffmpeg -i {input_fn} -filter:v "crop=iw*{x_end - x_begin:.4f}:ih*{y_end - y_begin:.4f}:iw*{x_begin:0.4f}:ih*{1 - y_end:0.4f}" {output_fn}' + ) -def accelerate_video(input, output, speed): - os.system('ffmpeg -i {} -filter:v "setpts={:.4f}*PTS" {}'.format( - input, 1 / speed, output)) +def accelerate_video(input_fn, output_fn, speed): + os.system( + f'ffmpeg -i {input_fn} -filter:v "setpts={1 / speed:.4f}*PTS" {output_fn}' + ) def get_ffmpeg_path(): @@ -36,19 +37,17 @@ def mp4_to_gif(input_fn, output_fn, framerate): palette_name = 'palette.png' if get_os_name() == 'win': command = get_ffmpeg_path( - ) + " -loglevel panic -i %s -vf 'palettegen' -y %s" % (input_fn, - palette_name) + ) + f" -loglevel panic -i {input_fn} -vf 'palettegen' -y {palette_name}" else: command = get_ffmpeg_path( - ) + " -loglevel panic -i %s -vf 'fps=%d,scale=320:640:flags=lanczos,palettegen' -y %s" % ( - input_fn, framerate, palette_name) + ) + f" -loglevel panic -i {input_fn} -vf 'fps={framerate}," \ + f"scale=320:640:flags=lanczos,palettegen' -y {palette_name}" # print command os.system(command) # Generate the GIF command = get_ffmpeg_path( - ) + " -loglevel panic -i %s -i %s -lavfi paletteuse -y %s" % ( - input_fn, palette_name, output_fn) + ) + f" -loglevel panic -i {input_fn} -i {palette_name} -lavfi paletteuse -y {output_fn}" # print command os.system(command) os.remove(palette_name) @@ -117,7 +116,8 @@ def clean_frames(self): def make_video(self, mp4=True, gif=True): fn = self.get_output_filename('.mp4') - command = (get_ffmpeg_path() + " -loglevel panic -framerate %d -i " % self.framerate) + os.path.join(self.frame_directory, FRAME_FN_TEMPLATE) + \ + command = (get_ffmpeg_path() + f" -loglevel panic -framerate {self.framerate} -i ") + os.path.join( + self.frame_directory, FRAME_FN_TEMPLATE) + \ " -s:v " + str(self.width) + 'x' + str(self.height) + \ " -c:v libx264 -profile:v high -crf 1 -pix_fmt yuv420p -y " + fn @@ -139,7 +139,7 @@ def interpolate_frames(frame_dir, mul=4): images_interpolated = [] for f in sorted(files): if f.endswith('png'): - images.append(cv2.imread(f) / 255.0) + images.append(cv2.imread(f) / 255.0) # pylint: disable=E1101 for i in range(len(images) - 1): images_interpolated.append(images[i]) @@ -152,12 +152,12 @@ def interpolate_frames(frame_dir, mul=4): os.makedirs('interpolated', exist_ok=True) for i, img in enumerate(images_interpolated): - cv2.imwrite('interpolated/{:05d}.png'.format(i), img * 255.0) + cv2.imwrite(f'interpolated/{i:05d}.png', img * 255.0) # pylint: disable=E1101 -def ffmpeg_common_args(frame_rate, input, width, height, crf, output_path): - return f"{get_ffmpeg_path()} -y -loglevel panic -framerate {frame_rate} -i {input} -s:v {width}x{height} " + \ - f"-c:v libx264 -profile:v high -crf {crf} -pix_fmt yuv420p {output_path}" +def ffmpeg_common_args(frame_rate, input_fn, width, height, crf, output_path): + return f"{get_ffmpeg_path()} -y -loglevel panic -framerate {frame_rate} -i {input_fn} -s:v {width}x{height} " + \ + f"-c:v libx264 -profile:v high -crf {crf} -pix_fmt yuv420p {output_path}" def make_video(input_files, @@ -173,20 +173,20 @@ def make_video(input_files, tmp_dir = 'tmp_ffmpeg_dir' os.mkdir(tmp_dir) if width % 2 != 0: - print("Width ({}) not divisible by 2".format(width)) + print(f"Width ({width}) not divisible by 2") width -= 1 if height % 2 != 0: - print("Height ({}) not divisible by 2".format(width)) + print(f"Height ({width}) not divisible by 2") height -= 1 for i, inp in enumerate(input_files): - shutil.copy(inp, os.path.join(tmp_dir, '%06d.png' % i)) + shutil.copy(inp, os.path.join(tmp_dir, f'{i:06d}.png')) inputs = f'{tmp_dir}/%06d.png' command = ffmpeg_common_args(frame_rate, inputs, width, height, crf, output_path) ret = os.system(command) assert ret == 0, "ffmpeg failed to generate video file." for i in range(len(input_files)): - os.remove(os.path.join(tmp_dir, '%06d.png' % i)) + os.remove(os.path.join(tmp_dir, f'{i:06d}.png')) os.rmdir(tmp_dir) elif isinstance(input_files, str): assert width != 0 and height != 0 @@ -196,3 +196,6 @@ def make_video(input_files, assert ret == 0, "ffmpeg failed to generate video file." else: assert False, f'input_files should be list (of files) or str (of file template, e.g., "%04d.png") instead of {type(input_files)}' + + +__all__ = ['VideoManager'] diff --git a/python/taichi/torch_io.py b/python/taichi/torch_io.py deleted file mode 100644 index 819c4d9164572..0000000000000 --- a/python/taichi/torch_io.py +++ /dev/null @@ -1,26 +0,0 @@ -from taichi.lang.kernel_impl import kernel -from taichi.type.annotations import ext_arr, template - - -@kernel -def from_torch_template(expr: template(), torch_tensor: ext_arr()): - for i in expr: - expr[i] = torch_tensor[i] - - -@kernel -def to_torch_template(expr: template(), torch_tensor: ext_arr()): - for i in expr: - torch_tensor[i] = expr[i] - - -def from_torch(expr, torch_tensor): - if not expr.from_torch_: - expr.from_torch_ = lambda x: from_torch_template(expr, x.contiguous()) - expr.from_torch_(torch_tensor) - - -def to_torch(expr, torch_tensor): - if not expr.to_torch_: - expr.to_torch_ = lambda x: to_torch_template(expr, x.contiguous()) - expr.to_torch_(torch_tensor) diff --git a/python/taichi/type/__init__.py b/python/taichi/type/__init__.py deleted file mode 100644 index cc82c2a60626c..0000000000000 --- a/python/taichi/type/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from taichi.type.annotations import * -from taichi.type.primitive_types import * diff --git a/python/taichi/type/primitive_types.py b/python/taichi/type/primitive_types.py deleted file mode 100644 index 09a47006529a9..0000000000000 --- a/python/taichi/type/primitive_types.py +++ /dev/null @@ -1,86 +0,0 @@ -from taichi.core import ti_core - -# Real types - -float32 = ti_core.DataType_f32 -"""32-bit single precision floating point data type. -""" -f32 = float32 -"""Alias for :const:`~taichi.type.primitive_types.float32` -""" -float64 = ti_core.DataType_f64 -"""64-bit double precision floating point data type. -""" -f64 = float64 -"""Alias for :const:`~taichi.type.primitive_types.float64` -""" - -real_types = [f32, f64, float] -real_type_ids = [id(t) for t in real_types] - -# Integer types - -int8 = ti_core.DataType_i8 -i8 = int8 -int16 = ti_core.DataType_i16 -i16 = int16 -int32 = ti_core.DataType_i32 -"""32-bit signed integer data type. -""" -i32 = int32 -"""Alias for :const:`~taichi.type.primitive_types.int32` -""" -int64 = ti_core.DataType_i64 -"""64-bit signed integer data type. -""" -i64 = int64 -"""Alias for :const:`~taichi.type.primitive_types.int64` -""" - -uint8 = ti_core.DataType_u8 -u8 = uint8 -uint16 = ti_core.DataType_u16 -u16 = uint16 -uint32 = ti_core.DataType_u32 -"""32-bit unsigned integer data type. -""" -u32 = uint32 -"""Alias for :const:`~taichi.type.primitive_types.uint32` -""" -uint64 = ti_core.DataType_u64 -"""64-bit unsigned integer data type. -""" -u64 = uint64 -"""Alias for :const:`~taichi.type.primitive_types.uint64` -""" - -integer_types = [i8, i16, i32, i64, u8, u16, u32, u64, int] -integer_type_ids = [id(t) for t in integer_types] - -types = real_types + integer_types -type_ids = [id(t) for t in types] - -__all__ = [ - 'float32', - 'f32', - 'float64', - 'f64', - 'int8', - 'i8', - 'int16', - 'i16', - 'int32', - 'i32', - 'int64', - 'i64', - 'uint8', - 'u8', - 'uint16', - 'u16', - 'uint32', - 'u32', - 'uint64', - 'u64', - 'real_types', - 'integer_types', -] diff --git a/python/taichi/types/__init__.py b/python/taichi/types/__init__.py new file mode 100644 index 0000000000000..ee913ba4f7ccd --- /dev/null +++ b/python/taichi/types/__init__.py @@ -0,0 +1,5 @@ +from taichi.types.annotations import * +from taichi.types.compound_types import * +from taichi.types.primitive_types import * +from taichi.types.quantized_types import * +from taichi.types.utils import * diff --git a/python/taichi/type/annotations.py b/python/taichi/types/annotations.py similarity index 58% rename from python/taichi/type/annotations.py rename to python/taichi/types/annotations.py index 3de93335c435b..4e98870c2dc7a 100644 --- a/python/taichi/type/annotations.py +++ b/python/taichi/types/annotations.py @@ -6,28 +6,54 @@ class ArgAnyArray: Args: element_dim (Union[Int, NoneType], optional): None if not specified (will be treated as 0 for external arrays), 0 if scalar elements, 1 if vector elements, and 2 if matrix elements. + element_shape (Union[Tuple[Int], NoneType]): None if not specified, shapes of each element. For example, element_shape must be 1d for vector and 2d tuple for matrix. This argument is ignored for external arrays for now. + field_dim (Union[Int, NoneType]): None if not specified, number of field dimensions. This argument is ignored for external arrays for now. layout (Union[Layout, NoneType], optional): None if not specified (will be treated as Layout.AOS for external arrays), Layout.AOS or Layout.SOA. """ - def __init__(self, element_dim=None, layout=None): + def __init__(self, + element_dim=None, + element_shape=None, + field_dim=None, + layout=None): if element_dim is not None and (element_dim < 0 or element_dim > 2): raise ValueError( "Only scalars, vectors, and matrices are allowed as elements of ti.any_arr()" ) - self.element_dim = element_dim + if element_dim is not None and element_shape is not None and len( + element_shape) != element_dim: + raise ValueError( + f"Both element_shape and element_dim are specified, but shape doesn't match specified dim: {len(element_shape)}!={element_dim}" + ) + self.element_shape = element_shape + self.element_dim = len( + element_shape) if element_shape is not None else element_dim + self.field_dim = field_dim self.layout = layout - def check_element_dim(self, arg, arg_dim): + def _check_element_dim(self, arg, arg_dim): if self.element_dim is not None and self.element_dim != arg_dim: raise ValueError( f"Invalid argument into ti.any_arr() - required element_dim={self.element_dim}, but {arg} is provided" ) - def check_layout(self, arg): + def _check_layout(self, arg): if self.layout is not None and self.layout != arg.layout: raise ValueError( f"Invalid argument into ti.any_arr() - required layout={self.layout}, but {arg} is provided" ) + def _check_element_shape(self, shapes): + if self.element_shape is not None and shapes != self.element_shape: + raise ValueError( + f"Invalid argument into ti.any_arr() - required element_shape={self.element_shape}, but {shapes} is provided" + ) + + def _check_field_dim(self, field_dim): + if self.field_dim is not None and field_dim != self.field_dim: + raise ValueError( + f"Invalid argument into ti.any_arr() - required field_dim={self.field_dim}, but {field_dim} is provided" + ) + def ext_arr(): """Type annotation for external arrays. @@ -49,7 +75,7 @@ def ext_arr(): any_arr = ArgAnyArray -"""Alias for :class:`~taichi.type.annotations.ArgAnyArray`. +"""Alias for :class:`~taichi.types.annotations.ArgAnyArray`. Example:: @@ -80,7 +106,12 @@ def __init__(self, tensor=None, dim=None): template = Template -"""Alias for :class:`~taichi.type.annotations.Template`. +"""Alias for :class:`~taichi.types.annotations.Template`. """ -__all__ = ['ext_arr', 'any_arr', 'template'] + +class sparse_matrix_builder: + pass + + +__all__ = ['ext_arr', 'any_arr', 'template', 'sparse_matrix_builder'] diff --git a/python/taichi/types/compound_types.py b/python/taichi/types/compound_types.py new file mode 100644 index 0000000000000..cad8854969f16 --- /dev/null +++ b/python/taichi/types/compound_types.py @@ -0,0 +1,21 @@ +import taichi + + +class CompoundType: + pass + + +# TODO: maybe move MatrixType, StructType here to avoid the circular import? +def matrix(n, m, dtype): + return taichi.lang.matrix.MatrixType(n, m, dtype) + + +def vector(n, dtype): + return taichi.lang.matrix.MatrixType(n, 1, dtype) + + +def struct(**kwargs): + return taichi.lang.struct.StructType(**kwargs) + + +__all__ = ['matrix', 'vector', 'struct'] diff --git a/python/taichi/types/primitive_types.py b/python/taichi/types/primitive_types.py new file mode 100644 index 0000000000000..726da1831a32a --- /dev/null +++ b/python/taichi/types/primitive_types.py @@ -0,0 +1,176 @@ +from taichi._lib import core as ti_core + +# ======================================== +# real types + +# ---------------------------------------- + +float16 = ti_core.DataType_f16 +"""16-bit precision floating point data type. +""" + +# ---------------------------------------- + +f16 = float16 +"""Alias for :const:`~taichi.types.primitive_types.float16` +""" + +# ---------------------------------------- + +float32 = ti_core.DataType_f32 +"""32-bit single precision floating point data type. +""" + +# ---------------------------------------- + +f32 = float32 +"""Alias for :const:`~taichi.types.primitive_types.float32` +""" + +# ---------------------------------------- + +float64 = ti_core.DataType_f64 +"""64-bit double precision floating point data type. +""" + +# ---------------------------------------- + +f64 = float64 +"""Alias for :const:`~taichi.types.primitive_types.float64` +""" +# ---------------------------------------- + +# ======================================== +# Integer types + +# ---------------------------------------- + +int8 = ti_core.DataType_i8 +"""8-bit signed integer data type. +""" + +# ---------------------------------------- + +i8 = int8 +"""Alias for :const:`~taichi.types.primitive_types.int8` +""" + +# ---------------------------------------- + +int16 = ti_core.DataType_i16 +"""16-bit signed integer data type. +""" + +# ---------------------------------------- + +i16 = int16 +"""Alias for :const:`~taichi.types.primitive_types.int16` +""" + +# ---------------------------------------- + +int32 = ti_core.DataType_i32 +"""32-bit signed integer data type. +""" + +# ---------------------------------------- + +i32 = int32 +"""Alias for :const:`~taichi.types.primitive_types.int32` +""" + +# ---------------------------------------- + +int64 = ti_core.DataType_i64 +"""64-bit signed integer data type. +""" + +# ---------------------------------------- + +i64 = int64 +"""Alias for :const:`~taichi.types.primitive_types.int64` +""" + +# ---------------------------------------- + +uint8 = ti_core.DataType_u8 +"""8-bit unsigned integer data type. +""" + +# ---------------------------------------- + +u8 = uint8 +"""Alias for :const:`~taichi.types.primitive_types.uint8` +""" + +# ---------------------------------------- + +uint16 = ti_core.DataType_u16 +"""16-bit unsigned integer data type. +""" + +# ---------------------------------------- + +u16 = uint16 +"""Alias for :const:`~taichi.types.primitive_types.uint16` +""" + +# ---------------------------------------- + +uint32 = ti_core.DataType_u32 +"""32-bit unsigned integer data type. +""" + +# ---------------------------------------- + +u32 = uint32 +"""Alias for :const:`~taichi.types.primitive_types.uint32` +""" + +# ---------------------------------------- + +uint64 = ti_core.DataType_u64 +"""64-bit unsigned integer data type. +""" + +# ---------------------------------------- + +u64 = uint64 +"""Alias for :const:`~taichi.types.primitive_types.uint64` +""" + +# ---------------------------------------- + +real_types = [f16, f32, f64, float] +real_type_ids = [id(t) for t in real_types] + +integer_types = [i8, i16, i32, i64, u8, u16, u32, u64, int] +integer_type_ids = [id(t) for t in integer_types] + +types = real_types + integer_types +type_ids = [id(t) for t in types] + +__all__ = [ + 'float32', + 'f32', + 'float64', + 'f64', + 'float16', + 'f16', + 'int8', + 'i8', + 'int16', + 'i16', + 'int32', + 'i32', + 'int64', + 'i64', + 'uint8', + 'u8', + 'uint16', + 'u16', + 'uint32', + 'u32', + 'uint64', + 'u64', +] diff --git a/python/taichi/types/quantized_types.py b/python/taichi/types/quantized_types.py new file mode 100644 index 0000000000000..19a6dec44a2cb --- /dev/null +++ b/python/taichi/types/quantized_types.py @@ -0,0 +1,129 @@ +from taichi._lib.utils import ti_core as _ti_core +from taichi.lang import impl +from taichi.types.primitive_types import i32 + + +class TypeFactory: + """A Python-side TypeFactory wrapper.""" + def __init__(self): + self.core = _ti_core.get_type_factory_instance() + + def custom_int(self, bits, signed=True, compute_type=None): + """Generates a custom int type. + + Args: + bits (int): Number of bits. + signed (bool): Signed or unsigned. + compute_type (DataType): Type for computation. + + Returns: + DataType: The specified type. + """ + if compute_type is None: + compute_type = impl.get_runtime().default_ip + if isinstance(compute_type, _ti_core.DataType): + compute_type = compute_type.get_ptr() + return self.core.get_custom_int_type(bits, signed, compute_type) + + def custom_float(self, + significand_type, + exponent_type=None, + compute_type=None, + scale=1.0): + """Generates a custom float type. + + Args: + significand_type (DataType): Type of significand. + exponent_type (DataType): Type of exponent. + compute_type (DataType): Type for computation. + scale (float): Scaling factor. + + Returns: + DataType: The specified type. + """ + if compute_type is None: + compute_type = impl.get_runtime().default_fp + if isinstance(compute_type, _ti_core.DataType): + compute_type = compute_type.get_ptr() + return self.core.get_custom_float_type(significand_type, + exponent_type, + compute_type, + scale=scale) + + +# Unstable API +type_factory = TypeFactory() + + +class Quant: + """Generator of quantized types. + + For more details, read https://yuanming.taichi.graphics/publication/2021-quantaichi/quantaichi.pdf. + """ + @staticmethod + def int(bits, signed=False, compute=None): + """Generates a quantized type for integers. + + Args: + bits (int): Number of bits. + signed (bool): Signed or unsigned. + compute (DataType): Type for computation. + + Returns: + DataType: The specified type. + """ + if compute is None: + compute = impl.get_runtime().default_ip + return type_factory.custom_int(bits, signed, compute) + + @staticmethod + def fixed(frac, signed=True, num_range=1.0, compute=None): + """Generates a quantized type for fixed-point real numbers. + + Args: + frac (int): Number of bits. + signed (bool): Signed or unsigned. + num_range (float): Range of the number. + compute (DataType): Type for computation. + + Returns: + DataType: The specified type. + """ + # TODO: handle cases with frac > 32 + frac_type = Quant.int(bits=frac, signed=signed, compute=i32) + if signed: + scale = num_range / 2**(frac - 1) + else: + scale = num_range / 2**frac + if compute is None: + compute = impl.get_runtime().default_fp + return type_factory.custom_float(frac_type, None, compute, scale) + + @staticmethod + def float(exp, frac, signed=True, compute=None): + """Generates a quantized type for floating-point real numbers. + + Args: + exp (int): Number of exponent bits. + frac (int): Number of fraction bits. + signed (bool): Signed or unsigned. + compute (DataType): Type for computation. + + Returns: + DataType: The specified type. + """ + # Exponent is always unsigned + exp_type = Quant.int(bits=exp, signed=False, compute=i32) + # TODO: handle cases with frac > 32 + frac_type = Quant.int(bits=frac, signed=signed, compute=i32) + if compute is None: + compute = impl.get_runtime().default_fp + return type_factory.custom_float(significand_type=frac_type, + exponent_type=exp_type, + compute_type=compute) + + +# Unstable API +quant = Quant + +__all__ = [] diff --git a/python/taichi/types/utils.py b/python/taichi/types/utils.py new file mode 100644 index 0000000000000..5f168cfc8d605 --- /dev/null +++ b/python/taichi/types/utils.py @@ -0,0 +1,7 @@ +from taichi._lib import core as ti_core + +is_signed = ti_core.is_signed + +is_integral = ti_core.is_integral + +__all__ = ['is_signed', 'is_integral'] diff --git a/python/taichi/ui/__init__.py b/python/taichi/ui/__init__.py index 3023d814c3c81..972d3fc806635 100644 --- a/python/taichi/ui/__init__.py +++ b/python/taichi/ui/__init__.py @@ -1 +1,2 @@ +from .gui import * from .ui import * diff --git a/python/taichi/ui/camera.py b/python/taichi/ui/camera.py index 09586945f39c1..e9c20f5b5849b 100644 --- a/python/taichi/ui/camera.py +++ b/python/taichi/ui/camera.py @@ -3,7 +3,6 @@ from taichi.lang.matrix import Vector from .utils import euler_to_vec, vec_to_euler -from .window import Window class Camera: @@ -51,7 +50,7 @@ def bottom(self, bottom): def z_near(self, z_near): self.ptr.z_near(z_near) - def z_near(self, z_far): + def z_far(self, z_far): self.ptr.z_far(z_far) # move the camera according to user inputs, FPS game style. @@ -64,6 +63,7 @@ def track_user_inputs(self, front = (self.curr_lookat - self.curr_position).normalized() position_change = Vector([0.0, 0.0, 0.0]) left = self.curr_up.cross(front) + up = self.curr_up if window.is_pressed('w'): position_change = front * movement_speed if window.is_pressed('s'): @@ -72,6 +72,10 @@ def track_user_inputs(self, position_change = left * movement_speed if window.is_pressed('d'): position_change = -left * movement_speed + if window.is_pressed('e'): + position_change = up * movement_speed + if window.is_pressed('q'): + position_change = -up * movement_speed self.position(*(self.curr_position + position_change)) self.lookat(*(self.curr_lookat + position_change)) diff --git a/python/taichi/ui/canvas.py b/python/taichi/ui/canvas.py index eccf0fb6e5353..985b4ce3f28a1 100644 --- a/python/taichi/ui/canvas.py +++ b/python/taichi/ui/canvas.py @@ -1,12 +1,6 @@ -from taichi.core import ti_core as _ti_core -from taichi.lang.impl import default_cfg -from taichi.lang.kernel_impl import kernel -from taichi.lang.ops import get_addr -from taichi.type.annotations import ext_arr, template - from .staging_buffer import (copy_colors_to_vbo, copy_vertices_to_vbo, get_vbo_field, to_u8_rgba) -from .utils import * +from .utils import get_field_info class Canvas: @@ -92,4 +86,4 @@ def circles(self, def scene(self, scene): """Draw a 3D scene on the canvas""" - self.canvas.scene(scene) + self.canvas.scene(scene.scene) diff --git a/python/taichi/ui/gui.py b/python/taichi/ui/gui.py index 404e120451db3..946d21d49ffef 100644 --- a/python/taichi/ui/gui.py +++ b/python/taichi/ui/gui.py @@ -1,99 +1,938 @@ -import pathlib -from contextlib import contextmanager +import math +import numbers +import os -from taichi.core import ti_core as _ti_core -from taichi.lang.impl import default_cfg -from taichi.lang.kernel_impl import kernel -from taichi.lang.ops import get_addr -from taichi.type.annotations import ext_arr, template +import numpy as np +import taichi.lang +from taichi._kernels import (tensor_to_image, vector_to_fast_image, + vector_to_image) +from taichi._lib import core as _ti_core +from taichi.lang.field import Field, ScalarField -from .utils import * +import taichi as ti -class Gui: - def __init__(self, gui) -> None: - self.gui = gui #reference to a PyGui +# For window creation and drawing in the original ti.GUI system. +class GUI: + """Taichi Graphical User Interface class. - @contextmanager - def sub_window(self, name, x, y, width, height): - """Creating a context manager for subwindow + Args: + name (str, optional): The name of the GUI to be constructed. + Default is 'Taichi'. + res (Union[int, List[int]], optional): The resolution of created + GUI. Default is 512*512. If `res` is scalar, then width will be equal to height. + background_color (int, optional): The background color of created GUI. + Default is 0x000000. + show_gui (bool, optional): Specify whether to render the GUI. Default is True. + fullscreen (bool, optional): Specify whether to render the GUI in + fullscreen mode. Default is False. + fast_gui (bool, optional): Specify whether to use fast gui mode of + Taichi. Default is False. - Note: - All args of this method should align with `begin`. + Returns: + :class:`~taichi.misc.gui.GUI` :The created taichi GUI object. + + """ + class Event: + def __init__(self): + self.type = None + self.modifier = None + self.pos = None + self.key = None + self.delta = None + + # Event keys + SHIFT = 'Shift' + ALT = 'Alt' + CTRL = 'Control' + ESCAPE = 'Escape' + RETURN = 'Return' + TAB = 'Tab' + BACKSPACE = 'BackSpace' + SPACE = ' ' + UP = 'Up' + DOWN = 'Down' + LEFT = 'Left' + RIGHT = 'Right' + CAPSLOCK = 'Caps_Lock' + LMB = 'LMB' + MMB = 'MMB' + RMB = 'RMB' + EXIT = 'WMClose' + WHEEL = 'Wheel' + MOVE = 'Motion' + + # Event types + MOTION = _ti_core.KeyEvent.EType.Move + PRESS = _ti_core.KeyEvent.EType.Press + RELEASE = _ti_core.KeyEvent.EType.Release + + def __init__(self, + name='Taichi', + res=512, + background_color=0x0, + show_gui=True, + fullscreen=False, + fast_gui=False): + show_gui = self.get_bool_environ('TI_GUI_SHOW', show_gui) + fullscreen = self.get_bool_environ('TI_GUI_FULLSCREEN', fullscreen) + fast_gui = self.get_bool_environ('TI_GUI_FAST', fast_gui) + + self.name = name + if isinstance(res, numbers.Number): + res = (res, res) + self.res = res + self.fast_gui = fast_gui + if fast_gui: + self.img = np.ascontiguousarray( + np.zeros(self.res[0] * self.res[1], dtype=np.uint32)) + fast_buf = self.img.ctypes.data + else: + # The GUI canvas uses RGBA for storage, therefore we need NxMx4 for an image. + self.img = np.ascontiguousarray( + np.zeros(self.res + (4, ), np.float32)) + fast_buf = 0 + self.core = _ti_core.GUI(name, core_veci(*res), show_gui, fullscreen, + fast_gui, fast_buf) + self.canvas = self.core.get_canvas() + self.background_color = background_color + self.key_pressed = set() + self.event = None + self.frame = 0 + self.clear() + + def __enter__(self): + return self + + def __exit__(self, e_type, val, tb): + self.close() + + def __del__(self): + self.close() + + def close(self): + self.core = None # dereference to call GUI::~GUI() + + # Widget system + + class WidgetValue: + def __init__(self, gui, wid): + self.gui = gui + self.wid = wid + + @property + def value(self): + return self.gui.core.get_widget_value(self.wid) + + @value.setter + def value(self, value): + self.gui.core.set_widget_value(self.wid, value) + + @staticmethod + def get_bool_environ(key, default): + """Get an environment variable and cast to bool. + Args: + key (str): The environment variable key. + default (bool): The default value. + Return: + The environment variable value cast to bool. If the value is not found, directly return argument 'default'. + """ + if key not in os.environ: + return default + return bool(int(os.environ[key])) + + def slider(self, text, minimum, maximum, step=1): + """Create a slider object on canvas to be manipulated with. Args: - x (float): The x-coordinate (between 0 and 1) of the top-left corner of the subwindow, relative to the full window. - y (float): The y-coordinate (between 0 and 1) of the top-left corner of the subwindow, relative to the full window. - width (float): The width of the subwindow relative to the full window. - height (float): The height of the subwindow relative to the full window. + text (str): The title of slider. + minimum (Number): The minimum value of slider. + maximum (Number): The maximum value of slider. + step (Number, optional): The changing step of slider. Default is 1. - Usage:: + Return: + :class:`~taichi.misc.gui.GUI.WidgetValue` :The created slider object. - >>> with gui.sub_window(name, x, y, width, height) as g: - >>> g.text("Hello, World!") """ - self.begin(name, x, y, width, height) - try: - yield self - finally: - self.end() + wid = self.core.make_slider(text, minimum, minimum, maximum, step) + return GUI.WidgetValue(self, wid) - def begin(self, name, x, y, width, height): - """Creates a subwindow that holds imgui widgets. + def label(self, text): + """Create a label object on canvas. - All widget function calls (e.g. `text`, `button`) after the `begin` and before the next `end` will describe the widgets within this subwindow. + Args: + text (str): The title of label. + + Return: + :class:`~taichi.misc.gui.GUI.WidgetValue` :The created label object. + + """ + wid = self.core.make_label(text, 0) + return GUI.WidgetValue(self, wid) + + def button(self, text, event_name=None): + """Create a button object on canvas to be manipulated with. Args: - x (float): The x-coordinate (between 0 and 1) of the top-left corner of the subwindow, relative to the full window. - y (float): The y-coordinate (between 0 and 1) of the top-left corner of the subwindow, relative to the full window. - width (float): The width of the subwindow relative to the full window. - height (float): The height of the subwindow relative to the full window. + text (str): The title of button. + event_name (str, optional): The event name associated with button. + Default is WidgetButton_{text} + + Return: + The event name associated with created button. + """ - self.gui.begin(name, x, y, width, height) + event_name = event_name or f'WidgetButton_{text}' + self.core.make_button(text, event_name) + return event_name + + # Drawing system + + def clear(self, color=None): + """Clear the canvas with the color provided. + + Args: + color (int, optional): Specify the color to clear the canvas. Default + is the background color of GUI. - def end(self): - """End the description of the current subwindow. """ - self.gui.end() + if color is None: + color = self.background_color + self.canvas.clear(color) + + def cook_image(self, img): + if img.dtype in [np.uint8, np.uint16, np.uint32, np.uint64]: + img = img.astype(np.float32) * (1 / np.iinfo(img.dtype).max) + elif img.dtype in [np.float16, np.float32, np.float64]: + img = img.astype(np.float32) + else: + raise ValueError( + f'Data type {img.dtype} not supported in GUI.set_image') + + if len(img.shape) == 2: + img = img[..., None] + + if img.shape[2] == 1: + img = img + np.zeros((1, 1, 4), np.float32) + if img.shape[2] == 3: + zeros = np.zeros((img.shape[0], img.shape[1], 1), np.float32) + img = np.concatenate([img, zeros], axis=2) + if img.shape[2] == 2: + zeros = np.zeros((img.shape[0], img.shape[1], 2), np.float32) + img = np.concatenate([img, zeros], axis=2) + + assert img.shape[2] == 4, "Image must be grayscale, RG, RGB or RGBA" + + res = img.shape[:2] + assert res == self.res, "Image resolution does not match GUI resolution" + return np.ascontiguousarray(img) + + def get_image(self): + """Get the image data. + + Returns: + :class:`numpy.array` :The image data in numpy contiguous array type. - def text(self, text): - """Declares a line of text. """ - self.gui.text(text) + self.img = np.ascontiguousarray(self.img) + self.core.get_img(self.img.ctypes.data) + return self.img - def checkbox(self, text, old_value): - """Declares a checkbox, and returns whether or not it has been checked. + def set_image(self, img): + """Sets an image to display on the window. + The image pixels are set from the values of `img[i, j]`, where `i` indicates the horizontal coordinates (from left to right) and `j` the vertical coordinates (from bottom to top). + If the window size is `(x, y)`, then `img` must be one of: + - `ti.field(shape=(x, y))`, a gray-scale image + - `ti.field(shape=(x, y, 3))`, where `3` is for `(r, g, b)` channels + - `ti.field(shape=(x, y, 2))`, where `2` is for `(r, g)` channels + - `ti.Vector.field(3, shape=(x, y))` `(r, g, b)` channels on each component + - `ti.Vector.field(2, shape=(x, y))` `(r, g)` channels on each component + - `np.ndarray(shape=(x, y))` + - `np.ndarray(shape=(x, y, 3))` + - `np.ndarray(shape=(x, y, 2))` + The data type of `img` must be one of: + - `uint8`, range `[0, 255]` + - `uint16`, range `[0, 65535]` + - `uint32`, range `[0, 4294967295]` + - `float32`, range `[0, 1]` + - `float64`, range `[0, 1]` Args: - text (str): a line of text to be shown next to the checkbox - old_value (bool): whether the checkbox is currently checked + img (Union[ti.field, numpy.array]): The color array representing the + image to be drawn. Support greyscale, RG, RGB, and RGBA color + representations. Its shape must match GUI resolution. + """ - return self.gui.checkbox(text, old_value) - def slider_float(self, text, old_value, minimum, maximum): - """Declares a slider, and returns its newest value. + if self.fast_gui: + assert isinstance(img, taichi.lang.matrix.MatrixField), \ + "Only ti.Vector.field is supported in GUI.set_image when fast_gui=True" + assert img.shape == self.res, \ + "Image resolution does not match GUI resolution" + assert img.n in [3, 4] and img.m == 1, \ + "Only RGB images are supported in GUI.set_image when fast_gui=True" + assert img.dtype in [ti.f32, ti.f64, ti.u8], \ + "Only f32, f64, u8 are supported in GUI.set_image when fast_gui=True" + + vector_to_fast_image(img, self.img) + return + + if isinstance(img, ScalarField): + if _ti_core.is_integral(img.dtype) or len(img.shape) != 2: + # Images of uint is not optimized by xxx_to_image + self.img = self.cook_image(img.to_numpy()) + else: + # Type matched! We can use an optimized copy kernel. + assert img.shape \ + == self.res, "Image resolution does not match GUI resolution" + tensor_to_image(img, self.img) + ti.sync() + + elif isinstance(img, taichi.lang.matrix.MatrixField): + if _ti_core.is_integral(img.dtype): + self.img = self.cook_image(img.to_numpy()) + else: + # Type matched! We can use an optimized copy kernel. + assert img.shape == self.res, \ + "Image resolution does not match GUI resolution" + assert img.n in [2, 3, 4] and img.m == 1, \ + "Only greyscale, RG, RGB or RGBA images are supported in GUI.set_image" + + vector_to_image(img, self.img) + ti.sync() + + elif isinstance(img, np.ndarray): + self.img = self.cook_image(img) + + else: + raise ValueError( + f"GUI.set_image only takes a Taichi field or NumPy array, not {type(img)}" + ) + + self.core.set_img(self.img.ctypes.data) + + def circle(self, pos, color=0xFFFFFF, radius=1): + """Draw a single circle on canvas. Args: - text (str): a line of text to be shown next to the slider - old_value (float): the current value of the slider. - minimum (float): the minimum value of the slider. - maximum (float): the maximum value of the slider. + pos (Union[List[int], numpy.array]): The position of the circle. + color (int, Optional): The color of the circle. Default is 0xFFFFFF. + radius (Number, Optional): The radius of the circle in pixel. Default is 1. + """ - return self.gui.slider_float(text, old_value, minimum, maximum) + self.canvas.circle_single(pos[0], pos[1], color, radius) - def color_edit_3(self, text, old_value): - """Declares a color edit palate. + def circles(self, + pos, + radius=1, + color=0xFFFFFF, + palette=None, + palette_indices=None): + """Draw a list of circles on canvas. Args: - text (str): a line of text to be shown next to the palate - old_value (Tuple[float]): the current value of the color, this should be a tuple of floats in [0,1] that indicates RGB values. + pos (numpy.array): The positions of the circles. + radius (Number, optional): The radius of the circles in pixel. Default is 1. + color (int, optional): The color of the circles. Default is 0xFFFFFF. + palette (list[int], optional): The List of colors from which to + choose to draw. Default is None. + palette_indices (Union[list[int], ti.field, numpy.array], optional): + The List of indices that choose color from palette for each + circle. Shape must match pos. Default is None. + """ - return self.gui.color_edit_3(text, old_value) + n = pos.shape[0] + if len(pos.shape) == 3: + assert pos.shape[2] == 1 + pos = pos[:, :, 0] - def button(self, text): - """Declares a button, and returns whether or not it had just been clicked. + assert pos.shape == (n, 2) + pos = np.ascontiguousarray(pos.astype(np.float32)) + # Note: do not use "pos = int(pos.ctypes.data)" here + # Otherwise pos will get garbage collected by Python + # and the pointer to its data becomes invalid + pos_ptr = int(pos.ctypes.data) + + if isinstance(color, np.ndarray): + assert color.shape == (n, ) + color = np.ascontiguousarray(color.astype(np.uint32)) + color_array = int(color.ctypes.data) + color_single = 0 + elif isinstance(color, int): + color_array = 0 + color_single = color + else: + raise ValueError( + 'Color must be an ndarray or int (e.g., 0x956333)') + + if palette is not None: + assert palette_indices is not None, 'palette must be used together with palette_indices' + + if isinstance(palette_indices, Field): + ind_int = palette_indices.to_numpy().astype(np.uint32) + elif isinstance(palette_indices, list) or isinstance( + palette_indices, np.ndarray): + ind_int = np.array(palette_indices).astype(np.uint32) + else: + try: + ind_int = np.array(palette_indices) + except: + raise TypeError( + 'palette_indices must be a type that can be converted to numpy.ndarray' + ) + + assert issubclass( + ind_int.dtype.type, + np.integer), 'palette_indices must be an integer array' + assert ind_int.shape == ( + n, + ), 'palette_indices must be in 1-d shape with shape (num_particles, )' + assert min( + ind_int + ) >= 0, 'the min of palette_indices must not be less than zero' + assert max(ind_int) < len( + palette + ), 'the max of palette_indices must not exceed the length of palette' + color_array = np.array(palette, dtype=np.uint32)[ind_int] + color_array = np.ascontiguousarray(color_array) + color_array = color_array.ctypes.data + + if isinstance(radius, np.ndarray): + assert radius.shape == (n, ) + radius = np.ascontiguousarray(radius.astype(np.float32)) + radius_array = int(radius.ctypes.data) + radius_single = 0 + elif isinstance(radius, numbers.Number): + radius_array = 0 + radius_single = radius + else: + raise ValueError('Radius must be an ndarray or float (e.g., 0.4)') + + self.canvas.circles_batched(n, pos_ptr, color_single, color_array, + radius_single, radius_array) + + def triangles(self, a, b, c, color=0xFFFFFF): + """Draw a list of triangles on canvas. Args: - text (str): a line of text to be shown next to the button + a (numpy.array): The positions of the first points of triangles. + b (numpy.array): The positions of the second points of triangles. + c (numpy.array): The positions of the thrid points of triangles. + color (Union[int, numpy.array], optional): The color or colors of triangles. + Can be either a single color or a list of colors whose shape matches + the shape of a & b & c. Default is 0xFFFFFF. + """ - return self.gui.button(text) + assert a.shape == b.shape + assert a.shape == c.shape + n = a.shape[0] + if len(a.shape) == 3: + assert a.shape[2] == 1 + a = a[:, :, 0] + b = b[:, :, 0] + c = c[:, :, 0] + + assert a.shape == (n, 2) + a = np.ascontiguousarray(a.astype(np.float32)) + b = np.ascontiguousarray(b.astype(np.float32)) + c = np.ascontiguousarray(c.astype(np.float32)) + # Note: do not use "a = int(a.ctypes.data)" here + # Otherwise a will get garbage collected by Python + # and the pointer to its data becomes invalid + a_ptr = int(a.ctypes.data) + b_ptr = int(b.ctypes.data) + c_ptr = int(c.ctypes.data) + + if isinstance(color, np.ndarray): + assert color.shape == (n, ) + color = np.ascontiguousarray(color.astype(np.uint32)) + color_array = int(color.ctypes.data) + color_single = 0 + elif isinstance(color, int): + color_array = 0 + color_single = color + else: + raise ValueError( + '"color" must be an ndarray or int (e.g., 0x956333)') + + self.canvas.triangles_batched(n, a_ptr, b_ptr, c_ptr, color_single, + color_array) + + def triangle(self, a, b, c, color=0xFFFFFF): + """Draw a single triangle on canvas. + + Args: + a (List[Number]): The position of the first point of triangle. Shape must be 2. + b (List[Number]): The position of the second point of triangle. Shape must be 2. + c (List[Number]): The position of the third point of triangle. Shape must be 2. + color (int, optional): The color of the triangle. Default is 0xFFFFFF. + + """ + self.canvas.triangle_single(a[0], a[1], b[0], b[1], c[0], c[1], color) + + def lines(self, begin, end, radius=1, color=0xFFFFFF): + """Draw a list of lines on canvas. + + Args: + begin (numpy.array): The positions of one end of lines. + end (numpy.array): The positions of the other end of lines. + radius (Union[Number, numpy.array], optional): The width of lines. + Can be either a single width or a list of width whose shape matches + the shape of begin & end. Default is 1. + color (Union[int, numpy.array], optional): The color or colors of lines. + Can be either a single color or a list of colors whose shape matches + the shape of begin & end. Default is 0xFFFFFF. + + """ + assert begin.shape == end.shape + n = begin.shape[0] + if len(begin.shape) == 3: + assert begin.shape[2] == 1 + begin = begin[:, :, 0] + end = end[:, :, 0] + + assert begin.shape == (n, 2) + begin = np.ascontiguousarray(begin.astype(np.float32)) + end = np.ascontiguousarray(end.astype(np.float32)) + # Note: do not use "begin = int(begin.ctypes.data)" here + # Otherwise begin will get garbage collected by Python + # and the pointer to its data becomes invalid + begin_ptr = int(begin.ctypes.data) + end_ptr = int(end.ctypes.data) + + if isinstance(color, np.ndarray): + assert color.shape == (n, ) + color = np.ascontiguousarray(color.astype(np.uint32)) + color_array = int(color.ctypes.data) + color_single = 0 + elif isinstance(color, int): + color_array = 0 + color_single = color + else: + raise ValueError( + 'Color must be an ndarray or int (e.g., 0x956333)') + + if isinstance(radius, np.ndarray): + assert radius.shape == (n, ) + radius = np.ascontiguousarray(radius.astype(np.float32)) + radius_array = int(radius.ctypes.data) + radius_single = 0 + elif isinstance(radius, numbers.Number): + radius_array = 0 + radius_single = radius + else: + raise ValueError('Radius must be an ndarray or float (e.g., 0.4)') + + self.canvas.paths_batched(n, begin_ptr, end_ptr, color_single, + color_array, radius_single, radius_array) + + def line(self, begin, end, radius=1, color=0xFFFFFF): + """Draw a single line on canvas. + + Args: + begin (List[Number]): The position of one end of line. Shape must be 2. + end (List[Number]): The position of the other end of line. Shape must be 2. + radius (Number, optional): The width of line. Default is 1. + color (int, optional): The color of line. Default is 0xFFFFFF. + + """ + self.canvas.path_single(begin[0], begin[1], end[0], end[1], color, + radius) + + @staticmethod + def _arrow_to_lines(orig, major, tip_scale=0.2, angle=45): + angle = math.radians(180 - angle) + c, s = math.cos(angle), math.sin(angle) + minor1 = np.array([ + major[:, 0] * c - major[:, 1] * s, + major[:, 0] * s + major[:, 1] * c + ]).swapaxes(0, 1) + minor2 = np.array([ + major[:, 0] * c + major[:, 1] * s, + -major[:, 0] * s + major[:, 1] * c + ]).swapaxes(0, 1) + end = orig + major + return [(orig, end), (end, end + minor1 * tip_scale), + (end, end + minor2 * tip_scale)] + + def arrows(self, orig, direction, radius=1, color=0xffffff, **kwargs): + """Draw a list arrows on canvas. + + Args: + orig (numpy.array): The positions where arrows start. + direction (numpy.array): The directions where arrows point to. + radius (Union[Number, np.array], optional): The width of arrows. Default is 1. + color (Union[int, np.array], optional): The color or colors of arrows. Default is 0xffffff. + + """ + for begin, end in self._arrow_to_lines(orig, direction, **kwargs): + self.lines(begin, end, radius, color) + + def arrow(self, orig, direction, radius=1, color=0xffffff, **kwargs): + """Draw a single arrow on canvas. + + Args: + orig (List[Number]): The position where arrow starts. Shape must be 2. + direction (List[Number]): The direction where arrow points to. Shape must be 2. + radius (Number, optional): The width of arrow. Default is 1. + color (int, optional): The color of arrow. Default is 0xFFFFFF. + + """ + orig = np.array([orig]) + direction = np.array([direction]) + for begin, end in self._arrow_to_lines(orig, direction, **kwargs): + self.line(begin[0], end[0], radius, color) + + def rect(self, topleft, bottomright, radius=1, color=0xFFFFFF): + """Draw a single rectangle on canvas. + + Args: + topleft (List[Number]): The position of the topleft corner of rectangle. + Shape must be 2. + bottomright (List[Number]): The position of the bottomright corner + of rectangle. Shape must be 2. + radius (Number, optional): The width of rectangle's sides. Default is 1. + color (int, optional): The color of rectangle. Default is 0xFFFFFF. + + """ + a = topleft[0], topleft[1] + b = bottomright[0], topleft[1] + c = bottomright[0], bottomright[1] + d = topleft[0], bottomright[1] + self.line(a, b, radius, color) + self.line(b, c, radius, color) + self.line(c, d, radius, color) + self.line(d, a, radius, color) + + def text(self, content, pos, font_size=15, color=0xFFFFFF): + """Draw texts on canvas. + + Args: + content (str): The text to be drawn on canvas. + pos (List[Number]): The position where the text is to be put. + font_size (Number, optional): The font size of the text. + color (int, optional): The color of the text. Default is 0xFFFFFF. + + """ + + # TODO: refactor Canvas::text + font_size = float(font_size) + pos = core_vec(*pos) + r, g, b = hex_to_rgb(color) + color = core_vec(r, g, b, 1) + self.canvas.text(content, pos, font_size, color) + + @staticmethod + def _make_field_base(w, h, bound): + x = np.linspace(bound / w, 1 - bound / w, w) + y = np.linspace(bound / h, 1 - bound / h, h) + base = np.array(np.meshgrid(x, y)) + base = base.swapaxes(0, 1).swapaxes(1, 2).swapaxes(0, 1) + return base.reshape(w * h, 2) + + def point_field(self, radius, color=0xffffff, bound=0.5): + """Draw a field of points on canvas. + + Args: + radius (np.array): The pattern and radius of the field of points. + color (Union[int, np.array], optional): The color or colors of points. + Default is 0xFFFFFF. + bound (Number, optional): The boundary of the field. Default is 0.5. + + """ + assert len(radius.shape) == 2 + base = self._make_field_base(radius.shape[0], radius.shape[1], bound) + radius = radius.reshape(radius.shape[0] * radius.shape[1]) + self.circles(base, radius=radius, color=color) + + def arrow_field(self, + direction, + radius=1, + color=0xffffff, + bound=0.5, + **kwargs): + """Draw a field of arrows on canvas. + + Args: + direction (np.array): The pattern and direction of the field of arrows. + color (Union[int, np.array], optional): The color or colors of arrows. + Default is 0xFFFFFF. + bound (Number, optional): The boundary of the field. Default is 0.5. + + """ + assert len(direction.shape) == 3 + assert direction.shape[2] == 2 + base = self._make_field_base(direction.shape[0], direction.shape[1], + bound) + direction = direction.reshape(direction.shape[0] * direction.shape[1], + 2) + self.arrows(base, direction, radius=radius, color=color, **kwargs) + + def show(self, file=None): + """Show the frame or save current frame as a picture. + + Args: + file (str, optional): The path & name of the picture to be saved. + Default is None. + + """ + self.core.update() + if file: + self.core.screenshot(file) + self.frame += 1 + self.clear() + + # Event system + + class EventFilter: + def __init__(self, *e_filter): + self.filter = set() + for ent in e_filter: + if isinstance(ent, (list, tuple)): + e_type, key = ent + ent = (e_type, key) + self.filter.add(ent) + + def match(self, e): + if (e.type, e.key) in self.filter: + return True + if e.type in self.filter: + return True + if e.key in self.filter: + return True + return False + + def has_key_event(self): + """Check if there are any key event registered. + + Returns: + Bool to indicate whether there is any key event registered. + + """ + return self.core.has_key_event() + + def get_event(self, *e_filter): + """Check if the specific event is triggered. + + Args: + *e_filter (ti.GUI.EVENT): The specific event to be checked. + + Returns: + Bool to indicate whether the specific event is triggered. + + """ + for e in self.get_events(*e_filter): + self.event = e + return True + else: + return False + + def get_events(self, *e_filter): + """Get a list of events that are triggered. + + Args: + *e_filter (List[ti.GUI.EVENT]): The type of events to be filtered. + + Returns: + :class:`~taichi.misc.gui.GUI.EVENT` :A list of events that are triggered. + + """ + e_filter = e_filter and GUI.EventFilter(*e_filter) or None + + while True: + if not self.has_key_event(): + break + e = self.get_key_event() + if e_filter is None or e_filter.match(e): # pylint: disable=E1101 + yield e + + def get_key_event(self): + """Get keyboard triggered event. + + Returns: + :class:`~taichi.misc.gui.GUI.EVENT` :The keyboard triggered event. + + """ + self.core.wait_key_event() + + e = GUI.Event() + event = self.core.get_key_event_head() + + e.type = event.type + e.key = event.key + e.pos = self.core.canvas_untransform(event.pos) + e.pos = (e.pos[0], e.pos[1]) + e.modifier = [] + + if e.key == GUI.WHEEL: + e.delta = event.delta + else: + e.delta = (0, 0) + + for mod in ['Shift', 'Alt', 'Control']: + if self.is_pressed(mod): + e.modifier.append(mod) + + if e.type == GUI.PRESS: + self.key_pressed.add(e.key) + else: + self.key_pressed.discard(e.key) + + self.core.pop_key_event_head() + return e + + def is_pressed(self, *keys): + """Check if the specific key or keys are pressed. + + Args: + *keys (Union[str, List[str]]): The string that stands for keys in keyboard. + + Returns: + Bool to indicate whether the key or keys are pressed. + + """ + for key in keys: + if key in ['Shift', 'Alt', 'Control']: + if key + '_L' in self.key_pressed or key + '_R' in self.key_pressed: + return True + if key in self.key_pressed: + return True + else: + return False + + def get_cursor_pos(self): + """Get the current position of mouse. + + Returns: + The current position of mouse. + + """ + pos = self.core.get_cursor_pos() + return pos[0], pos[1] + + @property + def running(self): + """Get the property of whether the gui is running. + + Returns: + The running property of gui(bool). + + """ + return not self.core.should_close + + @running.setter + def running(self, value): + if value: + self.core.should_close = 0 + elif not self.core.should_close: + self.core.should_close = 1 + + @property + def fps_limit(self): + """Get the property of fps limit. + + Returns: + The property of fps limit of gui. + + """ + if self.core.frame_delta_limit == 0: + return None + return 1 / self.core.frame_delta_limit + + @fps_limit.setter + def fps_limit(self, value): + if value is None: + self.core.frame_delta_limit = 0 + else: + self.core.frame_delta_limit = 1 / value + + +def rgb_to_hex(c): + """Convert rgb color format to hex color format. + + Args: + c (List[int]): The rgb representation of color. + + Returns: + The hex representation of color. + + """ + def to255(x): + return np.clip(np.int32(x * 255), 0, 255) + + return (to255(c[0]) << 16) + (to255(c[1]) << 8) + to255(c[2]) + + +def hex_to_rgb(color): + """Convert hex color format to rgb color format. + + Args: + color (int): The hex representation of color. + + Returns: + The rgb representation of color. + + """ + r, g, b = (color >> 16) & 0xff, (color >> 8) & 0xff, color & 0xff + return r / 255, g / 255, b / 255 + + +def core_veci(*args): + if isinstance(args[0], _ti_core.Vector2i): + return args[0] + if isinstance(args[0], _ti_core.Vector3i): + return args[0] + if isinstance(args[0], tuple): + args = tuple(*args) + if len(args) == 2: + return _ti_core.Vector2i(int(args[0]), int(args[1])) + if len(args) == 3: + return _ti_core.Vector3i(int(args[0]), int(args[1]), int(args[2])) + if len(args) == 4: + return _ti_core.Vector4i(int(args[0]), int(args[1]), int(args[2]), + int(args[3])) + assert False, type(args[0]) + + +def core_vec(*args): + if isinstance(args[0], _ti_core.Vector2f): + return args[0] + if isinstance(args[0], _ti_core.Vector3f): + return args[0] + if isinstance(args[0], _ti_core.Vector4f): + return args[0] + if isinstance(args[0], _ti_core.Vector2d): + return args[0] + if isinstance(args[0], _ti_core.Vector3d): + return args[0] + if isinstance(args[0], _ti_core.Vector4d): + return args[0] + if isinstance(args[0], tuple): + args = tuple(*args) + if _ti_core.get_default_float_size() == 4: + if len(args) == 2: + return _ti_core.Vector2f(float(args[0]), float(args[1])) + if len(args) == 3: + return _ti_core.Vector3f(float(args[0]), float(args[1]), + float(args[2])) + if len(args) == 4: + return _ti_core.Vector4f(float(args[0]), float(args[1]), + float(args[2]), float(args[3])) + assert False, type(args[0]) + else: + if len(args) == 2: + return _ti_core.Vector2d(float(args[0]), float(args[1])) + if len(args) == 3: + return _ti_core.Vector3d(float(args[0]), float(args[1]), + float(args[2])) + if len(args) == 4: + return _ti_core.Vector4d(float(args[0]), float(args[1]), + float(args[2]), float(args[3])) + assert False, type(args[0]) + + +__all__ = [ + 'GUI', + 'rgb_to_hex', + 'hex_to_rgb', +] diff --git a/python/taichi/ui/imgui.py b/python/taichi/ui/imgui.py new file mode 100644 index 0000000000000..28b426d48246e --- /dev/null +++ b/python/taichi/ui/imgui.py @@ -0,0 +1,91 @@ +from contextlib import contextmanager + + +#For declaring IMGUI components in a ti.Window created by the GGUI system. +class Gui: + def __init__(self, gui) -> None: + self.gui = gui #reference to a PyGui + + @contextmanager + def sub_window(self, name, x, y, width, height): + """Creating a context manager for subwindow + + Note: + All args of this method should align with `begin`. + + Args: + x (float): The x-coordinate (between 0 and 1) of the top-left corner of the subwindow, relative to the full window. + y (float): The y-coordinate (between 0 and 1) of the top-left corner of the subwindow, relative to the full window. + width (float): The width of the subwindow relative to the full window. + height (float): The height of the subwindow relative to the full window. + + Usage:: + + >>> with gui.sub_window(name, x, y, width, height) as g: + >>> g.text("Hello, World!") + """ + self.begin(name, x, y, width, height) + try: + yield self + finally: + self.end() + + def begin(self, name, x, y, width, height): + """Creates a subwindow that holds imgui widgets. + + All widget function calls (e.g. `text`, `button`) after the `begin` and before the next `end` will describe the widgets within this subwindow. + + Args: + x (float): The x-coordinate (between 0 and 1) of the top-left corner of the subwindow, relative to the full window. + y (float): The y-coordinate (between 0 and 1) of the top-left corner of the subwindow, relative to the full window. + width (float): The width of the subwindow relative to the full window. + height (float): The height of the subwindow relative to the full window. + """ + self.gui.begin(name, x, y, width, height) + + def end(self): + """End the description of the current subwindow. + """ + self.gui.end() + + def text(self, text): + """Declares a line of text. + """ + self.gui.text(text) + + def checkbox(self, text, old_value): + """Declares a checkbox, and returns whether or not it has been checked. + + Args: + text (str): a line of text to be shown next to the checkbox + old_value (bool): whether the checkbox is currently checked + """ + return self.gui.checkbox(text, old_value) + + def slider_float(self, text, old_value, minimum, maximum): + """Declares a slider, and returns its newest value. + + Args: + text (str): a line of text to be shown next to the slider + old_value (float): the current value of the slider. + minimum (float): the minimum value of the slider. + maximum (float): the maximum value of the slider. + """ + return self.gui.slider_float(text, old_value, minimum, maximum) + + def color_edit_3(self, text, old_value): + """Declares a color edit palate. + + Args: + text (str): a line of text to be shown next to the palate + old_value (Tuple[float]): the current value of the color, this should be a tuple of floats in [0,1] that indicates RGB values. + """ + return self.gui.color_edit_3(text, old_value) + + def button(self, text): + """Declares a button, and returns whether or not it had just been clicked. + + Args: + text (str): a line of text to be shown next to the button + """ + return self.gui.button(text) diff --git a/python/taichi/ui/scene.py b/python/taichi/ui/scene.py index 8b5f5646aaa88..d7333f9630d26 100644 --- a/python/taichi/ui/scene.py +++ b/python/taichi/ui/scene.py @@ -1,17 +1,14 @@ -import pathlib - -from taichi.core import ti_core as _ti_core -from taichi.lang.impl import default_cfg, field +from taichi._lib import core as _ti_core +from taichi.lang.impl import field from taichi.lang.kernel_impl import kernel from taichi.lang.matrix import Vector -from taichi.lang.ops import atomic_add, get_addr -from taichi.type.annotations import ext_arr, template -from taichi.type.primitive_types import f32 +from taichi.lang.ops import atomic_add +from taichi.types.annotations import template +from taichi.types.primitive_types import f32 -from .camera import Camera from .staging_buffer import (copy_colors_to_vbo, copy_normals_to_vbo, copy_vertices_to_vbo, get_vbo_field) -from .utils import get_field_info +from .utils import check_ggui_availability, get_field_info normals_field_cache = {} @@ -23,8 +20,7 @@ def get_normals_field(vertices): normal_weights = field(f32, shape=(N, )) normals_field_cache[vertices] = (normals, normal_weights) return (normals, normal_weights) - else: - return normals_field_cache[vertices] + return normals_field_cache[vertices] @kernel @@ -76,14 +72,15 @@ def gen_normals(vertices, indices): return normals -class Scene(_ti_core.PyScene): +class Scene: """A 3D scene, which can contain meshes and particles, and can be rendered on a canvas """ def __init__(self): - super().__init__() + check_ggui_availability() + self.scene = _ti_core.PyScene() def set_camera(self, camera): - super().set_camera(camera.ptr) + self.scene.set_camera(camera.ptr) def mesh(self, vertices, @@ -113,8 +110,8 @@ def mesh(self, vbo_info = get_field_info(vbo) indices_info = get_field_info(indices) - super().mesh(vbo_info, has_per_vertex_color, indices_info, color, - two_sided) + self.scene.mesh(vbo_info, has_per_vertex_color, indices_info, color, + two_sided) def particles(self, centers, @@ -135,10 +132,10 @@ def particles(self, if has_per_vertex_color: copy_colors_to_vbo(vbo, per_vertex_color) vbo_info = get_field_info(vbo) - super().particles(vbo_info, has_per_vertex_color, color, radius) + self.scene.particles(vbo_info, has_per_vertex_color, color, radius) - def point_light(self, pos, color): - super().point_light(pos, color) + def point_light(self, pos, color): # pylint: disable=W0235 + self.scene.point_light(pos, color) def ambient_light(self, color): - super().ambient_light(color) + self.scene.ambient_light(color) diff --git a/python/taichi/ui/staging_buffer.py b/python/taichi/ui/staging_buffer.py index b66822a125b7c..f7377c9843002 100644 --- a/python/taichi/ui/staging_buffer.py +++ b/python/taichi/ui/staging_buffer.py @@ -1,16 +1,10 @@ -from taichi.core import ti_core as _ti_core -from taichi.lang.impl import default_cfg, field, static from taichi.lang.kernel_impl import kernel from taichi.lang.matrix import Vector -from taichi.lang.ndrange import ndrange -from taichi.lang.ops import atomic_add, get_addr -from taichi.type.annotations import ext_arr, template -from taichi.type.primitive_types import f32, u8 +from taichi.types.annotations import template +from taichi.types.primitive_types import f32, u8 import taichi as ti -from .utils import get_field_info - vbo_field_cache = {} @@ -20,13 +14,12 @@ def get_vbo_field(vertices): pos = 3 normal = 3 tex_coord = 2 - color = 3 + color = 4 vertex_stride = pos + normal + tex_coord + color vbo = Vector.field(vertex_stride, f32, shape=(N, )) vbo_field_cache[vertices] = vbo return vbo - else: - return vbo_field_cache[vertices] + return vbo_field_cache[vertices] @kernel @@ -37,6 +30,14 @@ def copy_to_vbo(vbo: template(), src: template(), offset: template(), vbo[i][offset + c] = src[i][c] +@kernel +def fill_vbo(vbo: template(), value: template(), offset: template(), + num_components: template()): + for i in vbo: + for c in ti.static(range(num_components)): + vbo[i][offset + c] = value + + def validate_input_field(f, name): if f.dtype != f32: raise Exception(f"{name} needs to have dtype f32") @@ -53,29 +54,31 @@ def validate_input_field(f, name): def copy_vertices_to_vbo(vbo, vertices): validate_input_field(vertices, "vertices") if not 2 <= vertices.n <= 3: - raise Exception(f'vertices can only be 2D or 3D vector fields') + raise Exception('vertices can only be 2D or 3D vector fields') copy_to_vbo(vbo, vertices, 0, vertices.n) def copy_normals_to_vbo(vbo, normals): validate_input_field(normals, "normals") if normals.n != 3: - raise Exception(f'normals can only be 3D vector fields') + raise Exception('normals can only be 3D vector fields') copy_to_vbo(vbo, normals, 3, normals.n) def copy_texcoords_to_vbo(vbo, texcoords): validate_input_field(texcoords, "texcoords") if texcoords.n != 2: - raise Exception(f'texcoords can only be 3D vector fields') + raise Exception('texcoords can only be 3D vector fields') copy_to_vbo(vbo, texcoords, 6, texcoords.n) def copy_colors_to_vbo(vbo, colors): validate_input_field(colors, "colors") - if colors.n != 3: - raise Exception(f'colors can only be 3D vector fields') + if colors.n != 3 and colors.n != 4: + raise Exception('colors can only be 3D/4D vector fields') copy_to_vbo(vbo, colors, 8, colors.n) + if colors.n == 3: + fill_vbo(vbo, 1.0, 11, 1) @ti.kernel @@ -87,6 +90,9 @@ def copy_image_f32_to_u8(src: ti.template(), dst: ti.template(), c = max(0.0, min(1.0, c)) c = c * 255 dst[i, j][k] = int(c) + if num_components < 4: + # alpha channel + dst[i, j][3] = 255 @ti.kernel @@ -95,6 +101,9 @@ def copy_image_u8_to_u8(src: ti.template(), dst: ti.template(), for i, j in src: for k in ti.static(range(num_components)): dst[i, j][k] = src[i, j][k] + if num_components < 4: + # alpha channel + dst[i, j][3] = 255 # ggui renderer always assumes the input image to be u8 RGBA @@ -105,11 +114,11 @@ def copy_image_u8_to_u8(src: ti.template(), dst: ti.template(), def to_u8_rgba(image): if not hasattr(image, 'n') or image.m != 1: raise Exception( - f'the input image needs to be a Vector field (matrix with 1 column)' + 'the input image needs to be a Vector field (matrix with 1 column)' ) if len(image.shape) != 2: raise Exception( - f"the shape of the image must be of the form (width,height)") + "the shape of the image must be of the form (width,height)") if image.dtype == u8 and image.n == 4: # already in the desired format diff --git a/python/taichi/ui/ui.py b/python/taichi/ui/ui.py index 7026e4c6b01b8..d98c1d1215ae1 100644 --- a/python/taichi/ui/ui.py +++ b/python/taichi/ui/ui.py @@ -1,37 +1,17 @@ -import pathlib +from taichi._lib import core as _ti_core -from taichi.core import ti_core as _ti_core -from taichi.lang.impl import default_cfg, field -from taichi.lang.kernel_impl import kernel -from taichi.lang.ops import get_addr -from taichi.type.annotations import ext_arr, template +from .camera import Camera +from .canvas import Canvas # pylint: disable=unused-import +from .constants import * # pylint: disable=unused-import,wildcard-import +from .imgui import Gui # pylint: disable=unused-import +from .scene import Scene # pylint: disable=unused-import +from .utils import check_ggui_availability +from .window import Window # pylint: disable=unused-import -if _ti_core.GGUI_AVAILABLE: - from .camera import Camera - from .canvas import Canvas - from .constants import * - from .gui import Gui - from .scene import Scene - from .window import Window +def make_camera(): + check_ggui_availability() + return Camera(_ti_core.PyCamera()) - def make_camera(): - return Camera(_ti_core.PyCamera()) - ProjectionMode = _ti_core.ProjectionMode - -else: - - def err_no_ggui(): - raise Exception("GGUI Not Available") - - class Window: - def __init__(self, name, res, vsync=False): - err_no_ggui() - - class Scene: - def __init__(self): - err_no_ggui() - - def make_camera(): - err_no_ggui() +ProjectionMode = _ti_core.ProjectionMode if _ti_core.GGUI_AVAILABLE else None diff --git a/python/taichi/ui/utils.py b/python/taichi/ui/utils.py index 0b18c726f81f1..3bc7aa1604450 100644 --- a/python/taichi/ui/utils.py +++ b/python/taichi/ui/utils.py @@ -1,13 +1,8 @@ -import pathlib -from math import acos, asin, cos, pi, sin +from math import acos, asin, cos, sin -from taichi.core import ti_core as _ti_core +from taichi._lib import core as _ti_core from taichi.lang.impl import default_cfg -from taichi.lang.kernel_impl import kernel from taichi.lang.matrix import Vector -from taichi.lang.ops import get_addr -from taichi.type.annotations import ext_arr, template -from taichi.type.primitive_types import u64 def get_field_info(field): @@ -20,6 +15,8 @@ def get_field_info(field): info.field_source = _ti_core.FieldSource.TaichiCuda elif default_cfg().arch == _ti_core.x64: info.field_source = _ti_core.FieldSource.TaichiX64 + elif default_cfg().arch == _ti_core.vulkan: + info.field_source = _ti_core.FieldSource.TaichiVulkan else: raise Exception("unsupported taichi backend") info.shape = [n for n in field.shape] @@ -59,7 +56,12 @@ def vec_to_euler(v): yaw = 0 else: yaw = acos(cos_yaw) - if (sin_yaw < 0): + if sin_yaw < 0: yaw = -yaw return yaw, pitch + + +def check_ggui_availability(): + if not _ti_core.GGUI_AVAILABLE: + raise Exception("GGUI Not Available") diff --git a/python/taichi/ui/window.py b/python/taichi/ui/window.py index fd442798bf007..7d6827e3d787e 100644 --- a/python/taichi/ui/window.py +++ b/python/taichi/ui/window.py @@ -1,18 +1,15 @@ import pathlib -from taichi.core import ti_core as _ti_core -from taichi.lang.impl import default_cfg -from taichi.lang.kernel_impl import kernel -from taichi.lang.ops import get_addr -from taichi.type.annotations import ext_arr, template +from taichi._lib import core as _ti_core +from taichi.lang.impl import default_cfg, get_runtime from .canvas import Canvas -from .constants import * -from .gui import Gui -from .utils import get_field_info +from .constants import PRESS, RELEASE +from .imgui import Gui +from .utils import check_ggui_availability -class Window(_ti_core.PyWindow): +class Window: """The window class. Args: @@ -20,20 +17,31 @@ class Window(_ti_core.PyWindow): res (Tuple[Int]): resolution (width, height) of the window, in pixels. layout (vsync): whether or not vertical sync should be enabled. """ - def __init__(self, name, res, vsync=False): + def __init__(self, name, res, vsync=False, show_window=True): + check_ggui_availability() package_path = str(pathlib.Path(__file__).parent.parent) ti_arch = default_cfg().arch is_packed = default_cfg().packed - super().__init__(name, res, vsync, package_path, ti_arch, is_packed) + self.window = _ti_core.PyWindow(get_runtime().prog, name, res, vsync, + show_window, package_path, ti_arch, + is_packed) @property def running(self): - return self.is_running() + return self.window.is_running() @running.setter def running(self, value): - self.set_is_running(value) + self.window.set_is_running(value) + + @property + def event(self): + return self.window.get_current_event() + + @event.setter + def event(self, value): + self.window.set_current_event(value) def get_events(self, tag=None): """ Obtain a list of unprocessed events. @@ -42,11 +50,11 @@ def get_events(self, tag=None): tag (str): A tag used for filtering events. If it is None, then all events are returned. """ if tag is None: - return super().get_events(_ti_core.EventType.Any) - elif tag is PRESS: - return super().get_events(_ti_core.EventType.Press) - elif tag is RELEASE: - return super().get_events(_ti_core.EventType.Release) + return self.window.get_events(_ti_core.EventType.Any) + if tag is PRESS: + return self.window.get_events(_ti_core.EventType.Press) + if tag is RELEASE: + return self.window.get_events(_ti_core.EventType.Release) raise Exception("unrecognized event tag") def get_event(self, tag=None): @@ -56,24 +64,36 @@ def get_event(self, tag=None): """ if tag is None: - return super().get_event(_ti_core.EventType.Any) - elif tag is PRESS: - return super().get_event(_ti_core.EventType.Press) - elif tag is RELEASE: - return super().get_event(_ti_core.EventType.Release) + return self.window.get_event(_ti_core.EventType.Any) + if tag is PRESS: + return self.window.get_event(_ti_core.EventType.Press) + if tag is RELEASE: + return self.window.get_event(_ti_core.EventType.Release) raise Exception("unrecognized event tag") def is_pressed(self, *keys): for k in keys: - if super().is_pressed(k): + if self.window.is_pressed(k): return True return False def get_canvas(self): """Returns a canvas handle. See :class`~taichi.ui.canvas.Canvas` """ - return Canvas(super().get_canvas()) + return Canvas(self.window.get_canvas()) @property def GUI(self): """Returns a IMGUI handle. See :class`~taichi.ui.ui.Gui` """ - return Gui(super().GUI()) + return Gui(self.window.GUI()) + + def get_cursor_pos(self): + return self.window.get_cursor_pos() + + def show(self): + return self.window.show() + + def write_image(self, filename): + return self.window.write_image(filename) + + def destroy(self): + return self.window.destroy() diff --git a/python/ti.cpp b/python/ti.cpp deleted file mode 100644 index f5ab3400cccf9..0000000000000 --- a/python/ti.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include -#include -#include -#include -#include "taichi/platform/windows/windows.h" -#include -#include -#include - -void main(int argc, char **argv) { - Py_SetProgramName(L"ti"); - Py_Initialize(); - - std::vector argv_converted; - std::vector argv_char; - argv_converted.resize(argc); - argv_char.resize(argc); - - for (int i = 0; i < argc; i++) { - int buffer_len = 3 * std::strlen(argv[i]) + 2; - // printf("len %d\n", buffer_len); - // would rather be safe here... TODO: figure out the maximum converted - // length - argv_converted[i].resize(buffer_len); - auto ret = mbstowcs(&argv_converted[i][0], argv[i], buffer_len); - argv_char[i] = &argv_converted[i][0]; - } - PySys_SetArgv(argc, &argv_char[0]); - // TODO: implement release mode for this - auto dir = getenv("TAICHI_REPO_DIR"); - if (dir == nullptr) { - std::cout << "Please set TAICHI_REPO_DIR" << std::endl; - exit(-1); - } - auto path = std::string(dir) + "/bin/ti"; - auto file = std::fopen(path.c_str(), "r"); - PyRun_SimpleFile(file, "ti"); - Py_Finalize(); - return; -} diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 93b7d71e9a97d..0000000000000 --- a/requirements.txt +++ /dev/null @@ -1,15 +0,0 @@ -GitPython -astor -autograd -colorama -coverage -isort -numpy -pybind11 -pylint -pytest -pytest-rerunfailures -pytest-xdist -setuptools -sourceinspect -yapf==0.31.0 diff --git a/requirements_dev.txt b/requirements_dev.txt index 2620b6b2fb3dd..718434b2d8d72 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,19 +1,16 @@ cmake colorama coverage -numpy Pillow pybind11 GitPython yapf==0.31.0 distro -autograd astor sourceinspect -pylint -pytest -pytest-xdist -pytest-rerunfailures -pytest-cov -torch isort +pylint +requests==2.26 +twine +wheel +astunparse diff --git a/requirements_test.txt b/requirements_test.txt new file mode 100644 index 0000000000000..84228242d1dc7 --- /dev/null +++ b/requirements_test.txt @@ -0,0 +1,8 @@ +pytest +pytest-xdist +pytest-rerunfailures +pytest-cov +numpy +autograd +requests==2.26 +matplotlib diff --git a/scripts/generate_pylint_tags.py b/scripts/generate_pylint_tags.py new file mode 100644 index 0000000000000..5413cc7ff3797 --- /dev/null +++ b/scripts/generate_pylint_tags.py @@ -0,0 +1,31 @@ +TAGS = { + 'C0121': True, + 'C0415': True, + 'W0611': True, + 'W0202': True, + 'W0621': True, + 'W0622': True, + 'W0401': True, + 'C0209': True, + 'W0404': True, + 'W0612': True, + 'E1101': True, + 'R0402': True, + 'R0201': True, + 'W0235': True, + 'R1705': True, + 'C0200': True, + 'R0205': True, + 'R1732': True, + 'W0101': True, + 'R1710': True, + 'R1703': True, + 'W0108': True, + 'W1309': True, + 'C0321': True, + 'C0325': True, +} + +if __name__ == '__main__': + enabled = [kv[0] for kv in TAGS.items() if kv[1]] + print(','.join(enabled)) diff --git a/scripts/run-clang-tidy.py b/scripts/run_clang_tidy.py similarity index 97% rename from scripts/run-clang-tidy.py rename to scripts/run_clang_tidy.py index 8921b1859c8ac..229d91bfdf4d0 100644 --- a/scripts/run-clang-tidy.py +++ b/scripts/run_clang_tidy.py @@ -82,6 +82,7 @@ def get_tidy_invocation(f, clang_tidy_binary, checks, tmpdir, build_path, config): """Gets a command line for clang-tidy.""" start = [clang_tidy_binary] + start.append('-warnings-as-errors=*') if header_filter is not None: start.append('-header-filter=' + header_filter) if checks: @@ -172,12 +173,12 @@ def run_tidy(args, tmpdir, build_path, queue, lock, failed_files): output, err = proc.communicate() if proc.returncode != 0: failed_files.append(name) - with lock: - sys.stdout.write(' '.join(invocation) + '\n' + - output.decode('utf-8')) - if len(err) > 0: - sys.stdout.flush() - sys.stderr.write(err.decode('utf-8')) + with lock: + sys.stdout.write(' '.join(invocation) + '\n' + + output.decode('utf-8')) + if len(err) > 0: + sys.stdout.flush() + sys.stderr.write(err.decode('utf-8')) queue.task_done() @@ -324,6 +325,8 @@ def main(): task_queue.join() if len(failed_files): return_code = 1 + else: + print("No errors detected, congratulations!") except KeyboardInterrupt: # This is a sad hack. Unfortunately subprocess goes diff --git a/scripts/run_clang_tidy.sh b/scripts/run_clang_tidy.sh index bc268d3e797d9..e5bd6813270ad 100644 --- a/scripts/run_clang_tidy.sh +++ b/scripts/run_clang_tidy.sh @@ -6,4 +6,6 @@ mkdir -p build_clang_tidy/ cd build_clang_tidy cmake .. -DCMAKE_CXX_COMPILER=clang -DCMAKE_C_COMPILER=clang -DCMAKE_EXPORT_COMPILE_COMMANDS=ON cd .. -python3 scripts/run-clang-tidy.py $PWD/taichi -header-filter="$PWD/taichi/" -p build_clang_tidy -j16 -fix +TAICHI_SRC=$PWD/taichi +VAR=${1:-${TAICHI_SRC}} +python3 scripts/run_clang_tidy.py $PWD/taichi -header-filter="$PWD/taichi/" -p build_clang_tidy -j16 -fix diff --git a/setup.cfg b/setup.cfg index bd804696ecc4f..e48864773ecf0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,3 @@ [metadata] -description_file = README +long_description = file: README.md +long_description_content_type = text/markdown; charset=UTF-8 diff --git a/setup.py b/setup.py index 274f795ce3f54..355e09679ef74 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ # Optional environment variables supported by setup.py: -# DEBUG -# build the C++ taichi_core extension with debug symbols. +# {DEBUG, RELWITHDEBINFO, MINSIZEREL} +# build the C++ taichi_core extension with various build types. # # TAICHI_CMAKE_ARGS # extra cmake args for C++ taichi_core extension. @@ -32,15 +32,25 @@ 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', ] + +def get_version(): + if os.getenv("RELEASE_VERSION"): + version = os.environ["RELEASE_VERSION"] + else: + version_file = os.path.join(os.path.dirname(__file__), 'version.txt') + with open(version_file, 'r') as f: + version = f.read().strip() + return version.lstrip("v") + + project_name = os.getenv('PROJECT_NAME', 'taichi') -TI_VERSION_MAJOR = 0 -TI_VERSION_MINOR = 8 -TI_VERSION_PATCH = 3 -version = f'{TI_VERSION_MAJOR}.{TI_VERSION_MINOR}.{TI_VERSION_PATCH}' +version = get_version() +TI_VERSION_MAJOR, TI_VERSION_MINOR, TI_VERSION_PATCH = version.split('.') -data_files = glob.glob('python/lib/*') +data_files = glob.glob('python/_lib/runtime/*') print(data_files) packages = find_packages('python') print(packages) @@ -65,13 +75,19 @@ def get_os_name(): return 'win' elif name.lower().startswith('linux'): return 'linux' + elif 'bsd' in name.lower(): + return 'unix' assert False, "Unknown platform name %s" % name def remove_tmp(taichi_dir): shutil.rmtree(os.path.join(taichi_dir, 'assets'), ignore_errors=True) - shutil.rmtree(os.path.join(taichi_dir, 'examples'), ignore_errors=True) - shutil.rmtree(os.path.join(taichi_dir, 'tests'), ignore_errors=True) + + +def remove_files_with_extension(dir_name, extension): + for file in os.listdir(dir_name): + if file.endswith(extension): + os.remove(os.path.join(dir_name, file)) class CMakeExtension(Extension): @@ -84,8 +100,6 @@ def run(self): taichi_dir = os.path.join(package_dir, 'taichi') remove_tmp(taichi_dir) - shutil.copytree('tests/python', os.path.join(taichi_dir, 'tests')) - shutil.copytree('examples', os.path.join(taichi_dir, 'examples')) shutil.copytree('external/assets', os.path.join(taichi_dir, 'assets')) egg_info.run(self) @@ -110,7 +124,7 @@ def parse_cmake_args_from_env(self): def run(self): try: - out = subprocess.check_output(['cmake', '--version']) + subprocess.check_call(['cmake', '--version']) except OSError: raise RuntimeError( "CMake must be installed to build the following extensions: " + @@ -132,8 +146,21 @@ def run(self): f'-DTI_VERSION_PATCH={TI_VERSION_PATCH}', ] - self.debug = os.getenv('DEBUG', '0') in ('1', 'ON') - cfg = 'Debug' if self.debug else 'Release' + emscriptened = os.getenv('TI_EMSCRIPTENED', '0') in ('1', 'ON') + if emscriptened: + cmake_args += ['-DTI_EMSCRIPTENED=ON'] + + if shutil.which('ninja'): + cmake_args += ['-GNinja'] + + cfg = 'Release' + if (os.getenv('DEBUG', '0') in ('1', 'ON')): + cfg = 'Debug' + elif (os.getenv('RELWITHDEBINFO', '0') in ('1', 'ON')): + cfg = 'RelWithDebInfo' + elif (os.getenv('MINSIZEREL', '0') in ('1', 'ON')): + cfg = 'MinSizeRel' + build_args = ['--config', cfg] cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg] @@ -150,6 +177,7 @@ def run(self): os.makedirs(self.build_temp, exist_ok=True) print('-' * 10, 'Running CMake prepare', '-' * 40) + print(' '.join(['cmake', cmake_list_dir] + cmake_args)) subprocess.check_call(['cmake', cmake_list_dir] + cmake_args, cwd=self.build_temp, env=env) @@ -164,36 +192,50 @@ def prepare_package(self): # We need to make sure these additional files are ready for # - develop mode: must exist in local python/taichi/lib/ folder # - install mode: must exist in self.build_lib/taichi/lib - taichi_lib_dir = 'taichi/lib' - for target in ( - os.path.join(package_dir, taichi_lib_dir), - os.path.join(self.build_lib, taichi_lib_dir), - ): - shutil.rmtree(target, ignore_errors=True) - os.makedirs(target) - if get_os_name() == 'linux': - shutil.copy(os.path.join(self.build_temp, 'libtaichi_core.so'), - os.path.join(target, 'taichi_core.so')) - elif get_os_name() == 'osx': - shutil.copy( - os.path.join(self.build_temp, 'libtaichi_core.dylib'), - os.path.join(target, 'taichi_core.so')) - else: - shutil.copy('runtimes/Release/taichi_core.dll', - os.path.join(target, 'taichi_core.pyd')) - - if get_os_name() != 'osx': - libdevice_path = 'external/cuda_libdevice/slim_libdevice.10.bc' - print("copying libdevice:", libdevice_path) - assert os.path.exists(libdevice_path) - shutil.copy(libdevice_path, - os.path.join(target, 'slim_libdevice.10.bc')) - - llvm_runtime_dir = 'taichi/runtime/llvm' - for f in os.listdir(llvm_runtime_dir): - if f.startswith('runtime_') and f.endswith('.bc'): - print(f"Fetching runtime file {f} to {target} folder") - shutil.copy(os.path.join(llvm_runtime_dir, f), target) + base_dir = package_dir if self.inplace else self.build_lib + taichi_lib_dir = os.path.join(base_dir, 'taichi', '_lib') + + runtime_dir = os.path.join(taichi_lib_dir, "runtime") + core_dir = os.path.join(taichi_lib_dir, "core") + os.makedirs(runtime_dir, exist_ok=True) + os.makedirs(core_dir, exist_ok=True) + + if (get_os_name() == 'linux' or get_os_name() == 'unix' + or get_os_name() == 'osx'): + remove_files_with_extension(core_dir, ".so") + else: + remove_files_with_extension(core_dir, ".pyd") + if get_os_name() == 'osx': + remove_files_with_extension(runtime_dir, ".dylib") + remove_files_with_extension(runtime_dir, ".bc") + + if get_os_name() == 'linux' or get_os_name() == 'unix': + self.copy_file(os.path.join(self.build_temp, 'libtaichi_core.so'), + os.path.join(core_dir, 'taichi_core.so')) + elif get_os_name() == 'osx': + self.copy_file( + os.path.join(self.build_temp, 'libtaichi_core.dylib'), + os.path.join(core_dir, 'taichi_core.so')) + moltenvk_path = os.path.join(self.build_temp, 'libMoltenVK.dylib') + if os.path.exists(moltenvk_path): + self.copy_file(moltenvk_path, + os.path.join(runtime_dir, 'libMoltenVK.dylib')) + else: + self.copy_file('runtimes/taichi_core.dll', + os.path.join(core_dir, 'taichi_core.pyd')) + + if get_os_name() != 'osx': + libdevice_path = 'external/cuda_libdevice/slim_libdevice.10.bc' + print("copying libdevice:", libdevice_path) + assert os.path.exists(libdevice_path) + self.copy_file(libdevice_path, + os.path.join(runtime_dir, 'slim_libdevice.10.bc')) + + llvm_runtime_dir = 'taichi/runtime/llvm' + for f in os.listdir(llvm_runtime_dir): + if f.startswith('runtime_') and f.endswith('.bc'): + print(f"Fetching runtime file {f} to {taichi_lib_dir} folder") + self.copy_file(os.path.join(llvm_runtime_dir, f), runtime_dir) class Clean(clean): @@ -203,8 +245,8 @@ def run(self): if os.path.exists(self.build_temp): remove_tree(self.build_temp, dry_run=self.dry_run) generated_folders = ('bin', 'dist', 'python/taichi/assets', - 'python/taichi/lib', 'python/taichi/examples', - 'python/taichi/tests', 'python/taichi.egg-info') + 'python/taichi/_lib/runtime', + 'python/taichi.egg-info') for d in generated_folders: if os.path.exists(d): remove_tree(d, dry_run=self.dry_run) @@ -212,7 +254,8 @@ def run(self): 'taichi/common/commit_hash.h', 'taichi/common/version.h' ] generated_files += glob.glob('taichi/runtime/llvm/runtime_*.bc') - generated_files += glob.glob('taichi/runtime/llvm/runtime_*.ll') + generated_files += glob.glob('python/taichi/_lib/core/*.so') + generated_files += glob.glob('python/taichi/_lib/core/*.pyd') for f in generated_files: if os.path.exists(f): print(f'removing generated file {f}') @@ -228,20 +271,18 @@ def run(self): author='Taichi developers', author_email='yuanmhu@gmail.com', url='https://github.com/taichi-dev/taichi', + python_requires=">=3.6,<3.11", install_requires=[ - 'numpy', - 'pybind11>=2.5.0', - 'sourceinspect>=0.0.4', - 'colorama', - 'astor', + 'numpy', 'sourceinspect>=0.0.4', 'colorama', 'astor', + 'astunparse;python_version<"3.9"' ], - data_files=[('lib', data_files)], + data_files=[(os.path.join('_lib', 'runtime'), data_files)], keywords=['graphics', 'simulation'], license='MIT', include_package_data=True, entry_points={ 'console_scripts': [ - 'ti=taichi.main:main', + 'ti=taichi._main:main', ], }, classifiers=classifiers, diff --git a/taichi/analysis/alias_analysis.cpp b/taichi/analysis/alias_analysis.cpp index 24e3acfb8fa8d..5f5928997eea1 100644 --- a/taichi/analysis/alias_analysis.cpp +++ b/taichi/analysis/alias_analysis.cpp @@ -29,8 +29,17 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2) { Stmt *origin1 = retrieve_local(var1); Stmt *origin2 = retrieve_local(var2); if (origin1 != nullptr && origin2 != nullptr) { - if (origin1 == origin2) + if (origin1 == origin2) { + if (var1->is() && var2->is()) { + auto diff = value_diff_ptr_index(var1->cast()->offset, + var2->cast()->offset); + if (diff.is_diff_certain) { + return diff.diff_range == 0 ? AliasResult::same + : AliasResult::different; + } + } return AliasResult::uncertain; + } if (origin1->is() || origin2->is()) return AliasResult::different; TI_ASSERT(origin1->is() && @@ -82,16 +91,36 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2) { : AliasResult::uncertain; } + TI_ASSERT(var1->width() == 1); + TI_ASSERT(var2->width() == 1); + if (var1->is() || var2->is()) { if (!var1->is() || !var2->is()) return AliasResult::different; - return AliasResult::uncertain; + auto ptr1 = var1->as(); + auto ptr2 = var2->as(); + if (ptr1->base_ptrs[0] != ptr2->base_ptrs[0]) { + auto base1 = ptr1->base_ptrs[0]->as(); + auto base2 = ptr2->base_ptrs[0]->as(); + if (base1->arg_id != base2->arg_id) { + return AliasResult::different; + } + } + TI_ASSERT(ptr1->indices.size() == ptr2->indices.size()); + bool uncertain = false; + for (int i = 0; i < (int)ptr1->indices.size(); i++) { + auto diff = value_diff_ptr_index(ptr1->indices[i], ptr2->indices[i]); + if (!diff.is_diff_certain) { + uncertain = true; + } else if (diff.diff_range != 0) { + return AliasResult::different; + } + } + return uncertain ? AliasResult::uncertain : AliasResult::same; } // If both statements are GlobalPtrStmts or GetChStmts, we can check by // SNode::id. - TI_ASSERT(var1->width() == 1); - TI_ASSERT(var2->width() == 1); auto get_snode_id = [](Stmt *s) { if (auto ptr = s->cast()) { return ptr->snodes[0]->id; diff --git a/taichi/analysis/bls_analyzer.cpp b/taichi/analysis/bls_analyzer.cpp index 868df66abd849..e9e7e1f4f54d1 100644 --- a/taichi/analysis/bls_analyzer.cpp +++ b/taichi/analysis/bls_analyzer.cpp @@ -50,7 +50,7 @@ void BLSAnalyzer::record_access(Stmt *stmt, AccessFlag flag) { for (int i = 0; i < num_indices; i++) { auto diff = irpass::analysis::value_diff_loop_index(ptr->indices[i], for_stmt_, i); - if (diff.related_() && diff.coeff > 0) { + if (diff.related() && diff.coeff > 0) { offsets[i].low = diff.low; offsets[i].high = diff.high; coeffs[i] = diff.coeff; diff --git a/taichi/analysis/build_cfg.cpp b/taichi/analysis/build_cfg.cpp index b4e2eda30b86f..1ae05aeb46d41 100644 --- a/taichi/analysis/build_cfg.cpp +++ b/taichi/analysis/build_cfg.cpp @@ -62,18 +62,18 @@ namespace lang { class CFGBuilder : public IRVisitor { public: CFGBuilder() - : current_block(nullptr), - last_node_in_current_block(nullptr), - current_stmt_id(-1), - begin_location(-1), - current_offload(nullptr), - in_parallel_for(false) { + : current_block_(nullptr), + last_node_in_current_block_(nullptr), + current_stmt_id_(-1), + begin_location_(-1), + current_offload_(nullptr), + in_parallel_for_(false) { allow_undefined_visitor = true; invoke_default_visitor = true; - graph = std::make_unique(); + graph_ = std::make_unique(); // Make an empty start node. - auto start_node = graph->push_back(); - prev_nodes.push_back(start_node); + auto start_node = graph_->push_back(); + prev_nodes_.push_back(start_node); } void visit(Stmt *stmt) override { @@ -93,18 +93,18 @@ class CFGBuilder : public IRVisitor { * @return The node which is just created. */ CFGNode *new_node(int next_begin_location) { - auto node = graph->push_back( - current_block, begin_location, /*end_location=*/current_stmt_id, - /*is_parallel_executed=*/in_parallel_for, - /*prev_node_in_same_block=*/last_node_in_current_block); - for (auto &prev_node : prev_nodes) { + auto node = graph_->push_back( + current_block_, begin_location_, /*end_location=*/current_stmt_id_, + /*is_parallel_executed=*/in_parallel_for_, + /*prev_node_in_same_block=*/last_node_in_current_block_); + for (auto &prev_node : prev_nodes_) { // Now that the "(next node)" is created, we should insert edges // "node... -> (next node)" here. CFGNode::add_edge(prev_node, node); } - prev_nodes.clear(); - begin_location = next_begin_location; - last_node_in_current_block = node; + prev_nodes_.clear(); + begin_location_ = next_begin_location; + last_node_in_current_block_ = node; return node; } @@ -125,7 +125,7 @@ class CFGBuilder : public IRVisitor { */ void visit(ContinueStmt *stmt) override { // Don't put ContinueStmt in any CFGNodes. - continues_in_current_loop.push_back(new_node(current_stmt_id + 1)); + continues_in_current_loop_.push_back(new_node(current_stmt_id_ + 1)); } /** @@ -145,9 +145,9 @@ class CFGBuilder : public IRVisitor { */ void visit(WhileControlStmt *stmt) override { // Don't put WhileControlStmt in any CFGNodes. - auto node = new_node(current_stmt_id + 1); - breaks_in_current_loop.push_back(node); - prev_nodes.push_back(node); + auto node = new_node(current_stmt_id_ + 1); + breaks_in_current_loop_.push_back(node); + prev_nodes_.push_back(node); } /** @@ -174,20 +174,20 @@ class CFGBuilder : public IRVisitor { */ void visit(FuncCallStmt *stmt) override { auto node_before_func_call = new_node(-1); - CFGFuncKey func_key = {stmt->func->func_key, in_parallel_for}; - if (node_func_begin.count(func_key) == 0) { + CFGFuncKey func_key = {stmt->func->func_key, in_parallel_for_}; + if (node_func_begin_.count(func_key) == 0) { // Generate CFG for the function. TI_ASSERT(stmt->func->ir->is()); - auto func_begin_index = graph->size(); + auto func_begin_index = graph_->size(); stmt->func->ir->accept(this); - node_func_begin[func_key] = graph->nodes[func_begin_index].get(); - node_func_end[func_key] = graph->nodes.back().get(); + node_func_begin_[func_key] = graph_->nodes[func_begin_index].get(); + node_func_end_[func_key] = graph_->nodes.back().get(); } - CFGNode::add_edge(node_before_func_call, node_func_begin[func_key]); - prev_nodes.push_back(node_func_end[func_key]); + CFGNode::add_edge(node_before_func_call, node_func_begin_[func_key]); + prev_nodes_.push_back(node_func_end_[func_key]); // Don't put FuncCallStmt in any CFGNodes. - begin_location = current_stmt_id + 1; + begin_location_ = current_stmt_id_ + 1; } /** @@ -219,27 +219,27 @@ class CFGBuilder : public IRVisitor { auto before_if = new_node(-1); CFGNode *true_branch_end = nullptr; if (if_stmt->true_statements) { - auto true_branch_begin = graph->size(); + auto true_branch_begin = graph_->size(); if_stmt->true_statements->accept(this); - CFGNode::add_edge(before_if, graph->nodes[true_branch_begin].get()); - true_branch_end = graph->back(); + CFGNode::add_edge(before_if, graph_->nodes[true_branch_begin].get()); + true_branch_end = graph_->back(); } CFGNode *false_branch_end = nullptr; if (if_stmt->false_statements) { - auto false_branch_begin = graph->size(); + auto false_branch_begin = graph_->size(); if_stmt->false_statements->accept(this); - CFGNode::add_edge(before_if, graph->nodes[false_branch_begin].get()); - false_branch_end = graph->back(); + CFGNode::add_edge(before_if, graph_->nodes[false_branch_begin].get()); + false_branch_end = graph_->back(); } - TI_ASSERT(prev_nodes.empty()); + TI_ASSERT(prev_nodes_.empty()); if (if_stmt->true_statements) - prev_nodes.push_back(true_branch_end); + prev_nodes_.push_back(true_branch_end); if (if_stmt->false_statements) - prev_nodes.push_back(false_branch_end); + prev_nodes_.push_back(false_branch_end); if (!if_stmt->true_statements || !if_stmt->false_statements) - prev_nodes.push_back(before_if); + prev_nodes_.push_back(before_if); // Container statements don't belong to any CFGNodes. - begin_location = current_stmt_id + 1; + begin_location_ = current_stmt_id_ + 1; } /** @@ -262,34 +262,34 @@ class CFGBuilder : public IRVisitor { * } */ void visit_loop(Block *body, CFGNode *before_loop, bool is_while_true) { - int loop_stmt_id = current_stmt_id; - auto backup_continues = std::move(continues_in_current_loop); - auto backup_breaks = std::move(breaks_in_current_loop); - continues_in_current_loop.clear(); - breaks_in_current_loop.clear(); + int loop_stmt_id = current_stmt_id_; + auto backup_continues = std::move(continues_in_current_loop_); + auto backup_breaks = std::move(breaks_in_current_loop_); + continues_in_current_loop_.clear(); + breaks_in_current_loop_.clear(); - auto loop_begin_index = graph->size(); + auto loop_begin_index = graph_->size(); body->accept(this); - auto loop_begin = graph->nodes[loop_begin_index].get(); + auto loop_begin = graph_->nodes[loop_begin_index].get(); CFGNode::add_edge(before_loop, loop_begin); - auto loop_end = graph->back(); + auto loop_end = graph_->back(); CFGNode::add_edge(loop_end, loop_begin); if (!is_while_true) { - prev_nodes.push_back(before_loop); - prev_nodes.push_back(loop_end); + prev_nodes_.push_back(before_loop); + prev_nodes_.push_back(loop_end); } - for (auto &node : continues_in_current_loop) { + for (auto &node : continues_in_current_loop_) { CFGNode::add_edge(node, loop_begin); - prev_nodes.push_back(node); + prev_nodes_.push_back(node); } - for (auto &node : breaks_in_current_loop) { - prev_nodes.push_back(node); + for (auto &node : breaks_in_current_loop_) { + prev_nodes_.push_back(node); } // Container statements don't belong to any CFGNodes. - begin_location = loop_stmt_id + 1; - continues_in_current_loop = std::move(backup_continues); - breaks_in_current_loop = std::move(backup_breaks); + begin_location_ = loop_stmt_id + 1; + continues_in_current_loop_ = std::move(backup_continues); + breaks_in_current_loop_ = std::move(backup_breaks); } void visit(WhileStmt *stmt) override { @@ -297,19 +297,27 @@ class CFGBuilder : public IRVisitor { } void visit(RangeForStmt *stmt) override { - auto old_in_parallel_for = in_parallel_for; - if (!current_offload) - in_parallel_for = true; + auto old_in_parallel_for = in_parallel_for_; + if (!current_offload_) + in_parallel_for_ = true; visit_loop(stmt->body.get(), new_node(-1), false); - in_parallel_for = old_in_parallel_for; + in_parallel_for_ = old_in_parallel_for; } void visit(StructForStmt *stmt) override { - auto old_in_parallel_for = in_parallel_for; - if (!current_offload) - in_parallel_for = true; + auto old_in_parallel_for = in_parallel_for_; + if (!current_offload_) + in_parallel_for_ = true; visit_loop(stmt->body.get(), new_node(-1), false); - in_parallel_for = old_in_parallel_for; + in_parallel_for_ = old_in_parallel_for; + } + + void visit(MeshForStmt *stmt) override { + auto old_in_parallel_for = in_parallel_for_; + if (!current_offload_) + in_parallel_for_ = true; + visit_loop(stmt->body.get(), new_node(-1), false); + in_parallel_for_ = old_in_parallel_for; } /** @@ -320,6 +328,9 @@ class CFGBuilder : public IRVisitor { * } -> node_tls_prologue; * node_tls_prologue { * ... + * } -> node_mesh_prologue; + * node_mesh_prologue: + * ... * } -> node_bls_prologue; * node_bls_prologue { * ... @@ -338,63 +349,74 @@ class CFGBuilder : public IRVisitor { * } */ void visit(OffloadedStmt *stmt) override { - current_offload = stmt; + current_offload_ = stmt; if (stmt->tls_prologue) { auto before_offload = new_node(-1); - int offload_stmt_id = current_stmt_id; - auto block_begin_index = graph->size(); + int offload_stmt_id = current_stmt_id_; + auto block_begin_index = graph_->size(); stmt->tls_prologue->accept(this); - prev_nodes.push_back(graph->back()); + prev_nodes_.push_back(graph_->back()); + // Container statements don't belong to any CFGNodes. + begin_location_ = offload_stmt_id + 1; + CFGNode::add_edge(before_offload, graph_->nodes[block_begin_index].get()); + } + if (stmt->mesh_prologue) { + auto before_offload = new_node(-1); + int offload_stmt_id = current_stmt_id_; + auto block_begin_index = graph_->size(); + stmt->mesh_prologue->accept(this); + prev_nodes_.push_back(graph_->back()); // Container statements don't belong to any CFGNodes. - begin_location = offload_stmt_id + 1; - CFGNode::add_edge(before_offload, graph->nodes[block_begin_index].get()); + begin_location_ = offload_stmt_id + 1; + CFGNode::add_edge(before_offload, graph_->nodes[block_begin_index].get()); } if (stmt->bls_prologue) { auto before_offload = new_node(-1); - int offload_stmt_id = current_stmt_id; - auto block_begin_index = graph->size(); + int offload_stmt_id = current_stmt_id_; + auto block_begin_index = graph_->size(); stmt->bls_prologue->accept(this); - prev_nodes.push_back(graph->back()); + prev_nodes_.push_back(graph_->back()); // Container statements don't belong to any CFGNodes. - begin_location = offload_stmt_id + 1; - CFGNode::add_edge(before_offload, graph->nodes[block_begin_index].get()); + begin_location_ = offload_stmt_id + 1; + CFGNode::add_edge(before_offload, graph_->nodes[block_begin_index].get()); } if (stmt->has_body()) { auto before_offload = new_node(-1); - int offload_stmt_id = current_stmt_id; - auto block_begin_index = graph->size(); + int offload_stmt_id = current_stmt_id_; + auto block_begin_index = graph_->size(); if (stmt->task_type == OffloadedStmt::TaskType::range_for || - stmt->task_type == OffloadedStmt::TaskType::struct_for) { - in_parallel_for = true; + stmt->task_type == OffloadedStmt::TaskType::struct_for || + stmt->task_type == OffloadedStmt::TaskType::mesh_for) { + in_parallel_for_ = true; } stmt->body->accept(this); - in_parallel_for = false; - prev_nodes.push_back(graph->back()); + in_parallel_for_ = false; + prev_nodes_.push_back(graph_->back()); // Container statements don't belong to any CFGNodes. - begin_location = offload_stmt_id + 1; - CFGNode::add_edge(before_offload, graph->nodes[block_begin_index].get()); + begin_location_ = offload_stmt_id + 1; + CFGNode::add_edge(before_offload, graph_->nodes[block_begin_index].get()); } if (stmt->bls_epilogue) { auto before_offload = new_node(-1); - int offload_stmt_id = current_stmt_id; - auto block_begin_index = graph->size(); + int offload_stmt_id = current_stmt_id_; + auto block_begin_index = graph_->size(); stmt->bls_epilogue->accept(this); - prev_nodes.push_back(graph->back()); + prev_nodes_.push_back(graph_->back()); // Container statements don't belong to any CFGNodes. - begin_location = offload_stmt_id + 1; - CFGNode::add_edge(before_offload, graph->nodes[block_begin_index].get()); + begin_location_ = offload_stmt_id + 1; + CFGNode::add_edge(before_offload, graph_->nodes[block_begin_index].get()); } if (stmt->tls_epilogue) { auto before_offload = new_node(-1); - int offload_stmt_id = current_stmt_id; - auto block_begin_index = graph->size(); + int offload_stmt_id = current_stmt_id_; + auto block_begin_index = graph_->size(); stmt->tls_epilogue->accept(this); - prev_nodes.push_back(graph->back()); + prev_nodes_.push_back(graph_->back()); // Container statements don't belong to any CFGNodes. - begin_location = offload_stmt_id + 1; - CFGNode::add_edge(before_offload, graph->nodes[block_begin_index].get()); + begin_location_ = offload_stmt_id + 1; + CFGNode::add_edge(before_offload, graph_->nodes[block_begin_index].get()); } - current_offload = nullptr; + current_offload_ = nullptr; } /** @@ -415,56 +437,56 @@ class CFGBuilder : public IRVisitor { * graph->final_node = node_block_end; */ void visit(Block *block) override { - auto backup_block = current_block; - auto backup_last_node = last_node_in_current_block; - auto backup_stmt_id = current_stmt_id; + auto backup_block = current_block_; + auto backup_last_node = last_node_in_current_block_; + auto backup_stmt_id = current_stmt_id_; // |begin_location| must be -1 (indicating we are not building any CFGNode) // when the |current_block| changes. - TI_ASSERT(begin_location == -1); - TI_ASSERT(prev_nodes.empty() || graph->size() == 1); - current_block = block; - last_node_in_current_block = nullptr; - begin_location = 0; + TI_ASSERT(begin_location_ == -1); + TI_ASSERT(prev_nodes_.empty() || graph_->size() == 1); + current_block_ = block; + last_node_in_current_block_ = nullptr; + begin_location_ = 0; for (int i = 0; i < (int)block->size(); i++) { - current_stmt_id = i; + current_stmt_id_ = i; block->statements[i]->accept(this); } - current_stmt_id = block->size(); + current_stmt_id_ = block->size(); new_node(-1); // Each block has a deterministic last node. - graph->final_node = (int)graph->size() - 1; + graph_->final_node = (int)graph_->size() - 1; - current_block = backup_block; - last_node_in_current_block = backup_last_node; - current_stmt_id = backup_stmt_id; + current_block_ = backup_block; + last_node_in_current_block_ = backup_last_node; + current_stmt_id_ = backup_stmt_id; } static std::unique_ptr run(IRNode *root) { CFGBuilder builder; root->accept(&builder); - if (!builder.graph->nodes[builder.graph->final_node]->empty()) { + if (!builder.graph_->nodes[builder.graph_->final_node]->empty()) { // Make the final node empty (by adding an empty final node). - builder.graph->push_back(); - CFGNode::add_edge(builder.graph->nodes[builder.graph->final_node].get(), - builder.graph->back()); - builder.graph->final_node = (int)builder.graph->size() - 1; + builder.graph_->push_back(); + CFGNode::add_edge(builder.graph_->nodes[builder.graph_->final_node].get(), + builder.graph_->back()); + builder.graph_->final_node = (int)builder.graph_->size() - 1; } - return std::move(builder.graph); + return std::move(builder.graph_); } private: - std::unique_ptr graph; - Block *current_block; - CFGNode *last_node_in_current_block; - std::vector continues_in_current_loop; - std::vector breaks_in_current_loop; - int current_stmt_id; - int begin_location; - std::vector prev_nodes; - OffloadedStmt *current_offload; - bool in_parallel_for; - std::unordered_map node_func_begin; - std::unordered_map node_func_end; + std::unique_ptr graph_; + Block *current_block_; + CFGNode *last_node_in_current_block_; + std::vector continues_in_current_loop_; + std::vector breaks_in_current_loop_; + int current_stmt_id_; + int begin_location_; + std::vector prev_nodes_; + OffloadedStmt *current_offload_; + bool in_parallel_for_; + std::unordered_map node_func_begin_; + std::unordered_map node_func_end_; }; namespace irpass::analysis { diff --git a/taichi/analysis/clone.cpp b/taichi/analysis/clone.cpp index dc2e1b230a7c1..e4c2476008da4 100644 --- a/taichi/analysis/clone.cpp +++ b/taichi/analysis/clone.cpp @@ -12,7 +12,7 @@ TLANG_NAMESPACE_BEGIN class IRCloner : public IRVisitor { private: IRNode *other_node; - std::unordered_map operand_map; + std::unordered_map operand_map_; public: enum Phase { register_operand_map, replace_operand } phase; @@ -34,16 +34,16 @@ class IRCloner : public IRVisitor { void generic_visit(Stmt *stmt) { if (phase == register_operand_map) - operand_map[stmt] = other_node->as(); + operand_map_[stmt] = other_node->as(); else { TI_ASSERT(phase == replace_operand); auto other_stmt = other_node->as(); TI_ASSERT(stmt->num_operands() == other_stmt->num_operands()); for (int i = 0; i < stmt->num_operands(); i++) { - if (operand_map.find(stmt->operand(i)) == operand_map.end()) + if (operand_map_.find(stmt->operand(i)) == operand_map_.end()) other_stmt->set_operand(i, stmt->operand(i)); else - other_stmt->set_operand(i, operand_map[stmt->operand(i)]); + other_stmt->set_operand(i, operand_map_[stmt->operand(i)]); } } } @@ -67,14 +67,6 @@ class IRCloner : public IRVisitor { } } - void visit(FuncBodyStmt *stmt) override { - generic_visit(stmt); - auto other = other_node->as(); - other_node = other->body.get(); - stmt->body->accept(this); - other_node = other; - } - void visit(WhileStmt *stmt) override { generic_visit(stmt); auto other = other_node->as(); @@ -112,6 +104,7 @@ class IRCloner : public IRVisitor { CLONE_BLOCK(tls_prologue) CLONE_BLOCK(bls_prologue) + CLONE_BLOCK(mesh_prologue) if (stmt->body) { other_node = other->body.get(); diff --git a/taichi/analysis/count_statements.cpp b/taichi/analysis/count_statements.cpp index 43796c3c501f7..78733878b0a3e 100644 --- a/taichi/analysis/count_statements.cpp +++ b/taichi/analysis/count_statements.cpp @@ -8,7 +8,7 @@ TLANG_NAMESPACE_BEGIN class StmtCounter : public BasicStmtVisitor { private: StmtCounter() { - counter = 0; + counter_ = 0; allow_undefined_visitor = true; invoke_default_visitor = true; } @@ -17,21 +17,21 @@ class StmtCounter : public BasicStmtVisitor { public: void preprocess_container_stmt(Stmt *stmt) override { - counter++; + counter_++; } void visit(Stmt *stmt) override { - counter++; + counter_++; } static int run(IRNode *root) { StmtCounter stmt_counter; root->accept(&stmt_counter); - return stmt_counter.counter; + return stmt_counter.counter_; } private: - int counter; + int counter_; }; namespace irpass::analysis { diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index d1218f947d093..ec23e8be085ef 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -42,7 +42,7 @@ std::vector get_load_pointers(Stmt *load_stmt) { Stmt *get_store_data(Stmt *store_stmt) { // If store_stmt provides one data source, return the data. - if (store_stmt->is()) { + if (store_stmt->is() && !store_stmt->ret_type->is()) { // For convenience, return store_stmt instead of the const [0] it actually // stores. return store_stmt; @@ -57,7 +57,7 @@ Stmt *get_store_data(Stmt *store_stmt) { std::vector get_store_destination(Stmt *store_stmt) { // If store_stmt provides some data sources, return the pointers of the data. - if (store_stmt->is()) { + if (store_stmt->is() && !store_stmt->ret_type->is()) { // The statement itself provides a data source (const [0]). return std::vector(1, store_stmt); } else if (auto local_store = store_stmt->cast()) { @@ -67,7 +67,12 @@ std::vector get_store_destination(Stmt *store_stmt) { } else if (auto atomic = store_stmt->cast()) { return std::vector(1, atomic->dest); } else if (auto external_func = store_stmt->cast()) { - return external_func->output_stmts; + if (store_stmt->cast()->type == + ExternalFuncCallStmt::BITCODE) { + return external_func->arg_stmts; + } else { + return external_func->output_stmts; + } } else { return std::vector(); } diff --git a/taichi/analysis/gather_mesh_thread_local.cpp b/taichi/analysis/gather_mesh_thread_local.cpp new file mode 100644 index 0000000000000..41bc86a8aa953 --- /dev/null +++ b/taichi/analysis/gather_mesh_thread_local.cpp @@ -0,0 +1,82 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/snode.h" +#include "taichi/ir/mesh.h" +#include "taichi/ir/visitors.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/statements.h" + +TLANG_NAMESPACE_BEGIN + +using MeshElementTypeSet = std::unordered_set; + +class GatherMeshThreadLocal : public BasicStmtVisitor { + public: + using BasicStmtVisitor::visit; + + GatherMeshThreadLocal(OffloadedStmt *offload_, + MeshElementTypeSet *owned_ptr_, + MeshElementTypeSet *total_ptr_, + bool optimize_mesh_reordered_mapping_) { + allow_undefined_visitor = true; + invoke_default_visitor = true; + + this->offload = offload_; + this->owned_ptr = owned_ptr_; + this->total_ptr = total_ptr_; + this->optimize_mesh_reordered_mapping = optimize_mesh_reordered_mapping_; + } + + static void run(OffloadedStmt *offload, + MeshElementTypeSet *owned_ptr, + MeshElementTypeSet *total_ptr, + const CompileConfig &config) { + TI_ASSERT(offload->task_type == OffloadedStmt::TaskType::mesh_for); + GatherMeshThreadLocal analyser(offload, owned_ptr, total_ptr, + config.optimize_mesh_reordered_mapping); + offload->accept(&analyser); + } + + void visit(LoopIndexStmt *stmt) override { + if (stmt->is_mesh_index()) { + this->owned_ptr->insert(stmt->mesh_index_type()); + } + } + + void visit(MeshRelationAccessStmt *stmt) override { + if (mesh::element_order(stmt->from_type()) > + mesh::element_order(stmt->to_type)) { + this->total_ptr->insert(stmt->from_type()); + } else { + this->owned_ptr->insert(stmt->from_type()); + } + } + + void visit(MeshIndexConversionStmt *stmt) override { + this->total_ptr->insert(stmt->idx_type); + if (optimize_mesh_reordered_mapping && + stmt->conv_type == mesh::ConvType::l2r) { + this->owned_ptr->insert(stmt->idx_type); + } + } + + OffloadedStmt *offload{nullptr}; + MeshElementTypeSet *owned_ptr{nullptr}; + MeshElementTypeSet *total_ptr{nullptr}; + bool optimize_mesh_reordered_mapping{false}; +}; + +namespace irpass::analysis { + +std::pair +gather_mesh_thread_local(OffloadedStmt *offload, const CompileConfig &config) { + MeshElementTypeSet local_owned{}; + MeshElementTypeSet local_total{}; + + GatherMeshThreadLocal::run(offload, &local_owned, &local_total, config); + return std::make_pair(local_owned, local_total); +} + +} // namespace irpass::analysis + +TLANG_NAMESPACE_END diff --git a/taichi/analysis/gather_meshfor_relation_types.cpp b/taichi/analysis/gather_meshfor_relation_types.cpp new file mode 100644 index 0000000000000..21ce9ef60f328 --- /dev/null +++ b/taichi/analysis/gather_meshfor_relation_types.cpp @@ -0,0 +1,65 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/snode.h" +#include "taichi/ir/mesh.h" +#include "taichi/ir/visitors.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/statements.h" + +TLANG_NAMESPACE_BEGIN + +namespace irpass::analysis { + +class GatherMeshforRelationTypes : public BasicStmtVisitor { + public: + using BasicStmtVisitor::visit; + + GatherMeshforRelationTypes() { + allow_undefined_visitor = true; + invoke_default_visitor = true; + } + + static void run(IRNode *root) { + GatherMeshforRelationTypes analyser; + root->accept(&analyser); + } + + void visit(MeshForStmt *stmt) override { + TI_ASSERT(mesh_for == nullptr); + TI_ASSERT(stmt->major_to_types.size() == 0); + TI_ASSERT(stmt->minor_relation_types.size() == 0); + mesh_for = stmt; + stmt->body->accept(this); + mesh_for = nullptr; + } + + void visit(MeshRelationAccessStmt *stmt) override { + if (auto from_stmt = + stmt->mesh_idx->cast()) { // major relation + TI_ASSERT(from_stmt->mesh_index_type() == mesh_for->major_from_type); + mesh_for->major_to_types.insert(stmt->to_type); + } else if (auto from_stmt = + stmt->mesh_idx + ->cast()) { // minor relation + TI_ASSERT(!from_stmt->is_size()); + auto from_order = mesh::element_order(from_stmt->to_type); + auto to_order = mesh::element_order(stmt->to_type); + TI_ASSERT_INFO(from_order > to_order, + "Cannot access an indeterminate relation (E.g, Vert-Vert) " + "in a nested neighbor access"); + mesh_for->minor_relation_types.insert( + mesh::relation_by_orders(from_order, to_order)); + } else { + TI_NOT_IMPLEMENTED; + } + } + + MeshForStmt *mesh_for{nullptr}; +}; + +void gather_meshfor_relation_types(IRNode *node) { + GatherMeshforRelationTypes::run(node); +} + +} // namespace irpass::analysis + +TLANG_NAMESPACE_END diff --git a/taichi/analysis/gather_statements.cpp b/taichi/analysis/gather_statements.cpp index 3264cfea72181..785afd0b767df 100644 --- a/taichi/analysis/gather_statements.cpp +++ b/taichi/analysis/gather_statements.cpp @@ -6,27 +6,27 @@ TLANG_NAMESPACE_BEGIN class StmtSearcher : public BasicStmtVisitor { private: - std::function test; - std::vector results; + std::function test_; + std::vector results_; public: using BasicStmtVisitor::visit; - StmtSearcher(std::function test) : test(test) { + StmtSearcher(std::function test) : test_(test) { allow_undefined_visitor = true; invoke_default_visitor = true; } - void visit(Stmt *stmt) { - if (test(stmt)) - results.push_back(stmt); + void visit(Stmt *stmt) override { + if (test_(stmt)) + results_.push_back(stmt); } static std::vector run(IRNode *root, const std::function &test) { StmtSearcher searcher(test); root->accept(&searcher); - return searcher.results; + return searcher.results_; } }; diff --git a/taichi/analysis/gather_uniquely_accessed_pointers.cpp b/taichi/analysis/gather_uniquely_accessed_pointers.cpp index 8febbe134ba36..69b4d767a2960 100644 --- a/taichi/analysis/gather_uniquely_accessed_pointers.cpp +++ b/taichi/analysis/gather_uniquely_accessed_pointers.cpp @@ -42,6 +42,10 @@ class LoopUniqueStmtSearcher : public BasicStmtVisitor { loop_invariant_.insert(stmt); } + void visit(ExternalTensorShapeAlongAxisStmt *stmt) override { + loop_invariant_.insert(stmt); + } + void visit(UnaryOpStmt *stmt) override { if (loop_invariant_.count(stmt->operand) > 0) { loop_invariant_.insert(stmt); @@ -55,6 +59,21 @@ class LoopUniqueStmtSearcher : public BasicStmtVisitor { } } + void visit(DecorationStmt *stmt) override { + if (stmt->decoration.size() == 2 && + stmt->decoration[0] == + uint32_t(DecorationStmt::Decoration::kLoopUnique)) { + if (loop_unique_.find(stmt->operand) == loop_unique_.end()) { + // This decoration exists IFF we are looping over NDArray (or any other + // cases where the array index is linearized by the codegen) In that + // case the original loop dimensions have been reduced to 1D. + loop_unique_[stmt->operand] = stmt->decoration[1]; + num_different_loop_indices = std::max(loop_unique_[stmt->operand] + 1, + num_different_loop_indices); + } + } + } + void visit(BinaryOpStmt *stmt) override { if (loop_invariant_.count(stmt->lhs) > 0 && loop_invariant_.count(stmt->rhs) > 0) { @@ -81,6 +100,10 @@ class LoopUniqueStmtSearcher : public BasicStmtVisitor { } } + bool is_partially_loop_unique(Stmt *stmt) const { + return loop_unique_.find(stmt) != loop_unique_.end(); + } + bool is_ptr_indices_loop_unique(GlobalPtrStmt *stmt) const { // Check if the address is loop-unique, i.e., stmt contains // either a loop-unique index or all top-level loop indices. @@ -108,6 +131,36 @@ class LoopUniqueStmtSearcher : public BasicStmtVisitor { // b[i, i] is not loop-unique (because there's no j) return current_num_different_loop_indices == num_different_loop_indices; } + + bool is_ptr_indices_loop_unique(ExternalPtrStmt *stmt) const { + // Check if the address is loop-unique, i.e., stmt contains + // either a loop-unique index or all top-level loop indices. + TI_ASSERT(num_different_loop_indices != -1); + std::vector loop_indices; + loop_indices.reserve(stmt->indices.size()); + for (auto &index : stmt->indices) { + auto loop_unique_index = loop_unique_.find(index); + if (loop_unique_index != loop_unique_.end()) { + if (loop_unique_index->second == -1) { + // LoopUniqueStmt + return true; + } else { + // LoopIndexStmt + loop_indices.push_back(loop_unique_index->second); + } + } + } + std::sort(loop_indices.begin(), loop_indices.end()); + auto current_num_different_loop_indices = + std::unique(loop_indices.begin(), loop_indices.end()) - + loop_indices.begin(); + + // for i, j in x: + // a[j, i] is loop-unique + // b[i, i] is not loop-unique (because there's no j) + // c[j, i, 1] is loop-unique + return current_num_different_loop_indices == num_different_loop_indices; + } }; class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor { @@ -118,6 +171,10 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor { // one GlobalPtrStmt (or by definitely-same-address GlobalPtrStmts), // and that GlobalPtrStmt's address is loop-unique. std::unordered_map accessed_pointer_; + std::unordered_map rel_access_pointer_; + + // Search any_arrs that are uniquely accessed. Maps: ArgID -> ExternalPtrStmt + std::unordered_map accessed_arr_pointer_; public: using BasicStmtVisitor::visit; @@ -128,6 +185,33 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor { } void visit(GlobalPtrStmt *stmt) override { + // mesh-for loop unique + if (stmt->indices.size() == 1 && + stmt->indices[0]->is()) { + auto idx = stmt->indices[0]->as()->idx; + while (idx->is()) { // special case: l2g + + // g2r + idx = idx->as()->idx; + } + if (idx->is() && + idx->as()->is_mesh_index()) { // from-end access + for (auto &snode : stmt->snodes.data) { + if (rel_access_pointer_.find(snode) == + rel_access_pointer_.end()) { // not accessed by neibhours yet + accessed_pointer_[snode] = stmt; + } else { // accessed by neibhours, so it's not unique + accessed_pointer_[snode] = nullptr; + } + } + } else { // to-end access + for (auto &snode : stmt->snodes.data) { + rel_access_pointer_[snode] = stmt; + accessed_pointer_[snode] = + nullptr; // from-end access should not be unique + } + } + } + // Range-for / struct-for for (auto &snode : stmt->snodes.data) { auto accessed_ptr = accessed_pointer_.find(snode); if (accessed_ptr == accessed_pointer_.end()) { @@ -145,11 +229,70 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor { } } - static std::unordered_map run(IRNode *root) { + void visit(ExternalPtrStmt *stmt) override { + // A memory location of an ExternalPtrStmt depends on the indices + // If the accessed indices are loop unique, + // the accessed memory location is loop unique + for (auto base_ptr : stmt->base_ptrs.data) { + ArgLoadStmt *arg_load_stmt = base_ptr->as(); + int arg_id = arg_load_stmt->arg_id; + + auto accessed_ptr = accessed_arr_pointer_.find(arg_id); + + bool stmt_loop_unique = + loop_unique_stmt_searcher_.is_ptr_indices_loop_unique(stmt); + + if (!stmt_loop_unique) { + accessed_arr_pointer_[arg_id] = nullptr; // not loop-unique + } else { + if (accessed_ptr == accessed_arr_pointer_.end()) { + // First time using arr @ arg_id + accessed_arr_pointer_[arg_id] = stmt; + } else { + /** + * We know stmt->base_ptr and the previously recorded pointers + * are loop-unique. We need to figure out whether their loop-unique + * indicies are the same while ignoring the others. + * e.g. a[i, j, 1] and a[i, j, 2] are both uniquely accessed + * a[i, j, 1] and a[j, i, 2] are not uniquely accessed + * a[i, j + 1, 1] and a[i, j, 2] are not uniquely accessed + * This is a bit stricter than needed. + * e.g. a[i, j, i] and a[i, j, 0] are uniquely accessed + * However this is probably not common and improvements can be made + * in a future patch. + */ + if (accessed_ptr->second) { + ExternalPtrStmt *other_ptr = accessed_ptr->second; + TI_ASSERT(stmt->indices.size() == other_ptr->indices.size()); + for (int axis = 0; axis < stmt->indices.size(); axis++) { + Stmt *this_index = stmt->indices[axis]; + Stmt *other_index = other_ptr->indices[axis]; + // We only compare unique indices here. + // Since both pointers are loop-unique, all the unique indices + // need to be the same for both to be uniquely accessed + if (loop_unique_stmt_searcher_.is_partially_loop_unique( + this_index)) { + if (!irpass::analysis::same_value(this_index, other_index)) { + // Not equal -> not uniquely accessed + accessed_arr_pointer_[arg_id] = nullptr; + break; + } + } + } + } + } + } + } + } + + static std::pair, + std::unordered_map> + run(IRNode *root) { TI_ASSERT(root->is()); auto offload = root->as(); UniquelyAccessedSNodeSearcher searcher; - if (offload->task_type == OffloadedTaskType::range_for) { + if (offload->task_type == OffloadedTaskType::range_for || + offload->task_type == OffloadedTaskType::mesh_for) { searcher.loop_unique_stmt_searcher_.num_different_loop_indices = 1; } else if (offload->task_type == OffloadedTaskType::struct_for) { searcher.loop_unique_stmt_searcher_.num_different_loop_indices = @@ -160,7 +303,9 @@ class UniquelyAccessedSNodeSearcher : public BasicStmtVisitor { } root->accept(&searcher.loop_unique_stmt_searcher_); root->accept(&searcher); - return searcher.accessed_pointer_; + + return std::make_pair(searcher.accessed_pointer_, + searcher.accessed_arr_pointer_); } }; @@ -180,13 +325,18 @@ class UniquelyAccessedBitStructGatherer : public BasicStmtVisitor { void visit(OffloadedStmt *stmt) override { if (stmt->task_type == OffloadedTaskType::range_for || + stmt->task_type == OffloadedTaskType::mesh_for || stmt->task_type == OffloadedTaskType::struct_for) { auto &loop_unique_bit_struct = result_[stmt]; auto loop_unique_ptr = - irpass::analysis::gather_uniquely_accessed_pointers(stmt); + irpass::analysis::gather_uniquely_accessed_pointers(stmt).first; for (auto &it : loop_unique_ptr) { auto *snode = it.first; auto *ptr1 = it.second; + if (ptr1 != nullptr && ptr1->indices.size() > 0 && + ptr1->indices[0]->is()) { + continue; + } if (snode->is_bit_level) { // Find the nearest non-bit-level ancestor while (snode->is_bit_level) { @@ -229,7 +379,8 @@ const std::string GatherUniquelyAccessedBitStructsPass::id = "GatherUniquelyAccessedBitStructsPass"; namespace irpass::analysis { -std::unordered_map +std::pair, + std::unordered_map> gather_uniquely_accessed_pointers(IRNode *root) { // TODO: What about SNodeOpStmts? return UniquelyAccessedSNodeSearcher::run(root); diff --git a/taichi/analysis/gather_used_atomics.cpp b/taichi/analysis/gather_used_atomics.cpp index 1664be7688e6a..cb1d8d32150ba 100644 --- a/taichi/analysis/gather_used_atomics.cpp +++ b/taichi/analysis/gather_used_atomics.cpp @@ -8,7 +8,7 @@ TLANG_NAMESPACE_BEGIN class UsedAtomicsSearcher : public BasicStmtVisitor { private: - std::unique_ptr> used_atomics; + std::unique_ptr> used_atomics_; public: using BasicStmtVisitor::visit; @@ -16,13 +16,13 @@ class UsedAtomicsSearcher : public BasicStmtVisitor { UsedAtomicsSearcher() { allow_undefined_visitor = true; invoke_default_visitor = true; - used_atomics = std::make_unique>(); + used_atomics_ = std::make_unique>(); } void search_operands(Stmt *stmt) { for (auto &op : stmt->get_operands()) { if (op != nullptr && op->is()) { - used_atomics->insert(op->as()); + used_atomics_->insert(op->as()); } } } @@ -38,7 +38,7 @@ class UsedAtomicsSearcher : public BasicStmtVisitor { static std::unique_ptr> run(IRNode *root) { UsedAtomicsSearcher searcher; root->accept(&searcher); - return std::move(searcher.used_atomics); + return std::move(searcher.used_atomics_); } }; diff --git a/taichi/analysis/has_store_or_atomic.cpp b/taichi/analysis/has_store_or_atomic.cpp index 04d546673364c..391ee99fbba5a 100644 --- a/taichi/analysis/has_store_or_atomic.cpp +++ b/taichi/analysis/has_store_or_atomic.cpp @@ -8,14 +8,14 @@ TLANG_NAMESPACE_BEGIN // Find if there is a store (or AtomicOpStmt). class LocalStoreSearcher : public BasicStmtVisitor { private: - const std::vector &vars; - bool result; + const std::vector &vars_; + bool result_; public: using BasicStmtVisitor::visit; explicit LocalStoreSearcher(const std::vector &vars) - : vars(vars), result(false) { + : vars_(vars), result_(false) { for (auto var : vars) { TI_ASSERT(var->is()); } @@ -24,18 +24,18 @@ class LocalStoreSearcher : public BasicStmtVisitor { } void visit(LocalStoreStmt *stmt) override { - for (auto var : vars) { + for (auto var : vars_) { if (stmt->dest == var) { - result = true; + result_ = true; break; } } } void visit(AtomicOpStmt *stmt) override { - for (auto var : vars) { + for (auto var : vars_) { if (stmt->dest == var) { - result = true; + result_ = true; break; } } @@ -44,7 +44,7 @@ class LocalStoreSearcher : public BasicStmtVisitor { static bool run(IRNode *root, const std::vector &vars) { LocalStoreSearcher searcher(vars); root->accept(&searcher); - return searcher.result; + return searcher.result_; } }; diff --git a/taichi/analysis/last_store_or_atomic.cpp b/taichi/analysis/last_store_or_atomic.cpp index 744e4ef1c62d9..4c90a094a11b1 100644 --- a/taichi/analysis/last_store_or_atomic.cpp +++ b/taichi/analysis/last_store_or_atomic.cpp @@ -9,37 +9,37 @@ TLANG_NAMESPACE_BEGIN // after the last store. class LocalStoreForwarder : public BasicStmtVisitor { private: - Stmt *var; - bool is_valid; - Stmt *result; + Stmt *var_; + bool is_valid_; + Stmt *result_; public: using BasicStmtVisitor::visit; explicit LocalStoreForwarder(Stmt *var) - : var(var), is_valid(true), result(nullptr) { + : var_(var), is_valid_(true), result_(nullptr) { TI_ASSERT(var->is()); allow_undefined_visitor = true; invoke_default_visitor = true; } void visit(LocalStoreStmt *stmt) override { - if (stmt->dest == var) { - is_valid = true; - result = stmt; + if (stmt->dest == var_) { + is_valid_ = true; + result_ = stmt; } } void visit(AllocaStmt *stmt) override { - if (stmt == var) { - is_valid = true; - result = stmt; + if (stmt == var_) { + is_valid_ = true; + result_ = stmt; } } void visit(AtomicOpStmt *stmt) override { - if (stmt->dest == var) { - is_valid = false; + if (stmt->dest == var_) { + is_valid_ = false; } } @@ -50,33 +50,33 @@ class LocalStoreForwarder : public BasicStmtVisitor { std::pair true_branch(true, nullptr); if (if_stmt->true_statements) { // create a new LocalStoreForwarder instance - true_branch = run(if_stmt->true_statements.get(), var); + true_branch = run(if_stmt->true_statements.get(), var_); } std::pair false_branch(true, nullptr); if (if_stmt->false_statements) { - false_branch = run(if_stmt->false_statements.get(), var); + false_branch = run(if_stmt->false_statements.get(), var_); } auto true_stmt = true_branch.second; auto false_stmt = false_branch.second; if (!true_branch.first || !false_branch.first) { // at least one branch finally modifies the variable without storing - is_valid = false; + is_valid_ = false; } else if (true_stmt == nullptr && false_stmt == nullptr) { // both branches don't modify the variable return; } else if (true_stmt == nullptr || false_stmt == nullptr) { // only one branch modifies the variable - is_valid = false; + is_valid_ = false; } else { TI_ASSERT(true_stmt->is()); TI_ASSERT(false_stmt->is()); if (true_stmt->as()->val != false_stmt->as()->val) { // two branches finally store the variable differently - is_valid = false; + is_valid_ = false; } else { - is_valid = true; - result = true_stmt; // same as false_stmt + is_valid_ = true; + result_ = true_stmt; // same as false_stmt } } } @@ -85,33 +85,33 @@ class LocalStoreForwarder : public BasicStmtVisitor { // the "last" store inside a loop to the local load statement. // What we can do is just check if the loop doesn't modify the variable. void visit(WhileStmt *stmt) override { - if (irpass::analysis::has_store_or_atomic(stmt, {var})) { - is_valid = false; + if (irpass::analysis::has_store_or_atomic(stmt, {var_})) { + is_valid_ = false; } } void visit(RangeForStmt *stmt) override { - if (irpass::analysis::has_store_or_atomic(stmt, {var})) { - is_valid = false; + if (irpass::analysis::has_store_or_atomic(stmt, {var_})) { + is_valid_ = false; } } void visit(StructForStmt *stmt) override { - if (irpass::analysis::has_store_or_atomic(stmt, {var})) { - is_valid = false; + if (irpass::analysis::has_store_or_atomic(stmt, {var_})) { + is_valid_ = false; } } void visit(OffloadedStmt *stmt) override { - if (irpass::analysis::has_store_or_atomic(stmt, {var})) { - is_valid = false; + if (irpass::analysis::has_store_or_atomic(stmt, {var_})) { + is_valid_ = false; } } static std::pair run(IRNode *root, Stmt *var) { LocalStoreForwarder searcher(var); root->accept(&searcher); - return std::make_pair(searcher.is_valid, searcher.result); + return std::make_pair(searcher.is_valid_, searcher.result_); } }; diff --git a/taichi/analysis/mesh_bls_analyzer.cpp b/taichi/analysis/mesh_bls_analyzer.cpp new file mode 100644 index 0000000000000..280636acf006e --- /dev/null +++ b/taichi/analysis/mesh_bls_analyzer.cpp @@ -0,0 +1,118 @@ +#include "taichi/analysis/mesh_bls_analyzer.h" + +#include "taichi/system/profiler.h" +#include "taichi/ir/analysis.h" + +namespace taichi { +namespace lang { + +MeshBLSAnalyzer::MeshBLSAnalyzer(OffloadedStmt *for_stmt, + MeshBLSCaches *caches, + bool auto_mesh_local, + const CompileConfig &config) + : for_stmt_(for_stmt), + caches_(caches), + auto_mesh_local_(auto_mesh_local), + config_(config) { + TI_AUTO_PROF; + allow_undefined_visitor = true; + invoke_default_visitor = false; +} + +void MeshBLSAnalyzer::record_access(Stmt *stmt, AccessFlag flag) { + if (!analysis_ok_) { + return; + } + if (!stmt->is()) + return; // local alloca + auto ptr = stmt->as(); + if (ptr->indices.size() != std::size_t(1) || + !ptr->indices[0]->is()) + return; + auto conv = ptr->indices[0]->as(); + auto element_type = conv->idx_type; + auto conv_type = conv->conv_type; + auto idx = conv->idx; + if (conv_type == mesh::ConvType::g2r) + return; + for (int l = 0; l < stmt->width(); l++) { + auto snode = ptr->snodes[l]; + if (!caches_->has(snode)) { + if (auto_mesh_local_ && + (flag == AccessFlag::accumulate || + (flag == AccessFlag::read && config_.arch == Arch::cuda)) && + (!idx->is() || + !idx->as()->is_mesh_index())) { + caches_->insert(snode); + } else { + continue; + } + } + + if (!caches_->access(snode, element_type, conv_type, flag, + idx->as()->neighbor_idx)) { + analysis_ok_ = false; + break; + } + } +} + +void MeshBLSAnalyzer::visit(GlobalLoadStmt *stmt) { + TI_ASSERT(stmt->width() == 1); // TODO: support vectorization + record_access(stmt->src, AccessFlag::read); +} + +void MeshBLSAnalyzer::visit(GlobalStoreStmt *stmt) { + TI_ASSERT(stmt->width() == 1); // TODO: support vectorization + record_access(stmt->dest, AccessFlag::write); +} + +void MeshBLSAnalyzer::visit(AtomicOpStmt *stmt) { + if (stmt->op_type == AtomicOpType::add) { + record_access(stmt->dest, AccessFlag::accumulate); + } +} + +void MeshBLSAnalyzer::visit(Stmt *stmt) { + TI_ASSERT(!stmt->is_container_statement()); +} + +bool MeshBLSAnalyzer::run() { + const auto &block = for_stmt_->body; + + for (int i = 0; i < (int)block->statements.size(); i++) { + block->statements[i]->accept(this); + } + + return analysis_ok_; +} + +namespace irpass { +namespace analysis { + +std::unique_ptr initialize_mesh_local_attribute( + OffloadedStmt *offload, + bool auto_mesh_local, + const CompileConfig &config) { + TI_AUTO_PROF + TI_ASSERT(offload->task_type == OffloadedTaskType::mesh_for); + std::unique_ptr caches; + caches = std::make_unique(); + for (auto snode : offload->mem_access_opt.get_snodes_with_flag( + SNodeAccessFlag::mesh_local)) { + caches->insert(snode); + } + + MeshBLSAnalyzer bls_analyzer(offload, caches.get(), auto_mesh_local, config); + bool analysis_ok = bls_analyzer.run(); + if (!analysis_ok) { + TI_ERROR("Mesh BLS analysis failed !"); + } + return caches; +} + +} // namespace analysis +} // namespace irpass + +} // namespace lang +} // namespace taichi diff --git a/taichi/analysis/mesh_bls_analyzer.h b/taichi/analysis/mesh_bls_analyzer.h new file mode 100644 index 0000000000000..f8db1c3d2b9ba --- /dev/null +++ b/taichi/analysis/mesh_bls_analyzer.h @@ -0,0 +1,158 @@ +#pragma once + +#include "taichi/program/compile_config.h" +#include "taichi/ir/visitors.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/mesh.h" + +#include + +namespace taichi { +namespace lang { + +class MeshBLSCache { + public: + using AccessFlag = taichi::lang::AccessFlag; + using Rec = std::map, + std::set>>; + + SNode *snode{nullptr}; + mesh::MeshElementType element_type; + mesh::ConvType conv_type; + + bool initialized; + bool finalized; + bool loop_index; + int unique_accessed; + AccessFlag total_flags; + + MeshBLSCache() = default; + + MeshBLSCache(SNode *snode) : snode(snode) { + total_flags = AccessFlag(0); + initialized = false; + finalized = false; + loop_index = false; + unique_accessed = 0; + } + + bool access(mesh::MeshElementType element_type, + mesh::ConvType conv_type, + AccessFlag flags, + Stmt *idx) { + if (!initialized) { + initialized = true; + this->conv_type = conv_type; + this->element_type = element_type; + } else { + if (this->conv_type != conv_type || this->element_type != element_type) + return false; + } + this->total_flags |= flags; + if (idx->is()) { + loop_index = true; + } else { + unique_accessed++; + } + return true; + } + + void finalize(Rec &rec) { + TI_ASSERT(!finalized); + finalized = true; + if (initialized) { + const auto cache_type = std::make_pair(element_type, conv_type); + auto ptr = rec.find(cache_type); + if (ptr == rec.end()) { + ptr = rec.emplace(std::piecewise_construct, + std::forward_as_tuple(cache_type), + std::forward_as_tuple()) + .first; + } + ptr->second.insert(std::make_pair(snode, total_flags)); + } + } +}; + +class MeshBLSCaches { + public: + std::map caches; + + using AccessFlag = MeshBLSCache::AccessFlag; + using Rec = MeshBLSCache::Rec; + + void insert(SNode *snode) { + if (caches.find(snode) == caches.end()) { + caches.emplace(std::piecewise_construct, std::forward_as_tuple(snode), + std::forward_as_tuple(snode)); + } else { + TI_ERROR("mesh::MeshBLSCaches for {} already exists.", + snode->node_type_name); + } + } + + bool access(SNode *snode, + mesh::MeshElementType element_type, + mesh::ConvType conv_type, + AccessFlag flags, + Stmt *idx) { + if (caches.find(snode) == caches.end()) + return false; + return caches.find(snode)->second.access(element_type, conv_type, flags, + idx); + } + + Rec finalize() { + Rec rec; + for (auto &cache : caches) { + cache.second.finalize(rec); + } + return rec; + } + + bool has(SNode *snode) { + return caches.find(snode) != caches.end(); + } + + MeshBLSCache &get(SNode *snode) { + TI_ASSERT(caches.find(snode) != caches.end()); + return caches[snode]; + } +}; + +// Figure out accessed SNodes, and their ranges in this for stmt +class MeshBLSAnalyzer : public BasicStmtVisitor { + using BasicStmtVisitor::visit; + + public: + MeshBLSAnalyzer(OffloadedStmt *for_stmt, + MeshBLSCaches *caches, + bool auto_mesh_local, + const CompileConfig &config); + + void visit(GlobalPtrStmt *stmt) override { + } + + // Do not eliminate global data access + void visit(GlobalLoadStmt *stmt) override; + + void visit(GlobalStoreStmt *stmt) override; + + void visit(AtomicOpStmt *stmt) override; + + void visit(Stmt *stmt) override; + + bool run(); + + private: + void record_access(Stmt *stmt, AccessFlag flag); + + OffloadedStmt *for_stmt_{nullptr}; + MeshBLSCaches *caches_{nullptr}; + bool analysis_ok_{true}; + bool auto_mesh_local_{false}; + CompileConfig config_; +}; + +} // namespace lang +} // namespace taichi diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index e53b4baa35a5a..4a8cd890ff1b8 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -12,9 +12,9 @@ TLANG_NAMESPACE_BEGIN // Compare if two IRNodes are equivalent. class IRNodeComparator : public IRVisitor { private: - IRNode *other_node; + IRNode *other_node_; // map the id from this node to the other node - std::unordered_map id_map; + std::unordered_map id_map_; bool recursively_check_; @@ -39,13 +39,13 @@ class IRNodeComparator : public IRVisitor { const std::optional> &possibly_modified_states, IRBank *ir_bank) - : other_node(other_node) { + : other_node_(other_node) { allow_undefined_visitor = true; invoke_default_visitor = true; same = true; if (id_map.has_value()) { recursively_check_ = true; - this->id_map = id_map.value(); + this->id_map_ = id_map.value(); } else { recursively_check_ = false; } @@ -66,9 +66,9 @@ class IRNodeComparator : public IRVisitor { } void map_id(int this_id, int other_id) { - auto it = id_map.find(this_id); - if (it == id_map.end()) { - id_map[this_id] = other_id; + auto it = id_map_.find(this_id); + if (it == id_map_.end()) { + id_map_[this_id] = other_id; } else if (it->second != other_id) { same = false; } @@ -77,8 +77,8 @@ class IRNodeComparator : public IRVisitor { void check_mapping(Stmt *this_stmt, Stmt *other_stmt) { // get the corresponding id in the other node // and check if it is other_stmt->id - auto it = id_map.find(this_stmt->id); - if (it != id_map.end()) { + auto it = id_map_.find(this_stmt->id); + if (it != id_map_.end()) { if (it->second != other_stmt->id) { same = false; } @@ -89,43 +89,43 @@ class IRNodeComparator : public IRVisitor { if (this_stmt->id != other_stmt->id) { same = false; } - id_map[this_stmt->id] = other_stmt->id; + id_map_[this_stmt->id] = other_stmt->id; } else { // recursively check them - IRNode *backup_other_node = other_node; - other_node = other_stmt; + IRNode *backup_other_node = other_node_; + other_node_ = other_stmt; this_stmt->accept(this); - other_node = backup_other_node; + other_node_ = backup_other_node; } } void visit(Block *stmt_list) override { - if (!other_node->is()) { + if (!other_node_->is()) { same = false; return; } - auto other = other_node->as(); + auto other = other_node_->as(); if (stmt_list->size() != other->size()) { same = false; return; } for (int i = 0; i < (int)stmt_list->size(); i++) { - other_node = other->statements[i].get(); + other_node_ = other->statements[i].get(); stmt_list->statements[i]->accept(this); if (!same) break; } - other_node = other; + other_node_ = other; } void basic_check(Stmt *stmt) { // type check - if (typeid(*other_node) != typeid(*stmt)) { + if (typeid(*other_node_) != typeid(*stmt)) { same = false; return; } - auto other = other_node->as(); + auto other = other_node_->as(); if (stmt == other) { return; } @@ -240,78 +240,68 @@ class IRNodeComparator : public IRVisitor { basic_check(stmt); if (!same) return; - auto other = other_node->as(); + auto other = other_node_->as(); if (stmt->true_statements) { if (!other->true_statements) { same = false; return; } - other_node = other->true_statements.get(); + other_node_ = other->true_statements.get(); stmt->true_statements->accept(this); - other_node = other; + other_node_ = other; } if (stmt->false_statements && same) { if (!other->false_statements) { same = false; return; } - other_node = other->false_statements.get(); + other_node_ = other->false_statements.get(); stmt->false_statements->accept(this); - other_node = other; + other_node_ = other; } } - void visit(FuncBodyStmt *stmt) override { - basic_check(stmt); - if (!same) - return; - auto other = other_node->as(); - other_node = other->body.get(); - stmt->body->accept(this); - other_node = other; - } - void visit(WhileStmt *stmt) override { basic_check(stmt); if (!same) return; - auto other = other_node->as(); - other_node = other->body.get(); + auto other = other_node_->as(); + other_node_ = other->body.get(); stmt->body->accept(this); - other_node = other; + other_node_ = other; } void visit(RangeForStmt *stmt) override { basic_check(stmt); if (!same) return; - auto other = other_node->as(); - other_node = other->body.get(); + auto other = other_node_->as(); + other_node_ = other->body.get(); stmt->body->accept(this); - other_node = other; + other_node_ = other; } void visit(StructForStmt *stmt) override { basic_check(stmt); if (!same) return; - auto other = other_node->as(); - other_node = other->body.get(); + auto other = other_node_->as(); + other_node_ = other->body.get(); stmt->body->accept(this); - other_node = other; + other_node_ = other; } void visit(OffloadedStmt *stmt) override { basic_check(stmt); if (!same) return; - auto other = other_node->as(); + auto other = other_node_->as(); if (stmt->has_body()) { TI_ASSERT(stmt->body); TI_ASSERT(other->body); - other_node = other->body.get(); + other_node_ = other->body.get(); stmt->body->accept(this); - other_node = other; + other_node_ = other; } } diff --git a/taichi/analysis/value_diff.cpp b/taichi/analysis/value_diff.cpp index b6d88e4b25eb7..bea01f39ef63a 100644 --- a/taichi/analysis/value_diff.cpp +++ b/taichi/analysis/value_diff.cpp @@ -9,18 +9,18 @@ namespace taichi { namespace lang { DiffRange operator+(const DiffRange &a, const DiffRange &b) { - return DiffRange(a.related_() && b.related_(), a.coeff + b.coeff, - a.low + b.low, a.high + b.high - 1); + return DiffRange(a.related() && b.related(), a.coeff + b.coeff, a.low + b.low, + a.high + b.high - 1); } DiffRange operator-(const DiffRange &a, const DiffRange &b) { - return DiffRange(a.related_() && b.related_(), a.coeff - b.coeff, + return DiffRange(a.related() && b.related(), a.coeff - b.coeff, a.low - b.high + 1, a.high - b.low); } DiffRange operator*(const DiffRange &a, const DiffRange &b) { return DiffRange( - a.related_() && b.related_() && a.coeff * b.coeff == 0, + a.related() && b.related() && a.coeff * b.coeff == 0, fmax(a.low * b.coeff, a.coeff * b.low), fmin(a.low * b.low, fmin(a.low * (b.high - 1), @@ -33,7 +33,7 @@ DiffRange operator*(const DiffRange &a, const DiffRange &b) { DiffRange operator<<(const DiffRange &a, const DiffRange &b) { return DiffRange( - a.related_() && b.related_() && b.coeff == 0 && b.high - b.low == 1, + a.related() && b.related() && b.coeff == 0 && b.high - b.low == 1, a.coeff << b.low, a.low << b.low, ((a.high - 1) << b.low) + 1); } @@ -113,7 +113,7 @@ class ValueDiffLoopIndex : public IRVisitor { stmt->rhs->accept(this); auto ret1 = results[stmt->lhs->instance_id]; auto ret2 = results[stmt->rhs->instance_id]; - if (ret1.related_() && ret2.related_()) { + if (ret1.related() && ret2.related()) { if (stmt->op_type == BinaryOpType::add) { results[stmt->instance_id] = ret1 + ret2; } else if (stmt->op_type == BinaryOpType::sub) { diff --git a/taichi/analysis/verify.cpp b/taichi/analysis/verify.cpp index ed22468538304..cd76cb8e30316 100644 --- a/taichi/analysis/verify.cpp +++ b/taichi/analysis/verify.cpp @@ -12,51 +12,51 @@ TLANG_NAMESPACE_BEGIN class IRVerifier : public BasicStmtVisitor { private: - Block *current_block; - Stmt *current_container_stmt; + Block *current_block_; + Stmt *current_container_stmt_; // each scope corresponds to an unordered_set - std::vector> visible_stmts; + std::vector> visible_stmts_; public: using BasicStmtVisitor::visit; explicit IRVerifier(IRNode *root) - : current_block(nullptr), current_container_stmt(nullptr) { + : current_block_(nullptr), current_container_stmt_(nullptr) { allow_undefined_visitor = true; invoke_default_visitor = true; if (!root->is()) - visible_stmts.emplace_back(); + visible_stmts_.emplace_back(); if (root->is() && root->as()->is_container_statement()) { - current_container_stmt = root->as(); + current_container_stmt_ = root->as(); } } void basic_verify(Stmt *stmt) { - TI_ASSERT_INFO(stmt->parent == current_block, + TI_ASSERT_INFO(stmt->parent == current_block_, "stmt({})->parent({}) != current_block({})", stmt->id, - fmt::ptr(stmt->parent), fmt::ptr(current_block)); + fmt::ptr(stmt->parent), fmt::ptr(current_block_)); for (auto &op : stmt->get_operands()) { if (op == nullptr) continue; bool found = false; - for (int depth = (int)visible_stmts.size() - 1; depth >= 0; depth--) { - if (visible_stmts[depth].find(op) != visible_stmts[depth].end()) { + for (int depth = (int)visible_stmts_.size() - 1; depth >= 0; depth--) { + if (visible_stmts_[depth].find(op) != visible_stmts_[depth].end()) { found = true; break; } } TI_ASSERT_INFO( found, - "IR broken: stmt {} cannot have operand {}." + "IR broken: stmt {} {} cannot have operand {} {}." " If you are using autodiff, please check" " https://docs.taichi.graphics/lang/articles/advanced/" "differentiable_programming#kernel-simplicity-rule" " If it doesn't help, please report this bug by opening an issue at" " https://github.com/taichi-dev/taichi to help us improve." " Thanks in advance!", - stmt->id, op->id); + stmt->type(), stmt->id, op->type(), op->id); } - visible_stmts.back().insert(stmt); + visible_stmts_.back().insert(stmt); } void preprocess_container_stmt(Stmt *stmt) override { @@ -69,23 +69,25 @@ class IRVerifier : public BasicStmtVisitor { void visit(Block *block) override { TI_ASSERT_INFO( - block->parent_stmt == current_container_stmt, + block->parent_stmt == current_container_stmt_, "block({})->parent({}) != current_container_stmt({})", fmt::ptr(block), block->parent_stmt ? block->parent_stmt->name() : "nullptr", - current_container_stmt ? current_container_stmt->name() : "nullptr"); - auto backup_block = current_block; - current_block = block; - auto backup_container_stmt = current_container_stmt; - visible_stmts.emplace_back(); + current_container_stmt_ ? current_container_stmt_->name() : "nullptr"); + auto backup_block = current_block_; + current_block_ = block; + auto backup_container_stmt = current_container_stmt_; + if (!block->parent_stmt || !block->parent_stmt->is()) + visible_stmts_.emplace_back(); for (auto &stmt : block->statements) { if (stmt->is_container_statement()) - current_container_stmt = stmt.get(); + current_container_stmt_ = stmt.get(); stmt->accept(this); if (stmt->is_container_statement()) - current_container_stmt = backup_container_stmt; + current_container_stmt_ = backup_container_stmt; } - current_block = backup_block; - visible_stmts.pop_back(); + current_block_ = backup_block; + if (!block->parent_stmt || !block->parent_stmt->is()) + current_block_ = backup_block; } void visit(OffloadedStmt *stmt) override { @@ -122,10 +124,13 @@ class IRVerifier : public BasicStmtVisitor { if (stmt->loop->is()) { TI_ASSERT(stmt->loop->as()->task_type == OffloadedStmt::TaskType::struct_for || + stmt->loop->as()->task_type == + OffloadedStmt::TaskType::mesh_for || stmt->loop->as()->task_type == OffloadedStmt::TaskType::range_for); } else { TI_ASSERT(stmt->loop->is() || + stmt->loop->is() || stmt->loop->is()); } } diff --git a/taichi/program/aot_module_builder.cpp b/taichi/aot/module_builder.cpp similarity index 59% rename from taichi/program/aot_module_builder.cpp rename to taichi/aot/module_builder.cpp index 306b5a49997e9..b194d2ee384c6 100644 --- a/taichi/program/aot_module_builder.cpp +++ b/taichi/aot/module_builder.cpp @@ -1,4 +1,4 @@ -#include "taichi/program/aot_module_builder.h" +#include "taichi/aot/module_builder.h" #include "taichi/program/kernel.h" namespace taichi { @@ -12,12 +12,14 @@ void AotModuleBuilder::add(const std::string &identifier, Kernel *kernel) { } void AotModuleBuilder::add_field(const std::string &identifier, + const SNode *rep_snode, bool is_scalar, DataType dt, std::vector shape, int row_num, int column_num) { - add_per_backend_field(identifier, is_scalar, dt, shape, row_num, column_num); + add_field_per_backend(identifier, rep_snode, is_scalar, dt, shape, row_num, + column_num); } void AotModuleBuilder::add_kernel_template(const std::string &identifier, @@ -29,5 +31,26 @@ void AotModuleBuilder::add_kernel_template(const std::string &identifier, add_per_backend_tmpl(identifier, key, kernel); } +bool AotModuleBuilder::all_fields_are_dense_in_container( + const SNode *container) { + for (const auto &ch : container->ch) { + if (ch->type != SNodeType::place) { + return false; + } + } + const auto *parent = container->parent; + if (!parent) { + return false; + } + if (parent->type != SNodeType::root) { + return false; + } + return true; +} + +void AotModuleBuilder::load(const std::string &output_dir) { + TI_ERROR("Aot loader not supported"); +} + } // namespace lang } // namespace taichi diff --git a/taichi/program/aot_module_builder.h b/taichi/aot/module_builder.h similarity index 64% rename from taichi/program/aot_module_builder.h rename to taichi/aot/module_builder.h index 15fb4013b8662..1b05af1875038 100644 --- a/taichi/program/aot_module_builder.h +++ b/taichi/aot/module_builder.h @@ -3,6 +3,11 @@ #include #include +#include "taichi/aot/module_data.h" +#include "taichi/backends/device.h" +#include "taichi/ir/snode.h" +#include "taichi/aot/module_data.h" + namespace taichi { namespace lang { @@ -16,6 +21,7 @@ class AotModuleBuilder { void add(const std::string &identifier, Kernel *kernel); void add_field(const std::string &identifier, + const SNode *rep_snode, bool is_scalar, DataType dt, std::vector shape, @@ -26,6 +32,8 @@ class AotModuleBuilder { const std::string &key, Kernel *kernel); + virtual void load(const std::string &output_dir); + virtual void dump(const std::string &output_dir, const std::string &filename) const = 0; @@ -35,15 +43,27 @@ class AotModuleBuilder { */ virtual void add_per_backend(const std::string &identifier, Kernel *kernel) = 0; - virtual void add_per_backend_field(const std::string &identifier, + virtual void add_field_per_backend(const std::string &identifier, + const SNode *rep_snode, bool is_scalar, DataType dt, std::vector shape, int row_num, int column_num) = 0; + virtual void add_ndarray_per_backend(const std::string &identifier, + bool is_scalar, + DataType dt, + std::vector shape, + int row_num, + int column_num) { + TI_NOT_IMPLEMENTED; + } + virtual void add_per_backend_tmpl(const std::string &identifier, const std::string &key, Kernel *kernel) = 0; + + static bool all_fields_are_dense_in_container(const SNode *container); }; } // namespace lang diff --git a/taichi/aot/module_data.h b/taichi/aot/module_data.h new file mode 100644 index 0000000000000..7c401e6148a26 --- /dev/null +++ b/taichi/aot/module_data.h @@ -0,0 +1,104 @@ +#pragma once + +#include +#include + +#include "taichi/common/core.h" +#include "taichi/common/serialization.h" + +namespace taichi { +namespace lang { +namespace aot { + +struct CompiledFieldData { + std::string field_name; + uint32_t dtype{0}; + std::string dtype_name; + size_t mem_offset_in_parent{0}; + std::vector shape; + bool is_scalar{false}; + std::vector element_shape; + + TI_IO_DEF(field_name, + dtype, + dtype_name, + mem_offset_in_parent, + shape, + is_scalar, + element_shape); +}; + +struct CompiledOffloadedTask { + std::string type; + std::string range_hint; + std::string name; + // Do we need to inline the source code? + std::string source_path; + int gpu_block_size{0}; + + TI_IO_DEF(type, range_hint, name, source_path, gpu_block_size); +}; + +struct ScalarArg { + std::string dtype_name; + // Unit: byte + size_t offset_in_args_buf{0}; + + TI_IO_DEF(dtype_name, offset_in_args_buf); +}; + +struct ArrayArg { + std::string dtype_name; + std::size_t field_dim{0}; + // If |element_shape| is empty, it means this is a scalar + std::vector element_shape; + // Unit: byte + std::size_t shape_offset_in_args_buf{0}; + // For Vulkan/OpenGL/Metal, this is the binding index + int bind_index{0}; + + TI_IO_DEF(dtype_name, + field_dim, + element_shape, + shape_offset_in_args_buf, + bind_index); +}; + +struct CompiledTaichiKernel { + std::vector tasks; + int args_count{0}; + int rets_count{0}; + size_t args_buffer_size{0}; + size_t rets_buffer_size{0}; + + std::unordered_map scalar_args; + std::unordered_map arr_args; + + TI_IO_DEF(tasks, + args_count, + rets_count, + args_buffer_size, + rets_buffer_size, + scalar_args, + arr_args); +}; + +struct ModuleData { + std::unordered_map kernels; + std::unordered_map kernel_tmpls; + std::vector fields; + + size_t root_buffer_size; + + void dump_json(std::string path) { + TextSerializer ts; + ts.serialize_to_json("aot_data", *this); + ts.write_to_file(path); + } + + TI_IO_DEF(kernels, kernel_tmpls, fields, root_buffer_size); +}; + +} // namespace aot +} // namespace lang +} // namespace taichi diff --git a/taichi/aot/module_loader.cpp b/taichi/aot/module_loader.cpp new file mode 100644 index 0000000000000..9d168d3055da3 --- /dev/null +++ b/taichi/aot/module_loader.cpp @@ -0,0 +1,28 @@ +#include "taichi/aot/module_loader.h" + +#include "taichi/backends/vulkan/aot_module_loader_impl.h" +#include "taichi/backends/metal/aot_module_loader_impl.h" + +namespace taichi { +namespace lang { +namespace aot { + +std::unique_ptr Module::load(const std::string &path, + Arch arch, + std::any mod_params) { + if (arch == Arch::vulkan) { +#ifdef TI_WITH_VULKAN + return vulkan::make_aot_module(mod_params); +#endif + } else if (arch == Arch::metal) { +#ifdef TI_WITH_METAL + return metal::make_aot_module(mod_params); +#endif + } else { + TI_NOT_IMPLEMENTED; + } +} + +} // namespace aot +} // namespace lang +} // namespace taichi diff --git a/taichi/aot/module_loader.h b/taichi/aot/module_loader.h new file mode 100644 index 0000000000000..634f8462e92ec --- /dev/null +++ b/taichi/aot/module_loader.h @@ -0,0 +1,127 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "taichi/aot/module_data.h" +#include "taichi/backends/device.h" +#include "taichi/ir/snode.h" +#include "taichi/aot/module_data.h" + +namespace taichi { +namespace lang { + +struct RuntimeContext; + +namespace aot { + +class TI_DLL_EXPORT Field { + public: + // Rule of 5 to make MSVC happy + Field() = default; + virtual ~Field() = default; + Field(const Field &) = delete; + Field &operator=(const Field &) = delete; + Field(Field &&) = default; + Field &operator=(Field &&) = default; +}; + +class TI_DLL_EXPORT Kernel { + public: + // Rule of 5 to make MSVC happy + Kernel() = default; + virtual ~Kernel() = default; + Kernel(const Kernel &) = delete; + Kernel &operator=(const Kernel &) = delete; + Kernel(Kernel &&) = default; + Kernel &operator=(Kernel &&) = default; + + /** + * @brief Launches the kernel to the device + * + * This does not manage the device to host synchronization. + * + * @param ctx Host context + */ + virtual void launch(RuntimeContext *ctx) = 0; +}; + +class TI_DLL_EXPORT Module { + public: + // Rule of 5 to make MSVC happy + Module() = default; + virtual ~Module() = default; + Module(const Module &) = delete; + Module &operator=(const Module &) = delete; + Module(Module &&) = default; + Module &operator=(Module &&) = default; + + static std::unique_ptr load(const std::string &path, + Arch arch, + std::any mod_params); + + // Module metadata + virtual Arch arch() const = 0; + virtual uint64_t version() const = 0; + virtual std::unique_ptr get_kernel(const std::string &name) = 0; + virtual std::unique_ptr get_field(const std::string &name) = 0; + virtual size_t get_root_size() const = 0; + + protected: + virtual std::unique_ptr make_new_kernel(const std::string &name) = 0; + + private: + std::unordered_map> loaded_kernels_; +}; + +// Only responsible for reporting device capabilities +class TargetDevice : public Device { + public: + TargetDevice(Arch arch) { + // TODO: make this configurable + set_default_caps(arch); + } + + void set_default_caps(Arch arch) { + if (arch == Arch::vulkan) { + set_cap(DeviceCapability::spirv_version, 0x10300); + } + } + + DeviceAllocation allocate_memory(const AllocParams ¶ms) override { + TI_NOT_IMPLEMENTED; + } + void dealloc_memory(DeviceAllocation handle) override { + TI_NOT_IMPLEMENTED; + } + std::unique_ptr create_pipeline( + const PipelineSourceDesc &src, + std::string name = "Pipeline") override { + TI_NOT_IMPLEMENTED; + } + void *map_range(DevicePtr ptr, uint64_t size) override { + TI_NOT_IMPLEMENTED; + } + void *map(DeviceAllocation alloc) override { + TI_NOT_IMPLEMENTED; + } + void unmap(DevicePtr ptr) override { + TI_NOT_IMPLEMENTED; + } + void unmap(DeviceAllocation alloc) override { + TI_NOT_IMPLEMENTED; + } + void memcpy_internal(DevicePtr dst, DevicePtr src, uint64_t size) override { + TI_NOT_IMPLEMENTED; + } + Stream *get_compute_stream() override { + TI_NOT_IMPLEMENTED; + } +}; + +} // namespace aot +} // namespace lang +} // namespace taichi diff --git a/taichi/program/arch.cpp b/taichi/backends/arch.cpp similarity index 94% rename from taichi/program/arch.cpp rename to taichi/backends/arch.cpp index 53848d8c58751..9de9c8e82f7e3 100644 --- a/taichi/program/arch.cpp +++ b/taichi/backends/arch.cpp @@ -1,6 +1,6 @@ -#include "taichi/program/arch.h" +#include "taichi/backends/arch.h" -TLANG_NAMESPACE_BEGIN +namespace taichi { std::string arch_name(Arch arch) { switch (arch) { @@ -79,4 +79,4 @@ int default_simd_width(Arch arch) { } } -TLANG_NAMESPACE_END +} // namespace taichi diff --git a/taichi/program/arch.h b/taichi/backends/arch.h similarity index 92% rename from taichi/program/arch.h rename to taichi/backends/arch.h index b2627b5ff4d22..2d7cffde8950f 100644 --- a/taichi/program/arch.h +++ b/taichi/backends/arch.h @@ -4,7 +4,6 @@ #include "taichi/common/core.h" namespace taichi { -namespace lang { enum class Arch : int { #define PER_ARCH(x) x, @@ -29,5 +28,4 @@ bool arch_use_host_memory(Arch arch); int default_simd_width(Arch arch); -} // namespace lang } // namespace taichi diff --git a/taichi/backends/cc/cc_kernel.h b/taichi/backends/cc/cc_kernel.h index fc322d4375fdd..5ec18eebfeeac 100644 --- a/taichi/backends/cc/cc_kernel.h +++ b/taichi/backends/cc/cc_kernel.h @@ -24,7 +24,7 @@ class CCKernel { } void compile(); - void launch(Context *ctx); + void launch(RuntimeContext *ctx); std::string get_object() { return obj_path_; } diff --git a/taichi/backends/cc/cc_program.cpp b/taichi/backends/cc/cc_program.cpp index 9c704c19c9cb7..49d9d1d6e866a 100644 --- a/taichi/backends/cc/cc_program.cpp +++ b/taichi/backends/cc/cc_program.cpp @@ -20,7 +20,7 @@ FunctionType CCProgramImpl::compile(Kernel *kernel, OffloadedStmt *) { auto ker = codegen.compile(); auto ker_ptr = ker.get(); this->add_kernel(std::move(ker)); - return [ker_ptr](Context &ctx) { return ker_ptr->launch(&ctx); }; + return [ker_ptr](RuntimeContext &ctx) { return ker_ptr->launch(&ctx); }; } void CCProgramImpl::materialize_runtime(MemoryPool *memory_pool, @@ -35,14 +35,13 @@ void CCProgramImpl::materialize_runtime(MemoryPool *memory_pool, void CCProgramImpl::materialize_snode_tree( SNodeTree *tree, std::vector> &, - std::unordered_map &, uint64 *result_buffer) { auto *const root = tree->root(); CCLayoutGen gen(this, root); layout_ = gen.compile(); size_t root_size = layout_->compile(); size_t gtmp_size = taichi_global_tmp_buffer_size; - size_t args_size = taichi_max_num_args * sizeof(uint64); + size_t args_size = taichi_result_buffer_entries * sizeof(uint64); TI_INFO("[cc] C backend root buffer size: {} B", root_size); @@ -100,7 +99,7 @@ void CCRuntime::compile() { execute(cc_program_impl_->config->cc_compile_cmd, obj_path_, src_path_); } -void CCKernel::launch(Context *ctx) { +void CCKernel::launch(RuntimeContext *ctx) { if (!kernel_->is_evaluator) ActionRecorder::get_instance().record("launch_kernel", { @@ -180,7 +179,7 @@ CCFuncEntryType *CCProgramImpl::load_kernel(std::string const &name) { return reinterpret_cast(dll_->load_function("Tk_" + name)); } -CCContext *CCProgramImpl::update_context(Context *ctx) { +CCContext *CCProgramImpl::update_context(RuntimeContext *ctx) { // TODO(k-ye): Do you have other zero-copy ideas for arg buf? std::memcpy(context_->args, ctx->args, taichi_max_num_args * sizeof(uint64)); context_->earg = (int *)ctx->extra_args; @@ -190,7 +189,7 @@ CCContext *CCProgramImpl::update_context(Context *ctx) { void CCProgramImpl::context_to_result_buffer() { TI_ASSERT(result_buffer_); std::memcpy(result_buffer_, context_->args, - sizeof(uint64)); // XXX: assumed 1 return + taichi_max_num_ret_value * sizeof(uint64)); context_->earg = nullptr; } diff --git a/taichi/backends/cc/cc_program.h b/taichi/backends/cc/cc_program.h index e6400ad6e3c06..adeea855ef96f 100644 --- a/taichi/backends/cc/cc_program.h +++ b/taichi/backends/cc/cc_program.h @@ -43,7 +43,6 @@ class CCProgramImpl : public ProgramImpl { void materialize_snode_tree(SNodeTree *tree, std::vector> &, - std::unordered_map &, uint64 *result_buffer) override; void synchronize() override { @@ -67,13 +66,13 @@ class CCProgramImpl : public ProgramImpl { return runtime_.get(); } - ~CCProgramImpl() { + ~CCProgramImpl() override { } CCFuncEntryType *load_kernel(std::string const &name); void relink(); - CCContext *update_context(Context *ctx); + CCContext *update_context(RuntimeContext *ctx); void context_to_result_buffer(); private: diff --git a/taichi/backends/cc/codegen_cc.cpp b/taichi/backends/cc/codegen_cc.cpp index cc6231e28aa5f..3b54b129a81cd 100644 --- a/taichi/backends/cc/codegen_cc.cpp +++ b/taichi/backends/cc/codegen_cc.cpp @@ -7,7 +7,6 @@ #include "taichi/ir/transforms.h" #include "taichi/util/line_appender.h" #include "taichi/util/str.h" -#include "taichi/llvm/llvm_program.h" #include "cc_utils.h" #define C90_COMPAT 0 @@ -23,51 +22,50 @@ std::string get_node_ptr_name(SNode *snode) { class CCTransformer : public IRVisitor { private: - [[maybe_unused]] Kernel *kernel; - [[maybe_unused]] CCLayout *layout; + [[maybe_unused]] Kernel *kernel_; + [[maybe_unused]] CCLayout *layout_; - LineAppender line_appender; - LineAppender line_appender_header; - bool is_top_level{true}; - GetRootStmt *root_stmt; + LineAppender line_appender_; + LineAppender line_appender_header_; + bool is_top_level_{true}; + GetRootStmt *root_stmt_; public: CCTransformer(Kernel *kernel, CCLayout *layout) - : kernel(kernel), layout(layout) { + : kernel_(kernel), layout_(layout) { allow_undefined_visitor = true; invoke_default_visitor = true; } void run() { this->lower_ast(); - emit_header("void Tk_{}(struct Ti_Context *ti_ctx) {{", kernel->name); - kernel->ir->accept(this); + emit_header("void Tk_{}(struct Ti_Context *ti_ctx) {{", kernel_->name); + kernel_->ir->accept(this); emit("}}"); } void lower_ast() { - auto ir = kernel->ir.get(); - auto config = kernel->program->config; + auto ir = kernel_->ir.get(); + auto config = kernel_->program->config; config.demote_dense_struct_fors = true; - irpass::compile_to_executable(ir, config, kernel, - /*vectorize=*/false, kernel->grad, + irpass::compile_to_executable(ir, config, kernel_, kernel_->grad, /*ad_use_stack=*/true, config.print_ir, /*lower_global_access*/ true); } std::string get_source() { - return line_appender_header.lines() + line_appender.lines(); + return line_appender_header_.lines() + line_appender_.lines(); } private: void visit(Block *stmt) override { - if (!is_top_level) - line_appender.push_indent(); + if (!is_top_level_) + line_appender_.push_indent(); for (auto &s : stmt->statements) { s->accept(this); } - if (!is_top_level) - line_appender.pop_indent(); + if (!is_top_level_) + line_appender_.pop_indent(); } void visit(Stmt *stmt) override { @@ -90,10 +88,10 @@ class CCTransformer : public IRVisitor { } void visit(GetRootStmt *stmt) override { - auto *root = kernel->program->get_snode_root(SNodeTree::kFirstID); + auto *root = kernel_->program->get_snode_root(SNodeTree::kFirstID); emit("{} = ti_ctx->root;", define_var(get_node_ptr_name(root), stmt->raw_name())); - root_stmt = stmt; + root_stmt_ = stmt; } void visit(SNodeLookupStmt *stmt) override { @@ -101,8 +99,8 @@ class CCTransformer : public IRVisitor { if (stmt->input_snode) { input_ptr = stmt->input_snode; } else { - TI_ASSERT(root_stmt != nullptr); - input_ptr = root_stmt; + TI_ASSERT(root_stmt_ != nullptr); + input_ptr = root_stmt_; } emit("{} = &{}[{}];", @@ -186,8 +184,11 @@ class CCTransformer : public IRVisitor { } void visit(ReturnStmt *stmt) override { - emit("ti_ctx->args[0].val_{} = {};", data_type_name(stmt->element_type()), - stmt->value->raw_name()); + int idx{0}; + for (auto &value : stmt->values) { + emit("ti_ctx->args[{}].val_{} = {};", idx++, + data_type_name(value->element_type()), value->raw_name()); + } } void visit(ConstStmt *stmt) override { @@ -222,8 +223,8 @@ class CCTransformer : public IRVisitor { } void visit(ExternalFuncCallStmt *stmt) override { - TI_ASSERT(!stmt->func); - auto format = stmt->source; + TI_ASSERT(stmt->type == ExternalFuncCallStmt::ASSEMBLY); + auto format = stmt->asm_source; std::string source; for (int i = 0; i < format.size(); i++) { @@ -350,6 +351,9 @@ class CCTransformer : public IRVisitor { } } + void visit(DecorationStmt *stmt) override { + } + void visit(UnaryOpStmt *stmt) override { TI_ASSERT(stmt->width() == 1); const auto dt_name = cc_data_type_name(stmt->element_type()); @@ -421,7 +425,7 @@ class CCTransformer : public IRVisitor { void generate_range_for_kernel(OffloadedStmt *stmt) { if (stmt->const_begin && stmt->const_end) { - ScopedIndent _s(line_appender); + ScopedIndent _s(line_appender_); auto begin_value = stmt->begin_value; auto end_value = stmt->end_value; auto var = define_var("Ti_i32", stmt->raw_name()); @@ -455,8 +459,8 @@ class CCTransformer : public IRVisitor { } void visit(OffloadedStmt *stmt) override { - TI_ASSERT(is_top_level); - is_top_level = false; + TI_ASSERT(is_top_level_); + is_top_level_ = false; if (stmt->task_type == OffloadedStmt::TaskType::serial) { generate_serial_kernel(stmt); } else if (stmt->task_type == OffloadedStmt::TaskType::range_for) { @@ -465,7 +469,7 @@ class CCTransformer : public IRVisitor { TI_ERROR("[glsl] Unsupported offload type={} on C backend", stmt->task_name()); } - is_top_level = true; + is_top_level_ = true; } void visit(LoopIndexStmt *stmt) override { @@ -591,12 +595,12 @@ class CCTransformer : public IRVisitor { template void emit(std::string f, Args &&... args) { - line_appender.append(std::move(f), std::move(args)...); + line_appender_.append(std::move(f), std::move(args)...); } template void emit_header(std::string f, Args &&... args) { - line_appender_header.append(std::move(f), std::move(args)...); + line_appender_header_.append(std::move(f), std::move(args)...); } }; // namespace cccp diff --git a/taichi/backends/cpu/codegen_cpu.cpp b/taichi/backends/cpu/codegen_cpu.cpp index 97cef2ab2b60f..0b1809e4d75fe 100644 --- a/taichi/backends/cpu/codegen_cpu.cpp +++ b/taichi/backends/cpu/codegen_cpu.cpp @@ -1,6 +1,7 @@ #include "taichi/backends/cpu/codegen_cpu.h" #include "taichi/codegen/codegen_llvm.h" +#include "taichi/llvm/llvm_program.h" #include "taichi/common/core.h" #include "taichi/util/io.h" #include "taichi/lang_util.h" @@ -35,7 +36,7 @@ class CodeGenLLVMCPU : public CodeGenLLVM { llvm::Function *body; { auto guard = get_function_creation_guard( - {llvm::PointerType::get(get_runtime_type("Context"), 0), + {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0), llvm::Type::getInt8PtrTy(*llvm_context), tlctx->get_data_type()}); @@ -57,10 +58,96 @@ class CodeGenLLVMCPU : public CodeGenLLVM { tls_prologue, body, epilogue, tlctx->get_constant(stmt->tls_size)}); } + void create_offload_mesh_for(OffloadedStmt *stmt) override { + auto *tls_prologue = create_mesh_xlogue(stmt->tls_prologue); + + llvm::Function *body; + { + auto guard = get_function_creation_guard( + {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0), + llvm::Type::getInt8PtrTy(*llvm_context), + tlctx->get_data_type()}); + + for (int i = 0; i < stmt->mesh_prologue->size(); i++) { + auto &s = stmt->mesh_prologue->statements[i]; + s->accept(this); + } + + if (stmt->bls_prologue) { + stmt->bls_prologue->accept(this); + } + + auto loop_test_bb = + llvm::BasicBlock::Create(*llvm_context, "loop_test", func); + auto loop_body_bb = + llvm::BasicBlock::Create(*llvm_context, "loop_body", func); + auto func_exit = + llvm::BasicBlock::Create(*llvm_context, "func_exit", func); + auto loop_index = + create_entry_block_alloca(llvm::Type::getInt32Ty(*llvm_context)); + builder->CreateStore(tlctx->get_constant(0), loop_index); + builder->CreateBr(loop_test_bb); + + { + builder->SetInsertPoint(loop_test_bb); + auto cond = builder->CreateICmp( + llvm::CmpInst::Predicate::ICMP_SLT, builder->CreateLoad(loop_index), + llvm_val[stmt->owned_num_local.find(stmt->major_from_type) + ->second]); + builder->CreateCondBr(cond, loop_body_bb, func_exit); + } + + { + builder->SetInsertPoint(loop_body_bb); + loop_vars_llvm[stmt].push_back(loop_index); + for (int i = 0; i < stmt->body->size(); i++) { + auto &s = stmt->body->statements[i]; + s->accept(this); + } + builder->CreateStore(builder->CreateAdd(builder->CreateLoad(loop_index), + tlctx->get_constant(1)), + loop_index); + builder->CreateBr(loop_test_bb); + builder->SetInsertPoint(func_exit); + } + + if (stmt->bls_epilogue) { + stmt->bls_epilogue->accept(this); + } + + body = guard.body; + } + + llvm::Value *epilogue = create_mesh_xlogue(stmt->tls_epilogue); + + create_call("cpu_parallel_mesh_for", + {get_arg(0), tlctx->get_constant(stmt->num_cpu_threads), + tlctx->get_constant(stmt->mesh->num_patches), + tlctx->get_constant(stmt->block_dim), tls_prologue, body, + epilogue, tlctx->get_constant(stmt->tls_size)}); + } + + void create_bls_buffer(OffloadedStmt *stmt) { + auto type = llvm::ArrayType::get(llvm::Type::getInt8Ty(*llvm_context), + stmt->bls_size); + bls_buffer = new llvm::GlobalVariable( + *module, type, false, llvm::GlobalValue::ExternalLinkage, nullptr, + "bls_buffer", nullptr, llvm::GlobalVariable::LocalExecTLSModel, 0); + /* module->getOrInsertGlobal("bls_buffer", type); + bls_buffer = module->getNamedGlobal("bls_buffer"); + bls_buffer->setAlignment(llvm::MaybeAlign(8));*/ // TODO(changyu): Fix JIT session error: Symbols not found: [ __emutls_get_address ] in python 3.10 + + // initialize the variable with an undef value to ensure it is added to the + // symbol table + bls_buffer->setInitializer(llvm::UndefValue::get(type)); + } + void visit(OffloadedStmt *stmt) override { stat.add("codegen_offloaded_tasks"); TI_ASSERT(current_offload == nullptr); current_offload = stmt; + if (stmt->bls_size > 0) + create_bls_buffer(stmt); using Type = OffloadedStmt::TaskType; auto offloaded_task_name = init_offloaded_task_function(stmt); if (prog->config.kernel_profiler && arch_is_cpu(prog->config.arch)) { @@ -72,6 +159,8 @@ class CodeGenLLVMCPU : public CodeGenLLVM { stmt->body->accept(this); } else if (stmt->task_type == Type::range_for) { create_offload_range_for(stmt); + } else if (stmt->task_type == Type::mesh_for) { + create_offload_mesh_for(stmt); } else if (stmt->task_type == Type::struct_for) { stmt->block_dim = std::min(stmt->snode->parent->max_num_elements(), (int64)stmt->block_dim); @@ -93,30 +182,13 @@ class CodeGenLLVMCPU : public CodeGenLLVM { } void visit(ExternalFuncCallStmt *stmt) override { - std::vector arg_types; - std::vector arg_values; - - for (auto s : stmt->arg_stmts) { - TI_ASSERT(s->width() == 1); - arg_types.push_back(tlctx->get_data_type(s->ret_type)); - arg_values.push_back(llvm_val[s]); - } - - for (auto s : stmt->output_stmts) { - TI_ASSERT(s->width() == 1); - auto t = tlctx->get_data_type(s->ret_type); - auto ptr = llvm::PointerType::get(t, 0); - arg_types.push_back(ptr); - arg_values.push_back(llvm_val[s]); + if (stmt->type == ExternalFuncCallStmt::BITCODE) { + CodeGenLLVM::visit_call_bitcode(stmt); + } else if (stmt->type == ExternalFuncCallStmt::SHARED_OBJECT) { + CodeGenLLVM::visit_call_shared_object(stmt); + } else { + TI_NOT_IMPLEMENTED } - - auto func_type = llvm::FunctionType::get( - llvm::Type::getVoidTy(*llvm_context), arg_types, false); - auto func_ptr_type = llvm::PointerType::get(func_type, 0); - - auto addr = tlctx->get_constant((std::size_t)stmt->func); - auto func = builder->CreateIntToPtr(addr, func_ptr_type); - builder->CreateCall(func, arg_values); } }; diff --git a/taichi/backends/cpu/codegen_cpu.h b/taichi/backends/cpu/codegen_cpu.h index e203b43e99f88..c3d723c75eff2 100644 --- a/taichi/backends/cpu/codegen_cpu.h +++ b/taichi/backends/cpu/codegen_cpu.h @@ -11,7 +11,7 @@ class CodeGenCPU : public KernelCodeGen { CodeGenCPU(Kernel *kernel, IRNode *ir = nullptr) : KernelCodeGen(kernel, ir) { } - virtual FunctionType codegen() override; + FunctionType codegen() override; }; TLANG_NAMESPACE_END diff --git a/taichi/backends/cpu/cpu_device.cpp b/taichi/backends/cpu/cpu_device.cpp index ab918cc1d0b10..60a8bc26f883f 100644 --- a/taichi/backends/cpu/cpu_device.cpp +++ b/taichi/backends/cpu/cpu_device.cpp @@ -5,7 +5,7 @@ namespace lang { namespace cpu { -CpuDevice::AllocInfo CpuDevice::get_alloc_info(DeviceAllocation handle) { +CpuDevice::AllocInfo CpuDevice::get_alloc_info(const DeviceAllocation handle) { validate_device_alloc(handle); return allocations_[handle.alloc_id]; } @@ -16,6 +16,7 @@ DeviceAllocation CpuDevice::allocate_memory(const AllocParams ¶ms) { auto vm = std::make_unique(params.size); info.ptr = vm->ptr; info.size = vm->size; + info.use_cached = false; DeviceAllocation alloc; alloc.alloc_id = allocations_.size(); @@ -26,15 +27,32 @@ DeviceAllocation CpuDevice::allocate_memory(const AllocParams ¶ms) { return alloc; } +DeviceAllocation CpuDevice::allocate_memory_runtime( + const LlvmRuntimeAllocParams ¶ms) { + AllocInfo info; + info.ptr = allocate_llvm_runtime_memory_jit(params); + // TODO: Add caching allocator + info.size = params.size; + info.use_cached = params.use_cached; + DeviceAllocation alloc; + alloc.alloc_id = allocations_.size(); + alloc.device = this; + + allocations_.push_back(info); + return alloc; +} + void CpuDevice::dealloc_memory(DeviceAllocation handle) { validate_device_alloc(handle); AllocInfo &info = allocations_[handle.alloc_id]; if (info.ptr == nullptr) { TI_ERROR("the DeviceAllocation is already deallocated"); } - // Use at() to ensure that the memory is allocated, and not imported - virtual_memories_.at(handle.alloc_id).reset(); - info.ptr = nullptr; + if (!info.use_cached) { + // Use at() to ensure that the memory is allocated, and not imported + virtual_memories_.at(handle.alloc_id).reset(); + info.ptr = nullptr; + } } DeviceAllocation CpuDevice::import_memory(void *ptr, size_t size) { @@ -50,6 +68,11 @@ DeviceAllocation CpuDevice::import_memory(void *ptr, size_t size) { return alloc; } +uint64 CpuDevice::fetch_result_uint64(int i, uint64 *result_buffer) { + uint64 ret = result_buffer[i]; + return ret; +} + } // namespace cpu } // namespace lang } // namespace taichi diff --git a/taichi/backends/cpu/cpu_device.h b/taichi/backends/cpu/cpu_device.h index fba75b86462b9..5d5ccfd5e5ff1 100644 --- a/taichi/backends/cpu/cpu_device.h +++ b/taichi/backends/cpu/cpu_device.h @@ -5,7 +5,7 @@ #include #include "taichi/common/core.h" -#include "taichi/backends/device.h" +#include "taichi/llvm/llvm_device.h" #include "taichi/system/virtual_memory.h" namespace taichi { @@ -75,24 +75,29 @@ class CpuStream : public Stream { void command_sync() override{TI_NOT_IMPLEMENTED}; }; -class CpuDevice : public Device { +class CpuDevice : public LlvmDevice { public: struct AllocInfo { void *ptr{nullptr}; size_t size{0}; + bool use_cached{false}; }; - AllocInfo get_alloc_info(DeviceAllocation handle); + AllocInfo get_alloc_info(const DeviceAllocation handle); ~CpuDevice() override{}; DeviceAllocation allocate_memory(const AllocParams ¶ms) override; + DeviceAllocation allocate_memory_runtime( + const LlvmRuntimeAllocParams ¶ms) override; void dealloc_memory(DeviceAllocation handle) override; std::unique_ptr create_pipeline( const PipelineSourceDesc &src, std::string name = "Pipeline") override{TI_NOT_IMPLEMENTED}; + uint64 fetch_result_uint64(int i, uint64 *result_buffer) override; + void *map_range(DevicePtr ptr, uint64_t size) override{TI_NOT_IMPLEMENTED}; void *map(DeviceAllocation alloc) override{TI_NOT_IMPLEMENTED}; @@ -111,7 +116,7 @@ class CpuDevice : public Device { std::unordered_map> virtual_memories_; - void validate_device_alloc(DeviceAllocation alloc) { + void validate_device_alloc(const DeviceAllocation alloc) { if (allocations_.size() <= alloc.alloc_id) { TI_ERROR("invalid DeviceAllocation"); } diff --git a/taichi/backends/cpu/jit_cpu.cpp b/taichi/backends/cpu/jit_cpu.cpp index c414f1131af35..46366778d2f27 100644 --- a/taichi/backends/cpu/jit_cpu.cpp +++ b/taichi/backends/cpu/jit_cpu.cpp @@ -2,6 +2,7 @@ #include +#ifdef TI_WITH_LLVM #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/ADT/StringRef.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" @@ -18,6 +19,7 @@ #include "llvm/ExecutionEngine/RuntimeDyld.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/IR/Module.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Verifier.h" #include "llvm/IR/LLVMContext.h" @@ -32,6 +34,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/GVN.h" #include "llvm/Transforms/IPO.h" +#endif #include "taichi/lang_util.h" #include "taichi/program/program.h" @@ -41,8 +44,10 @@ TLANG_NAMESPACE_BEGIN +#ifdef TI_WITH_LLVM using namespace llvm; using namespace llvm::orc; +#endif std::pair get_host_target_info() { #if defined(TI_PLATFORM_OSX) and defined(TI_ARCH_ARM) @@ -69,12 +74,12 @@ class JITSessionCPU; class JITModuleCPU : public JITModule { private: - JITSessionCPU *session; - JITDylib *dylib; + JITSessionCPU *session_; + JITDylib *dylib_; public: JITModuleCPU(JITSessionCPU *session, JITDylib *dylib) - : session(session), dylib(dylib) { + : session_(session), dylib_(dylib) { } void *lookup_function(const std::string &name) override; @@ -86,45 +91,48 @@ class JITModuleCPU : public JITModule { class JITSessionCPU : public JITSession { private: - ExecutionSession ES; - RTDyldObjectLinkingLayer object_layer; - IRCompileLayer compile_layer; - DataLayout DL; - MangleAndInterner Mangle; - std::mutex mut; - std::vector all_libs; - int module_counter; - SectionMemoryManager *memory_manager; + ExecutionSession es_; + RTDyldObjectLinkingLayer object_layer_; + IRCompileLayer compile_layer_; + DataLayout dl_; + MangleAndInterner mangle_; + std::mutex mut_; + std::vector all_libs_; + int module_counter_; + SectionMemoryManager *memory_manager_; public: - JITSessionCPU(JITTargetMachineBuilder JTMB, DataLayout DL) - : object_layer(ES, - [&]() { - auto smgr = std::make_unique(); - memory_manager = smgr.get(); - return smgr; - }), - compile_layer(ES, - object_layer, - std::make_unique(JTMB)), - DL(DL), - Mangle(ES, this->DL), - module_counter(0), - memory_manager(nullptr) { + JITSessionCPU(LlvmProgramImpl *llvm_prog, + JITTargetMachineBuilder JTMB, + DataLayout DL) + : JITSession(llvm_prog), + object_layer_(es_, + [&]() { + auto smgr = std::make_unique(); + memory_manager_ = smgr.get(); + return smgr; + }), + compile_layer_(es_, + object_layer_, + std::make_unique(JTMB)), + dl_(DL), + mangle_(es_, this->dl_), + module_counter_(0), + memory_manager_(nullptr) { if (JTMB.getTargetTriple().isOSBinFormatCOFF()) { - object_layer.setOverrideObjectFlagsWithResponsibilityFlags(true); - object_layer.setAutoClaimResponsibilityForObjectSymbols(true); + object_layer_.setOverrideObjectFlagsWithResponsibilityFlags(true); + object_layer_.setAutoClaimResponsibilityForObjectSymbols(true); } } - ~JITSessionCPU() { - std::lock_guard _(mut); - if (memory_manager) - memory_manager->deregisterEHFrames(); + ~JITSessionCPU() override { + std::lock_guard _(mut_); + if (memory_manager_) + memory_manager_->deregisterEHFrames(); } DataLayout get_data_layout() override { - return DL; + return dl_; } void global_optimize_module(llvm::Module *module) override { @@ -135,31 +143,31 @@ class JITSessionCPU : public JITSession { TI_ASSERT(max_reg == 0); // No need to specify max_reg on CPUs TI_ASSERT(M); global_optimize_module_cpu(M.get()); - std::lock_guard _(mut); - auto &dylib = ES.createJITDylib(fmt::format("{}", module_counter)); + std::lock_guard _(mut_); + auto &dylib = es_.createJITDylib(fmt::format("{}", module_counter_)); dylib.addGenerator( cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( - DL.getGlobalPrefix()))); - auto *thread_safe_context = get_current_program() - .get_llvm_program_impl() + dl_.getGlobalPrefix()))); + auto *thread_safe_context = this->llvm_prog() ->get_llvm_context(host_arch()) ->get_this_thread_thread_safe_context(); - cantFail(compile_layer.add(dylib, llvm::orc::ThreadSafeModule( - std::move(M), *thread_safe_context))); - all_libs.push_back(&dylib); + cantFail(compile_layer_.add( + dylib, + llvm::orc::ThreadSafeModule(std::move(M), *thread_safe_context))); + all_libs_.push_back(&dylib); auto new_module = std::make_unique(this, &dylib); auto new_module_raw_ptr = new_module.get(); modules.push_back(std::move(new_module)); - module_counter++; + module_counter_++; return new_module_raw_ptr; } void *lookup(const std::string Name) override { - std::lock_guard _(mut); + std::lock_guard _(mut_); #ifdef __APPLE__ - auto symbol = ES.lookup(all_libs, Mangle(Name)); + auto symbol = es_.lookup(all_libs_, mangle_(Name)); #else - auto symbol = ES.lookup(all_libs, ES.intern(Name)); + auto symbol = es_.lookup(all_libs_, es_.intern(Name)); #endif if (!symbol) TI_ERROR("Function \"{}\" not found", Name); @@ -167,11 +175,11 @@ class JITSessionCPU : public JITSession { } void *lookup_in_module(JITDylib *lib, const std::string Name) { - std::lock_guard _(mut); + std::lock_guard _(mut_); #ifdef __APPLE__ - auto symbol = ES.lookup({lib}, Mangle(Name)); + auto symbol = es_.lookup({lib}, mangle_(Name)); #else - auto symbol = ES.lookup({lib}, ES.intern(Name)); + auto symbol = es_.lookup({lib}, es_.intern(Name)); #endif if (!symbol) TI_ERROR("Function \"{}\" not found", Name); @@ -179,11 +187,11 @@ class JITSessionCPU : public JITSession { } private: - static void global_optimize_module_cpu(llvm::Module *module); + void global_optimize_module_cpu(llvm::Module *module); }; void *JITModuleCPU::lookup_function(const std::string &name) { - return session->lookup_in_module(dylib, name); + return session_->lookup_in_module(dylib_, name); } void JITSessionCPU::global_optimize_module_cpu(llvm::Module *module) { @@ -202,8 +210,7 @@ void JITSessionCPU::global_optimize_module_cpu(llvm::Module *module) { TargetOptions options; options.PrintMachineCode = false; - bool fast_math = get_current_program().config.fast_math; - if (fast_math) { + if (this->llvm_prog()->config->fast_math) { options.AllowFPOpFusion = FPOpFusion::Fast; options.UnsafeFPMath = 1; options.NoInfsFPMath = 1; @@ -261,7 +268,7 @@ void JITSessionCPU::global_optimize_module_cpu(llvm::Module *module) { module_pass_manager.run(*module); } - if (get_current_program().config.print_kernel_llvm_ir_optimized) { + if (this->llvm_prog()->config->print_kernel_llvm_ir_optimized) { if (false) { TI_INFO("Functions with > 100 instructions in optimized LLVM IR:"); TaichiLLVMContext::print_huge_functions(module); @@ -273,10 +280,13 @@ void JITSessionCPU::global_optimize_module_cpu(llvm::Module *module) { } } -std::unique_ptr create_llvm_jit_session_cpu(Arch arch) { +std::unique_ptr create_llvm_jit_session_cpu( + LlvmProgramImpl *llvm_prog, + Arch arch) { TI_ASSERT(arch_is_cpu(arch)); auto target_info = get_host_target_info(); - return std::make_unique(target_info.first, target_info.second); + return std::make_unique(llvm_prog, target_info.first, + target_info.second); } TLANG_NAMESPACE_END diff --git a/taichi/backends/cuda/codegen_cuda.cpp b/taichi/backends/cuda/codegen_cuda.cpp index 39f6970dd6377..45e1e6a9affef 100644 --- a/taichi/backends/cuda/codegen_cuda.cpp +++ b/taichi/backends/cuda/codegen_cuda.cpp @@ -2,6 +2,7 @@ #include #include +#include #include "taichi/common/core.h" #include "taichi/util/io.h" @@ -48,7 +49,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { jit->add_module(std::move(module), kernel->program->config.gpu_max_reg); return [offloaded_local, cuda_module, - kernel = this->kernel](Context &context) { + kernel = this->kernel](RuntimeContext &context) { CUDAContext::get_instance().make_current(); auto args = kernel->args; std::vector arg_buffers(args.size(), nullptr); @@ -61,31 +62,59 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { Kernel::LaunchContextBuilder ctx_builder(kernel, &context); bool transferred = false; for (int i = 0; i < (int)args.size(); i++) { - if (args[i].is_external_array && args[i].size > 0) { - // Note: both numpy and PyTorch support arrays/tensors with zeros - // in shapes, e.g., shape=(0) or shape=(100, 0, 200). This makes - // args[i].size = 0. + if (args[i].is_array) { + if (args[i].size == 0) + continue; arg_buffers[i] = context.get_arg(i); - unsigned int attr_val = 0; - uint32_t ret_code = CUDADriver::get_instance().mem_get_attribute.call( - &attr_val, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, - (void *)arg_buffers[i]); - if (ret_code != CUDA_SUCCESS || attr_val != CU_MEMORYTYPE_DEVICE) { - // Copy to device buffer if arg is on host - // - ret_code != CUDA_SUCCESS: - // arg_buffers[i] is not on device - // - attr_val != CU_MEMORYTYPE_DEVICE: - // Cuda driver is aware of arg_buffers[i] but it might be on host. - // See CUDA driver API `cuPointerGetAttribute` for more details. - transferred = true; - CUDADriver::get_instance().malloc(&device_buffers[i], args[i].size); - CUDADriver::get_instance().memcpy_host_to_device( - (void *)device_buffers[i], arg_buffers[i], args[i].size); - } else { - device_buffers[i] = arg_buffers[i]; + if (!context.is_device_allocation[i]) { + // Note: both numpy and PyTorch support arrays/tensors with zeros + // in shapes, e.g., shape=(0) or shape=(100, 0, 200). This makes + // args[i].size = 0. + unsigned int attr_val = 0; + uint32_t ret_code = + CUDADriver::get_instance().mem_get_attribute.call( + &attr_val, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, + (void *)arg_buffers[i]); + if (ret_code != CUDA_SUCCESS || attr_val != CU_MEMORYTYPE_DEVICE) { + // Copy to device buffer if arg is on host + // - ret_code != CUDA_SUCCESS: + // arg_buffers[i] is not on device + // - attr_val != CU_MEMORYTYPE_DEVICE: + // Cuda driver is aware of arg_buffers[i] but it might be on + // host. + // See CUDA driver API `cuPointerGetAttribute` for more details. + transferred = true; + CUDADriver::get_instance().malloc(&device_buffers[i], + args[i].size); + CUDADriver::get_instance().memcpy_host_to_device( + (void *)device_buffers[i], arg_buffers[i], args[i].size); + } else { + device_buffers[i] = arg_buffers[i]; + } + // device_buffers[i] saves a raw ptr on CUDA device. + ctx_builder.set_arg_external_array(i, (uint64)device_buffers[i], + args[i].size, + /*is_device_allocation=*/false); + + } else if (args[i].size > 0) { + // arg_buffers[i] is a DeviceAllocation* + // TODO: Unwraps DeviceAllocation* can be done at CodeGenLLVM since + // it's shared by cpu and cuda. + DeviceAllocation *ptr = + static_cast(arg_buffers[i]); + device_buffers[i] = kernel->program->get_llvm_program_impl() + ->get_ndarray_alloc_info_ptr(*ptr); + // We compare arg_buffers[i] and device_buffers[i] later to check + // if transfer happened. + // TODO: this logic can be improved but I'll leave it to a followup + // PR. + arg_buffers[i] = device_buffers[i]; + + // device_buffers[i] saves the unwrapped raw ptr from arg_buffers[i] + ctx_builder.set_arg_external_array(i, (uint64)device_buffers[i], + args[i].size, + /*is_device_allocation=*/false); } - ctx_builder.set_arg_external_array(i, (uint64)device_buffers[i], - args[i].size); } } if (transferred) { @@ -162,7 +191,8 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { auto value_type = tlctx->get_data_type(arg_stmt->ret_type); auto value = llvm_val[arg_stmt]; - if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32)) { + if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) || + arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { value_type = tlctx->get_data_type(PrimitiveType::f64); value = builder->CreateFPExt(value, value_type); } @@ -190,49 +220,49 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { // functions from libdevice auto input = llvm_val[stmt->operand]; auto input_taichi_type = stmt->operand->ret_type; + if (input_taichi_type->is_primitive(PrimitiveTypeID::f16)) { + // Promote to f32 since we don't have f16 support for extra unary ops in + // libdevice. + input = + builder->CreateFPExt(input, llvm::Type::getFloatTy(*llvm_context)); + input_taichi_type = PrimitiveType::f32; + } + auto op = stmt->op_type; -#define UNARY_STD(x) \ - else if (op == UnaryOpType::x) { \ - if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) { \ - llvm_val[stmt] = \ - builder->CreateCall(get_runtime_function("__nv_" #x "f"), input); \ - } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) { \ - llvm_val[stmt] = \ - builder->CreateCall(get_runtime_function("__nv_" #x), input); \ - } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) { \ - llvm_val[stmt] = builder->CreateCall(get_runtime_function(#x), input); \ - } else { \ - TI_NOT_IMPLEMENTED \ - } \ +#define UNARY_STD(x) \ + else if (op == UnaryOpType::x) { \ + if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) { \ + llvm_val[stmt] = create_call("__nv_" #x "f", input); \ + } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) { \ + llvm_val[stmt] = create_call("__nv_" #x, input); \ + } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) { \ + llvm_val[stmt] = create_call(#x, input); \ + } else { \ + TI_NOT_IMPLEMENTED \ + } \ } if (op == UnaryOpType::abs) { if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) { - llvm_val[stmt] = - builder->CreateCall(get_runtime_function("__nv_fabsf"), input); + llvm_val[stmt] = create_call("__nv_fabsf", input); } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) { - llvm_val[stmt] = - builder->CreateCall(get_runtime_function("__nv_fabs"), input); + llvm_val[stmt] = create_call("__nv_fabs", input); } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) { - llvm_val[stmt] = - builder->CreateCall(get_runtime_function("__nv_abs"), input); + llvm_val[stmt] = create_call("__nv_abs", input); } else { TI_NOT_IMPLEMENTED } } else if (op == UnaryOpType::sqrt) { if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) { - llvm_val[stmt] = - builder->CreateCall(get_runtime_function("__nv_sqrtf"), input); + llvm_val[stmt] = create_call("__nv_sqrtf", input); } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) { - llvm_val[stmt] = - builder->CreateCall(get_runtime_function("__nv_sqrt"), input); + llvm_val[stmt] = create_call("__nv_sqrt", input); } else { TI_NOT_IMPLEMENTED } } else if (op == UnaryOpType::logic_not) { if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) { - llvm_val[stmt] = - builder->CreateCall(get_runtime_function("logic_not_i32"), input); + llvm_val[stmt] = create_call("logic_not_i32", input); } else { TI_NOT_IMPLEMENTED } @@ -251,6 +281,11 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { TI_NOT_IMPLEMENTED } #undef UNARY_STD + if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { + // Convert back to f16. + llvm_val[stmt] = builder->CreateFPTrunc( + llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context)); + } } // Not all reduction statements can be optimized. @@ -324,15 +359,174 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { llvm::AtomicOrdering::SequentiallyConsistent); } - llvm::Value *real_type_atomic(AtomicOpStmt *stmt) { + // A huge hack for supporting f16 atomic add/max/min! Borrowed from + // https://github.com/tensorflow/tensorflow/blob/470d58a83470f8ede3beaa584e6992bc71b7baa6/tensorflow/compiler/xla/service/gpu/ir_emitter.cc#L378-L490 + // The reason is that LLVM10 does not support generating atomicCAS for f16 on + // NVPTX backend. + // + // Implements atomic binary operations using atomic compare-and-swap + // (atomicCAS) as follows: + // 1. Reads the value from the memory pointed to by output_address and + // records it as old_output. + // 2. Uses old_output as one of the source operand to perform the binary + // operation and stores the result in new_output. + // 3. Calls atomicCAS which implements compare-and-swap as an atomic + // operation. In particular, atomicCAS reads the value from the memory + // pointed to by output_address, and compares the value with old_output. + // If the two values equal, new_output is written to the same memory + // location and true is returned to indicate that the atomic operation + // succeeds. Otherwise, the new value read from the memory is returned. In + // this case, the new value is copied to old_output, and steps 2. and 3. + // are repeated until atomicCAS succeeds. + // + // int32 is used for the atomicCAS operation. So atomicCAS reads and writes 32 + // bit values from the memory, which is larger than the memory size required + // by the original atomic binary operation. We mask off the last two bits of + // the output_address and use the result as an address to read the 32 bit + // values from the memory. + // + // This can avoid out of bound memory accesses, based on the assumption: + // All buffers are 4 byte aligned and have a size of 4N. + // + // The pseudo code is shown below. + // + // cas_new_output_address = alloca(32); + // cas_old_output_address = alloca(32); + // atomic_address = output_address & ((int64)(-4)); + // new_output_address = cas_new_output_address + (output_address & 3); + // + // *cas_old_output_address = *atomic_address; + // do { + // *cas_new_output_address = *cas_old_output_address; + // *new_output_address = operation(*new_output_address, *source_address); + // (*cas_old_output_address, success) = + // atomicCAS(atomic_address, *cas_old_output_address, + // *cas_new_output_address); + // } while (!success); + // + // TODO(sjwsl): Try to rewrite this after upgrading LLVM or supporting raw + // NVPTX + + llvm::Value *atomic_op_using_cas( + llvm::Value *output_address, + llvm::Value *val, + std::function op) { + llvm::PointerType *output_address_type = + llvm::dyn_cast(output_address->getType()); + TI_ASSERT(output_address_type != nullptr); + + // element_type is the data type for the binary operation. + llvm::Type *element_type = output_address_type->getPointerElementType(); + llvm::Type *element_address_type = element_type->getPointerTo(); + + int atomic_size = 32; + llvm::Type *atomic_type = builder->getIntNTy(atomic_size); + llvm::Type *atomic_address_type = atomic_type->getPointerTo( + output_address_type->getPointerAddressSpace()); + + // cas_old_output_address and cas_new_output_address point to the scratch + // memory where we store the old and new values for the repeated atomicCAS + // operations. + llvm::Value *cas_old_output_address = + builder->CreateAlloca(atomic_type, nullptr); + llvm::Value *cas_new_output_address = + builder->CreateAlloca(atomic_type, nullptr); + + llvm::Value *atomic_memory_address; + // binop_output_address points to the scratch memory that stores the + // result of the binary operation. + llvm::Value *binop_output_address; + + // Calculate bin_output_address output_address + llvm::Type *address_int_type = + module->getDataLayout().getIntPtrType(output_address_type); + atomic_memory_address = + builder->CreatePtrToInt(output_address, address_int_type); + llvm::Value *mask = llvm::ConstantInt::get(address_int_type, 3); + llvm::Value *offset = builder->CreateAnd(atomic_memory_address, mask); + mask = llvm::ConstantInt::get(address_int_type, -4); + atomic_memory_address = builder->CreateAnd(atomic_memory_address, mask); + atomic_memory_address = + builder->CreateIntToPtr(atomic_memory_address, atomic_address_type); + binop_output_address = builder->CreateAdd( + builder->CreatePtrToInt(cas_new_output_address, address_int_type), + offset); + binop_output_address = + builder->CreateIntToPtr(binop_output_address, element_address_type); + + // Use the value from the memory that atomicCAS operates on to initialize + // cas_old_output. + llvm::Value *cas_old_output = + builder->CreateLoad(atomic_memory_address, "cas_old_output"); + builder->CreateStore(cas_old_output, cas_old_output_address); + + llvm::BasicBlock *loop_body_bb = + BasicBlock::Create(*llvm_context, "atomic_op_loop_body", func); + llvm::BasicBlock *loop_exit_bb = + BasicBlock::Create(*llvm_context, "loop_exit_bb", func); + builder->CreateBr(loop_body_bb); + builder->SetInsertPoint(loop_body_bb); + + // loop body for one atomicCAS + { + // Use cas_old_output to initialize cas_new_output. + cas_old_output = + builder->CreateLoad(cas_old_output_address, "cas_old_output"); + builder->CreateStore(cas_old_output, cas_new_output_address); + + auto binop_output = op(builder->CreateLoad(binop_output_address), val); + builder->CreateStore(binop_output, binop_output_address); + + llvm::Value *cas_new_output = + builder->CreateLoad(cas_new_output_address, "cas_new_output"); + + // Emit code to perform the atomicCAS operation + // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output, + // cas_new_output); + llvm::Value *ret_value = builder->CreateAtomicCmpXchg( + atomic_memory_address, cas_old_output, cas_new_output, + llvm::AtomicOrdering::SequentiallyConsistent, + llvm::AtomicOrdering::SequentiallyConsistent); + + // Extract the memory value returned from atomicCAS and store it as + // cas_old_output. + builder->CreateStore( + builder->CreateExtractValue(ret_value, 0, "cas_old_output"), + cas_old_output_address); + // Extract the success bit returned from atomicCAS and generate a + // conditional branch on the success bit. + builder->CreateCondBr( + builder->CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, + loop_body_bb); + } + + builder->SetInsertPoint(loop_exit_bb); + + return output_address; + } + + llvm::Value *real_or_unsigned_type_atomic(AtomicOpStmt *stmt) { if (!stmt->val->ret_type->is()) { return nullptr; } AtomicOpType op = stmt->op_type; - if (is_real(stmt->val->ret_type) && op == AtomicOpType::add) { - return builder->CreateAtomicRMW(llvm::AtomicRMWInst::FAdd, - llvm_val[stmt->dest], llvm_val[stmt->val], - AtomicOrdering::SequentiallyConsistent); + if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f16)) { + switch (op) { + case AtomicOpType::add: + return atomic_op_using_cas( + llvm_val[stmt->dest], llvm_val[stmt->val], + [&](auto v1, auto v2) { return builder->CreateFAdd(v1, v2); }); + case AtomicOpType::max: + return atomic_op_using_cas( + llvm_val[stmt->dest], llvm_val[stmt->val], + [&](auto v1, auto v2) { return builder->CreateMaxNum(v1, v2); }); + case AtomicOpType::min: + return atomic_op_using_cas( + llvm_val[stmt->dest], llvm_val[stmt->val], + [&](auto v1, auto v2) { return builder->CreateMinNum(v1, v2); }); + default: + break; + } } PrimitiveTypeID prim_type = @@ -342,21 +536,28 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { std::unordered_map> atomics; + atomics[PrimitiveTypeID::f32][AtomicOpType::add] = "atomic_add_f32"; + atomics[PrimitiveTypeID::f64][AtomicOpType::add] = "atomic_add_f64"; atomics[PrimitiveTypeID::f32][AtomicOpType::min] = "atomic_min_f32"; atomics[PrimitiveTypeID::f64][AtomicOpType::min] = "atomic_min_f64"; atomics[PrimitiveTypeID::f32][AtomicOpType::max] = "atomic_max_f32"; atomics[PrimitiveTypeID::f64][AtomicOpType::max] = "atomic_max_f64"; + atomics[PrimitiveTypeID::u32][AtomicOpType::min] = "atomic_min_u32"; + atomics[PrimitiveTypeID::u64][AtomicOpType::min] = "atomic_min_u64"; + atomics[PrimitiveTypeID::u32][AtomicOpType::max] = "atomic_max_u32"; + atomics[PrimitiveTypeID::u64][AtomicOpType::max] = "atomic_max_u64"; if (atomics.find(prim_type) == atomics.end()) { return nullptr; } + if (is_integral(stmt->val->ret_type) && + atomics.at(prim_type).find(op) == atomics.at(prim_type).end()) { + return nullptr; + } TI_ASSERT(atomics.at(prim_type).find(op) != atomics.at(prim_type).end()); - return builder->CreateCall( - get_runtime_function(atomics.at(prim_type).at(op)), - {llvm_val[stmt->dest], llvm_val[stmt->val]}); - - return nullptr; + return create_call(atomics.at(prim_type).at(op), + {llvm_val[stmt->dest], llvm_val[stmt->val]}); } void visit(AtomicOpStmt *stmt) override { @@ -373,9 +574,9 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { old_value = result; } else if (llvm::Value *result = custom_type_atomic(stmt)) { old_value = result; - } else if (llvm::Value *result = integral_type_atomic(stmt)) { + } else if (llvm::Value *result = real_or_unsigned_type_atomic(stmt)) { old_value = result; - } else if (llvm::Value *result = real_type_atomic(stmt)) { + } else if (llvm::Value *result = integral_type_atomic(stmt)) { old_value = result; } else { TI_NOT_IMPLEMENTED @@ -394,7 +595,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { llvm::Function *body; { auto guard = get_function_creation_guard( - {llvm::PointerType::get(get_runtime_type("Context"), 0), + {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0), get_tls_buffer_type(), tlctx->get_data_type()}); auto loop_var = create_entry_block_alloca(PrimitiveType::i32); @@ -413,6 +614,79 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { tlctx->get_constant(stmt->tls_size)}); } + void create_offload_mesh_for(OffloadedStmt *stmt) override { + auto tls_prologue = create_mesh_xlogue(stmt->tls_prologue); + + llvm::Function *body; + { + auto guard = get_function_creation_guard( + {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0), + get_tls_buffer_type(), tlctx->get_data_type()}); + + for (int i = 0; i < stmt->mesh_prologue->size(); i++) { + auto &s = stmt->mesh_prologue->statements[i]; + s->accept(this); + } + + if (stmt->bls_prologue) { + stmt->bls_prologue->accept(this); + call("block_barrier"); // "__syncthreads()" + } + + auto loop_test_bb = + llvm::BasicBlock::Create(*llvm_context, "loop_test", func); + auto loop_body_bb = + llvm::BasicBlock::Create(*llvm_context, "loop_body", func); + auto func_exit = + llvm::BasicBlock::Create(*llvm_context, "func_exit", func); + auto loop_index = + create_entry_block_alloca(llvm::Type::getInt32Ty(*llvm_context)); + llvm::Value *thread_idx = + builder->CreateIntrinsic(Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}); + llvm::Value *block_dim = builder->CreateIntrinsic( + Intrinsic::nvvm_read_ptx_sreg_ntid_x, {}, {}); + builder->CreateStore(thread_idx, loop_index); + builder->CreateBr(loop_test_bb); + + { + builder->SetInsertPoint(loop_test_bb); + auto cond = builder->CreateICmp( + llvm::CmpInst::Predicate::ICMP_SLT, builder->CreateLoad(loop_index), + llvm_val[stmt->owned_num_local.find(stmt->major_from_type) + ->second]); + builder->CreateCondBr(cond, loop_body_bb, func_exit); + } + + { + builder->SetInsertPoint(loop_body_bb); + loop_vars_llvm[stmt].push_back(loop_index); + for (int i = 0; i < stmt->body->size(); i++) { + auto &s = stmt->body->statements[i]; + s->accept(this); + } + builder->CreateStore( + builder->CreateAdd(builder->CreateLoad(loop_index), block_dim), + loop_index); + builder->CreateBr(loop_test_bb); + builder->SetInsertPoint(func_exit); + } + + if (stmt->bls_epilogue) { + call("block_barrier"); // "__syncthreads()" + stmt->bls_epilogue->accept(this); + } + + body = guard.body; + } + + auto tls_epilogue = create_mesh_xlogue(stmt->tls_epilogue); + + create_call( + "gpu_parallel_mesh_for", + {get_arg(0), tlctx->get_constant(stmt->mesh->num_patches), tls_prologue, + body, tls_epilogue, tlctx->get_constant(stmt->tls_size)}); + } + void emit_cuda_gc(OffloadedStmt *stmt) { auto snode_id = tlctx->get_constant(stmt->snode->id); { @@ -487,7 +761,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { data_ptr = builder->CreateBitCast(data_ptr, llvm_ptr_type(dtype)); auto data = create_intrinsic_load(dtype, data_ptr); llvm_val[stmt] = extract_custom_int(data, bit_offset, int_in_mem); - } else if (auto cft = val_type->cast()) { + } else if (val_type->cast()) { // TODO: support __ldg llvm_val[stmt] = load_custom_float(stmt->src); } else { @@ -536,6 +810,8 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { create_offload_range_for(stmt); } else if (stmt->task_type == Type::struct_for) { create_offload_struct_for(stmt, true); + } else if (stmt->task_type == Type::mesh_for) { + create_offload_mesh_for(stmt); } else if (stmt->task_type == Type::listgen) { emit_list_gen(stmt); } else { @@ -543,6 +819,26 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { } finalize_offloaded_task_function(); current_task->grid_dim = stmt->grid_dim; + if (stmt->task_type == Type::range_for) { + if (stmt->const_begin && stmt->const_end) { + int num_threads = stmt->end_value - stmt->begin_value; + int grid_dim = ((num_threads % stmt->block_dim) == 0) + ? (num_threads / stmt->block_dim) + : (num_threads / stmt->block_dim) + 1; + grid_dim = std::max(grid_dim, 1); + current_task->grid_dim = std::min(stmt->grid_dim, grid_dim); + } + } + if (stmt->task_type == Type::listgen) { + int query_max_block_per_sm; + CUDADriver::get_instance().device_get_attribute( + &query_max_block_per_sm, + CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR, nullptr); + int num_SMs; + CUDADriver::get_instance().device_get_attribute( + &num_SMs, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, nullptr); + current_task->grid_dim = num_SMs * query_max_block_per_sm; + } current_task->block_dim = stmt->block_dim; TI_ASSERT(current_task->grid_dim != 0); TI_ASSERT(current_task->block_dim != 0); @@ -555,13 +851,75 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { #endif } + void visit(ExternalFuncCallStmt *stmt) override { + if (stmt->type == ExternalFuncCallStmt::BITCODE) { + CodeGenLLVM::visit_call_bitcode(stmt); + } else { + TI_NOT_IMPLEMENTED + } + } + void visit(ExternalTensorShapeAlongAxisStmt *stmt) override { const auto arg_id = stmt->arg_id; const auto axis = stmt->axis; - llvm_val[stmt] = - builder->CreateCall(get_runtime_function("Context_get_extra_args"), - {get_context(), tlctx->get_constant(arg_id), - tlctx->get_constant(axis)}); + llvm_val[stmt] = create_call("RuntimeContext_get_extra_args", + {get_context(), tlctx->get_constant(arg_id), + tlctx->get_constant(axis)}); + } + + void visit(BinaryOpStmt *stmt) override { + auto op = stmt->op_type; + if (op != BinaryOpType::atan2 && op != BinaryOpType::pow) { + return CodeGenLLVM::visit(stmt); + } + + auto ret_type = stmt->ret_type; + + llvm::Value *lhs = llvm_val[stmt->lhs]; + llvm::Value *rhs = llvm_val[stmt->rhs]; + + // This branch contains atan2 and pow which use runtime.cpp function for + // **real** type. We don't have f16 support there so promoting to f32 is + // necessary. + if (stmt->lhs->ret_type->is_primitive(PrimitiveTypeID::f16)) { + lhs = builder->CreateFPExt(lhs, llvm::Type::getFloatTy(*llvm_context)); + } + if (stmt->rhs->ret_type->is_primitive(PrimitiveTypeID::f16)) { + rhs = builder->CreateFPExt(rhs, llvm::Type::getFloatTy(*llvm_context)); + } + if (ret_type->is_primitive(PrimitiveTypeID::f16)) { + ret_type = PrimitiveType::f32; + } + + if (op == BinaryOpType::atan2) { + if (ret_type->is_primitive(PrimitiveTypeID::f32)) { + llvm_val[stmt] = create_call("__nv_atan2f", {lhs, rhs}); + } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) { + llvm_val[stmt] = create_call("__nv_atan2", {lhs, rhs}); + } else { + TI_P(data_type_name(ret_type)); + TI_NOT_IMPLEMENTED + } + } else { + if (ret_type->is_primitive(PrimitiveTypeID::f32)) { + llvm_val[stmt] = create_call("__nv_powf", {lhs, rhs}); + } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) { + llvm_val[stmt] = create_call("__nv_pow", {lhs, rhs}); + } else if (ret_type->is_primitive(PrimitiveTypeID::i32)) { + llvm_val[stmt] = create_call("pow_i32", {lhs, rhs}); + } else if (ret_type->is_primitive(PrimitiveTypeID::i64)) { + llvm_val[stmt] = create_call("pow_i64", {lhs, rhs}); + } else { + TI_P(data_type_name(ret_type)); + TI_NOT_IMPLEMENTED + } + } + + // Convert back to f16 if applicable. + if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { + llvm_val[stmt] = builder->CreateFPTrunc( + llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context)); + } } }; diff --git a/taichi/backends/cuda/codegen_cuda.h b/taichi/backends/cuda/codegen_cuda.h index cac00c24f6f39..0d4eec87e4b26 100644 --- a/taichi/backends/cuda/codegen_cuda.h +++ b/taichi/backends/cuda/codegen_cuda.h @@ -12,7 +12,7 @@ class CodeGenCUDA : public KernelCodeGen { : KernelCodeGen(kernel, ir) { } - virtual FunctionType codegen() override; + FunctionType codegen() override; }; TLANG_NAMESPACE_END diff --git a/taichi/backends/cuda/cuda_caching_allocator.cpp b/taichi/backends/cuda/cuda_caching_allocator.cpp new file mode 100644 index 0000000000000..cb037af88f14b --- /dev/null +++ b/taichi/backends/cuda/cuda_caching_allocator.cpp @@ -0,0 +1,40 @@ +#include "taichi/backends/cuda/cuda_caching_allocator.h" + +namespace taichi { +namespace lang { +namespace cuda { + +CudaCachingAllocator::CudaCachingAllocator(LlvmDevice *device) + : device_(device) { +} + +uint64_t *CudaCachingAllocator::allocate( + const LlvmDevice::LlvmRuntimeAllocParams ¶ms) { + uint64_t *ret{nullptr}; + auto size_aligned = taichi::iroundup(params.size, taichi_page_size); + auto it_blk = mem_blocks_.lower_bound(size_aligned); + + if (it_blk != mem_blocks_.end()) { + size_t remaining_sz = it_blk->first - size_aligned; + if (remaining_sz > 0) { + TI_ASSERT(remaining_sz % taichi_page_size == 0); + auto remaining_head = + reinterpret_cast(it_blk->second) + size_aligned; + mem_blocks_.insert( + {remaining_sz, reinterpret_cast(remaining_head)}); + } + ret = it_blk->second; + mem_blocks_.erase(it_blk); + } else { + ret = device_->allocate_llvm_runtime_memory_jit(params); + } + return ret; +} + +void CudaCachingAllocator::release(size_t sz, uint64_t *ptr) { + mem_blocks_.insert({sz, ptr}); +} + +} // namespace cuda +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/cuda/cuda_caching_allocator.h b/taichi/backends/cuda/cuda_caching_allocator.h new file mode 100644 index 0000000000000..14af5a493792f --- /dev/null +++ b/taichi/backends/cuda/cuda_caching_allocator.h @@ -0,0 +1,27 @@ +#pragma once + +#include "taichi/common/core.h" +#include "taichi/math/arithmetic.h" +#include "taichi/llvm/llvm_device.h" +#include +#include + +namespace taichi { +namespace lang { +namespace cuda { + +class CudaCachingAllocator { + public: + CudaCachingAllocator(LlvmDevice *device); + + uint64_t *allocate(const LlvmDevice::LlvmRuntimeAllocParams ¶ms); + void release(size_t sz, uint64_t *ptr); + + private: + std::multimap mem_blocks_; + LlvmDevice *device_{nullptr}; +}; + +} // namespace cuda +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/cuda/cuda_context.h b/taichi/backends/cuda/cuda_context.h index c4e33f98e42ea..69a02adf6f082 100644 --- a/taichi/backends/cuda/cuda_context.h +++ b/taichi/backends/cuda/cuda_context.h @@ -95,8 +95,8 @@ class CUDAContext { return ContextGuard(this); } - std::lock_guard &&get_lock_guard() { - return std::move(std::lock_guard(lock_)); + std::unique_lock get_lock_guard() { + return std::unique_lock(lock_); } static CUDAContext &get_instance(); diff --git a/taichi/backends/cuda/cuda_device.cpp b/taichi/backends/cuda/cuda_device.cpp index c939a7c139d6e..6c4c0fed720f0 100644 --- a/taichi/backends/cuda/cuda_device.cpp +++ b/taichi/backends/cuda/cuda_device.cpp @@ -5,7 +5,8 @@ namespace lang { namespace cuda { -CudaDevice::AllocInfo CudaDevice::get_alloc_info(DeviceAllocation handle) { +CudaDevice::AllocInfo CudaDevice::get_alloc_info( + const DeviceAllocation handle) { validate_device_alloc(handle); return allocations_[handle.alloc_id]; } @@ -22,6 +23,35 @@ DeviceAllocation CudaDevice::allocate_memory(const AllocParams ¶ms) { info.size = params.size; info.is_imported = false; + info.use_cached = false; + info.use_preallocated = false; + + DeviceAllocation alloc; + alloc.alloc_id = allocations_.size(); + alloc.device = this; + + allocations_.push_back(info); + return alloc; +} + +DeviceAllocation CudaDevice::allocate_memory_runtime( + const LlvmRuntimeAllocParams ¶ms) { + AllocInfo info; + info.size = taichi::iroundup(params.size, taichi_page_size); + if (params.host_read || params.host_write) { + TI_NOT_IMPLEMENTED + } else if (params.use_cached) { + if (caching_allocator_ == nullptr) { + caching_allocator_ = std::make_unique(this); + } + info.ptr = caching_allocator_->allocate(params); + CUDADriver::get_instance().memset((void *)info.ptr, 0, info.size); + } else { + info.ptr = allocate_llvm_runtime_memory_jit(params); + } + info.is_imported = false; + info.use_cached = params.use_cached; + info.use_preallocated = true; DeviceAllocation alloc; alloc.alloc_id = allocations_.size(); @@ -38,8 +68,15 @@ void CudaDevice::dealloc_memory(DeviceAllocation handle) { TI_ERROR("the DeviceAllocation is already deallocated"); } TI_ASSERT(!info.is_imported); - CUDADriver::get_instance().mem_free(info.ptr); - info.ptr = nullptr; + if (info.use_cached) { + if (caching_allocator_ == nullptr) { + TI_ERROR("the CudaCachingAllocator is not initialized"); + } + caching_allocator_->release(info.size, (uint64_t *)info.ptr); + } else if (!info.use_preallocated) { + CUDADriver::get_instance().mem_free(info.ptr); + info.ptr = nullptr; + } } DeviceAllocation CudaDevice::import_memory(void *ptr, size_t size) { @@ -56,6 +93,13 @@ DeviceAllocation CudaDevice::import_memory(void *ptr, size_t size) { return alloc; } +uint64 CudaDevice::fetch_result_uint64(int i, uint64 *result_buffer) { + CUDADriver::get_instance().stream_synchronize(nullptr); + uint64 ret; + CUDADriver::get_instance().memcpy_device_to_host(&ret, result_buffer + i, + sizeof(uint64)); + return ret; +} } // namespace cuda } // namespace lang } // namespace taichi diff --git a/taichi/backends/cuda/cuda_device.h b/taichi/backends/cuda/cuda_device.h index 8215028526557..039c17b012061 100644 --- a/taichi/backends/cuda/cuda_device.h +++ b/taichi/backends/cuda/cuda_device.h @@ -4,8 +4,9 @@ #include "taichi/common/core.h" #include "taichi/backends/cuda/cuda_driver.h" +#include "taichi/backends/cuda/cuda_caching_allocator.h" #include "taichi/backends/cuda/cuda_context.h" -#include "taichi/backends/device.h" +#include "taichi/llvm/llvm_device.h" namespace taichi { namespace lang { @@ -74,25 +75,41 @@ class CudaStream : public Stream { void command_sync() override{TI_NOT_IMPLEMENTED}; }; -class CudaDevice : public Device { +class CudaDevice : public LlvmDevice { public: struct AllocInfo { void *ptr{nullptr}; size_t size{0}; bool is_imported{false}; + /* Note: Memory allocation in CUDA device. + * CudaDevice can use either its own cuda malloc mechanism via + * `allocate_memory` or the preallocated memory managed by Llvmprogramimpl + * via `allocate_memory_runtime`. The `use_preallocated` is used to track + * this option. For now, we keep both options and the preallocated method is + * used by default for CUDA backend. The `use_cached` is to enable/disable + * the caching behavior in `allocate_memory_runtime`. Later it should be + * always enabled, for now we keep both options to allow a scenario when + * using preallocated memory while disabling the caching behavior. + * */ + bool use_preallocated{true}; + bool use_cached{false}; }; - AllocInfo get_alloc_info(DeviceAllocation handle); + AllocInfo get_alloc_info(const DeviceAllocation handle); ~CudaDevice() override{}; DeviceAllocation allocate_memory(const AllocParams ¶ms) override; + DeviceAllocation allocate_memory_runtime( + const LlvmRuntimeAllocParams ¶ms) override; void dealloc_memory(DeviceAllocation handle) override; std::unique_ptr create_pipeline( const PipelineSourceDesc &src, std::string name = "Pipeline") override{TI_NOT_IMPLEMENTED}; + uint64 fetch_result_uint64(int i, uint64 *result_buffer) override; + void *map_range(DevicePtr ptr, uint64_t size) override{TI_NOT_IMPLEMENTED}; void *map(DeviceAllocation alloc) override{TI_NOT_IMPLEMENTED}; @@ -108,11 +125,12 @@ class CudaDevice : public Device { private: std::vector allocations_; - void validate_device_alloc(DeviceAllocation alloc) { + void validate_device_alloc(const DeviceAllocation alloc) { if (allocations_.size() <= alloc.alloc_id) { TI_ERROR("invalid DeviceAllocation"); } } + std::unique_ptr caching_allocator_{nullptr}; }; } // namespace cuda diff --git a/taichi/backends/cuda/cuda_driver.cpp b/taichi/backends/cuda/cuda_driver.cpp index 7bfef397a446e..e01e1c0fe5bf2 100644 --- a/taichi/backends/cuda/cuda_driver.cpp +++ b/taichi/backends/cuda/cuda_driver.cpp @@ -16,12 +16,16 @@ std::string get_cuda_error_message(uint32 err) { } bool CUDADriver::detected() { - if (get_environ_config("TI_ENABLE_CUDA", 1) == 0) - return false; - return loader_->loaded(); + return !disabled_by_env_ && cuda_version_valid_ && loader_->loaded(); } CUDADriver::CUDADriver() { + disabled_by_env_ = (get_environ_config("TI_ENABLE_CUDA", 1) == 0); + if (disabled_by_env_) { + TI_TRACE("CUDA driver disabled by enviroment variable \"TI_ENABLE_CUDA\"."); + return; + } + #if defined(TI_PLATFORM_LINUX) loader_ = std::make_unique("libcuda.so"); #elif defined(TI_PLATFORM_WINDOWS) @@ -30,25 +34,34 @@ CUDADriver::CUDADriver() { static_assert(false, "Taichi CUDA driver supports only Windows and Linux."); #endif - if (detected()) { - loader_->load_function("cuGetErrorName", get_error_name); - loader_->load_function("cuGetErrorString", get_error_string); + if (!loader_->loaded()) { + TI_WARN("CUDA driver not found."); + return; + } + + loader_->load_function("cuGetErrorName", get_error_name); + loader_->load_function("cuGetErrorString", get_error_string); + loader_->load_function("cuDriverGetVersion", driver_get_version); + + int version; + driver_get_version(&version); + TI_TRACE("CUDA driver API (v{}.{}) loaded.", version / 1000, + version % 1000 / 10); + // CUDA versions should >= 10. + if (version < 10000) { + TI_WARN("The Taichi CUDA backend requires at least CUDA 10.0, got v{}.{}.", + version / 1000, version % 1000 / 10); + return; + } + + cuda_version_valid_ = true; #define PER_CUDA_FUNCTION(name, symbol_name, ...) \ name.set(loader_->load_function(#symbol_name)); \ name.set_lock(&lock_); \ name.set_names(#name, #symbol_name); #include "taichi/backends/cuda/cuda_driver_functions.inc.h" #undef PER_CUDA_FUNCTION - - int version; - driver_get_version(&version); - - TI_TRACE("CUDA driver API (v{}.{}) loaded.", version / 1000, - version % 1000 / 10); - } else { - TI_TRACE("CUDA driver not found."); - } } // This is for initializing the CUDA driver itself diff --git a/taichi/backends/cuda/cuda_driver.h b/taichi/backends/cuda/cuda_driver.h index 769919b8100b4..40e3eff765c63 100644 --- a/taichi/backends/cuda/cuda_driver.h +++ b/taichi/backends/cuda/cuda_driver.h @@ -30,6 +30,7 @@ constexpr uint32 CU_STREAM_NON_BLOCKING = 0x1; constexpr uint32 CU_MEM_ATTACH_GLOBAL = 0x1; constexpr uint32 CU_MEM_ADVISE_SET_PREFERRED_LOCATION = 3; constexpr uint32 CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X = 2; +constexpr uint32 CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR = 106; constexpr uint32 CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16; constexpr uint32 CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75; constexpr uint32 CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76; @@ -105,6 +106,8 @@ class CUDADriver { void (*get_error_string)(uint32, const char **); + void (*driver_get_version)(int *); + bool detected(); ~CUDADriver() = default; @@ -119,6 +122,9 @@ class CUDADriver { std::unique_ptr loader_; std::mutex lock_; + + bool disabled_by_env_{false}; + bool cuda_version_valid_{false}; }; TLANG_NAMESPACE_END diff --git a/taichi/backends/cuda/cuda_driver_functions.inc.h b/taichi/backends/cuda/cuda_driver_functions.inc.h index bbf1318f36ae9..25149a9a3d400 100644 --- a/taichi/backends/cuda/cuda_driver_functions.inc.h +++ b/taichi/backends/cuda/cuda_driver_functions.inc.h @@ -2,7 +2,6 @@ // Driver PER_CUDA_FUNCTION(init, cuInit, int); -PER_CUDA_FUNCTION(driver_get_version, cuDriverGetVersion, int*); // Device management PER_CUDA_FUNCTION(device_get_count, cuDeviceGetCount, int *); @@ -28,6 +27,7 @@ PER_CUDA_FUNCTION(memcpy_device_to_host_async, cuMemcpyDtoHAsync_v2, void *, voi PER_CUDA_FUNCTION(malloc, cuMemAlloc_v2, void **, std::size_t); PER_CUDA_FUNCTION(malloc_managed, cuMemAllocManaged, void **, std::size_t, uint32); PER_CUDA_FUNCTION(memset, cuMemsetD8_v2, void *, uint8, std::size_t); +PER_CUDA_FUNCTION(memsetd32, cuMemsetD32_v2, void *, uint32, std::size_t); PER_CUDA_FUNCTION(mem_free, cuMemFree_v2, void *); PER_CUDA_FUNCTION(mem_advise, cuMemAdvise, void *, std::size_t, uint32, uint32); PER_CUDA_FUNCTION(mem_get_info, cuMemGetInfo_v2, std::size_t *, std::size_t *); diff --git a/taichi/backends/cuda/cuda_profiler.cpp b/taichi/backends/cuda/cuda_profiler.cpp index d27faa4079341..8fb65b2b5f8e6 100644 --- a/taichi/backends/cuda/cuda_profiler.cpp +++ b/taichi/backends/cuda/cuda_profiler.cpp @@ -11,22 +11,53 @@ TLANG_NAMESPACE_BEGIN // will not affect default toolkit (cuEvent) KernelProfilerCUDA::KernelProfilerCUDA(bool enable) { metric_list_.clear(); - if (enable) { + if (enable) { // default profiling toolkit: event tool_ = ProfilingToolkit::event; + event_toolkit_ = std::make_unique(); + } +} + +ProfilingToolkit get_toolkit_enum(std::string toolkit_name) { + if (toolkit_name.compare("default") == 0) + return ProfilingToolkit::event; + else if (toolkit_name.compare("cupti") == 0) + return ProfilingToolkit::cupti; + else + return ProfilingToolkit::undef; +} + +bool KernelProfilerCUDA::set_profiler_toolkit(std::string toolkit_name) { + sync(); + ProfilingToolkit set_toolkit = get_toolkit_enum(toolkit_name); + TI_TRACE("profiler toolkit enum = {} >>> {}", tool_, set_toolkit); + if (set_toolkit == tool_) + return true; + + // current toolkit is CUPTI: disable + if (tool_ == ProfilingToolkit::cupti) { + cupti_toolkit_->end_profiling(); + cupti_toolkit_->deinit_cupti(); + cupti_toolkit_->set_status(false); + tool_ = ProfilingToolkit::event; + TI_TRACE("cupti >>> event ... DONE"); + return true; + } + // current toolkit is cuEvent: check CUPTI availability + else if (tool_ == ProfilingToolkit::event) { #if defined(TI_WITH_CUDA_TOOLKIT) - // if Taichi was compiled with CUDA toolit, then use CUPTI - // TODO : add set_mode() to select toolkit by user - if (check_cupti_availability() && check_cupti_privileges()) + if (check_cupti_availability() && check_cupti_privileges()) { + if (cupti_toolkit_ == nullptr) + cupti_toolkit_ = std::make_unique(); + cupti_toolkit_->init_cupti(); + cupti_toolkit_->begin_profiling(); tool_ = ProfilingToolkit::cupti; + cupti_toolkit_->set_status(true); + TI_TRACE("event >>> cupti ... DONE"); + return true; + } #endif } - if (tool_ == ProfilingToolkit::event) { - event_toolkit_ = std::make_unique(); - } else if (tool_ == ProfilingToolkit::cupti) { - cupti_toolkit_ = std::make_unique(); - cupti_toolkit_->init_cupti(); - cupti_toolkit_->begin_profiling(); - } + return false; } std::string KernelProfilerCUDA::get_device_name() { @@ -54,6 +85,8 @@ bool KernelProfilerCUDA::reinit_with_metrics( metric_list_.size()); return true; } + + TI_NOT_IMPLEMENTED; } // deprecated, move to trace() diff --git a/taichi/backends/cuda/cuda_profiler.h b/taichi/backends/cuda/cuda_profiler.h index 8175caaebe7fb..8c97f3b6715ef 100644 --- a/taichi/backends/cuda/cuda_profiler.h +++ b/taichi/backends/cuda/cuda_profiler.h @@ -37,6 +37,8 @@ class KernelProfilerCUDA : public KernelProfilerBase { void clear() override; void stop(KernelProfilerBase::TaskHandle handle) override; + bool set_profiler_toolkit(std::string toolkit_name) override; + bool statistics_on_traced_records(); KernelProfilerBase::TaskHandle start_with_handle( @@ -49,11 +51,11 @@ class KernelProfilerCUDA : public KernelProfilerBase { private: ProfilingToolkit tool_ = ProfilingToolkit::undef; + + // Instances of these toolkits may exist at the same time, + // but only one will be enabled. std::unique_ptr event_toolkit_{nullptr}; - // if(tool_ == ProfilingToolkit::cupti) event_toolkit_ = nullptr std::unique_ptr cupti_toolkit_{nullptr}; - // if(tool_ == ProfilingToolkit::event) cupti_toolkit_ = nullptr - // TODO : switch profiling toolkit at runtime std::vector metric_list_; uint32_t records_size_after_sync_{0}; }; diff --git a/taichi/backends/cuda/cupti_toolkit.cpp b/taichi/backends/cuda/cupti_toolkit.cpp index ebd24f4686c7c..5c4244fba6222 100644 --- a/taichi/backends/cuda/cupti_toolkit.cpp +++ b/taichi/backends/cuda/cupti_toolkit.cpp @@ -22,7 +22,7 @@ enum class CuptiMetricsDefault : uint { CUPTI_METRIC_DEFAULT_TOTAL = 2 }; -constexpr char *MetricListDeafult[] = { +[[maybe_unused]] constexpr const char *MetricListDefault[] = { "smsp__cycles_elapsed.avg", // CUPTI_METRIC_KERNEL_ELAPSED_CLK_NUMS "smsp__cycles_elapsed.avg.per_second", // CUPTI_METRIC_CORE_FREQUENCY_HZS }; @@ -798,13 +798,20 @@ CuptiToolkit::CuptiToolkit() { uint metric_list_size = static_cast(CuptiMetricsDefault::CUPTI_METRIC_DEFAULT_TOTAL); for (uint idx = 0; idx < metric_list_size; idx++) { - cupti_config_.metric_list.push_back(MetricListDeafult[idx]); + cupti_config_.metric_list.push_back(MetricListDefault[idx]); } + set_status(true); } CuptiToolkit::~CuptiToolkit() { - end_profiling(); - deinit_cupti(); + if (enabled_) { + end_profiling(); + deinit_cupti(); + } +} + +void CuptiToolkit::set_status(bool enable) { + enabled_ = enable; } void CuptiToolkit::reset_metrics(const std::vector &metrics) { @@ -812,7 +819,7 @@ void CuptiToolkit::reset_metrics(const std::vector &metrics) { uint metric_list_size = static_cast(CuptiMetricsDefault::CUPTI_METRIC_DEFAULT_TOTAL); for (uint idx = 0; idx < metric_list_size; idx++) { - cupti_config_.metric_list.push_back(MetricListDeafult[idx]); + cupti_config_.metric_list.push_back(MetricListDefault[idx]); } // user selected metrics for (auto metric : metrics) @@ -1099,6 +1106,9 @@ CuptiToolkit::CuptiToolkit() { CuptiToolkit::~CuptiToolkit() { TI_NOT_IMPLEMENTED; } +void CuptiToolkit::set_status(bool enable) { + TI_NOT_IMPLEMENTED; +} void CuptiToolkit::reset_metrics(const std::vector &metrics) { TI_NOT_IMPLEMENTED; } diff --git a/taichi/backends/cuda/cupti_toolkit.h b/taichi/backends/cuda/cupti_toolkit.h index 775b005637438..98a5c3309730f 100644 --- a/taichi/backends/cuda/cupti_toolkit.h +++ b/taichi/backends/cuda/cupti_toolkit.h @@ -6,7 +6,7 @@ TLANG_NAMESPACE_BEGIN struct CuptiConfig { #if defined(TI_WITH_CUDA_TOOLKIT) - uint32_t num_ranges = 16384; // max number of kernels traced by CUPTI + uint32_t num_ranges = 1048576; // max number of kernels traced by CUPTI std::vector metric_list; #endif }; @@ -35,9 +35,11 @@ class CuptiToolkit { bool update_record(uint32_t records_size_after_sync, std::vector &traced_records); void reset_metrics(const std::vector &metrics); + void set_status(bool enable); private: - CuptiConfig cupti_config_; + [[maybe_unused]] bool enabled_{false}; + [[maybe_unused]] CuptiConfig cupti_config_; CuptiImage cupti_image_; }; diff --git a/taichi/backends/cuda/jit_cuda.cpp b/taichi/backends/cuda/jit_cuda.cpp index de27656a7917e..1e336f75dead2 100644 --- a/taichi/backends/cuda/jit_cuda.cpp +++ b/taichi/backends/cuda/jit_cuda.cpp @@ -1,4 +1,5 @@ #include "taichi/backends/cuda/jit_cuda.h" +#include "taichi/llvm/llvm_program.h" TLANG_NAMESPACE_BEGIN @@ -7,7 +8,7 @@ TLANG_NAMESPACE_BEGIN JITModule *JITSessionCUDA ::add_module(std::unique_ptr M, int max_reg) { auto ptx = compile_module_to_ptx(M); - if (get_current_program().config.print_kernel_nvptx) { + if (this->llvm_prog()->config->print_kernel_nvptx) { static FileSequenceWriter writer("taichi_kernel_nvptx_{:04d}.ptx", "module NVPTX"); writer.write(ptx); @@ -20,8 +21,7 @@ JITModule *JITSessionCUDA ::add_module(std::unique_ptr M, TI_TRACE("PTX size: {:.2f}KB", ptx.size() / 1024.0); auto t = Time::get_time(); TI_TRACE("Loading module..."); - [[maybe_unused]] auto &&_ = - std::move(CUDAContext::get_instance().get_lock_guard()); + [[maybe_unused]] auto _ = CUDAContext::get_instance().get_lock_guard(); constexpr int max_num_options = 8; int num_options = 0; @@ -82,7 +82,7 @@ std::string JITSessionCUDA::compile_module_to_ptx( using namespace llvm; - if (get_current_program().config.print_kernel_llvm_ir) { + if (this->llvm_prog()->config->print_kernel_llvm_ir) { static FileSequenceWriter writer("taichi_kernel_cuda_llvm_ir_{:04d}.ll", "unoptimized LLVM IR (CUDA)"); writer.write(module.get()); @@ -102,11 +102,9 @@ std::string JITSessionCUDA::compile_module_to_ptx( TargetRegistry::lookupTarget(triple.str(), err_str); TI_ERROR_UNLESS(target, err_str); - bool fast_math = get_current_program().config.fast_math; - TargetOptions options; options.PrintMachineCode = 0; - if (fast_math) { + if (this->llvm_prog()->config->fast_math) { options.AllowFPOpFusion = FPOpFusion::Fast; // See NVPTXISelLowering.cpp // Setting UnsafeFPMath true will result in approximations such as @@ -209,7 +207,7 @@ std::string JITSessionCUDA::compile_module_to_ptx( module_pass_manager.run(*module); } - if (get_current_program().config.print_kernel_llvm_ir_optimized) { + if (this->llvm_prog()->config->print_kernel_llvm_ir_optimized) { static FileSequenceWriter writer( "taichi_kernel_cuda_llvm_ir_optimized_{:04d}.ll", "optimized LLVM IR (CUDA)"); @@ -223,16 +221,20 @@ std::string JITSessionCUDA::compile_module_to_ptx( return buffer; } -std::unique_ptr create_llvm_jit_session_cuda(Arch arch) { +std::unique_ptr create_llvm_jit_session_cuda( + LlvmProgramImpl *llvm_prog, + Arch arch) { TI_ASSERT(arch == Arch::cuda); // https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#data-layout auto data_layout = llvm::DataLayout( "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-" "f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"); - return std::make_unique(data_layout); + return std::make_unique(llvm_prog, data_layout); } #else -std::unique_ptr create_llvm_jit_session_cuda(Arch arch) { +std::unique_ptr create_llvm_jit_session_cuda( + LlvmProgramImpl *llvm_prog, + Arch arch) { TI_NOT_IMPLEMENTED } #endif diff --git a/taichi/backends/cuda/jit_cuda.h b/taichi/backends/cuda/jit_cuda.h index 23118f2eef731..4c188add7763d 100644 --- a/taichi/backends/cuda/jit_cuda.h +++ b/taichi/backends/cuda/jit_cuda.h @@ -4,6 +4,7 @@ #include "llvm/Support/DynamicLibrary.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" +#include "llvm/IR/Module.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" @@ -35,10 +36,10 @@ TLANG_NAMESPACE_BEGIN #if defined(TI_WITH_CUDA) class JITModuleCUDA : public JITModule { private: - void *module; + void *module_; public: - explicit JITModuleCUDA(void *module) : module(module) { + explicit JITModuleCUDA(void *module) : module_(module) { } void *lookup_function(const std::string &name) override { @@ -48,7 +49,7 @@ class JITModuleCUDA : public JITModule { void *func = nullptr; auto t = Time::get_time(); auto err = CUDADriver::get_instance().module_get_function.call_with_warning( - &func, module, name.c_str()); + &func, module_, name.c_str()); if (err) { TI_ERROR("Cannot look up function {}", name); } @@ -63,11 +64,11 @@ class JITModuleCUDA : public JITModule { launch(name, 1, 1, 0, arg_pointers); } - virtual void launch(const std::string &name, - std::size_t grid_dim, - std::size_t block_dim, - std::size_t dynamic_shared_mem_bytes, - const std::vector &arg_pointers) override { + void launch(const std::string &name, + std::size_t grid_dim, + std::size_t block_dim, + std::size_t dynamic_shared_mem_bytes, + const std::vector &arg_pointers) override { auto func = lookup_function(name); CUDAContext::get_instance().launch(func, name, arg_pointers, grid_dim, block_dim, dynamic_shared_mem_bytes); @@ -82,23 +83,24 @@ class JITSessionCUDA : public JITSession { public: llvm::DataLayout data_layout; - explicit JITSessionCUDA(llvm::DataLayout data_layout) - : data_layout(data_layout) { + JITSessionCUDA(LlvmProgramImpl *llvm_prog, llvm::DataLayout data_layout) + : JITSession(llvm_prog), data_layout(data_layout) { } - virtual JITModule *add_module(std::unique_ptr M, - int max_reg) override; + JITModule *add_module(std::unique_ptr M, int max_reg) override; - virtual llvm::DataLayout get_data_layout() override { + llvm::DataLayout get_data_layout() override { return data_layout; } - static std::string compile_module_to_ptx( - std::unique_ptr &module); + private: + std::string compile_module_to_ptx(std::unique_ptr &module); }; #endif -std::unique_ptr create_llvm_jit_session_cuda(Arch arch); +std::unique_ptr create_llvm_jit_session_cuda( + LlvmProgramImpl *llvm_prog, + Arch arch); TLANG_NAMESPACE_END diff --git a/taichi/backends/cuda/runtime.cpp b/taichi/backends/cuda/runtime.cpp index 217077a109c21..7befd9d1163fa 100644 --- a/taichi/backends/cuda/runtime.cpp +++ b/taichi/backends/cuda/runtime.cpp @@ -26,7 +26,7 @@ class RuntimeCUDA : public Runtime { return CUDAContext::get_instance().detected(); } - ~RuntimeCUDA() { + ~RuntimeCUDA() override { } }; diff --git a/taichi/backends/device.cpp b/taichi/backends/device.cpp index dcbff2e1e3095..550f762b5ab4b 100644 --- a/taichi/backends/device.cpp +++ b/taichi/backends/device.cpp @@ -1,9 +1,11 @@ #include -#include #if TI_WITH_VULKAN #include #include +#if TI_WITH_LLVM +#include +#endif #if TI_WITH_CUDA #include #include @@ -29,11 +31,13 @@ Device::MemcpyCapability Device::check_memcpy_capability(DevicePtr dst, } #if TI_WITH_VULKAN +#if TI_WITH_LLVM if (dynamic_cast(dst.device) && dynamic_cast(src.device)) { // TODO: support direct copy if dst itself supports host write. return Device::MemcpyCapability::RequiresStagingBuffer; } +#endif #if TI_WITH_CUDA if (dynamic_cast(dst.device) && dynamic_cast(src.device)) { @@ -71,7 +75,7 @@ void Device::memcpy_via_staging(DevicePtr dst, DevicePtr src, uint64_t size) { // Inter-device copy -#if TI_WITH_VULKAN +#if defined(TI_WITH_VULKAN) && defined(TI_WITH_LLVM) if (dynamic_cast(dst.device) && dynamic_cast(src.device)) { memcpy_cpu_to_vulkan_via_staging(dst, staging, src, size); @@ -89,6 +93,53 @@ void Device::memcpy_via_host(DevicePtr dst, TI_NOT_IMPLEMENTED; } +void Device::print_all_cap() const { + const std::unordered_map names{ + {DeviceCapability::vk_api_version, "vk_api_version"}, + {DeviceCapability::vk_has_physical_features2, + "vk_has_physical_features2"}, + {DeviceCapability::vk_has_external_memory, "vk_has_external_memory"}, + {DeviceCapability::vk_has_surface, "vk_has_surface"}, + {DeviceCapability::vk_has_presentation, "vk_has_presentation"}, + {DeviceCapability::spirv_version, "spirv_version"}, + {DeviceCapability::spirv_has_int8, "spirv_has_int8"}, + {DeviceCapability::spirv_has_int16, "spirv_has_int16"}, + {DeviceCapability::spirv_has_int64, "spirv_has_int64"}, + {DeviceCapability::spirv_has_float16, "spirv_has_float16"}, + {DeviceCapability::spirv_has_float64, "spirv_has_float64"}, + {DeviceCapability::spirv_has_atomic_i64, "spirv_has_atomic_i64"}, + {DeviceCapability::spirv_has_atomic_float16, "spirv_has_atomic_float16"}, + {DeviceCapability::spirv_has_atomic_float16_add, + "spirv_has_atomic_float16_add"}, + {DeviceCapability::spirv_has_atomic_float16_minmax, + "spirv_has_atomic_float16_minmax"}, + {DeviceCapability::spirv_has_atomic_float, "spirv_has_atomic_float"}, + {DeviceCapability::spirv_has_atomic_float_add, + "spirv_has_atomic_float_add"}, + {DeviceCapability::spirv_has_atomic_float_minmax, + "spirv_has_atomic_float_minmax"}, + {DeviceCapability::spirv_has_atomic_float64, "spirv_has_atomic_float64"}, + {DeviceCapability::spirv_has_atomic_float64_add, + "spirv_has_atomic_float64_add"}, + {DeviceCapability::spirv_has_atomic_float64_minmax, + "spirv_has_atomic_float64_minmax"}, + {DeviceCapability::spirv_has_variable_ptr, "spirv_has_variable_ptr"}, + {DeviceCapability::spirv_has_physical_storage_buffer, + "spirv_has_physical_storage_buffer"}, + {DeviceCapability::spirv_has_subgroup_basic, "spirv_has_subgroup_basic"}, + {DeviceCapability::spirv_has_subgroup_vote, "spirv_has_subgroup_vote"}, + {DeviceCapability::spirv_has_subgroup_arithmetic, + "spirv_has_subgroup_arithmetic"}, + {DeviceCapability::spirv_has_subgroup_ballot, + "spirv_has_subgroup_ballot"}, + {DeviceCapability::wide_lines, "wide_lines"}, + }; + for (auto &pair : caps_) { + TI_TRACE("DeviceCapability::{} ({}) = {}", names.at(pair.first), + int(pair.first), pair.second); + } +} + void GraphicsDevice::image_transition(DeviceAllocation img, ImageLayout old_layout, ImageLayout new_layout) { diff --git a/taichi/backends/device.h b/taichi/backends/device.h index f9e5ff59f69b3..82b2788f43c64 100644 --- a/taichi/backends/device.h +++ b/taichi/backends/device.h @@ -1,6 +1,7 @@ #pragma once #include "taichi/lang_util.h" +#include "taichi/jit/jit_module.h" #include "taichi/program/compile_config.h" #include #include @@ -26,6 +27,9 @@ enum class DeviceCapability : uint32_t { spirv_has_float16, spirv_has_float64, spirv_has_atomic_i64, + spirv_has_atomic_float16, // load, store, exchange + spirv_has_atomic_float16_add, + spirv_has_atomic_float16_minmax, spirv_has_atomic_float, // load, store, exchange spirv_has_atomic_float_add, spirv_has_atomic_float_minmax, @@ -33,17 +37,27 @@ enum class DeviceCapability : uint32_t { spirv_has_atomic_float64_add, spirv_has_atomic_float64_minmax, spirv_has_variable_ptr, + spirv_has_physical_storage_buffer, + spirv_has_subgroup_basic, + spirv_has_subgroup_vote, + spirv_has_subgroup_arithmetic, + spirv_has_subgroup_ballot, + // Graphics Caps, + wide_lines }; class Device; struct DeviceAllocation; struct DevicePtr; +struct LLVMRuntime; // TODO: Figure out how to support images. Temporary solutions is to have all // opque types such as images work as an allocation -struct DeviceAllocation { +using DeviceAllocationId = uint32_t; + +struct TI_DLL_EXPORT DeviceAllocation { Device *device{nullptr}; - uint32_t alloc_id{0}; + DeviceAllocationId alloc_id{0}; DevicePtr get_ptr(uint64_t offset = 0) const; @@ -56,14 +70,14 @@ struct DeviceAllocation { } }; -struct DeviceAllocationGuard : public DeviceAllocation { +struct TI_DLL_EXPORT DeviceAllocationGuard : public DeviceAllocation { DeviceAllocationGuard(DeviceAllocation alloc) : DeviceAllocation(alloc) { } DeviceAllocationGuard(const DeviceAllocationGuard &) = delete; ~DeviceAllocationGuard(); }; -struct DevicePtr : public DeviceAllocation { +struct TI_DLL_EXPORT DevicePtr : public DeviceAllocation { uint64_t offset{0}; bool operator==(const DevicePtr &other) const { @@ -219,7 +233,8 @@ enum class ImageLayout { depth_attachment, depth_attachment_read, transfer_dst, - transfer_src + transfer_src, + present_src }; struct BufferImageCopyParams { @@ -240,6 +255,12 @@ struct BufferImageCopyParams { uint32_t image_layer_count{1}; }; +struct ImageCopyParams { + uint32_t width{1}; + uint32_t height{1}; + uint32_t depth{1}; +}; + class CommandList { public: virtual ~CommandList() { @@ -256,6 +277,17 @@ class CommandList { virtual void buffer_fill(DevicePtr ptr, size_t size, uint32_t data) = 0; virtual void dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) = 0; + struct ComputeSize { + uint32_t x{0}; + uint32_t y{0}; + uint32_t z{0}; + }; + // Some GPU APIs can set the block (workgroup, threadsgroup) size at + // dispatch time. + virtual void dispatch(ComputeSize grid_size, ComputeSize block_size) { + dispatch(grid_size.x, grid_size.y, grid_size.z); + } + // These are not implemented in compute only device virtual void begin_renderpass(int x0, int y0, @@ -303,6 +335,20 @@ class CommandList { const BufferImageCopyParams ¶ms) { TI_NOT_IMPLEMENTED } + virtual void copy_image(DeviceAllocation dst_img, + DeviceAllocation src_img, + ImageLayout dst_img_layout, + ImageLayout src_img_layout, + const ImageCopyParams ¶ms) { + TI_NOT_IMPLEMENTED + } + virtual void blit_image(DeviceAllocation dst_img, + DeviceAllocation src_img, + ImageLayout dst_img_layout, + ImageLayout src_img_layout, + const ImageCopyParams ¶ms) { + TI_NOT_IMPLEMENTED + } }; struct PipelineSourceDesc { @@ -341,16 +387,18 @@ class Device { public: virtual ~Device(){}; - virtual uint32_t get_cap(DeviceCapability capability_id) const { + uint32_t get_cap(DeviceCapability capability_id) const { if (caps_.find(capability_id) == caps_.end()) return 0; return caps_.at(capability_id); } - virtual void set_cap(DeviceCapability capability_id, uint32_t val) { + void set_cap(DeviceCapability capability_id, uint32_t val) { caps_[capability_id] = val; } + void print_all_cap() const; + struct AllocParams { uint64_t size{0}; bool host_write{false}; @@ -360,8 +408,13 @@ class Device { }; virtual DeviceAllocation allocate_memory(const AllocParams ¶ms) = 0; + virtual void dealloc_memory(DeviceAllocation handle) = 0; + virtual uint64_t get_memory_physical_pointer(DeviceAllocation handle) { + TI_NOT_IMPLEMENTED + } + virtual std::unique_ptr create_pipeline( const PipelineSourceDesc &src, std::string name = "Pipeline") = 0; @@ -372,6 +425,13 @@ class Device { this->allocate_memory(params)); } + virtual uint64 fetch_result_uint64(int i, uint64 *result_buffer) { + TI_NOT_IMPLEMENTED + } + + // Each thraed will acquire its own stream + virtual Stream *get_compute_stream() = 0; + // Mapping can fail and will return nullptr virtual void *map_range(DevicePtr ptr, uint64_t size) = 0; virtual void *map(DeviceAllocation alloc) = 0; @@ -404,9 +464,6 @@ class Device { DevicePtr src, uint64_t size); - // Each thraed will acquire its own stream - virtual Stream *get_compute_stream() = 0; - private: std::unordered_map caps_; }; @@ -419,8 +476,12 @@ class Surface { virtual DeviceAllocation get_target_image() = 0; virtual void present_image() = 0; virtual std::pair get_size() = 0; + virtual int get_image_count() = 0; virtual BufferFormat image_format() = 0; virtual void resize(uint32_t width, uint32_t height) = 0; + virtual DeviceAllocation get_image_data() { + TI_NOT_IMPLEMENTED + } }; struct VertexInputBinding { @@ -444,6 +505,8 @@ struct SurfaceConfig { bool vsync{false}; bool adaptive{true}; void *window_handle{nullptr}; + uint32_t width{1}; + uint32_t height{1}; }; struct ImageParams { diff --git a/taichi/backends/dx/dx_api.cpp b/taichi/backends/dx/dx_api.cpp new file mode 100644 index 0000000000000..04e497fb386cd --- /dev/null +++ b/taichi/backends/dx/dx_api.cpp @@ -0,0 +1,17 @@ +#include "taichi/backends/dx/dx_api.h" + +namespace taichi { +namespace lang { +namespace directx11 { + +bool is_dx_api_available() { +#ifdef TI_WITH_DX11 + return true; +#else + return false; +#endif +} + +} // namespace directx11 +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/dx/dx_api.h b/taichi/backends/dx/dx_api.h new file mode 100644 index 0000000000000..753299a3df15e --- /dev/null +++ b/taichi/backends/dx/dx_api.h @@ -0,0 +1,20 @@ +#pragma once +#pragma comment(lib, "d3d11.lib") +#pragma comment(lib, "d3dcompiler.lib") +#pragma comment(lib, "dxguid.lib") + +#include "taichi/common/core.h" + +#ifdef TI_WITH_DX11 +#include +#endif + +namespace taichi { +namespace lang { +namespace directx11 { + +bool is_dx_api_available(); + +} +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/dx/dx_device.cpp b/taichi/backends/dx/dx_device.cpp new file mode 100644 index 0000000000000..23deeb61b7c5b --- /dev/null +++ b/taichi/backends/dx/dx_device.cpp @@ -0,0 +1,603 @@ +#include "taichi/backends/dx/dx_device.h" + +#include "spirv_hlsl.hpp" +#include + +namespace taichi { +namespace lang { +namespace directx11 { + +void check_dx_error(HRESULT hr, const char *msg) { + if (!SUCCEEDED(hr)) { + TI_ERROR("Error in {}: {}", msg, hr); + } +} + +std::unique_ptr Dx11ResourceBinder::materialize() { + TI_NOT_IMPLEMENTED; +} + +void Dx11ResourceBinder::rw_buffer(uint32_t set, + uint32_t binding, + DevicePtr ptr, + size_t size) { + TI_NOT_IMPLEMENTED; +} + +void Dx11ResourceBinder::rw_buffer(uint32_t set, + uint32_t binding, + DeviceAllocation alloc) { + TI_NOT_IMPLEMENTED; +} + +void Dx11ResourceBinder::buffer(uint32_t set, + uint32_t binding, + DevicePtr ptr, + size_t size) { + TI_NOT_IMPLEMENTED; +} + +void Dx11ResourceBinder::buffer(uint32_t set, + uint32_t binding, + DeviceAllocation alloc) { + TI_NOT_IMPLEMENTED; +} + +void Dx11ResourceBinder::image(uint32_t set, + uint32_t binding, + DeviceAllocation alloc, + ImageSamplerConfig sampler_config) { + TI_NOT_IMPLEMENTED; +} + +void Dx11ResourceBinder::vertex_buffer(DevicePtr ptr, uint32_t binding) { + TI_NOT_IMPLEMENTED; +} + +void Dx11ResourceBinder::index_buffer(DevicePtr ptr, size_t index_width) { + TI_NOT_IMPLEMENTED; +} + +Dx11ResourceBinder::~Dx11ResourceBinder() { +} + +Dx11CommandList::Dx11CommandList(Dx11Device *ti_device) : device_(ti_device) { +} + +Dx11CommandList::~Dx11CommandList() { +} + +void Dx11CommandList::bind_pipeline(Pipeline *p) { + TI_NOT_IMPLEMENTED; +} + +void Dx11CommandList::bind_resources(ResourceBinder *binder) { + TI_NOT_IMPLEMENTED; +} + +void Dx11CommandList::bind_resources(ResourceBinder *binder, + ResourceBinder::Bindings *bindings) { + TI_NOT_IMPLEMENTED; +} + +void Dx11CommandList::buffer_barrier(DevicePtr ptr, size_t size) { + TI_NOT_IMPLEMENTED; +} + +void Dx11CommandList::buffer_barrier(DeviceAllocation alloc) { + TI_NOT_IMPLEMENTED; +} + +void Dx11CommandList::memory_barrier() { + TI_NOT_IMPLEMENTED; +} + +void Dx11CommandList::buffer_copy(DevicePtr dst, DevicePtr src, size_t size) { + TI_NOT_IMPLEMENTED; +} + +void Dx11CommandList::buffer_fill(DevicePtr ptr, size_t size, uint32_t data) { + std::unique_ptr cmd = + std::make_unique(this); + ID3D11Buffer *buf = device_->alloc_id_to_buffer(ptr.alloc_id); + ID3D11UnorderedAccessView *uav = device_->alloc_id_to_uav(ptr.alloc_id); + cmd->uav = uav; + D3D11_BUFFER_DESC desc; + buf->GetDesc(&desc); + cmd->size = desc.ByteWidth; + recorded_commands_.push_back(std::move(cmd)); +} + +void Dx11CommandList::CmdBufferFill::execute() { + ID3D11DeviceContext *context = cmdlist_->device_->d3d11_context(); + const UINT values[4] = {data, data, data, data}; + context->ClearUnorderedAccessViewUint(uav, values); +} + +void Dx11CommandList::dispatch(uint32_t x, uint32_t y, uint32_t z) { + TI_NOT_IMPLEMENTED; +} + +void Dx11CommandList::begin_renderpass(int x0, + int y0, + int x1, + int y1, + uint32_t num_color_attachments, + DeviceAllocation *color_attachments, + bool *color_clear, + std::vector *clear_colors, + DeviceAllocation *depth_attachment, + bool depth_clear) { + TI_NOT_IMPLEMENTED; +} + +void Dx11CommandList::end_renderpass() { + TI_NOT_IMPLEMENTED; +} + +void Dx11CommandList::draw(uint32_t num_verticies, uint32_t start_vertex) { + TI_NOT_IMPLEMENTED; +} + +void Dx11CommandList::clear_color(float r, float g, float b, float a) { + TI_NOT_IMPLEMENTED; +} + +void Dx11CommandList::set_line_width(float width) { + TI_NOT_IMPLEMENTED; +} + +void Dx11CommandList::draw_indexed(uint32_t num_indicies, + uint32_t start_vertex, + uint32_t start_index) { + TI_NOT_IMPLEMENTED; +} + +void Dx11CommandList::image_transition(DeviceAllocation img, + ImageLayout old_layout, + ImageLayout new_layout) { + TI_NOT_IMPLEMENTED; +} + +void Dx11CommandList::buffer_to_image(DeviceAllocation dst_img, + DevicePtr src_buf, + ImageLayout img_layout, + const BufferImageCopyParams ¶ms) { + TI_NOT_IMPLEMENTED; +} + +void Dx11CommandList::image_to_buffer(DevicePtr dst_buf, + DeviceAllocation src_img, + ImageLayout img_layout, + const BufferImageCopyParams ¶ms) { + TI_NOT_IMPLEMENTED; +} + +void Dx11CommandList::run_commands() { + for (const auto &cmd : recorded_commands_) { + cmd->execute(); + } +} + +namespace { +HRESULT create_compute_device(ID3D11Device **out_device, + ID3D11DeviceContext **out_context, + bool force_ref, + bool debug_enabled) { + const D3D_FEATURE_LEVEL levels[] = { + D3D_FEATURE_LEVEL_11_1, + D3D_FEATURE_LEVEL_11_0, + D3D_FEATURE_LEVEL_10_1, + D3D_FEATURE_LEVEL_10_0, + }; + + UINT flags = 0; + if (debug_enabled) + flags |= D3D11_CREATE_DEVICE_DEBUG; + + ID3D11Device *device = nullptr; + ID3D11DeviceContext *context = nullptr; + HRESULT hr; + + D3D_DRIVER_TYPE driver_types[] = { + D3D_DRIVER_TYPE_HARDWARE, D3D_DRIVER_TYPE_SOFTWARE, + D3D_DRIVER_TYPE_REFERENCE, D3D_DRIVER_TYPE_WARP}; + const char *driver_type_names[] = { + "D3D_DRIVER_TYPE_HARDWARE", "D3D_DRIVER_TYPE_SOFTWARE", + "D3D_DRIVER_TYPE_REFERENCE", "D3D_DRIVER_TYPE_WARP"}; + + const int num_types = sizeof(driver_types) / sizeof(driver_types[0]); + + int attempt_idx = 0; + if (force_ref) { + attempt_idx = 2; + } + + for (; attempt_idx < num_types; attempt_idx++) { + D3D_DRIVER_TYPE driver_type = driver_types[attempt_idx]; + hr = D3D11CreateDevice(nullptr, driver_type, nullptr, flags, levels, + _countof(levels), D3D11_SDK_VERSION, &device, + nullptr, &context); + + if (FAILED(hr) || device == nullptr) { + TI_WARN("Failed to create D3D11 device with type {}: {}\n", driver_type, + driver_type_names[attempt_idx]); + continue; + } + + if (device->GetFeatureLevel() < D3D_FEATURE_LEVEL_11_0) { + D3D11_FEATURE_DATA_D3D10_X_HARDWARE_OPTIONS hwopts = {0}; + device->CheckFeatureSupport(D3D11_FEATURE_D3D10_X_HARDWARE_OPTIONS, + &hwopts, sizeof(hwopts)); + if (!hwopts.ComputeShaders_Plus_RawAndStructuredBuffers_Via_Shader_4_x) { + device->Release(); + TI_WARN( + "DirectCompute not supported via " + "ComputeShaders_Plus_RawAndStructuredBuffers_Via_Shader_4"); + } + continue; + } + + TI_INFO("Successfully created DX11 device with type {}", + driver_type_names[attempt_idx]); + *out_device = device; + *out_context = context; + break; + } + + if (*out_device == nullptr || *out_context == nullptr) { + TI_ERROR("Failed to create DX11 device using all {} driver types", + num_types); + } + + return hr; +} + +HRESULT create_raw_buffer(ID3D11Device *device, + UINT size, + void *init_data, + ID3D11Buffer **out_buf) { + *out_buf = nullptr; + D3D11_BUFFER_DESC desc = {}; + desc.BindFlags = D3D11_BIND_UNORDERED_ACCESS | D3D11_BIND_SHADER_RESOURCE; + desc.ByteWidth = size; + desc.MiscFlags = D3D11_RESOURCE_MISC_BUFFER_ALLOW_RAW_VIEWS; + if (init_data) { + D3D11_SUBRESOURCE_DATA data; + data.pSysMem = init_data; + return device->CreateBuffer(&desc, &data, out_buf); + } else { + return device->CreateBuffer(&desc, nullptr, out_buf); + } +} + +HRESULT create_buffer_uav(ID3D11Device *device, + ID3D11Buffer *buffer, + ID3D11UnorderedAccessView **out_uav) { + D3D11_BUFFER_DESC buf_desc = {}; + buffer->GetDesc(&buf_desc); + D3D11_UNORDERED_ACCESS_VIEW_DESC uav_desc = {}; + uav_desc.ViewDimension = D3D11_UAV_DIMENSION_BUFFER; + uav_desc.Buffer.FirstElement = 0; + if (buf_desc.MiscFlags & D3D11_RESOURCE_MISC_BUFFER_ALLOW_RAW_VIEWS) { + uav_desc.Format = DXGI_FORMAT_R32_TYPELESS; + uav_desc.Buffer.Flags = D3D11_BUFFER_UAV_FLAG_RAW; + uav_desc.Buffer.NumElements = buf_desc.ByteWidth / 4; + } else if (buf_desc.MiscFlags & D3D11_RESOURCE_MISC_BUFFER_STRUCTURED) { + uav_desc.Format = DXGI_FORMAT_UNKNOWN; + uav_desc.Buffer.NumElements = + buf_desc.ByteWidth / buf_desc.StructureByteStride; + } else + return E_INVALIDARG; + return device->CreateUnorderedAccessView(buffer, &uav_desc, out_uav); +} + +HRESULT compile_compute_shader_from_string(const std::string &source, + LPCSTR entry_point, + ID3D11Device *device, + ID3DBlob **blob) { + UINT flags = D3DCOMPILE_OPTIMIZATION_LEVEL2; + LPCSTR profile = (device->GetFeatureLevel() >= D3D_FEATURE_LEVEL_11_0) + ? "cs_5_0" + : "cs_4_0"; + const D3D_SHADER_MACRO defines[] = {"EXAMPLE_DEFINE", "1", NULL, NULL}; + ID3DBlob *shader_blob = nullptr, *error_blob = nullptr; + HRESULT hr = + D3DCompile(source.data(), source.size(), nullptr, defines, nullptr, + entry_point, profile, flags, 0, &shader_blob, &error_blob); + if (FAILED(hr)) { + TI_WARN("Error in compile_compute_shader_from_string\n"); + if (error_blob) { + TI_WARN("{}", (char *)error_blob->GetBufferPointer()); + error_blob->Release(); + } else + TI_WARN("error_blob is null\n"); + if (shader_blob) { + shader_blob->Release(); + } + return hr; + } + *blob = shader_blob; + return hr; +} + +HRESULT create_cpu_accessible_buffer_copy(ID3D11Device *device, + ID3D11Buffer *src_buf, + ID3D11Buffer **out_buf) { + D3D11_BUFFER_DESC desc; + src_buf->GetDesc(&desc); + D3D11_BUFFER_DESC desc1 = {}; + desc1.BindFlags = 0; + desc1.ByteWidth = desc.ByteWidth; + desc1.Usage = D3D11_USAGE_STAGING; + desc1.CPUAccessFlags = D3D11_CPU_ACCESS_WRITE | D3D11_CPU_ACCESS_READ; + desc1.MiscFlags = 0; + HRESULT hr = device->CreateBuffer(&desc1, nullptr, out_buf); + return hr; +} + +} // namespace + +Dx11Device::Dx11Device() { + create_dx11_device(); + if (kD3d11DebugEnabled) { + info_queue_ = std::make_unique(device_); + } + set_cap(DeviceCapability::spirv_version, 0x10300); + + stream_ = new Dx11Stream(this); +} + +Dx11Device::~Dx11Device() { + destroy_dx11_device(); +} + +void Dx11Device::create_dx11_device() { + if (device_ != nullptr && context_ != nullptr) { + TI_TRACE("D3D11 device has already been created."); + return; + } + TI_TRACE("Creating D3D11 device"); + create_compute_device(&device_, &context_, kD3d11ForceRef, + kD3d11DebugEnabled); +} + +void Dx11Device::destroy_dx11_device() { + if (device_ != nullptr) { + device_->Release(); + device_ = nullptr; + } + if (context_ != nullptr) { + context_->Release(); + context_ = nullptr; + } +} + +int Dx11Device::live_dx11_object_count() { + TI_ASSERT(info_queue_ != nullptr); + return info_queue_->live_object_count(); +} + +DeviceAllocation Dx11Device::allocate_memory(const AllocParams ¶ms) { + ID3D11Buffer *buf; + HRESULT hr; + hr = create_raw_buffer(device_, params.size, nullptr, &buf); + check_dx_error(hr, "create raw buffer"); + alloc_id_to_buffer_[alloc_serial_] = buf; + + ID3D11UnorderedAccessView *uav; + hr = create_buffer_uav(device_, buf, &uav); + check_dx_error(hr, "create UAV for buffer"); + alloc_id_to_uav_[alloc_serial_] = uav; + + // Set debug names + std::string buf_name = "buffer alloc#" + std::to_string(alloc_serial_) + + " size=" + std::to_string(params.size) + '\0'; + hr = buf->SetPrivateData(WKPDID_D3DDebugObjectName, buf_name.size(), + buf_name.c_str()); + check_dx_error(hr, "set name for buffer"); + + std::string uav_name = "UAV of " + buf_name; + hr = uav->SetPrivateData(WKPDID_D3DDebugObjectName, uav_name.size(), + uav_name.c_str()); + check_dx_error(hr, "set name for UAV"); + + DeviceAllocation alloc; + alloc.device = this; + alloc.alloc_id = alloc_serial_; + ++alloc_serial_; + + return alloc; +} + +void Dx11Device::dealloc_memory(DeviceAllocation handle) { + uint32_t alloc_id = handle.alloc_id; + ID3D11Buffer *buf = alloc_id_to_buffer_[alloc_id]; + buf->Release(); + alloc_id_to_buffer_.erase(alloc_id); + ID3D11UnorderedAccessView *uav = alloc_id_to_uav_[alloc_id]; + uav->Release(); + ID3D11Buffer *cpucopy = alloc_id_to_cpucopy_[alloc_id]; + if (cpucopy) + cpucopy->Release(); + alloc_id_to_uav_.erase(alloc_id); +} + +std::unique_ptr Dx11Device::create_pipeline( + const PipelineSourceDesc &src, + std::string name) { + return std::make_unique(src, name, this); +} + +void *Dx11Device::map_range(DevicePtr ptr, uint64_t size) { + TI_NOT_IMPLEMENTED; +} + +void *Dx11Device::map(DeviceAllocation alloc) { + uint32_t alloc_id = alloc.alloc_id; + ID3D11Buffer *buf = alloc_id_to_buffer(alloc_id); + ID3D11Buffer *cpucopy = alloc_id_to_buffer_cpu_copy(alloc_id); + + if (cpucopy == nullptr) { + create_cpu_accessible_buffer_copy(device_, buf, &cpucopy); + alloc_id_to_cpucopy_[alloc_id] = cpucopy; + } + + context_->CopyResource(cpucopy, buf); + D3D11_MAPPED_SUBRESOURCE mapped; + context_->Map(cpucopy, 0, D3D11_MAP_READ_WRITE, 0, &mapped); + return mapped.pData; +} + +void Dx11Device::unmap(DevicePtr ptr) { + ID3D11Buffer *cpucopy = alloc_id_to_buffer_cpu_copy(ptr.alloc_id); + context_->Unmap(cpucopy, 0); +} + +void Dx11Device::unmap(DeviceAllocation alloc) { + ID3D11Buffer *cpucopy = alloc_id_to_buffer_cpu_copy(alloc.alloc_id); + ID3D11Buffer *buf = alloc_id_to_buffer(alloc.alloc_id); + context_->Unmap(cpucopy, 0); + context_->CopyResource(buf, cpucopy); +} + +void Dx11Device::memcpy_internal(DevicePtr dst, DevicePtr src, uint64_t size) { + TI_NOT_IMPLEMENTED; +} + +Stream *Dx11Device::get_compute_stream() { + return stream_; +} + +std::unique_ptr Dx11Device::create_raster_pipeline( + const std::vector &src, + const RasterParams &raster_params, + const std::vector &vertex_inputs, + const std::vector &vertex_attrs, + std::string name) { + TI_NOT_IMPLEMENTED; +} + +Stream *Dx11Device::get_graphics_stream() { + TI_NOT_IMPLEMENTED; +} + +std::unique_ptr Dx11Device::create_surface( + const SurfaceConfig &config) { + TI_NOT_IMPLEMENTED; +} + +DeviceAllocation Dx11Device::create_image(const ImageParams ¶ms) { + TI_NOT_IMPLEMENTED; +} + +void Dx11Device::destroy_image(DeviceAllocation handle) { + TI_NOT_IMPLEMENTED; +} + +void Dx11Device::image_transition(DeviceAllocation img, + ImageLayout old_layout, + ImageLayout new_layout) { + TI_NOT_IMPLEMENTED; +} + +void Dx11Device::buffer_to_image(DeviceAllocation dst_img, + DevicePtr src_buf, + ImageLayout img_layout, + const BufferImageCopyParams ¶ms) { + TI_NOT_IMPLEMENTED; +} +void Dx11Device::image_to_buffer(DevicePtr dst_buf, + DeviceAllocation src_img, + ImageLayout img_layout, + const BufferImageCopyParams ¶ms) { + TI_NOT_IMPLEMENTED; +} + +ID3D11Buffer *Dx11Device::alloc_id_to_buffer(uint32_t alloc_id) { + return alloc_id_to_buffer_.at(alloc_id); +} + +ID3D11Buffer *Dx11Device::alloc_id_to_buffer_cpu_copy(uint32_t alloc_id) { + if (alloc_id_to_cpucopy_.find(alloc_id) == alloc_id_to_cpucopy_.end()) + return nullptr; + return alloc_id_to_cpucopy_.at(alloc_id); +} + +ID3D11UnorderedAccessView *Dx11Device::alloc_id_to_uav(uint32_t alloc_id) { + return alloc_id_to_uav_.at(alloc_id); +} + +Dx11Stream::Dx11Stream(Dx11Device *device_) : device_(device_) { +} + +Dx11Stream::~Dx11Stream() { +} + +std::unique_ptr Dx11Stream::new_command_list() { + return std::make_unique(device_); +} + +void Dx11Stream::submit(CommandList *cmdlist) { + TI_NOT_IMPLEMENTED; +} + +// No difference for DX11 +void Dx11Stream::submit_synced(CommandList *cmdlist) { + Dx11CommandList *dx_cmd_list = static_cast(cmdlist); + dx_cmd_list->run_commands(); +} + +void Dx11Stream::command_sync() { + // Not needed for DX11 +} + +Dx11Pipeline::Dx11Pipeline(const PipelineSourceDesc &desc, + const std::string &name, + Dx11Device *device) + : device_(device) { + // TODO: Currently, PipelineSourceType::hlsl_src still returns SPIRV binary. + // Will need to update this section when that changes + TI_ASSERT(desc.type == PipelineSourceType::hlsl_src || + desc.type == PipelineSourceType::spirv_binary); + + ID3DBlob *shader_blob; + HRESULT hr; + + std::vector spirv_binary( + (uint32_t *)desc.data, (uint32_t *)((uint8_t *)desc.data + desc.size)); + spirv_cross::CompilerHLSL hlsl(std::move(spirv_binary)); + spirv_cross::CompilerHLSL::Options options; + options.shader_model = 40; + hlsl.set_hlsl_options(options); + + std::string source = hlsl.compile(); + TI_TRACE("hlsl source: \n{}", source); + + hr = compile_compute_shader_from_string( + source, "main", device_->d3d11_device(), &shader_blob); + if (SUCCEEDED(hr)) { + hr = device_->d3d11_device()->CreateComputeShader( + shader_blob->GetBufferPointer(), shader_blob->GetBufferSize(), nullptr, + &compute_shader_); + shader_blob->Release(); + compute_shader_->SetPrivateData(WKPDID_D3DDebugObjectName, name.size(), + name.c_str()); + if (!SUCCEEDED(hr)) { + TI_ERROR("HLSL compute shader creation error"); + } + } else { + TI_ERROR("HLSL compute shader compilation error"); + } +} + +Dx11Pipeline::~Dx11Pipeline() { +} + +ResourceBinder *Dx11Pipeline::resource_binder() { + return nullptr; +} + +} // namespace directx11 +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/dx/dx_device.h b/taichi/backends/dx/dx_device.h new file mode 100644 index 0000000000000..2cf8500028473 --- /dev/null +++ b/taichi/backends/dx/dx_device.h @@ -0,0 +1,226 @@ +#pragma once + +#include "taichi/backends/device.h" +#include "taichi/backends/dx/dx_info_queue.h" +#include + +namespace taichi { +namespace lang { +namespace directx11 { + +// Only enable debug layer when the corresponding testing facility is enabled +constexpr bool kD3d11DebugEnabled = true; +constexpr bool kD3d11ForceRef = false; // Force REF device. May be used to + // force software rendering. + +void debug_enabled(bool); +void force_ref(bool); +void check_dx_error(HRESULT hr, const char *msg); + +class Dx11ResourceBinder : public ResourceBinder { + public: + ~Dx11ResourceBinder() override; + std::unique_ptr materialize() override; + void rw_buffer(uint32_t set, + uint32_t binding, + DevicePtr ptr, + size_t size) override; + void rw_buffer(uint32_t set, + uint32_t binding, + DeviceAllocation alloc) override; + void buffer(uint32_t set, + uint32_t binding, + DevicePtr ptr, + size_t size) override; + void buffer(uint32_t set, uint32_t binding, DeviceAllocation alloc) override; + void image(uint32_t set, + uint32_t binding, + DeviceAllocation alloc, + ImageSamplerConfig sampler_config) override; + + // Set vertex buffer (not implemented in compute only device) + void vertex_buffer(DevicePtr ptr, uint32_t binding = 0) override; + + // Set index buffer (not implemented in compute only device) + // index_width = 4 -> uint32 index + // index_width = 2 -> uint16 index + void index_buffer(DevicePtr ptr, size_t index_width) override; + + const std::unordered_map &binding_to_alloc_id() { + return binding_to_alloc_id_; + } + + private: + std::unordered_map binding_to_alloc_id_; +}; + +class Dx11Device; + +class Dx11Pipeline : public Pipeline { + public: + Dx11Pipeline(const PipelineSourceDesc &desc, + const std::string &name, + Dx11Device *device); + ~Dx11Pipeline() override; + ResourceBinder *resource_binder() override; + + private: + std::shared_ptr device_{}; + ID3D11ComputeShader *compute_shader_{}; + Dx11ResourceBinder binder_{}; +}; + +class Dx11Stream : public Stream { + public: + Dx11Stream(Dx11Device *); + ~Dx11Stream() override; + + std::unique_ptr new_command_list() override; + void submit(CommandList *cmdlist) override; + void submit_synced(CommandList *cmdlist) override; + void command_sync() override; + + private: + Dx11Device *device_; +}; + +class Dx11CommandList : public CommandList { + public: + Dx11CommandList(Dx11Device *ti_device); + ~Dx11CommandList() override; + + void bind_pipeline(Pipeline *p) override; + void bind_resources(ResourceBinder *binder) override; + void bind_resources(ResourceBinder *binder, + ResourceBinder::Bindings *bindings) override; + void buffer_barrier(DevicePtr ptr, size_t size) override; + void buffer_barrier(DeviceAllocation alloc) override; + void memory_barrier() override; + void buffer_copy(DevicePtr dst, DevicePtr src, size_t size) override; + void buffer_fill(DevicePtr ptr, size_t size, uint32_t data) override; + void dispatch(uint32_t x, uint32_t y = 1, uint32_t z = 1) override; + + // These are not implemented in compute only device + void begin_renderpass(int x0, + int y0, + int x1, + int y1, + uint32_t num_color_attachments, + DeviceAllocation *color_attachments, + bool *color_clear, + std::vector *clear_colors, + DeviceAllocation *depth_attachment, + bool depth_clear) override; + void end_renderpass() override; + void draw(uint32_t num_verticies, uint32_t start_vertex = 0) override; + void clear_color(float r, float g, float b, float a) override; + void set_line_width(float width) override; + void draw_indexed(uint32_t num_indicies, + uint32_t start_vertex = 0, + uint32_t start_index = 0) override; + void image_transition(DeviceAllocation img, + ImageLayout old_layout, + ImageLayout new_layout) override; + void buffer_to_image(DeviceAllocation dst_img, + DevicePtr src_buf, + ImageLayout img_layout, + const BufferImageCopyParams ¶ms) override; + void image_to_buffer(DevicePtr dst_buf, + DeviceAllocation src_img, + ImageLayout img_layout, + const BufferImageCopyParams ¶ms) override; + + void run_commands(); + + private: + struct Cmd { + explicit Cmd(Dx11CommandList *cmdlist) : cmdlist_(cmdlist) { + } + virtual void execute() { + } + Dx11CommandList *cmdlist_; + }; + + struct CmdBufferFill : public Cmd { + explicit CmdBufferFill(Dx11CommandList *cmdlist) : Cmd(cmdlist) { + } + ID3D11UnorderedAccessView *uav{nullptr}; + size_t offset{0}, size{0}; + uint32_t data{0}; + void execute() override; + }; + + std::vector> recorded_commands_; + Dx11Device *device_; +}; + +class Dx11Device : public GraphicsDevice { + public: + Dx11Device(); + ~Dx11Device() override; + + DeviceAllocation allocate_memory(const AllocParams ¶ms) override; + void dealloc_memory(DeviceAllocation handle) override; + std::unique_ptr create_pipeline( + const PipelineSourceDesc &src, + std::string name = "Pipeline") override; + void *map_range(DevicePtr ptr, uint64_t size) override; + void *map(DeviceAllocation alloc) override; + void unmap(DevicePtr ptr) override; + void unmap(DeviceAllocation alloc) override; + void memcpy_internal(DevicePtr dst, DevicePtr src, uint64_t size) override; + Stream *get_compute_stream() override; + std::unique_ptr create_raster_pipeline( + const std::vector &src, + const RasterParams &raster_params, + const std::vector &vertex_inputs, + const std::vector &vertex_attrs, + std::string name = "Pipeline") override; + Stream *get_graphics_stream() override; + std::unique_ptr create_surface(const SurfaceConfig &config) override; + DeviceAllocation create_image(const ImageParams ¶ms) override; + void destroy_image(DeviceAllocation handle) override; + + void image_transition(DeviceAllocation img, + ImageLayout old_layout, + ImageLayout new_layout) override; + void buffer_to_image(DeviceAllocation dst_img, + DevicePtr src_buf, + ImageLayout img_layout, + const BufferImageCopyParams ¶ms) override; + void image_to_buffer(DevicePtr dst_buf, + DeviceAllocation src_img, + ImageLayout img_layout, + const BufferImageCopyParams ¶ms) override; + + int live_dx11_object_count(); + ID3D11DeviceContext *d3d11_context() { + return context_; + } + + ID3D11Buffer *alloc_id_to_buffer(uint32_t alloc_id); + ID3D11Buffer *alloc_id_to_buffer_cpu_copy(uint32_t alloc_id); + ID3D11UnorderedAccessView *alloc_id_to_uav(uint32_t alloc_id); + ID3D11Device *d3d11_device() { + return device_; + } + + private: + void create_dx11_device(); + void destroy_dx11_device(); + ID3D11Device *device_{}; + ID3D11DeviceContext *context_{}; + std::unique_ptr info_queue_{}; + std::unordered_map + alloc_id_to_buffer_; // binding ID to buffer + std::unordered_map + alloc_id_to_cpucopy_; // binding ID to CPU copy of buffer + std::unordered_map + alloc_id_to_uav_; // binding ID to UAV + int alloc_serial_; + Dx11Stream *stream_; +}; + +} // namespace directx11 +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/dx/dx_info_queue.cpp b/taichi/backends/dx/dx_info_queue.cpp new file mode 100644 index 0000000000000..6955a9fa79007 --- /dev/null +++ b/taichi/backends/dx/dx_info_queue.cpp @@ -0,0 +1,140 @@ +#include "taichi/backends/dx/dx_info_queue.h" + +namespace taichi { +namespace lang { +namespace directx11 { + +void check_dx_error(HRESULT hr, const char *msg); + +namespace { +inline std::string trim_string(const std::string &s) { + int begin = 0, end = (int)s.size(); + while (begin < end && std::isspace(s[begin])) { + begin++; + } + while (begin < end && std::isspace(s[end - 1])) { + end--; + } + return std::string(s.begin() + begin, s.begin() + end); +} +} // namespace + +std::string munch_token(std::string &s) { + if (s.empty()) { + return ""; + } + size_t idx = s.find(' '); + std::string ret; + ret = trim_string(s.substr(0, idx)); + s = s.substr(idx + 1); + if (ret.empty() == false && ret.back() == ',') + ret.pop_back(); + return ret; +} + +std::vector Dx11InfoQueue::parse_reference_count( + const std::vector &messages) { + std::vector ret; + for (std::string line : messages) { + Dx11InfoQueue::Entry entry; + line = trim_string(line); + std::string x; + x = munch_token(line); + + // Example 1: "Live ID3D11Query at 0x0000018F64E81DA0, Refcount: 0, IntRef: + // 1" Example 2: "Live ID3D11Buffer at 0x000001F8AF284370, Name: buffer + // alloc#0 size=1048576, Refcount: 1, IntRef: 1" + + if (x != "Live") + continue; + + x = munch_token(line); + entry.type = x; + + x = munch_token(line); + if (x != "at") + continue; + + x = munch_token(line); + if (x.empty()) + continue; + + entry.addr = reinterpret_cast(std::atoll(x.c_str())); + + while (true) { + x = munch_token(line); + if (x == "Refcount:") { + x = munch_token(line); + entry.refcount = std::atoi(x.c_str()); + } else if (x == "IntRef:") { + x = munch_token(line); + entry.intref = std::atoi(x.c_str()); + } else + break; + } + ret.push_back(entry); + } + return ret; +} + +Dx11InfoQueue::Dx11InfoQueue(ID3D11Device *device) + : device_(device), last_message_count_(0) { + init(); +} + +void Dx11InfoQueue::init() { + typedef HRESULT(WINAPI * DXGIGetDebugInterface)(REFIID, void **); + + HRESULT hr; + hr = device_->QueryInterface(__uuidof(ID3D11InfoQueue), + reinterpret_cast(&info_queue_)); + check_dx_error(hr, "Query ID3D11InfoQueue interface from the DX11 device"); + hr = device_->QueryInterface(__uuidof(ID3D11Debug), + reinterpret_cast(&debug_)); + check_dx_error(hr, "Query ID3D11Debug interface from the DX11 device"); +} + +std::vector Dx11InfoQueue::get_updated_messages() { + std::vector ret; + if (!info_queue_) { + return ret; + } + const int num_messages = info_queue_->GetNumStoredMessages(); + const int n = num_messages - last_message_count_; + ret.resize(n); + for (int i = 0; i < n; i++) { + D3D11_MESSAGE *msg; + size_t len = 0; + HRESULT hr = + info_queue_->GetMessageW(i + last_message_count_, nullptr, &len); + check_dx_error(hr, "Check D3D11 info queue message length"); + msg = (D3D11_MESSAGE *)malloc(len); + hr = info_queue_->GetMessageW(i + last_message_count_, msg, &len); + check_dx_error(hr, "Obtain D3D11 info queue message content"); + ret[i] = std::string(msg->pDescription); + free(msg); + } + last_message_count_ = num_messages; + return ret; +} + +bool Dx11InfoQueue::has_updated_messages() { + if (!info_queue_) { + return false; + } + const int n = info_queue_->GetNumStoredMessages(); + return n > last_message_count_; +} + +int Dx11InfoQueue::live_object_count() { + get_updated_messages(); // Drain message queue + debug_->ReportLiveDeviceObjects(D3D11_RLDO_DETAIL); + if (has_updated_messages()) { + live_objects_ = parse_reference_count(get_updated_messages()); + } + return static_cast(live_objects_.size()); +} + +} // namespace directx11 +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/dx/dx_info_queue.h b/taichi/backends/dx/dx_info_queue.h new file mode 100644 index 0000000000000..18cda1043a999 --- /dev/null +++ b/taichi/backends/dx/dx_info_queue.h @@ -0,0 +1,36 @@ +#pragma once + +#include "taichi/backends/device.h" +#include + +namespace taichi { +namespace lang { +namespace directx11 { + +class Dx11InfoQueue { + public: + struct Entry { + std::string type; + void *addr; + int refcount; + int intref; + }; + static std::vector parse_reference_count( + const std::vector &); + explicit Dx11InfoQueue(ID3D11Device *device); + int live_object_count(); + + private: + bool has_updated_messages(); + std::vector get_updated_messages(); + std::vector live_objects_; + void init(); + ID3D11Device *device_{}; + ID3D11Debug *debug_{}; + ID3D11InfoQueue *info_queue_{}; + int last_message_count_; +}; + +} // namespace directx11 +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/dx/dx_program.cpp b/taichi/backends/dx/dx_program.cpp new file mode 100644 index 0000000000000..c6335fa9127df --- /dev/null +++ b/taichi/backends/dx/dx_program.cpp @@ -0,0 +1,68 @@ +#include "taichi/backends/dx/dx_program.h" + +#include "taichi/backends/dx/dx_device.h" +#include "taichi/backends/vulkan/snode_tree_manager.h" + +namespace taichi { +namespace lang { +namespace directx11 { + +FunctionType compile_to_executable(Kernel *kernel, + vulkan::VkRuntime *runtime, + vulkan::SNodeTreeManager snode_tree_mgr) { + auto handle = runtime->register_taichi_kernel( + std::move(vulkan::run_codegen(kernel, runtime->get_ti_device(), + snode_tree_mgr->get_compiled_structs()))); + return [runtime, handle](RuntimeContext &ctx) { + runtime->launch_kernel(handle, &ctx); + }; +} + +} // namespace directx11 + +Dx11ProgramImpl::Dx11ProgramImpl(CompileConfig &config) : ProgramImpl(config) { +} + +FunctionType Dx11ProgramImpl::compile(Kernel *kernel, + OffloadedStmt *offloaded) { + spirv::lower(kernel); + return directx11::compile_to_executable(kernel, runtime_.get(), + snode_tree_mgr_.get()); +} + +void Dx11ProgramImpl::materialize_runtime(MemoryPool *memory_pool, + KernelProfilerBase *profiler, + uint64 **result_buffer_ptr) { + *result_buffer_ptr = (uint64 *)memory_pool->allocate( + sizeof(uint64) * taichi_result_buffer_entries, 8); + + device_ = std::make_unique(); + + vulkan::VkRuntime::Params params; + params.host_result_buffer = *result_buffer_ptr; + params.device = device_.get(); + runtime_ = std::make_unique(std::move(params)); + snode_tree_mgr_ = std::make_unique(runtime_.get()); +} + +void Dx11ProgramImpl::synchronize() { + TI_NOT_IMPLEMENTED; +} + +void Dx11ProgramImpl::materialize_snode_tree( + SNodeTree *tree, + std::vector> &snode_trees_, + uint64 *result_buffer_ptr) { + snode_tree_mgr_->materialize_snode_tree(tree); +} + +std::unique_ptr Dx11ProgramImpl::make_aot_module_builder() { + return nullptr; +} + +void Dx11ProgramImpl::destroy_snode_tree(SNodeTree *snode_tree) { + TI_NOT_IMPLEMENTED; +} + +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/dx/dx_program.h b/taichi/backends/dx/dx_program.h new file mode 100644 index 0000000000000..1d8353315d03e --- /dev/null +++ b/taichi/backends/dx/dx_program.h @@ -0,0 +1,39 @@ +#pragma once + +#include "taichi/backends/dx/dx_device.h" +#include "taichi/backends/vulkan/runtime.h" +#include "taichi/backends/vulkan/snode_tree_manager.h" +#include "taichi/program/program_impl.h" + +namespace taichi { +namespace lang { + +class Dx11ProgramImpl : public ProgramImpl { + public: + Dx11ProgramImpl(CompileConfig &config); + + FunctionType compile(Kernel *kernel, OffloadedStmt *offloaded) override; + std::size_t get_snode_num_dynamically_allocated( + SNode *snode, + uint64 *result_buffer) override { + return 0; + } + std::unique_ptr make_aot_module_builder(); + void materialize_runtime(MemoryPool *memory_pool, + KernelProfilerBase *profiler, + uint64 **result_buffer_ptr) override; + virtual void materialize_snode_tree( + SNodeTree *tree, + std::vector> &snode_trees_, + uint64 *result_buffer_ptr) override; + virtual void destroy_snode_tree(SNodeTree *snode_tree) override; + void synchronize() override; + + private: + std::shared_ptr device_{nullptr}; + std::unique_ptr runtime_{nullptr}; + std::unique_ptr snode_tree_mgr_{nullptr}; +}; + +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/interop/vulkan_cpu_interop.cpp b/taichi/backends/interop/vulkan_cpu_interop.cpp index 0ae14a270e86a..11359dcc2b8bb 100644 --- a/taichi/backends/interop/vulkan_cpu_interop.cpp +++ b/taichi/backends/interop/vulkan_cpu_interop.cpp @@ -10,7 +10,7 @@ namespace taichi { namespace lang { -#if TI_WITH_VULKAN +#if TI_WITH_VULKAN && defined(TI_WITH_LLVM) using namespace taichi::lang::vulkan; using namespace taichi::lang::cpu; diff --git a/taichi/backends/interop/vulkan_cuda_interop.cpp b/taichi/backends/interop/vulkan_cuda_interop.cpp index ae6e5520a199e..9376087f94d0e 100644 --- a/taichi/backends/interop/vulkan_cuda_interop.cpp +++ b/taichi/backends/interop/vulkan_cuda_interop.cpp @@ -149,14 +149,22 @@ void memcpy_cuda_to_vulkan(DevicePtr dst, DevicePtr src, uint64_t size) { DeviceAllocation dst_alloc(dst); DeviceAllocation src_alloc(src); - static std::unordered_map alloc_base_ptrs; + static std::unordered_map< + VulkanDevice *, + std::unordered_map>> + alloc_base_ptrs_all; + std::unordered_map &alloc_base_ptrs = + alloc_base_ptrs_all[vk_dev][cuda_dev]; if (alloc_base_ptrs.find(dst_alloc.alloc_id) == alloc_base_ptrs.end()) { auto [base_mem, alloc_offset, alloc_size] = vk_dev->get_vkmemory_offset_size(dst_alloc); - auto block_size = VulkanDevice::kMemoryBlockSize; + // this might be smaller than the actual size of the VkDeviceMemory, but it + // is big enough to cover the region of this buffer, so it's fine. + size_t mem_size = alloc_offset + alloc_size; void *alloc_base_ptr = get_cuda_memory_pointer( - base_mem, /*mem_size=*/block_size, /*offset=*/alloc_offset, + base_mem, /*mem_size=*/mem_size, /*offset=*/alloc_offset, /*buffer_size=*/alloc_size, vk_dev->vk_device()); alloc_base_ptrs[dst_alloc.alloc_id] = (unsigned char *)alloc_base_ptr; } diff --git a/taichi/backends/metal/aot_module_builder_impl.cpp b/taichi/backends/metal/aot_module_builder_impl.cpp index 263029b2776ed..a4d985670cf69 100644 --- a/taichi/backends/metal/aot_module_builder_impl.cpp +++ b/taichi/backends/metal/aot_module_builder_impl.cpp @@ -11,16 +11,18 @@ namespace metal { AotModuleBuilderImpl::AotModuleBuilderImpl( const CompiledRuntimeModule *compiled_runtime_module, const std::vector &compiled_snode_trees, - const BufferMetaData &buffer_meta_data) + const std::unordered_set &fields, + BufferMetaData buffer_meta_data) : compiled_runtime_module_(compiled_runtime_module), compiled_snode_trees_(compiled_snode_trees), - buffer_meta_data_(buffer_meta_data) { + fields_(fields) { + buffer_meta_data.root_buffer_size = compiled_snode_trees_[0].root_size; ti_aot_data_.metadata = buffer_meta_data; } -void AotModuleBuilderImpl::metalgen(const std::string &dir, - const std::string &filename, - const CompiledKernelData &k) const { +void AotModuleBuilderImpl::write_metal_file(const std::string &dir, + const std::string &filename, + const CompiledKernelData &k) const { const std::string mtl_path = fmt::format("{}/{}_{}.metal", dir, filename, k.kernel_name); std::ofstream fs{mtl_path}; @@ -41,12 +43,12 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir, ts.write_to_file(txt_path); for (const auto &k : ti_aot_data_.kernels) { - metalgen(output_dir, filename, k); + write_metal_file(output_dir, filename, k); } for (const auto &k : ti_aot_data_.tmpl_kernels) { for (auto &ki : k.kernel_tmpl_map) { - metalgen(output_dir, filename, ki.second); + write_metal_file(output_dir, filename, ki.second); } } } @@ -59,18 +61,26 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier, ti_aot_data_.kernels.push_back(std::move(compiled)); } -void AotModuleBuilderImpl::add_per_backend_field(const std::string &identifier, +void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier, + const SNode *rep_snode, bool is_scalar, DataType dt, std::vector shape, int row_num, int column_num) { + const auto *dense_snode = rep_snode->parent; + TI_ASSERT_INFO(fields_.find(dense_snode) != fields_.end(), + "dense_snode: id={} type={}", dense_snode->id, + dense_snode->get_node_type_name_hinted()); + const auto &dense_desc = + compiled_snode_trees_[0].snode_descriptors.at(dense_snode->id); CompiledFieldData field_data; field_data.field_name = identifier; field_data.is_scalar = is_scalar; field_data.dtype = to_metal_type(dt); field_data.dtype_name = metal_data_type_name(dt); field_data.shape = shape; + field_data.mem_offset_in_parent = dense_desc.mem_offset_in_parent; field_data.row_num = row_num; field_data.column_num = column_num; ti_aot_data_.fields.push_back(field_data); diff --git a/taichi/backends/metal/aot_module_builder_impl.h b/taichi/backends/metal/aot_module_builder_impl.h index ebc8981a965f2..5f0532789ee4c 100644 --- a/taichi/backends/metal/aot_module_builder_impl.h +++ b/taichi/backends/metal/aot_module_builder_impl.h @@ -2,10 +2,11 @@ #include #include +#include +#include "taichi/aot/module_builder.h" #include "taichi/backends/metal/aot_utils.h" #include "taichi/backends/metal/struct_metal.h" -#include "taichi/program/aot_module_builder.h" namespace taichi { namespace lang { @@ -16,14 +17,17 @@ class AotModuleBuilderImpl : public AotModuleBuilder { explicit AotModuleBuilderImpl( const CompiledRuntimeModule *compiled_runtime_module, const std::vector &compiled_snode_trees, - const BufferMetaData &buffer_meta_data); + const std::unordered_set &fields, + BufferMetaData buffer_meta_data); void dump(const std::string &output_dir, const std::string &filename) const override; protected: void add_per_backend(const std::string &identifier, Kernel *kernel) override; - void add_per_backend_field(const std::string &identifier, + + void add_field_per_backend(const std::string &identifier, + const SNode *rep_snode, bool is_scalar, DataType dt, std::vector shape, @@ -34,15 +38,15 @@ class AotModuleBuilderImpl : public AotModuleBuilder { Kernel *kernel) override; private: + void write_metal_file(const std::string &dir, + const std::string &filename, + const CompiledKernelData &k) const; + const CompiledRuntimeModule *compiled_runtime_module_; const std::vector &compiled_snode_trees_; - BufferMetaData buffer_meta_data_; + const std::unordered_set fields_; PrintStringTable strtab_; TaichiAotData ti_aot_data_; - - void metalgen(const std::string &dir, - const std::string &filename, - const CompiledKernelData &k) const; }; } // namespace metal diff --git a/taichi/backends/metal/aot_module_loader_impl.cpp b/taichi/backends/metal/aot_module_loader_impl.cpp new file mode 100644 index 0000000000000..a02192c6c7115 --- /dev/null +++ b/taichi/backends/metal/aot_module_loader_impl.cpp @@ -0,0 +1,88 @@ +#include "taichi/backends/metal/aot_module_loader_impl.h" + +#include "taichi/backends/metal/aot_utils.h" +#include "taichi/backends/metal/kernel_manager.h" + +namespace taichi { +namespace lang { +namespace metal { +namespace { + +class KernelImpl : public aot::Kernel { + public: + explicit KernelImpl(KernelManager *runtime, const std::string &kernel_name) + : runtime_(runtime), kernel_name_(kernel_name) { + } + + void launch(RuntimeContext *ctx) override { + runtime_->launch_taichi_kernel(kernel_name_, ctx); + } + + private: + KernelManager *const runtime_; + const std::string kernel_name_; +}; + +class AotModuleImpl : public aot::Module { + public: + explicit AotModuleImpl(const AotModuleParams ¶ms) + : runtime_(params.runtime) { + const std::string bin_path = + fmt::format("{}/metadata.tcb", params.module_path); + read_from_binary_file(aot_data_, bin_path); + // Do we still need to load each individual kernel? + for (const auto &k : aot_data_.kernels) { + kernels_[k.kernel_name] = &k; + } + } + + std::unique_ptr get_kernel(const std::string &name) override { + return make_new_kernel(name); + } + + std::unique_ptr get_field(const std::string &name) override { + TI_NOT_IMPLEMENTED; + } + + size_t get_root_size() const override { + return aot_data_.metadata.root_buffer_size; + } + + // Module metadata + Arch arch() const override { + return Arch::metal; + } + uint64_t version() const override { + TI_NOT_IMPLEMENTED; + } + + private: + std::unique_ptr make_new_kernel( + const std::string &name) override { + auto itr = kernels_.find(name); + if (itr == kernels_.end()) { + TI_DEBUG("Failed to load kernel {}", name); + return nullptr; + } + auto *kernel_data = itr->second; + runtime_->register_taichi_kernel(name, kernel_data->source_code, + kernel_data->kernel_attribs, + kernel_data->ctx_attribs); + return std::make_unique(runtime_, name); + } + + KernelManager *const runtime_; + TaichiAotData aot_data_; + std::unordered_map kernels_; +}; + +} // namespace + +std::unique_ptr make_aot_module(std::any mod_params) { + AotModuleParams params = std::any_cast(mod_params); + return std::make_unique(params); +} + +} // namespace metal +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/metal/aot_module_loader_impl.h b/taichi/backends/metal/aot_module_loader_impl.h new file mode 100644 index 0000000000000..04cc71f494f78 --- /dev/null +++ b/taichi/backends/metal/aot_module_loader_impl.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include +#include + +#include "taichi/aot/module_loader.h" + +namespace taichi { +namespace lang { +namespace metal { + +class KernelManager; + +struct AotModuleParams { + std::string module_path; + KernelManager *runtime{nullptr}; +}; + +std::unique_ptr make_aot_module(std::any mod_params); +} // namespace metal +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index 3d0780570b363..9a65afbcc7621 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -23,6 +23,7 @@ namespace shaders { #define TI_INSIDE_METAL_CODEGEN #include "taichi/backends/metal/shaders/ad_stack.metal.h" #include "taichi/backends/metal/shaders/helpers.metal.h" +#include "taichi/backends/metal/shaders/init_randseeds.metal.h" #include "taichi/backends/metal/shaders/print.metal.h" #include "taichi/backends/metal/shaders/runtime_kernels.metal.h" #undef TI_INSIDE_METAL_CODEGEN @@ -81,6 +82,19 @@ bool is_ret_type_bit_pointer(Stmt *s) { return false; } +bool is_full_bits(int bits) { + return bits == (sizeof(uint32_t) * 8); +} + +void validate_cft_for_metal(CustomFloatType *cft) { + if (cft->get_exponent_type() != nullptr) { + TI_NOT_IMPLEMENTED; + } + if (cft->get_compute_type()->as() != PrimitiveType::f32) { + TI_ERROR("Metal only supports 32-bit float"); + } +} + class RootIdsExtractor : public BasicStmtVisitor { public: static std::unordered_set run(Stmt *s) { @@ -106,9 +120,32 @@ class RootIdsExtractor : public BasicStmtVisitor { } private: + using BasicStmtVisitor::visit; std::unordered_set roots_; }; +class TaskPreprocessor final : public BasicStmtVisitor { + public: + struct Result { + bool should_init_randseeds{false}; + }; + + static Result run(Stmt *s) { + TaskPreprocessor tp; + s->accept(&tp); + return tp.res_; + } + + protected: + void visit(RandStmt *) override { + res_.should_init_randseeds = true; + } + using BasicStmtVisitor::visit; + + TaskPreprocessor() = default; + Result res_; +}; + class KernelCodegenImpl : public IRVisitor { private: enum class Section { @@ -232,7 +269,7 @@ class KernelCodegenImpl : public IRVisitor { const auto root_id = stmt->root()->id; root_id_to_stmts_[root_id] = stmt; const auto &cst = get_compiled_snode_tree(stmt->root()); - const auto root_desc = BufferDescriptor::Root(root_id); + const auto root_desc = BufferDescriptor::root(root_id); emit(R"({} {}({});)", cst.root_snode_type_name, stmt->raw_name(), buffer_to_name(root_desc)); } @@ -392,7 +429,11 @@ class KernelCodegenImpl : public IRVisitor { void visit(ReturnStmt *stmt) override { // TODO: use stmt->ret_id instead of 0 as index - emit("*{}.ret0() = {};", kContextVarName, stmt->value->raw_name()); + int idx{0}; + for (auto &value : stmt->values) { + emit("{}.ret0()[{}] = {};", kContextVarName, idx, value->raw_name()); + idx++; + } } void visit(ExternalPtrStmt *stmt) override { @@ -464,6 +505,9 @@ class KernelCodegenImpl : public IRVisitor { } } + void visit(DecorationStmt *stmt) override { + } + void visit(UnaryOpStmt *stmt) override { if (stmt->op_type == UnaryOpType::cast_value) { emit("const {} {} = static_cast<{}>({});", @@ -643,17 +687,21 @@ class KernelCodegenImpl : public IRVisitor { const auto root_ids = RootIdsExtractor::run(stmt); BufferDescSet used_root_descs; for (const auto rid : root_ids) { - used_root_descs.insert(BufferDescriptor::Root(rid)); + used_root_descs.insert(BufferDescriptor::root(rid)); } root_id_to_stmts_.clear(); + auto preproc_res = TaskPreprocessor::run(stmt); using Type = OffloadedStmt::TaskType; if (stmt->task_type == Type::serial) { - generate_serial_kernel(stmt, used_root_descs); + // For serial tasks, there is only one thread, so different calls to + // random() is guaranteed to produce different results. + preproc_res.should_init_randseeds = false; + generate_serial_kernel(stmt, used_root_descs, preproc_res); } else if (stmt->task_type == Type::range_for) { - generate_range_for_kernel(stmt, used_root_descs); + generate_range_for_kernel(stmt, used_root_descs, preproc_res); } else if (stmt->task_type == Type::struct_for) { - generate_struct_for_kernel(stmt, used_root_descs); + generate_struct_for_kernel(stmt, used_root_descs, preproc_res); } else if (stmt->task_type == Type::listgen) { add_runtime_list_op_kernel(stmt); } else if (stmt->task_type == Type::gc) { @@ -679,7 +727,16 @@ class KernelCodegenImpl : public IRVisitor { } void visit(ContinueStmt *stmt) override { - if (stmt->as_return()) { + auto stmt_in_off_for = [stmt]() { + TI_ASSERT(stmt->scope != nullptr); + if (auto *offl = stmt->scope->cast(); offl) { + TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for || + offl->task_type == OffloadedStmt::TaskType::struct_for); + return true; + } + return false; + }; + if (stmt_in_off_for()) { emit("return;"); } else { emit("continue;"); @@ -870,6 +927,9 @@ class KernelCodegenImpl : public IRVisitor { current_appender().append_raw(shaders::kMetalPrintSourceCode); emit(""); emit_kernel_args_struct(); + emit(""); + current_appender().append_raw(shaders::kMetalInitRandseedsSourceCode); + emit(""); } void handle_bit_pointer_global_store(GlobalStoreStmt *stmt) { @@ -986,19 +1046,6 @@ class KernelCodegenImpl : public IRVisitor { bit_ptr_stmt->raw_name(), num_bits); } - void validate_cft_for_metal(CustomFloatType *cft) const { - if (cft->get_exponent_type() != nullptr) { - TI_NOT_IMPLEMENTED; - } - if (cft->get_compute_type()->as() != PrimitiveType::f32) { - TI_ERROR("Metal only supports 32-bit float"); - } - } - - static bool is_full_bits(int bits) { - return bits == (sizeof(uint32_t) * 8); - } - void emit_kernel_args_struct() { if (ctx_attribs_.empty()) { return; @@ -1067,18 +1114,19 @@ class KernelCodegenImpl : public IRVisitor { TI_ASSERT(rhs.type() == BufferType::Root); return lhs.root_id() < rhs.root_id(); }); - result.push_back(BufferDescriptor::GlobalTmps()); + result.push_back(BufferDescriptor::global_tmps()); if (!ctx_attribs_.empty()) { - result.push_back(BufferDescriptor::Context()); + result.push_back(BufferDescriptor::context()); } - result.push_back(BufferDescriptor::Runtime()); + result.push_back(BufferDescriptor::runtime()); // TODO(k-ye): Bind this buffer only when print() is used. - result.push_back(BufferDescriptor::Print()); + result.push_back(BufferDescriptor::print()); return result; } void generate_serial_kernel(OffloadedStmt *stmt, - const BufferDescSet &root_buffer_descs) { + const BufferDescSet &root_buffer_descs, + const TaskPreprocessor::Result &preproc_res) { TI_ASSERT(stmt->task_type == OffloadedStmt::TaskType::serial); const std::string mtl_kernel_name = make_kernel_name(); KernelAttributes ka; @@ -1096,7 +1144,8 @@ class KernelCodegenImpl : public IRVisitor { current_kernel_attribs_ = &ka; const auto mtl_func_name = mtl_kernel_func_name(mtl_kernel_name); - emit_mtl_kernel_func_def(mtl_func_name, ka.buffers, stmt->body.get()); + emit_mtl_kernel_func_def(mtl_func_name, ka.buffers, preproc_res, + stmt->body.get()); emit_call_mtl_kernel_func(mtl_func_name, ka.buffers, /*loop_index_expr=*/"0"); } @@ -1108,7 +1157,8 @@ class KernelCodegenImpl : public IRVisitor { } void generate_range_for_kernel(OffloadedStmt *stmt, - const BufferDescSet &root_buffer_descs) { + const BufferDescSet &root_buffer_descs, + const TaskPreprocessor::Result &preproc_res) { TI_ASSERT(stmt->task_type == OffloadedStmt::TaskType::range_for); const std::string mtl_kernel_name = make_kernel_name(); KernelAttributes ka; @@ -1185,7 +1235,7 @@ class KernelCodegenImpl : public IRVisitor { extra_args.push_back(kTlsBufferName); } emit_mtl_kernel_func_def(mtl_func_name, ka.buffers, extra_func_params, - stmt->body.get()); + preproc_res, stmt->body.get()); emit_call_mtl_kernel_func(mtl_func_name, ka.buffers, extra_args, /*loop_index_expr=*/"ii"); } @@ -1204,7 +1254,8 @@ class KernelCodegenImpl : public IRVisitor { } void generate_struct_for_kernel(OffloadedStmt *stmt, - const BufferDescSet &root_buffer_descs) { + const BufferDescSet &root_buffer_descs, + const TaskPreprocessor::Result &preproc_res) { TI_ASSERT(stmt->task_type == OffloadedStmt::TaskType::struct_for); const std::string mtl_kernel_name = make_kernel_name(); @@ -1251,7 +1302,7 @@ class KernelCodegenImpl : public IRVisitor { kKernelGridSizeName); { const auto belonged_root_id = snode_to_roots_.at(sn_id).snode_id; - const auto root_desc = BufferDescriptor::Root(belonged_root_id); + const auto root_desc = BufferDescriptor::root(belonged_root_id); ScopedIndent s2(current_appender()); emit("const int parent_idx_ = (ii / child_num_slots);"); emit("if (parent_idx_ >= parent_list.num_active()) break;"); @@ -1283,7 +1334,7 @@ class KernelCodegenImpl : public IRVisitor { extra_args.push_back(kTlsBufferName); } emit_mtl_kernel_func_def(mtl_func_name, ka.buffers, extra_func_params, - stmt->body.get()); + preproc_res, stmt->body.get()); emit_call_mtl_kernel_func(mtl_func_name, ka.buffers, extra_args, /*loop_index_expr=*/"ii"); } @@ -1331,9 +1382,9 @@ class KernelCodegenImpl : public IRVisitor { std::min(total_num_self_from_root(sn_descs, sn->id), kMaxNumThreadsGridStrideLoop); ka.advisory_num_threads_per_group = stmt->block_dim; - ka.buffers = {BufferDescriptor::Runtime(), - BufferDescriptor::Root(snode_to_roots_.at(sn->id).snode_id), - BufferDescriptor::Context()}; + ka.buffers = {BufferDescriptor::runtime(), + BufferDescriptor::root(snode_to_roots_.at(sn->id).snode_id), + BufferDescriptor::context()}; ka.runtime_list_op_attribs = KernelAttributes::RuntimeListOpAttributes(); ka.runtime_list_op_attribs->snode = sn; @@ -1353,7 +1404,7 @@ class KernelCodegenImpl : public IRVisitor { ka.task_type = OffloadedTaskType::gc; ka.gc_op_attribs = KernelAttributes::GcOpAttributes(); ka.gc_op_attribs->snode = sn; - ka.buffers = {BufferDescriptor::Runtime(), BufferDescriptor::Context()}; + ka.buffers = {BufferDescriptor::runtime(), BufferDescriptor::context()}; current_kernel_attribs_ = nullptr; // stage 1 specific ka.name = "gc_compact_free_list"; @@ -1403,6 +1454,7 @@ class KernelCodegenImpl : public IRVisitor { const std::string &kernel_func_name, const std::vector &buffers, const std::vector &extra_params, + const TaskPreprocessor::Result &preproc_res, Block *func_ir) { SectionGuard sg(this, Section::KernelFuncs); @@ -1430,6 +1482,10 @@ class KernelCodegenImpl : public IRVisitor { fmt::arg("rtm", kRuntimeVarName), fmt::arg("lidx", kLinearLoopIndexName), fmt::arg("nums", kNumRandSeeds)); + if (preproc_res.should_init_randseeds) { + emit("mtl_init_random_seeds(({}->rand_seeds), {}, {});", + kRuntimeVarName, kLinearLoopIndexName, kNumRandSeeds); + } // Init AssertRecorder. emit("AssertRecorder {}({});", kAssertRecorderVarName, kPrintAssertBufferName); @@ -1452,9 +1508,10 @@ class KernelCodegenImpl : public IRVisitor { inline void emit_mtl_kernel_func_def( const std::string &kernel_func_name, const std::vector &buffers, + const TaskPreprocessor::Result &preproc_res, Block *func_ir) { emit_mtl_kernel_func_def(kernel_func_name, buffers, /*extra_params=*/{}, - func_ir); + preproc_res, func_ir); } void emit_call_mtl_kernel_func(const std::string &kernel_func_name, @@ -1625,7 +1682,8 @@ FunctionType compile_to_metal_executable( kernel_mgr->register_taichi_kernel( compiled_res.kernel_name, compiled_res.source_code, compiled_res.kernel_attribs, compiled_res.ctx_attribs); - return [kernel_mgr, kernel_name = compiled_res.kernel_name](Context &ctx) { + return [kernel_mgr, + kernel_name = compiled_res.kernel_name](RuntimeContext &ctx) { kernel_mgr->launch_taichi_kernel(kernel_name, &ctx); }; } diff --git a/taichi/backends/metal/data_types.cpp b/taichi/backends/metal/data_types.cpp index 46d9325032e95..7f3c2266967f5 100644 --- a/taichi/backends/metal/data_types.cpp +++ b/taichi/backends/metal/data_types.cpp @@ -93,6 +93,8 @@ std::string metal_unary_op_type_symbol(UnaryOpType type) { return "-"; case UnaryOpType::sqrt: return "sqrt"; + case UnaryOpType::round: + return "round"; case UnaryOpType::floor: return "floor"; case UnaryOpType::ceil: diff --git a/taichi/backends/metal/device.cpp b/taichi/backends/metal/device.cpp new file mode 100644 index 0000000000000..af00838ff1b4f --- /dev/null +++ b/taichi/backends/metal/device.cpp @@ -0,0 +1,338 @@ +#include "taichi/backends/metal/device.h" + +#include "taichi/platform/mac/objc_api.h" +#include "taichi/backends/metal/api.h" +#include "taichi/backends/metal/constants.h" +#include "taichi/backends/metal/runtime_utils.h" + +namespace taichi { +namespace lang { +namespace metal { + +#ifdef TI_PLATFORM_OSX +namespace { + +class ResourceBinderImpl : public ResourceBinder { + public: + struct Binding { + DeviceAllocationId alloc_id{0}; + // Not sure if this info is necessary yet. + // TODO: Make it an enum? + [[maybe_unused]] bool is_constant{false}; + }; + using BindingMap = std::unordered_map; + + explicit ResourceBinderImpl(const Device *dev) : dev_(dev) { + } + + std::unique_ptr materialize() override { + TI_NOT_IMPLEMENTED; + return nullptr; + } + // RW buffers + void rw_buffer(uint32_t set, + uint32_t binding, + DevicePtr ptr, + size_t size) override { + TI_NOT_IMPLEMENTED; + } + void rw_buffer(uint32_t set, + uint32_t binding, + DeviceAllocation alloc) override { + bind_buffer(set, binding, alloc, /*is_constant=*/false); + } + + // Constant buffers + void buffer(uint32_t set, + uint32_t binding, + DevicePtr ptr, + size_t size) override { + TI_NOT_IMPLEMENTED; + } + void buffer(uint32_t set, uint32_t binding, DeviceAllocation alloc) override { + bind_buffer(set, binding, alloc, /*is_constant=*/true); + } + + const BindingMap &binding_map() const { + return binding_map_; + } + + private: + void bind_buffer(uint32_t set, + uint32_t binding, + DeviceAllocation alloc, + bool is_constant) { + TI_ASSERT(set == 0); + TI_ASSERT(alloc.device == dev_); + binding_map_[binding] = {alloc.alloc_id, is_constant}; + } + + const Device *const dev_; + BindingMap binding_map_; +}; + +class PipelineImpl : public Pipeline { + public: + explicit PipelineImpl(nsobj_unique_ptr pipeline) + : pipeline_state_(std::move(pipeline)) { + } + + ResourceBinder *resource_binder() override { + // TODO: Hmm, why do we need this interface? + return nullptr; + } + + MTLComputePipelineState *mtl_pipeline_state() { + return pipeline_state_.get(); + } + + private: + nsobj_unique_ptr pipeline_state_{nullptr}; +}; + +class CommandListImpl : public CommandList { + private: + struct ComputeEncoderBuilder { + MTLComputePipelineState *pipeline{nullptr}; + ResourceBinderImpl::BindingMap binding_map; + }; + + public: + explicit CommandListImpl(nsobj_unique_ptr cb) + : command_buffer_(std::move(cb)) { + } + + MTLCommandBuffer *command_buffer() { + return command_buffer_.get(); + } + + void set_label(const std::string &label) { + inflight_label_ = label; + } + + void bind_pipeline(Pipeline *p) override { + get_or_make_compute_builder()->pipeline = + static_cast(p)->mtl_pipeline_state(); + } + + void bind_resources(ResourceBinder *binder) override { + get_or_make_compute_builder()->binding_map = + static_cast(binder)->binding_map(); + } + + void bind_resources(ResourceBinder *binder, + ResourceBinder::Bindings *bindings) override { + TI_NOT_IMPLEMENTED; + } + void buffer_barrier(DevicePtr ptr, size_t size) override { + TI_NOT_IMPLEMENTED; + } + void buffer_barrier(DeviceAllocation alloc) override { + TI_NOT_IMPLEMENTED; + } + void memory_barrier() override { + TI_NOT_IMPLEMENTED; + } + + void buffer_copy(DevicePtr dst, DevicePtr src, size_t size) override { + } + void buffer_fill(DevicePtr ptr, size_t size, uint32_t data) override { + } + void dispatch(uint32_t x, uint32_t y, uint32_t z) override { + TI_ERROR("Please call dispatch(grid_size, block_size) instead"); + } + + void dispatch(CommandList::ComputeSize grid_size, + CommandList::ComputeSize block_size) override { + auto encoder = new_compute_command_encoder(command_buffer_.get()); + TI_ASSERT(encoder != nullptr); + metal::set_label(encoder.get(), inflight_label_); + const auto &builder = inflight_compute_builder_.value(); + set_compute_pipeline_state(encoder.get(), builder.pipeline); + auto ceil_div = [](uint32_t a, uint32_t b) -> uint32_t { + return (a + b - 1) / b; + }; + const auto num_blocks_x = ceil_div(grid_size.x, block_size.x); + const auto num_blocks_y = ceil_div(grid_size.y, block_size.y); + const auto num_blocks_z = ceil_div(grid_size.z, block_size.z); + dispatch_threadgroups(encoder.get(), num_blocks_x, num_blocks_y, + num_blocks_z, block_size.x, block_size.y, + block_size.z); + finish_encoder(encoder.get()); + } + + // Graphics commands are not implemented on Metal + private: + ComputeEncoderBuilder *get_or_make_compute_builder() { + if (!inflight_compute_builder_.has_value()) { + inflight_compute_builder_ = ComputeEncoderBuilder{}; + } + return &(inflight_compute_builder_.value()); + } + + template + void finish_encoder(T *encoder) { + end_encoding(encoder); + inflight_label_.clear(); + inflight_compute_builder_.reset(); + } + + nsobj_unique_ptr command_buffer_{nullptr}; + std::string inflight_label_; + std::optional inflight_compute_builder_; +}; + +class StreamImpl : public Stream { + public: + explicit StreamImpl(MTLCommandQueue *command_queue) + : command_queue_(command_queue) { + } + + std::unique_ptr new_command_list() override { + auto cb = new_command_buffer(command_queue_); + TI_ASSERT(cb != nullptr); + set_label(cb.get(), fmt::format("command_buffer_{}", list_counter_++)); + return std::make_unique(std::move(cb)); + } + + void submit(CommandList *cmdlist) override { + auto *cb = static_cast(cmdlist)->command_buffer(); + commit_command_buffer(cb); + } + void submit_synced(CommandList *cmdlist) override { + auto *cb = static_cast(cmdlist)->command_buffer(); + commit_command_buffer(cb); + wait_until_completed(cb); + } + + void command_sync() override { + // No-op on Metal + } + + private: + MTLCommandQueue *const command_queue_; + uint32_t list_counter_{0}; +}; + +class DeviceImpl : public Device { + public: + explicit DeviceImpl(const ComputeDeviceParams ¶ms) + : device_(params.device), mem_pool_(params.mem_pool) { + command_queue_ = new_command_queue(device_); + TI_ASSERT(command_queue_ != nullptr); + // TODO: thread local streams? + stream_ = std::make_unique(command_queue_.get()); + TI_ASSERT(stream_ != nullptr); + } + + DeviceAllocation allocate_memory(const AllocParams ¶ms) override { + DeviceAllocation res; + res.device = this; + res.alloc_id = allocations_.size(); + + AllocationInternal &ialloc = + allocations_[res.alloc_id]; // "i" for internal + auto mem = std::make_unique(params.size, mem_pool_); + ialloc.buffer = new_mtl_buffer_no_copy(device_, mem->ptr(), mem->size()); + ialloc.buffer_mem = std::move(mem); + return res; + } + + void dealloc_memory(DeviceAllocation handle) override { + allocations_.erase(handle.alloc_id); + TI_NOT_IMPLEMENTED; + } + + std::unique_ptr create_pipeline(const PipelineSourceDesc &src, + std::string name) override { + TI_ASSERT(src.type == PipelineSourceType::metal_src); + TI_ASSERT(src.stage == PipelineStageType::compute); + // FIXME: infer version/fast_math + std::string src_code{static_cast(src.data), src.size}; + auto kernel_lib = new_library_with_source( + device_, src_code, /*fast_math=*/false, kMslVersionNone); + TI_ASSERT(kernel_lib != nullptr); + auto mtl_func = new_function_with_name(kernel_lib.get(), name); + TI_ASSERT(mtl_func != nullptr); + auto pipeline = + new_compute_pipeline_state_with_function(device_, mtl_func.get()); + TI_ASSERT(pipeline != nullptr); + return std::make_unique(std::move(pipeline)); + } + + void *map_range(DevicePtr ptr, uint64_t size) override { + auto *mem = find_buffer_mem(ptr.alloc_id); + if (!mem) { + return nullptr; + } + if ((ptr.offset + size) > mem->size()) { + TI_ERROR("Range exceeded"); + return nullptr; + } + return (mem->ptr() + ptr.offset); + } + + void *map(DeviceAllocation alloc) override { + auto *mem = find_buffer_mem(alloc.alloc_id); + if (!mem) { + return nullptr; + } + return mem->ptr(); + } + + void unmap(DevicePtr ptr) override { + // No-op on Metal + } + void unmap(DeviceAllocation alloc) override { + // No-op on Metal + } + + void memcpy_internal(DevicePtr dst, DevicePtr src, uint64_t size) override { + TI_NOT_IMPLEMENTED; + } + + Stream *get_compute_stream() override { + return stream_.get(); + } + + private: + const BufferMemoryView *find_buffer_mem(DeviceAllocationId id) const { + auto itr = allocations_.find(id); + if (itr == allocations_.end()) { + return nullptr; + } + return itr->second.buffer_mem.get(); + } + + struct AllocationInternal { + std::unique_ptr buffer_mem{nullptr}; + nsobj_unique_ptr buffer{nullptr}; + }; + + MTLDevice *const device_; + MemoryPool *const mem_pool_; + nsobj_unique_ptr command_queue_{nullptr}; + std::unique_ptr stream_{nullptr}; + std::unordered_map allocations_; +}; + +} // namespace + +std::unique_ptr make_compute_device( + const ComputeDeviceParams ¶ms) { + return std::make_unique(params); +} + +#else + +std::unique_ptr make_compute_device( + const ComputeDeviceParams ¶ms) { + TI_ERROR("Platform does not support Metal"); + return nullptr; +} + +#endif // TI_PLATFORM_OSX + +} // namespace metal +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/metal/device.h b/taichi/backends/metal/device.h new file mode 100644 index 0000000000000..2f4c2544ea20b --- /dev/null +++ b/taichi/backends/metal/device.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +#include "taichi/backends/device.h" + +namespace taichi { +namespace lang { + +class MemoryPool; + +namespace metal { + +struct MTLDevice; + +struct ComputeDeviceParams { + MTLDevice *device{nullptr}; + MemoryPool *mem_pool{nullptr}; +}; + +std::unique_ptr make_compute_device( + const ComputeDeviceParams ¶ms); + +} // namespace metal +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/metal/kernel_manager.cpp b/taichi/backends/metal/kernel_manager.cpp index 35cab785a1333..7334ca937cc27 100644 --- a/taichi/backends/metal/kernel_manager.cpp +++ b/taichi/backends/metal/kernel_manager.cpp @@ -10,6 +10,7 @@ #include "taichi/backends/metal/constants.h" #include "taichi/backends/metal/features.h" +#include "taichi/backends/metal/runtime_utils.h" #include "taichi/inc/constants.h" #include "taichi/math/arithmetic.h" #include "taichi/program/py_print_buffer.h" @@ -49,36 +50,6 @@ inline int infer_msl_version(const TaichiKernelAttributes::UsedFeatures &f) { return kMslVersionNone; } -// This class requests the Metal buffer memory of |size| bytes from |mem_pool|. -// Once allocated, it does not own the memory (hence the name "view"). Instead, -// GC is deferred to the memory pool. -class BufferMemoryView { - public: - BufferMemoryView(size_t size, MemoryPool *mem_pool) { - // Both |ptr_| and |size_| must be aligned to page size. - size_ = iroundup(size, taichi_page_size); - ptr_ = (char *)mem_pool->allocate(size_, /*alignment=*/taichi_page_size); - TI_ASSERT(ptr_ != nullptr); - std::memset(ptr_, 0, size_); - } - // Move only - BufferMemoryView(BufferMemoryView &&) = default; - BufferMemoryView &operator=(BufferMemoryView &&) = default; - BufferMemoryView(const BufferMemoryView &) = delete; - BufferMemoryView &operator=(const BufferMemoryView &) = delete; - - inline size_t size() const { - return size_; - } - inline char *ptr() const { - return ptr_; - } - - private: - size_t size_; - char *ptr_; -}; - // MetalRuntime maintains a series of MTLBuffers that are shared across all the // Metal kernels mapped by a single Taichi kernel. This map stores those buffers // from their enum. Each CompiledMtlKernelBase can then decide which specific @@ -406,7 +377,7 @@ class CompiledTaichiKernel { class HostMetalCtxBlitter { public: HostMetalCtxBlitter(const CompiledTaichiKernel &kernel, - Context *host_ctx, + RuntimeContext *host_ctx, uint64_t *host_result_buffer, const std::string &kernel_name) : ti_kernel_attribs_(&kernel.ti_kernel_attribs), @@ -472,9 +443,10 @@ class HostMetalCtxBlitter { } void metal_to_host() { -#define TO_HOST(type) \ - const type d = *reinterpret_cast(device_ptr); \ - host_result_buffer_[i] = taichi_union_cast_with_different_sizes(d); +#define TO_HOST(type, offset) \ + const type d = *(reinterpret_cast(device_ptr) + offset); \ + host_result_buffer_[offset] = \ + taichi_union_cast_with_different_sizes(d); if (ctx_attribs_->empty()) { return; @@ -508,25 +480,24 @@ class HostMetalCtxBlitter { // *arg* on the host context. const auto &ret = ctx_attribs_->rets()[i]; char *device_ptr = base + ret.offset_in_mem; - if (ret.is_array) { - void *host_ptr = host_ctx_->get_arg(i); - std::memcpy(host_ptr, device_ptr, ret.stride); - } else { + const int dt_bytes = metal_data_type_bytes(ret.dt); + const int num = ret.stride / dt_bytes; + for (int j = 0; j < num; ++j) { const auto dt = ret.dt; if (dt == MetalDataType::i32) { - TO_HOST(int32); + TO_HOST(int32, j); } else if (dt == MetalDataType::u32) { - TO_HOST(uint32); + TO_HOST(uint32, j); } else if (dt == MetalDataType::f32) { - TO_HOST(float32); + TO_HOST(float32, j); } else if (dt == MetalDataType::i8) { - TO_HOST(int8); + TO_HOST(int8, j); } else if (dt == MetalDataType::i16) { - TO_HOST(int16); + TO_HOST(int16, j); } else if (dt == MetalDataType::u8) { - TO_HOST(uint8); + TO_HOST(uint8, j); } else if (dt == MetalDataType::u16) { - TO_HOST(uint16); + TO_HOST(uint16, j); } else { TI_ERROR("Metal does not support return value type={}", metal_data_type_name(ret.dt)); @@ -538,7 +509,7 @@ class HostMetalCtxBlitter { static std::unique_ptr maybe_make( const CompiledTaichiKernel &kernel, - Context *ctx, + RuntimeContext *ctx, uint64_t *host_result_buffer, std::string name) { if (kernel.ctx_attribs.empty()) { @@ -551,7 +522,7 @@ class HostMetalCtxBlitter { private: const TaichiKernelAttributes *const ti_kernel_attribs_; const KernelContextAttributes *const ctx_attribs_; - Context *const host_ctx_; + RuntimeContext *const host_ctx_; uint64_t *const host_result_buffer_; BufferMemoryView *const kernel_ctx_mem_; MTLBuffer *const kernel_ctx_buffer_; @@ -620,13 +591,13 @@ class KernelManager::Impl { print_mem_->size()); TI_ASSERT(print_buffer_ != nullptr); - init_runtime_buffer(compiled_runtime_module_); + init_runtime_buffer(compiled_runtime_module_, params.config->random_seed); clear_print_assert_buffer(); } void add_compiled_snode_tree(const CompiledStructs &compiled_tree) { SNodesRootBuffer rtbuf{}; - rtbuf.desc = BufferDescriptor::Root(compiled_tree.root_id); + rtbuf.desc = BufferDescriptor::root(compiled_tree.root_id); if (compiled_tree.root_size > 0) { rtbuf.mem = std::make_unique(compiled_tree.root_size, mem_pool_); @@ -676,7 +647,7 @@ class KernelManager::Impl { } void launch_taichi_kernel(const std::string &taichi_kernel_name, - Context *ctx) { + RuntimeContext *ctx) { mac::ScopedAutoreleasePool pool; auto &ctk = *compiled_taichi_kernels_.find(taichi_kernel_name)->second; auto ctx_blitter = HostMetalCtxBlitter::maybe_make( @@ -689,13 +660,13 @@ class KernelManager::Impl { for (auto &rb : root_buffers_) { input_buffers[rb.desc] = rb.buffer.get(); } - input_buffers[BufferDescriptor::GlobalTmps()] = global_tmps_buffer_.get(); - input_buffers[BufferDescriptor::Runtime()] = runtime_buffer_.get(); - input_buffers[BufferDescriptor::Print()] = print_buffer_.get(); + input_buffers[BufferDescriptor::global_tmps()] = global_tmps_buffer_.get(); + input_buffers[BufferDescriptor::runtime()] = runtime_buffer_.get(); + input_buffers[BufferDescriptor::print()] = print_buffer_.get(); if (ctx_blitter) { ctx_blitter->host_to_metal(); - input_buffers[BufferDescriptor::Context()] = ctk.ctx_buffer.get(); + input_buffers[BufferDescriptor::context()] = ctk.ctx_buffer.get(); } for (const auto &mk : ctk.compiled_mtl_kernels) { @@ -756,14 +727,11 @@ class KernelManager::Impl { } private: - void init_runtime_buffer(const CompiledRuntimeModule &rtm_module) { + void init_runtime_buffer(const CompiledRuntimeModule &rtm_module, + int random_seed) { char *addr = runtime_mem_->ptr(); // init rand_seeds - // TODO(k-ye): Provide a way to use a fixed seed in dev mode. - std::mt19937 generator( - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); + std::default_random_engine generator((unsigned int)random_seed); std::uniform_int_distribution distr( 0, std::numeric_limits::max()); for (int i = 0; i < kNumRandSeeds; ++i) { @@ -997,11 +965,13 @@ class KernelManager::Impl { end_encoding(encoder.get()); } // Sync - profiler_->start("metal_synchronize"); + if (profiler_) + profiler_->start("metal_synchronize"); commit_command_buffer(cur_command_buffer_.get()); wait_until_completed(cur_command_buffer_.get()); create_new_command_buffer(); - profiler_->stop(); + if (profiler_) + profiler_->stop(); // print_runtime_debug(); } @@ -1179,7 +1149,7 @@ class KernelManager::Impl { } void launch_taichi_kernel(const std::string &taichi_kernel_name, - Context *ctx) { + RuntimeContext *ctx) { TI_ERROR("Metal not supported on the current OS"); } @@ -1221,7 +1191,7 @@ void KernelManager::register_taichi_kernel( } void KernelManager::launch_taichi_kernel(const std::string &taichi_kernel_name, - Context *ctx) { + RuntimeContext *ctx) { impl_->launch_taichi_kernel(taichi_kernel_name, ctx); } diff --git a/taichi/backends/metal/kernel_manager.h b/taichi/backends/metal/kernel_manager.h index 75f82bb8e7288..92c441306af5b 100644 --- a/taichi/backends/metal/kernel_manager.h +++ b/taichi/backends/metal/kernel_manager.h @@ -15,7 +15,7 @@ namespace taichi { namespace lang { -struct Context; +struct RuntimeContext; namespace metal { @@ -55,7 +55,7 @@ class KernelManager { // Kernel launching is asynchronous, therefore the Metal memory is not valid // to access until after a synchronize() call. void launch_taichi_kernel(const std::string &taichi_kernel_name, - Context *ctx); + RuntimeContext *ctx); // Synchronize the memory content from Metal to host (x86_64). void synchronize(); diff --git a/taichi/backends/metal/kernel_utils.cpp b/taichi/backends/metal/kernel_utils.cpp index c1cba5bbb5f5e..ac59811f999c5 100644 --- a/taichi/backends/metal/kernel_utils.cpp +++ b/taichi/backends/metal/kernel_utils.cpp @@ -76,7 +76,7 @@ std::string KernelAttributes::debug_string() const { } KernelContextAttributes::KernelContextAttributes(const Kernel &kernel) - : ctx_bytes_(0), extra_args_bytes_(Context::extra_args_size) { + : ctx_bytes_(0), extra_args_bytes_(RuntimeContext::extra_args_size) { arg_attribs_vec_.reserve(kernel.args.size()); for (const auto &ka : kernel.args) { ArgAttributes ma; @@ -87,22 +87,37 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel) TI_ERROR("Metal kernel only supports <= 32-bit data, got {}", metal_data_type_name(ma.dt)); } - ma.is_array = ka.is_external_array; + ma.is_array = ka.is_array; ma.stride = ma.is_array ? ka.size : dt_bytes; ma.index = arg_attribs_vec_.size(); arg_attribs_vec_.push_back(ma); } for (const auto &kr : kernel.rets) { RetAttributes mr; - mr.dt = to_metal_type(kr.dt); - const size_t dt_bytes = metal_data_type_bytes(mr.dt); - if (dt_bytes > 4) { - // Metal doesn't support 64bit data buffers. - TI_ERROR("Metal kernel only supports <= 32-bit data, got {}", - metal_data_type_name(mr.dt)); + if (auto tensor_type = kr.dt->cast()) { + mr.dt = to_metal_type(tensor_type->get_element_type()); + const size_t dt_bytes = metal_data_type_bytes(mr.dt); + mr.is_array = true; + if (dt_bytes > 4) { + // Metal doesn't support 64bit data buffers. + TI_ERROR( + "Metal kernel only supports <= 32-bit data, got {} which is " + "Tensor's element type", + metal_data_type_name(mr.dt)); + } + mr.stride = + tensor_type->get_num_elements() * metal_data_type_bytes(mr.dt); + } else { + mr.dt = to_metal_type(kr.dt); + const size_t dt_bytes = metal_data_type_bytes(mr.dt); + mr.is_array = false; + if (dt_bytes > 4) { + // Metal doesn't support 64bit data buffers. + TI_ERROR("Metal kernel only supports <= 32-bit data, got {}", + metal_data_type_name(mr.dt)); + } + mr.stride = metal_data_type_bytes(mr.dt); } - mr.is_array = false; // TODO(#909): this is a temporary limitation - mr.stride = dt_bytes; mr.index = ret_attribs_vec_.size(); ret_attribs_vec_.push_back(mr); } @@ -120,12 +135,17 @@ KernelContextAttributes::KernelContextAttributes(const Kernel &kernel) // Put scalar args in the memory first for (int i : scalar_indices) { auto &attribs = (*vec)[i]; + const size_t dt_bytes = metal_data_type_bytes(attribs.dt); + // Align bytes to the nearest multiple of dt_bytes + bytes = (bytes + dt_bytes - 1) / dt_bytes * dt_bytes; attribs.offset_in_mem = bytes; bytes += attribs.stride; } // Then the array args for (int i : array_indices) { auto &attribs = (*vec)[i]; + const size_t dt_bytes = metal_data_type_bytes(attribs.dt); + bytes = (bytes + dt_bytes - 1) / dt_bytes * dt_bytes; attribs.offset_in_mem = bytes; bytes += attribs.stride; } diff --git a/taichi/backends/metal/kernel_utils.h b/taichi/backends/metal/kernel_utils.h index e48bbf956a251..5138a374e36e6 100644 --- a/taichi/backends/metal/kernel_utils.h +++ b/taichi/backends/metal/kernel_utils.h @@ -45,23 +45,23 @@ struct BufferDescriptor { BufferDescriptor() = default; - static BufferDescriptor Root(int root_id) { + static BufferDescriptor root(int root_id) { return BufferDescriptor{Type::Root, root_id}; } - static BufferDescriptor GlobalTmps() { + static BufferDescriptor global_tmps() { return BufferDescriptor{Type::GlobalTmps}; } - static BufferDescriptor Context() { + static BufferDescriptor context() { return BufferDescriptor{Type::Context}; } - static BufferDescriptor Runtime() { + static BufferDescriptor runtime() { return BufferDescriptor{Type::Runtime}; } - static BufferDescriptor Print() { + static BufferDescriptor print() { return BufferDescriptor{Type::Print}; } @@ -105,7 +105,7 @@ struct BufferDescriptor { struct KernelAttributes { std::string name; // Total number of threads to launch (i.e. threads per grid). Note that this - // is only advisory, because eventually this numb er is also determined by the + // is only advisory, because eventually this number is also determined by the // runtime config. This works because grid strided loop is supported. int advisory_total_num_threads; // Block size in CUDA's terminology. On Metal, it is called a threadgroup. @@ -280,6 +280,7 @@ struct CompiledFieldData { MetalDataType dtype; std::string dtype_name; std::vector shape; + int mem_offset_in_parent{0}; bool is_scalar{false}; int row_num{0}; int column_num{0}; @@ -288,6 +289,7 @@ struct CompiledFieldData { dtype, dtype_name, shape, + mem_offset_in_parent, is_scalar, row_num, column_num); diff --git a/taichi/backends/metal/metal_program.cpp b/taichi/backends/metal/metal_program.cpp index 21623f0c5cfb8..4daef0d0eac7e 100644 --- a/taichi/backends/metal/metal_program.cpp +++ b/taichi/backends/metal/metal_program.cpp @@ -1,9 +1,46 @@ +#include + #include "metal_program.h" #include "taichi/backends/metal/codegen_metal.h" #include "taichi/backends/metal/struct_metal.h" namespace taichi { namespace lang { +namespace { + +std::unordered_set find_all_dense_snodes( + const metal::SNodeDescriptorsMap &snodes_map) { + std::unordered_set res; + for (const auto [_, desc] : snodes_map) { + const auto *sn = desc.snode; + if (sn->type == SNodeType::dense) { + res.insert(sn); + } + } + return res; +} + +bool all_fields_are_dense( + const std::unordered_set &placed_snodes) { + for (const auto *sn : placed_snodes) { + for (const auto &ch : sn->ch) { + if (ch->type != SNodeType::place) { + return false; + } + } + const auto *parent = sn->parent; + if (!parent) { + return false; + } + if (parent->type != SNodeType::root) { + return false; + } + } + return true; +} + +} // namespace + MetalProgramImpl::MetalProgramImpl(CompileConfig &config_) : ProgramImpl(config_) { } @@ -43,25 +80,40 @@ void MetalProgramImpl::materialize_runtime(MemoryPool *memory_pool, metal_kernel_mgr_ = std::make_unique(std::move(params)); } +void MetalProgramImpl::compile_snode_tree_types( + SNodeTree *tree, + std::vector> &snode_trees) { + (void)compile_snode_tree_types_impl(tree); +} + void MetalProgramImpl::materialize_snode_tree( SNodeTree *tree, std::vector> &, - std::unordered_map &, uint64 *result_buffer) { - // TODO: support materializing multiple snode trees - TI_ASSERT_INFO(config->use_llvm, - "Metal arch requires that LLVM being enabled"); - auto *const root = tree->root(); - auto csnode_tree = metal::compile_structs(*root); + const auto &csnode_tree = compile_snode_tree_types_impl(tree); metal_kernel_mgr_->add_compiled_snode_tree(csnode_tree); - compiled_snode_trees_.push_back(std::move(csnode_tree)); } std::unique_ptr MetalProgramImpl::make_aot_module_builder() { + TI_ERROR_IF(compiled_snode_trees_.size() > 1, + "AOT: only supports one SNodeTree"); + const auto fields = + find_all_dense_snodes(compiled_snode_trees_[0].snode_descriptors); + TI_ERROR_IF(!all_fields_are_dense(fields), "AOT: only supports dense field"); return std::make_unique( - &(compiled_runtime_module_.value()), compiled_snode_trees_, + &(compiled_runtime_module_.value()), compiled_snode_trees_, fields, metal_kernel_mgr_->get_buffer_meta_data()); } +const metal::CompiledStructs &MetalProgramImpl::compile_snode_tree_types_impl( + SNodeTree *tree) { + TI_ASSERT_INFO(config->use_llvm, + "Metal arch requires that LLVM being enabled"); + auto *const root = tree->root(); + auto csnode_tree = metal::compile_structs(*root); + compiled_snode_trees_.push_back(std::move(csnode_tree)); + return compiled_snode_trees_.back(); +} + } // namespace lang } // namespace taichi diff --git a/taichi/backends/metal/metal_program.h b/taichi/backends/metal/metal_program.h index 7d9ca9da5beef..ce58b89ed1216 100644 --- a/taichi/backends/metal/metal_program.h +++ b/taichi/backends/metal/metal_program.h @@ -30,10 +30,13 @@ class MetalProgramImpl : public ProgramImpl { KernelProfilerBase *profiler, uint64 **result_buffer_ptr) override; + void compile_snode_tree_types( + SNodeTree *tree, + std::vector> &snode_trees) override; + void materialize_snode_tree( SNodeTree *tree, std::vector> &snode_trees_, - std::unordered_map &snodes, uint64 *result_buffer) override; void synchronize() override { @@ -47,6 +50,8 @@ class MetalProgramImpl : public ProgramImpl { std::unique_ptr make_aot_module_builder() override; private: + const metal::CompiledStructs &compile_snode_tree_types_impl(SNodeTree *tree); + std::optional compiled_runtime_module_{ std::nullopt}; std::vector compiled_snode_trees_; diff --git a/taichi/backends/metal/runtime_utils.cpp b/taichi/backends/metal/runtime_utils.cpp new file mode 100644 index 0000000000000..d146c0357bbf1 --- /dev/null +++ b/taichi/backends/metal/runtime_utils.cpp @@ -0,0 +1,23 @@ +#include "taichi/backends/metal/runtime_utils.h" + +#include + +#include "taichi/inc/constants.h" +#include "taichi/math/arithmetic.h" +#include "taichi/system/memory_pool.h" + +namespace taichi { +namespace lang { +namespace metal { + +BufferMemoryView::BufferMemoryView(std::size_t size, MemoryPool *mem_pool) { + // Both |ptr_| and |size_| must be aligned to page size. + size_ = iroundup(size, taichi_page_size); + ptr_ = (char *)mem_pool->allocate(size_, /*alignment=*/taichi_page_size); + TI_ASSERT(ptr_ != nullptr); + std::memset(ptr_, 0, size_); +} + +} // namespace metal +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/metal/runtime_utils.h b/taichi/backends/metal/runtime_utils.h new file mode 100644 index 0000000000000..ad5121b16b044 --- /dev/null +++ b/taichi/backends/metal/runtime_utils.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +namespace taichi { +namespace lang { + +class MemoryPool; + +namespace metal { + +// This class requests the Metal buffer memory of |size| bytes from |mem_pool|. +// Once allocated, it does not own the memory (hence the name "view"). Instead, +// GC is deferred to the memory pool. +class BufferMemoryView { + public: + BufferMemoryView(std::size_t size, MemoryPool *mem_pool); + // Move only + BufferMemoryView(BufferMemoryView &&) = default; + BufferMemoryView &operator=(BufferMemoryView &&) = default; + BufferMemoryView(const BufferMemoryView &) = delete; + BufferMemoryView &operator=(const BufferMemoryView &) = delete; + + inline size_t size() const { + return size_; + } + inline char *ptr() const { + return ptr_; + } + + private: + std::size_t size_{0}; + char *ptr_{nullptr}; +}; + +} // namespace metal +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/metal/shaders/init_randseeds.metal.h b/taichi/backends/metal/shaders/init_randseeds.metal.h new file mode 100644 index 0000000000000..b0b4f6e2b0d58 --- /dev/null +++ b/taichi/backends/metal/shaders/init_randseeds.metal.h @@ -0,0 +1,30 @@ +#include "taichi/backends/metal/shaders/prolog.h" + +#ifdef TI_INSIDE_METAL_CODEGEN + +#ifndef TI_METAL_NESTED_INCLUDE +#define METAL_BEGIN_INIT_RANDSEEDS_DEF \ + constexpr auto kMetalInitRandseedsSourceCode = +#define METAL_END_INIT_RANDSEEDS_DEF ; +#else +#define METAL_BEGIN_INIT_RANDSEEDS_DEF +#define METAL_END_INIT_RANDSEEDS_DEF +#endif // TI_METAL_NESTED_INCLUDE + +#else + +#define METAL_BEGIN_INIT_RANDSEEDS_DEF +#define METAL_END_INIT_RANDSEEDS_DEF + +#endif // TI_INSIDE_METAL_CODEGEN + +METAL_BEGIN_INIT_RANDSEEDS_DEF +STR([[maybe_unused]] void mtl_init_random_seeds( + device uint32_t *rand_seeds, + const uint thread_position_in_grid, + const uint threads_per_grid) { + for (int ii = thread_position_in_grid; ii < threads_per_grid; ++ii) { + rand_seeds[ii] += ii; + } +}) +METAL_END_INIT_RANDSEEDS_DEF diff --git a/taichi/backends/metal/struct_metal.cpp b/taichi/backends/metal/struct_metal.cpp index 144153fd536cc..15ef3ab10bb48 100644 --- a/taichi/backends/metal/struct_metal.cpp +++ b/taichi/backends/metal/struct_metal.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -257,7 +258,14 @@ class StructCompiler { emit("struct {} {{", node_name); const auto snty_name = snode_type_name(snty); emit(" // {}", snty_name); - const int n = snode.num_cells_per_container; + const int64 n = snode.num_cells_per_container; + // There's no assert in metal shading language yet so we have to warn + // outside. + if (n > std::numeric_limits::max()) { + TI_WARN( + "Snode index might be out of int32 boundary but int64 is not " + "supported on metal backend."); + } emit(" constant static constexpr int n = {};", n); emit_snode_stride(snty, ch_name, n); emit_snode_constructor(snode); diff --git a/taichi/backends/opengl/aot_data.h b/taichi/backends/opengl/aot_data.h index 6017efb698f5a..076fdb8fa349f 100644 --- a/taichi/backends/opengl/aot_data.h +++ b/taichi/backends/opengl/aot_data.h @@ -9,42 +9,10 @@ namespace taichi { namespace lang { namespace opengl { -struct AotCompiledKernel { - CompiledProgram program; - std::string identifier; - - TI_IO_DEF(program, identifier); -}; - -struct AotCompiledKernelTmpl { - std::unordered_map program; - std::string identifier; - - TI_IO_DEF(program, identifier); -}; - -struct CompiledFieldData { - std::string field_name; - uint32_t dtype; - std::string dtype_name; - std::vector shape; - bool is_scalar{false}; - int row_num{0}; - int column_num{0}; - - TI_IO_DEF(field_name, - dtype, - dtype_name, - shape, - is_scalar, - row_num, - column_num); -}; - struct AotData { - std::vector kernels; - std::vector kernel_tmpls; - std::vector fields; + std::unordered_map kernels; + std::unordered_map kernel_tmpls; + std::vector fields; size_t root_buffer_size; diff --git a/taichi/backends/opengl/aot_module_builder_impl.cpp b/taichi/backends/opengl/aot_module_builder_impl.cpp index d94ca8665c30f..2e13bed36cc43 100644 --- a/taichi/backends/opengl/aot_module_builder_impl.cpp +++ b/taichi/backends/opengl/aot_module_builder_impl.cpp @@ -1,88 +1,175 @@ #include "taichi/backends/opengl/aot_module_builder_impl.h" -#include "glad/glad.h" + +#include "taichi/aot/module_data.h" +#include "taichi/backends/opengl/opengl_utils.h" + +#if !defined(TI_PLATFORM_WINDOWS) +#include +#endif namespace taichi { namespace lang { namespace opengl { +namespace { + +class AotDataConverter { + public: + static aot::ModuleData convert(const opengl::AotData &in) { + AotDataConverter c{}; + return c.visit(in); + } + + private: + explicit AotDataConverter() = default; + + aot::ModuleData visit(const opengl::AotData &in) const { + aot::ModuleData res{}; + for (const auto &[key, val] : in.kernels) { + res.kernels[key] = visit(val); + } + for (const auto &[key, val] : in.kernel_tmpls) { + res.kernel_tmpls[key] = visit(val); + } + res.fields = in.fields; + res.root_buffer_size = in.root_buffer_size; + return res; + } + + aot::CompiledTaichiKernel visit( + const opengl::CompiledTaichiKernel &in) const { + aot::CompiledTaichiKernel res{}; + res.tasks.reserve(in.tasks.size()); + for (const auto &t : in.tasks) { + res.tasks.push_back(visit(t)); + } + res.args_count = in.arg_count; + res.rets_count = in.ret_count; + res.args_buffer_size = in.args_buf_size; + res.rets_buffer_size = in.ret_buf_size; + for (const auto &[arg_id, val] : in.scalar_args) { + res.scalar_args[arg_id] = visit(val); + } + for (const auto &[arg_id, val] : in.arr_args) { + aot::ArrayArg out_arr = visit(val); + out_arr.bind_index = in.used.arr_arg_to_bind_idx.at(arg_id); + res.arr_args[arg_id] = out_arr; + } + return res; + } + + aot::CompiledOffloadedTask visit( + const opengl::CompiledOffloadedTask &in) const { + aot::CompiledOffloadedTask res{}; + res.type = offloaded_task_type_name(in.type); + res.name = in.name; + res.source_path = in.src; + res.range_hint = in.range_hint; + res.gpu_block_size = in.workgroup_size; + return res; + } + + aot::ScalarArg visit(const opengl::ScalarArg &in) const { + aot::ScalarArg res{}; + res.dtype_name = in.dtype_name; + res.offset_in_args_buf = in.offset_in_bytes_in_args_buf; + return res; + } + + aot::ArrayArg visit(const opengl::CompiledArrayArg &in) const { + aot::ArrayArg res{}; + res.dtype_name = in.dtype_name; + res.field_dim = in.field_dim; + res.element_shape = in.element_shape; + res.shape_offset_in_args_buf = in.shape_offset_in_bytes_in_args_buf; + return res; + } +}; + +void write_glsl_file(const std::string &output_dir, CompiledOffloadedTask &t) { + const std::string glsl_path = fmt::format("{}/{}.glsl", output_dir, t.name); + std::ofstream fs{glsl_path}; + fs << t.src; + t.src = glsl_path; + fs.close(); +} + +} // namespace AotModuleBuilderImpl::AotModuleBuilderImpl( - StructCompiledResult &compiled_structs) - : compiled_structs_(compiled_structs) { + StructCompiledResult &compiled_structs, + bool allow_nv_shader_extension) + : compiled_structs_(compiled_structs), + allow_nv_shader_extension_(allow_nv_shader_extension) { aot_data_.root_buffer_size = compiled_structs_.root_size; } void AotModuleBuilderImpl::dump(const std::string &output_dir, const std::string &filename) const { - const std::string bin_path = - fmt::format("{}/{}_metadata.tcb", output_dir, filename); + TI_WARN_IF(!filename.empty(), + "Filename prefix is ignored on opengl backend."); + const std::string bin_path = fmt::format("{}/metadata.tcb", output_dir); write_to_binary_file(aot_data_, bin_path); - // The txt file is mostly for debugging purpose. - const std::string txt_path = - fmt::format("{}/{}_metadata.txt", output_dir, filename); - TextSerializer ts; - ts("taichi aot data", aot_data_); - ts.write_to_file(txt_path); + // Json format doesn't support multiple line strings. + AotData aot_data_copy = aot_data_; + for (auto &k : aot_data_copy.kernels) { + for (auto &t : k.second.tasks) { + write_glsl_file(output_dir, t); + } + } + for (auto &k : aot_data_copy.kernel_tmpls) { + for (auto &t : k.second.tasks) { + write_glsl_file(output_dir, t); + } + } + auto aot_module_data = AotDataConverter::convert(aot_data_copy); + const std::string json_path = fmt::format("{}/metadata.json", output_dir); + aot_module_data.dump_json(json_path); } void AotModuleBuilderImpl::add_per_backend(const std::string &identifier, Kernel *kernel) { - opengl::OpenglCodeGen codegen(kernel->name, &compiled_structs_); + opengl::OpenglCodeGen codegen(kernel->name, &compiled_structs_, + allow_nv_shader_extension_); auto compiled = codegen.compile(*kernel); - aot_data_.kernels.push_back({compiled, identifier}); + aot_data_.kernels.insert(std::make_pair(identifier, std::move(compiled))); } -void AotModuleBuilderImpl::add_per_backend_field(const std::string &identifier, +void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier, + const SNode *rep_snode, bool is_scalar, DataType dt, std::vector shape, int row_num, int column_num) { - uint32_t gl_dtype_enum; - - if (dt == PrimitiveType::u64) { - gl_dtype_enum = GL_UNSIGNED_INT64_ARB; - } else if (dt == PrimitiveType::i64) { - gl_dtype_enum = GL_INT64_ARB; - } else if (dt == PrimitiveType::u32) { - gl_dtype_enum = GL_UNSIGNED_INT; - } else if (dt == PrimitiveType::i32) { - gl_dtype_enum = GL_INT; - } else if (dt == PrimitiveType::u16) { - gl_dtype_enum = GL_UNSIGNED_SHORT; - } else if (dt == PrimitiveType::i16) { - gl_dtype_enum = GL_SHORT; - } else if (dt == PrimitiveType::u8) { - gl_dtype_enum = GL_UNSIGNED_BYTE; - } else if (dt == PrimitiveType::i8) { - gl_dtype_enum = GL_BYTE; - } else if (dt == PrimitiveType::f64) { - gl_dtype_enum = GL_DOUBLE; - } else if (dt == PrimitiveType::f32) { - gl_dtype_enum = GL_FLOAT; - } + uint32_t gl_dtype_enum = to_gl_dtype_enum(dt); - aot_data_.fields.push_back({identifier, gl_dtype_enum, dt.to_string(), shape, - is_scalar, row_num, column_num}); + // Note that currently we only support adding dense fields in AOT for all + // backends. In opengl backend we only error out when a non dense field is + // added to the aot module, but in metal backend we error out earlier when + // constructing aot module. Ideally we will unify this behavior but it doesn't + // matter too much for now. + TI_ERROR_IF(!all_fields_are_dense_in_container(rep_snode->parent), + "AOT: only supports dense field"); + std::vector element_shape; + if (!is_scalar) { + element_shape = {row_num, column_num}; + } + aot_data_.fields.push_back( + {identifier, gl_dtype_enum, dt.to_string(), + compiled_structs_.snode_map.at(rep_snode->node_type_name) + .mem_offset_in_root, + shape, is_scalar, element_shape}); } void AotModuleBuilderImpl::add_per_backend_tmpl(const std::string &identifier, const std::string &key, Kernel *kernel) { - opengl::OpenglCodeGen codegen(kernel->name, &compiled_structs_); + opengl::OpenglCodeGen codegen(kernel->name, &compiled_structs_, + allow_nv_shader_extension_); auto compiled = codegen.compile(*kernel); - for (auto &k : aot_data_.kernel_tmpls) { - if (k.identifier == identifier) { - k.program.insert(std::make_pair(key, compiled)); - return; - } - } - - AotCompiledKernelTmpl tmpldata; - tmpldata.identifier = identifier; - tmpldata.program.insert(std::make_pair(key, compiled)); - - aot_data_.kernel_tmpls.push_back(std::move(tmpldata)); + aot_data_.kernel_tmpls.insert( + std::make_pair(identifier + "|" + key, std::move(compiled))); } } // namespace opengl diff --git a/taichi/backends/opengl/aot_module_builder_impl.h b/taichi/backends/opengl/aot_module_builder_impl.h index 55f631c3b56ee..b3ddd4c340067 100644 --- a/taichi/backends/opengl/aot_module_builder_impl.h +++ b/taichi/backends/opengl/aot_module_builder_impl.h @@ -3,7 +3,7 @@ #include #include -#include "taichi/program/aot_module_builder.h" +#include "taichi/aot/module_builder.h" #include "taichi/backends/opengl/aot_data.h" namespace taichi { @@ -12,14 +12,17 @@ namespace opengl { class AotModuleBuilderImpl : public AotModuleBuilder { public: - explicit AotModuleBuilderImpl(StructCompiledResult &compiled_structs); + explicit AotModuleBuilderImpl(StructCompiledResult &compiled_structs, + bool allow_nv_shader_extension); void dump(const std::string &output_dir, const std::string &filename) const override; protected: void add_per_backend(const std::string &identifier, Kernel *kernel) override; - void add_per_backend_field(const std::string &identifier, + + void add_field_per_backend(const std::string &identifier, + const SNode *rep_snode, bool is_scalar, DataType dt, std::vector shape, @@ -31,8 +34,8 @@ class AotModuleBuilderImpl : public AotModuleBuilder { private: StructCompiledResult &compiled_structs_; - AotData aot_data_; + bool allow_nv_shader_extension_ = false; }; } // namespace opengl diff --git a/taichi/backends/opengl/codegen_opengl.cpp b/taichi/backends/opengl/codegen_opengl.cpp index 2949884ac7fbe..c6bb01d871faa 100644 --- a/taichi/backends/opengl/codegen_opengl.cpp +++ b/taichi/backends/opengl/codegen_opengl.cpp @@ -19,28 +19,45 @@ namespace opengl { namespace { namespace shaders { +#define FOREACH_ARR_NAME(_) \ + _(arr0) \ + _(arr1) \ + _(arr2) \ + _(arr3) \ + _(arr4) \ + _(arr5) \ + _(arr6) \ + _(arr7) + #define TI_INSIDE_OPENGL_CODEGEN #include "taichi/backends/opengl/shaders/atomics_macro_f32.glsl.h" #include "taichi/backends/opengl/shaders/runtime.h" -#include "taichi/backends/opengl/shaders/listman.h" #include "taichi/backends/opengl/shaders/random.glsl.h" #include "taichi/backends/opengl/shaders/fast_pow.glsl.h" #include "taichi/backends/opengl/shaders/print.glsl.h" #include "taichi/backends/opengl/shaders/reduction.glsl.h" + +GENERATE_OPENGL_ATOMIC_F32(data); +GENERATE_OPENGL_ATOMIC_F32(gtmp); + +FOREACH_ARR_NAME(GENERATE_OPENGL_ATOMIC_F32); + +GENERATE_OPENGL_REDUCTION_FUNCTIONS(add, float); +GENERATE_OPENGL_REDUCTION_FUNCTIONS(max, float); +GENERATE_OPENGL_REDUCTION_FUNCTIONS(min, float); +GENERATE_OPENGL_REDUCTION_FUNCTIONS(add, int); +GENERATE_OPENGL_REDUCTION_FUNCTIONS(max, int); +GENERATE_OPENGL_REDUCTION_FUNCTIONS(min, int); +GENERATE_OPENGL_REDUCTION_FUNCTIONS(add, uint); +GENERATE_OPENGL_REDUCTION_FUNCTIONS(max, uint); +GENERATE_OPENGL_REDUCTION_FUNCTIONS(min, uint); + #undef TI_INSIDE_OPENGL_CODEGEN +#undef FOREACH_ARR_NAME } // namespace shaders using irpass::ExternalPtrAccess; -int find_children_id(const SNode *snode) { - auto parent = snode->parent; - for (int i = 0; i < parent->ch.size(); i++) { - if (parent->ch[i].get() == snode) - return i; - } - TI_ERROR("Child not found in parent!"); -} - std::string opengl_atomic_op_type_cap_name(AtomicOpType type) { static std::map type_names; if (type_names.empty()) { @@ -59,14 +76,20 @@ std::string opengl_atomic_op_type_cap_name(AtomicOpType type) { return type_names[type]; } +#if !defined(TI_PLATFORM_WINDOWS) +#include +#endif + class KernelGen : public IRVisitor { public: KernelGen(Kernel *kernel, const StructCompiledResult *struct_compiled, - const std::string &kernel_name) + const std::string &kernel_name, + bool allows_nv_shader_ext) : kernel_(kernel), - kernel_name_(kernel_name), struct_compiled_(struct_compiled), + kernel_name_(kernel_name), + allows_nv_shader_ext_(allows_nv_shader_ext), root_snode_type_name_(struct_compiled->root_snode_type_name), glsl_kernel_prefix_(kernel_name) { compiled_program_.init_args(kernel); @@ -77,14 +100,15 @@ class KernelGen : public IRVisitor { private: const Kernel *kernel_; const StructCompiledResult *struct_compiled_; - std::string kernel_name_; - std::string root_snode_type_name_; - std::string glsl_kernel_prefix_; + const std::string kernel_name_; + const bool allows_nv_shader_ext_; + const std::string root_snode_type_name_; + const std::string glsl_kernel_prefix_; GetRootStmt *root_stmt_; int glsl_kernel_count_{0}; bool is_top_level_{true}; - CompiledProgram compiled_program_; + CompiledTaichiKernel compiled_program_; UsedFeature used; // TODO: is this actually per-offload? // per-offload variables: @@ -93,8 +117,9 @@ class KernelGen : public IRVisitor { std::string glsl_kernel_name_; int num_workgroups_{1}; int workgroup_size_{1}; - bool used_tls; // TODO: move into UsedFeature? - std::unordered_map extptr_access; + bool used_tls_; // TODO: move into UsedFeature? + std::unordered_map extptr_access_; + std::unordered_set loaded_args_; template void emit(std::string f, Args &&... args) { @@ -135,8 +160,36 @@ class KernelGen : public IRVisitor { return opengl::opengl_data_type_name(dt); } - void generate_bottom() { - // TODO(archibate): () really necessary? How about just main()? + std::string gen_layout_line(std::string dt, + std::string dtype, + std::string buf, + std::string bind_id) { + return fmt::format( + "layout(std430, binding = {}) buffer {}_{} {{ {} _{}_{}_[];}}; \n", + bind_id, buf, dt, dtype, buf, dt); + } + + std::string gen_buffer_registration(const UsedFeature &used, + std::string buf, + std::string bind_id) { + std::string res = ""; + if (used.int32) + res += gen_layout_line("i32", "int", buf, bind_id); + if (used.int64) + res += gen_layout_line("i64", "int64_t", buf, bind_id); + if (used.uint32) + res += gen_layout_line("u32", "uint", buf, bind_id); + if (used.uint64) + res += gen_layout_line("u64", "uint64_t", buf, bind_id); + if (used.float32) + res += gen_layout_line("f32", "float", buf, bind_id); + if (used.float64) + res += gen_layout_line("f64", "double", buf, bind_id); + return res; + } + + void generate_task_bottom(OffloadedTaskType task_type, + std::string range_hint) { emit("void main()"); emit("{{"); if (used.random) @@ -145,78 +198,68 @@ class KernelGen : public IRVisitor { emit(" {}();", glsl_kernel_name_); emit("}}"); - // clang-format off if (used.print) // the runtime buffer is only used for print now.. line_appender_header_.append_raw(shaders::kOpenGlRuntimeSourceCode); - if (used.listman) - line_appender_header_.append_raw(shaders::kOpenGLListmanSourceCode); std::string kernel_header; -#define DEFINE_LAYOUT(layout, restype, name, id, dt, dtype) \ - kernel_header += "layout("#layout", binding = " + fmt::format("{}", id) \ - + ") " #restype " " #name "_" #dt " { " #dtype " _" \ - #name "_" #dt "_[]; };\n" -#define REGISTER_BUFFER(layout, restype, name, id) do { \ - if (used.int32) DEFINE_LAYOUT(layout, restype, name, id, i32, int); \ - if (used.int64) DEFINE_LAYOUT(layout, restype, name, id, i64, int64_t); \ - if (used.uint32) DEFINE_LAYOUT(layout, restype, name, id, u32, uint); \ - if (used.uint64) DEFINE_LAYOUT(layout, restype, name, id, u64, uint64_t); \ - if (used.float32) DEFINE_LAYOUT(layout, restype, name, id, f32, float); \ - if (used.float64) DEFINE_LAYOUT(layout, restype, name, id, f64, double); \ - } while (0) - - REGISTER_BUFFER(std430, buffer, data, GLBufId::Root); + if (used.buf_data) + kernel_header += gen_buffer_registration( + used, "data", std::to_string(static_cast(GLBufId::Root))); if (used.buf_gtmp) - REGISTER_BUFFER(std430, buffer, gtmp, GLBufId::Gtmp); + kernel_header += gen_buffer_registration( + used, "gtmp", std::to_string(static_cast(GLBufId::Gtmp))); if (used.buf_args) - REGISTER_BUFFER(std430, readonly buffer, args, GLBufId::Args); - if (used.buf_retr) - REGISTER_BUFFER(std430, writeonly buffer, retr, GLBufId::Retr); - if (used.buf_extr) { - bool write = false; - bool read = false; - - for (auto pair : this->extptr_access) { - write |= (pair.second & irpass::ExternalPtrAccess::WRITE) != irpass::ExternalPtrAccess::NONE; - read |= (pair.second & irpass::ExternalPtrAccess::WRITE) != irpass::ExternalPtrAccess::NONE; - } - - if (write && !read) { - REGISTER_BUFFER(std430, writeonly buffer, extr, GLBufId::Extr); - } else if (!write && read) { - REGISTER_BUFFER(std430, readonly buffer, extr, GLBufId::Extr); - } else { - REGISTER_BUFFER(std430, buffer, extr, GLBufId::Extr); - } + kernel_header += gen_buffer_registration( + used, "args", std::to_string(static_cast(GLBufId::Args))); + for (auto [arr_id, bind_idx] : used.arr_arg_to_bind_idx) { + kernel_header += gen_buffer_registration( + used, "arr" + std::to_string(arr_id), std::to_string(bind_idx)); } -#undef REGISTER_BUFFER -#undef DEFINE_LAYOUT - // clang-format on - if (used.simulated_atomic_float) { - line_appender_header_.append_raw(shaders::kOpenGLAtomicF32SourceCode); - kernel_header += ("DEFINE_ATOMIC_F32_FUNCTIONS(data);\n"); + if (used.buf_data) { + kernel_header += shaders::kOpenGlAtomicF32Source_data; + } if (used.buf_gtmp) { - kernel_header += ("DEFINE_ATOMIC_F32_FUNCTIONS(gtmp);\n"); + kernel_header += shaders::kOpenGlAtomicF32Source_gtmp; } - if (used.buf_extr) { - kernel_header += ("DEFINE_ATOMIC_F32_FUNCTIONS(extr);\n"); + std::unordered_set arr_ids; + for ([[maybe_unused]] const auto [arr_id, bind_idx] : + used.arr_arg_to_bind_idx) { + arr_ids.insert(arr_id); } + +#define FOREACH_ARR_ID(_) \ + _(0) \ + _(1) \ + _(2) \ + _(3) \ + _(4) \ + _(5) \ + _(6) \ + _(7) + +#define ADD_ARR_ATOMIC_F32_SOURCE(id) \ + if (arr_ids.count(id)) { \ + kernel_header += shaders::kOpenGlAtomicF32Source_arr##id; \ + } + + FOREACH_ARR_ID(ADD_ARR_ATOMIC_F32_SOURCE); +#undef ADD_ARR_ATOMIC_F32_SOURCE +#undef FOREACH_ARR_ID } if (used.reduction) { line_appender_header_.append_raw(shaders::kOpenGLReductionCommon); - line_appender_header_.append_raw(shaders::kOpenGLReductionSourceCode); - kernel_header += ("DEFINE_REDUCTION_FUNCTIONS(add, float);\n"); - kernel_header += ("DEFINE_REDUCTION_FUNCTIONS(max, float);\n"); - kernel_header += ("DEFINE_REDUCTION_FUNCTIONS(min, float);\n"); - kernel_header += ("DEFINE_REDUCTION_FUNCTIONS(add, int);\n"); - kernel_header += ("DEFINE_REDUCTION_FUNCTIONS(max, int);\n"); - kernel_header += ("DEFINE_REDUCTION_FUNCTIONS(min, int);\n"); - kernel_header += ("DEFINE_REDUCTION_FUNCTIONS(add, uint);\n"); - kernel_header += ("DEFINE_REDUCTION_FUNCTIONS(max, uint);\n"); - kernel_header += ("DEFINE_REDUCTION_FUNCTIONS(min, uint);\n"); + kernel_header += shaders::kOpenGlReductionSource_add_float; + kernel_header += shaders::kOpenGlReductionSource_max_float; + kernel_header += shaders::kOpenGlReductionSource_min_float; + kernel_header += shaders::kOpenGlReductionSource_add_int; + kernel_header += shaders::kOpenGlReductionSource_max_int; + kernel_header += shaders::kOpenGlReductionSource_min_int; + kernel_header += shaders::kOpenGlReductionSource_add_uint; + kernel_header += shaders::kOpenGlReductionSource_max_uint; + kernel_header += shaders::kOpenGlReductionSource_min_uint; } line_appender_header_.append_raw(kernel_header); @@ -239,12 +282,17 @@ class KernelGen : public IRVisitor { #include "taichi/inc/opengl_extension.inc.h" #undef PER_OPENGL_EXTENSION auto kernel_src_code = - "#version 430 core\n" + extensions + "precision highp float;\n" + - line_appender_header_.lines() + line_appender_.lines(); - compiled_program_.add(std::move(glsl_kernel_name_), kernel_src_code, - num_workgroups_, workgroup_size_, - &this->extptr_access); + (is_gles() ? "#version 310 es\n" : "#version 430 core\n") + extensions + + "precision highp float;\n" + line_appender_header_.lines() + + line_appender_.lines(); auto &config = kernel_->program->config; + const int prescribed_block_dim = config.max_block_dim; + workgroup_size_ = prescribed_block_dim > 0 + ? std::min(workgroup_size_, prescribed_block_dim) + : workgroup_size_; + compiled_program_.add(std::move(glsl_kernel_name_), kernel_src_code, + task_type, range_hint, num_workgroups_, + workgroup_size_, &this->extptr_access_); if (config.print_kernel_llvm_ir) { static FileSequenceWriter writer("shader{:04d}.comp", "OpenGL compute shader"); @@ -253,7 +301,7 @@ class KernelGen : public IRVisitor { line_appender_header_.clear_all(); line_appender_.clear_all(); num_workgroups_ = 1; - num_workgroups_ = 1; + workgroup_size_ = 1; } void visit(Block *stmt) override { @@ -266,7 +314,7 @@ class KernelGen : public IRVisitor { line_appender_.pop_indent(); } - virtual void visit(Stmt *stmt) override { + void visit(Stmt *stmt) override { TI_ERROR("[glsl] unsupported statement type {}", typeid(*stmt).name()); } @@ -350,28 +398,20 @@ class KernelGen : public IRVisitor { stmt->input_index->short_name(), stmt->snode->node_type_name); if (stmt->activate) { - if (stmt->snode->type == SNodeType::dense) { - // do nothing - } else if (stmt->snode->type == SNodeType::dynamic) { - used.int32 = true; - emit("atomicMax(_data_i32_[{} >> 2], {} + 1); // dynamic activate", - get_snode_meta_address(stmt->snode), - stmt->input_index->short_name()); - } else { - TI_NOT_IMPLEMENTED - } + TI_ASSERT(stmt->snode->type == SNodeType::dense); } } + void visit(AssertStmt *stmt) override { + // TODO: do the actual assert + TI_WARN("Assert is not supported for OpenGL arch"); + } + void visit(SNodeOpStmt *stmt) override { // IAPR? if (stmt->op_type == SNodeOpType::activate) { if (stmt->snode->type == SNodeType::dense || stmt->snode->type == SNodeType::root) { // do nothing - } else if (stmt->snode->type == SNodeType::dynamic) { - used.int32 = true; - emit("atomicMax(_data_i32_[{} >> 2], {} + 1); // dynamic activate", - get_snode_meta_address(stmt->snode), stmt->val->short_name()); } else { TI_NOT_IMPLEMENTED } @@ -380,10 +420,6 @@ class KernelGen : public IRVisitor { if (stmt->snode->type == SNodeType::dense || stmt->snode->type == SNodeType::root) { // do nothing - } else if (stmt->snode->type == SNodeType::dynamic) { - used.int32 = true; - emit("_data_i32_[{} >> 2] = 0; // dynamic deactivate", - get_snode_meta_address(stmt->snode), stmt->val->short_name()); } else { TI_NOT_IMPLEMENTED } @@ -393,103 +429,139 @@ class KernelGen : public IRVisitor { if (stmt->snode->type == SNodeType::dense || stmt->snode->type == SNodeType::root) { emit("int {} = 1;", stmt->short_name()); - } else if (stmt->snode->type == SNodeType::dynamic) { - used.int32 = true; - emit("int {} = int({} < _data_i32_[{} >> 2]);", stmt->short_name(), - stmt->val->short_name(), get_snode_meta_address(stmt->snode)); } else { TI_NOT_IMPLEMENTED } - - } else if (stmt->op_type == SNodeOpType::append) { - TI_ASSERT(stmt->snode->type == SNodeType::dynamic); - TI_ASSERT(stmt->ret_type->is_primitive(PrimitiveTypeID::i32)); - used.int32 = true; - emit("int {} = atomicAdd(_data_i32_[{} >> 2], 1);", stmt->short_name(), - get_snode_meta_address(stmt->snode)); - auto dt = stmt->val->element_type(); - emit("int _ad_{} = {} + {} * {};", stmt->short_name(), - get_snode_base_address(stmt->snode), stmt->short_name(), - struct_compiled_->snode_map.at(stmt->snode->node_type_name) - .elem_stride); - emit("_data_{}_[_ad_{} >> {}] = {};", opengl_data_type_short_name(dt), - stmt->short_name(), opengl_data_address_shifter(dt), - stmt->val->short_name()); - - } else if (stmt->op_type == SNodeOpType::length) { - TI_ASSERT(stmt->snode->type == SNodeType::dynamic); - TI_ASSERT(stmt->ret_type->is_primitive(PrimitiveTypeID::i32)); - used.int32 = true; - emit("int {} = _data_i32_[{} >> 2];", stmt->short_name(), - get_snode_meta_address(stmt->snode)); - } else { TI_NOT_IMPLEMENTED } } - std::map ptr_signats; + std::map ptr_signats_; void visit(GetChStmt *stmt) override { + used.buf_data = true; emit("int {} = {} + {}; // {}", stmt->short_name(), stmt->input_ptr->short_name(), struct_compiled_->snode_map.at(stmt->input_snode->node_type_name) .children_offsets[stmt->chid], stmt->output_snode->node_type_name); if (stmt->output_snode->is_place()) - ptr_signats[stmt->id] = "data"; + ptr_signats_[stmt->id] = "data"; } void visit(GlobalStoreStmt *stmt) override { TI_ASSERT(stmt->width() == 1); auto dt = stmt->val->element_type(); - emit("_{}_{}_[{} >> {}] = {};", - ptr_signats.at(stmt->dest->id), // throw out_of_range if not a pointer - opengl_data_type_short_name(dt), stmt->dest->short_name(), - opengl_data_address_shifter(dt), stmt->val->short_name()); + std::string index = stmt->dest->is() + ? stmt->dest->short_name() + : fmt::format("{} >> {}", stmt->dest->short_name(), + opengl_data_address_shifter(dt)); + + emit( + "_{}_{}_[{}] = {};", + ptr_signats_.at(stmt->dest->id), // throw out_of_range if not a pointer + opengl_data_type_short_name(dt), index, stmt->val->short_name()); } void visit(GlobalLoadStmt *stmt) override { TI_ASSERT(stmt->width() == 1); auto dt = stmt->element_type(); - emit("{} {} = _{}_{}_[{} >> {}];", - opengl_data_type_name(stmt->element_type()), stmt->short_name(), - ptr_signats.at(stmt->src->id), opengl_data_type_short_name(dt), - stmt->src->short_name(), opengl_data_address_shifter(dt)); + std::string index = stmt->src->is() + ? stmt->src->short_name() + : fmt::format("{} >> {}", stmt->src->short_name(), + opengl_data_address_shifter(dt)); + + emit("{} {} = _{}_{}_[{}];", opengl_data_type_name(stmt->element_type()), + stmt->short_name(), ptr_signats_.at(stmt->src->id), + opengl_data_type_short_name(dt), index); } void visit(ExternalPtrStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - const auto linear_index_name = fmt::format("_li_{}", stmt->short_name()); - emit("int {} = 0;", linear_index_name); - emit("{{ // linear seek"); - { - ScopedIndent _s(line_appender_); - const auto *argload = stmt->base_ptrs[0]->as(); - const int arg_id = argload->arg_id; - const int num_indices = stmt->indices.size(); - std::vector size_var_names; - for (int i = 0; i < num_indices; i++) { - used.buf_args = true; + const auto linear_index_name = stmt->short_name(); + const auto *argload = stmt->base_ptrs[0]->as(); + const int arg_id = argload->arg_id; + const int num_indices = stmt->indices.size(); + auto element_shape = stmt->element_shape; + std::vector size_var_names; + std::vector element_shape_size_var_names; + enum ExternalArrayLayout { layout_AOS = 0, layout_SOA = 1 }; + auto layout = stmt->element_dim <= 0 ? layout_AOS : layout_SOA; + + if (element_shape.size() > 0) { + int elem_beg = 0; + int elem_end = 0; + if (layout == layout_SOA) { + elem_beg = 0; + elem_end = element_shape.size(); + } else { + elem_beg = num_indices - element_shape.size(); + elem_end = num_indices; + } + for (int i = elem_beg; i < elem_end; i++) { used.int32 = true; - std::string var_name = fmt::format("_s{}_{}", i, stmt->short_name()); + std::string var_name = fmt::format("_s{}_{}{}", i, "arr", arg_id); + if (!loaded_args_.count(var_name)) { + emit("int {} = {};", var_name, element_shape[i - elem_beg]); + loaded_args_.insert(var_name); + } + element_shape_size_var_names.push_back(std::move(var_name)); + } + } + // Args buffer arrange dimensions from outer to inner + // AoS args buffer: array_shape|element_shape + // SoA args buffer: element_shape|array_shape + // + // ti.Matrix.ndarray(3, 2, ti.f32, (5, 4), layout=ti.Layout.AOS) + // args buffer: 5, 4, 3, 2 + // ti.Matrix.ndarray(3, 2, ti.f32, (5, 4), layout=ti.Layout.SOA) + // args buffer: 3, 2, 5, 4 + int ind_beg = 0; + int ind_end = 0; + if (layout == layout_SOA) { + ind_beg = element_shape.size(); + ind_end = num_indices; + } else { + ind_beg = 0; + ind_end = num_indices - element_shape.size(); + } + for (int i = ind_beg; i < ind_end; i++) { + used.buf_args = true; + used.int32 = true; + std::string var_name = fmt::format("_s{}_{}{}", i, "arr", arg_id); + + if (!loaded_args_.count(var_name)) { emit("int {} = _args_i32_[{} + {} * {} + {}];", var_name, - taichi_opengl_earg_base / sizeof(int), arg_id, + taichi_opengl_extra_args_base / sizeof(int), arg_id, taichi_max_num_indices, i); - size_var_names.push_back(std::move(var_name)); - } - for (int i = 0; i < num_indices; i++) { - emit("{} *= {};", linear_index_name, size_var_names[i]); - emit("{} += {};", linear_index_name, stmt->indices[i]->short_name()); + loaded_args_.insert(var_name); } + size_var_names.push_back(std::move(var_name)); } - emit("}}"); + // Arrange index stride and offsets in correct order + if (layout == layout_SOA) { + size_var_names.insert(size_var_names.begin(), + element_shape_size_var_names.begin(), + element_shape_size_var_names.end()); + } else { + size_var_names.insert(size_var_names.end(), + element_shape_size_var_names.begin(), + element_shape_size_var_names.end()); + } + + emit("int {} = {};", linear_index_name, + num_indices == 0 ? "0" : stmt->indices[0]->short_name()); - emit("int {} = {} + ({} << {});", stmt->short_name(), - stmt->base_ptrs[0]->short_name(), linear_index_name, - opengl_data_address_shifter(stmt->base_ptrs[0]->element_type())); - used.buf_extr = true; - ptr_signats[stmt->id] = "extr"; + for (int i = 1; i < num_indices; i++) { + emit("{} *= {};", linear_index_name, size_var_names[i]); + emit("{} += {};", linear_index_name, stmt->indices[i]->short_name()); + } + + ptr_signats_[stmt->id] = "arr" + std::to_string(arg_id); + } + + void visit(DecorationStmt *stmt) override { } void visit(UnaryOpStmt *stmt) override { @@ -592,6 +664,9 @@ class KernelGen : public IRVisitor { // floor(x / y) emit("{} {} = {} - {} * int({} / {});", dt_name, bin_name, lhs_name, rhs_name, lhs_name, rhs_name); + // FIXME: hack! doesn't make too much difference on mobile. + // emit("{} {} = {} & int({} - 1); // mod", dt_name, bin_name, lhs_name, + // rhs_name); return; } else if (bin->op_type == BinaryOpType::atan2) { if (bin->element_type() == @@ -679,19 +754,8 @@ class KernelGen : public IRVisitor { emit("{{ // Begin Atomic Op"); - if (dt->is_primitive(PrimitiveTypeID::i32) || - (TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_int64) && - dt->is_primitive(PrimitiveTypeID::i64)) || - ((stmt->op_type == AtomicOpType::add || - stmt->op_type == AtomicOpType::sub) && - ((TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_float) && - dt->is_primitive(PrimitiveTypeID::f32)) || - (TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_float64) && - dt->is_primitive(PrimitiveTypeID::f64))))) { - emit("{} = {}(_{}_{}_[{} >> {}], {});", stmt->short_name(), - opengl_atomic_op_type_cap_name(stmt->op_type), - ptr_signats.at(stmt->dest->id), opengl_data_type_short_name(dt), - stmt->dest->short_name(), opengl_data_address_shifter(dt), val_name); + if (maybe_generate_fatomics_using_nv_ext(stmt, dt, val_name)) { + // Do nothing } else { if (dt != PrimitiveType::f32) { TI_ERROR( @@ -702,15 +766,53 @@ class KernelGen : public IRVisitor { } used.simulated_atomic_float = true; used.int32 = true; // since simulated atomics are based on _data_i32_ - emit("{} = {}_{}_{}({} >> {}, {});", stmt->short_name(), + std::string index = + stmt->dest->is() + ? stmt->dest->short_name() + : fmt::format("{} >> {}", stmt->dest->short_name(), + opengl_data_address_shifter(dt)); + emit("{} = {}_{}_{}({}, {});", stmt->short_name(), opengl_atomic_op_type_cap_name(stmt->op_type), - ptr_signats.at(stmt->dest->id), opengl_data_type_short_name(dt), - stmt->dest->short_name(), opengl_data_address_shifter(dt), val_name); + ptr_signats_.at(stmt->dest->id), opengl_data_type_short_name(dt), + index, val_name); } emit("}} // End Atomic Op"); } + bool maybe_generate_fatomics_using_nv_ext(AtomicOpStmt *stmt, + DataType dt, + const std::string &val_name) { + if (!allows_nv_shader_ext_ && !dt->is_primitive(PrimitiveTypeID::i32)) { + return false; + } + const bool check_int = + (dt->is_primitive(PrimitiveTypeID::i32) || + (TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_int64) && + dt->is_primitive(PrimitiveTypeID::i64))); + const bool check_add = (stmt->op_type == AtomicOpType::add || + stmt->op_type == AtomicOpType::sub); + const bool check_float = + ((TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_float) && + dt->is_primitive(PrimitiveTypeID::f32)) || + (TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_float64) && + dt->is_primitive(PrimitiveTypeID::f64))); + if (check_int || (check_add && check_float)) { + std::string index = + stmt->dest->is() + ? stmt->dest->short_name() + : fmt::format("{} >> {}", stmt->dest->short_name(), + opengl_data_address_shifter(dt)); + + emit("{} = {}(_{}_{}_[{}], {});", stmt->short_name(), + opengl_atomic_op_type_cap_name(stmt->op_type), + ptr_signats_.at(stmt->dest->id), opengl_data_type_short_name(dt), + index, val_name); + return true; + } + return false; + } + void visit(TernaryOpStmt *tri) override { TI_ASSERT(tri->op_type == TernaryOpType::select); emit("{} {} = {} != 0 ? {} : {};", @@ -753,21 +855,30 @@ class KernelGen : public IRVisitor { } void visit(ReturnStmt *stmt) override { - used.buf_retr = true; + used.buf_args = true; // TODO: use stmt->ret_id instead of 0 as index - emit("_retr_{}_[0] = {};", - opengl_data_type_short_name(stmt->element_type()), - stmt->value->short_name()); + int idx{0}; + for (auto &value : stmt->values) { + emit("_args_{}_[({} >> {}) + {}] = {};", + opengl_data_type_short_name(value->element_type()), + taichi_opengl_ret_base, + opengl_data_address_shifter(value->element_type()), idx, + value->short_name()); + idx += (4 - opengl_data_address_shifter(value->element_type())); + // opengl only support i32, f32 and f64 array, but there are 64bit slots + // in taichi's result buffer,so we need two slots to make them match. + } } void visit(ArgLoadStmt *stmt) override { const auto dt = opengl_data_type_name(stmt->element_type()); - used.buf_args = true; if (stmt->is_ptr) { - used.int32 = true; - emit("int {} = _args_i32_[{} << 1]; // is ext pointer {}", - stmt->short_name(), stmt->arg_id, dt); + if (!used.arr_arg_to_bind_idx.count(stmt->arg_id)) { + used.arr_arg_to_bind_idx[stmt->arg_id] = + static_cast(GLBufId::Arr) + stmt->arg_id; + } } else { + used.buf_args = true; emit("{} {} = _args_{}_[{} << {}];", dt, stmt->short_name(), opengl_data_type_short_name(stmt->element_type()), stmt->arg_id, opengl_argument_address_shifter(stmt->element_type())); @@ -775,8 +886,8 @@ class KernelGen : public IRVisitor { } void visit(ExternalFuncCallStmt *stmt) override { - TI_ASSERT(!stmt->func); - auto format = stmt->source; + TI_ASSERT(stmt->type == ExternalFuncCallStmt::ASSEMBLY); + auto format = stmt->asm_source; std::string source; for (int i = 0; i < format.size(); i++) { @@ -810,9 +921,12 @@ class KernelGen : public IRVisitor { const auto axis = stmt->axis; used.buf_args = true; used.int32 = true; - emit("int {} = _args_i32_[{} + {} * {} + {}];", name, - taichi_opengl_earg_base / sizeof(int), arg_id, taichi_max_num_indices, - axis); + if (!loaded_args_.count(name)) { + emit("int {} = _args_i32_[{} + {} * {} + {}];", name, + taichi_opengl_extra_args_base / sizeof(int), arg_id, + taichi_max_num_indices, axis); + loaded_args_.insert(name); + } } std::string make_kernel_name() { @@ -857,7 +971,7 @@ class KernelGen : public IRVisitor { // Refs: // https://stackoverflow.com/questions/36374652/compute-shaders-optimal-data-division-on-invocations-threads-and-workgroups if (const_iterations > 0) { - if (gen->used_tls) { + if (gen->used_tls_) { // const range with TLS reduction gen->num_workgroups_ = std::max( const_iterations / std::max(gen->workgroup_size_, 1) / 32, 1); @@ -882,7 +996,13 @@ class KernelGen : public IRVisitor { gen->emit("}}"); } }; - + void gen_array_range(Stmt *stmt) { + int num_operands = stmt->num_operands(); + for (int i = 0; i < num_operands; i++) { + gen_array_range(stmt->operand(i)); + } + stmt->accept(this); + } void generate_range_for_kernel(OffloadedStmt *stmt) { TI_ASSERT(stmt->task_type == OffloadedStmt::TaskType::range_for); const std::string glsl_kernel_name = make_kernel_name(); @@ -890,8 +1010,8 @@ class KernelGen : public IRVisitor { this->glsl_kernel_name_ = glsl_kernel_name; emit("{{ // range for"); - used_tls = (stmt->tls_prologue != nullptr); - if (used_tls) { + used_tls_ = (stmt->tls_prologue != nullptr); + if (used_tls_) { auto tls_size = stmt->tls_size; // TODO(k-ye): support 'cursor' in LineAppender: emit("int _tls_i32_[{}];", (tls_size + 3) / 4); @@ -920,16 +1040,27 @@ class KernelGen : public IRVisitor { num_workgroups_ = stmt->grid_dim; ScopedGridStrideLoop _gsl(this, end_value - begin_value); emit("int _itv = {} + _sid;", begin_value); + // range_hint is known after compilation, e.g. range of field + stmt->range_hint = std::to_string(end_value - begin_value); stmt->body->accept(this); } else { ScopedIndent _s(line_appender_); - emit("// range known at runtime"); - auto begin_expr = stmt->const_begin ? std::to_string(stmt->begin_value) - : fmt::format("_gtmp_i32_[{} >> 2]", - stmt->begin_offset); - auto end_expr = stmt->const_end ? std::to_string(stmt->end_value) - : fmt::format("_gtmp_i32_[{} >> 2]", - stmt->end_offset); + std::string begin_expr, end_expr; + if (stmt->end_stmt) { + emit("// range from args buffer"); + TI_ASSERT(stmt->const_begin); + begin_expr = std::to_string(stmt->begin_value); + gen_array_range(stmt->end_stmt); + end_expr = stmt->end_stmt->short_name(); + } else { + emit("// range known at runtime"); + begin_expr = stmt->const_begin ? std::to_string(stmt->begin_value) + : fmt::format("_gtmp_i32_[{} >> 2]", + stmt->begin_offset); + end_expr = stmt->const_end + ? std::to_string(stmt->end_value) + : fmt::format("_gtmp_i32_[{} >> 2]", stmt->end_offset); + } workgroup_size_ = stmt->block_dim; num_workgroups_ = stmt->grid_dim; emit("int _beg = {}, _end = {};", begin_expr, end_expr); @@ -938,100 +1069,14 @@ class KernelGen : public IRVisitor { stmt->body->accept(this); } - if (used_tls) { + if (used_tls_) { TI_ASSERT(stmt->tls_epilogue != nullptr); emit("{{ // TLS epilogue"); stmt->tls_epilogue->accept(this); emit("}}"); } - used_tls = false; - - emit("}}\n"); - } - - void generate_struct_for_kernel(OffloadedStmt *stmt) { - TI_ASSERT(stmt->task_type == OffloadedStmt::TaskType::struct_for); - used.listman = true; - const std::string glsl_kernel_name = make_kernel_name(); - emit("void {}()", glsl_kernel_name); - this->glsl_kernel_name_ = glsl_kernel_name; - emit("{{ // struct for {}", stmt->snode->node_type_name); - { - ScopedIndent _s(line_appender_); - workgroup_size_ = stmt->block_dim; - num_workgroups_ = stmt->grid_dim; - ScopedGridStrideLoop _gsl(this, "_list_len_"); - emit("int _itv = _list_[_sid];"); - stmt->body->accept(this); - } - emit("}}\n"); - } - - size_t get_snode_base_address(const SNode *snode) { - if (snode->type == SNodeType::root) - return 0; - int chid = find_children_id(snode); - const auto &parent_meta = - struct_compiled_->snode_map.at(snode->parent->node_type_name); - auto choff = parent_meta.children_offsets[chid]; - return choff + get_snode_base_address(snode->parent); - } - - size_t get_snode_meta_address(const SNode *snode) { - auto addr = get_snode_base_address(snode); - addr += struct_compiled_->snode_map.at(snode->node_type_name).stride; - addr -= opengl_get_snode_meta_size(*snode); - return addr; - } - - void generate_listgen_for_dynamic(const SNode *snode) { - TI_ASSERT(snode->type == SNodeType::dynamic); - // the `length` field of a dynamic SNode is at it's end: - // | x[0] | x[1] | x[2] | x[3] | ... | len | - TI_ASSERT_INFO(snode->parent->type == SNodeType::root, - "Non-top-level dynamic not supported yet on OpenGL"); - size_t addr = get_snode_meta_address(snode); - used.int32 = true; - emit("_list_len_ = _data_i32_[{} >> 2];", addr); - emit("for (int i = 0; i < _list_len_; i++) {{"); - { - ScopedIndent _s(line_appender_); - emit("_list_[i] = i;"); - } - emit("}}"); - } - - void generate_listgen_for_dense(const SNode *snode) { - TI_ASSERT(snode->type == SNodeType::dense); - // the `length` field of a dynamic SNode is at it's end: - // | x[0] | x[1] | x[2] | x[3] | ... | len | - emit("_list_len_ = {};", - struct_compiled_->snode_map.at(snode->node_type_name).length); - emit("for (int i = 0; i < _list_len_; i++) {{"); - { - ScopedIndent _s(line_appender_); - emit("_list_[i] = i;"); - } - emit("}}"); - } + used_tls_ = false; - void generate_listgen_kernel(OffloadedStmt *stmt) { - TI_ASSERT(stmt->task_type == OffloadedStmt::TaskType::listgen); - const std::string glsl_kernel_name = make_kernel_name(); - emit("void {}()", glsl_kernel_name); - this->glsl_kernel_name_ = glsl_kernel_name; - used.listman = true; - emit("{{ // listgen {}", stmt->snode->node_type_name); - { - ScopedIndent _s(line_appender_); - if (stmt->snode->type == SNodeType::dense) { - generate_listgen_for_dense(stmt->snode); - } else if (stmt->snode->type == SNodeType::dynamic) { - generate_listgen_for_dynamic(stmt->snode); - } else { - TI_NOT_IMPLEMENTED - } - } emit("}}\n"); } @@ -1039,13 +1084,13 @@ class KernelGen : public IRVisitor { TI_ASSERT(stmt->width() == 1); used.buf_gtmp = true; emit("int {} = {};", stmt->short_name(), stmt->offset); - ptr_signats[stmt->id] = "gtmp"; + ptr_signats_[stmt->id] = "gtmp"; } void visit(ThreadLocalPtrStmt *stmt) override { TI_ASSERT(stmt->width() == 1); emit("int {} = {};", stmt->short_name(), stmt->offset); - ptr_signats[stmt->id] = "tls"; + ptr_signats_[stmt->id] = "tls"; } void visit(LoopIndexStmt *stmt) override { @@ -1103,20 +1148,16 @@ class KernelGen : public IRVisitor { void visit(OffloadedStmt *stmt) override { auto map = irpass::detect_external_ptr_access_in_task(stmt); - this->extptr_access = std::move(map); + this->extptr_access_ = std::move(map); generate_header(); TI_ASSERT(is_top_level_); is_top_level_ = false; - using Type = OffloadedStmt::TaskType; - if (stmt->task_type == Type::serial) { + const auto task_type = stmt->task_type; + if (task_type == OffloadedTaskType::serial) { generate_serial_kernel(stmt); - } else if (stmt->task_type == Type::range_for) { + } else if (task_type == OffloadedTaskType::range_for) { generate_range_for_kernel(stmt); - } else if (stmt->task_type == Type::struct_for) { - generate_struct_for_kernel(stmt); - } else if (stmt->task_type == Type::listgen) { - generate_listgen_kernel(stmt); } else { // struct_for is automatically lowered to ranged_for for dense snodes // (#378). So we only need to support serial and range_for tasks. @@ -1124,19 +1165,14 @@ class KernelGen : public IRVisitor { stmt->task_name()); } is_top_level_ = true; - generate_bottom(); + generate_task_bottom(task_type, stmt->range_hint); + loaded_args_.clear(); } void visit(StructForStmt *) override { TI_ERROR("[glsl] Struct for cannot be nested under OpenGL for now"); } - void visit(ClearListStmt *stmt) override { - used.listman = true; - emit("// clear list {}", stmt->snode->node_type_name); - emit("_list_len_ = 0;"); - } - void visit(IfStmt *if_stmt) override { emit("if ({} != 0) {{", if_stmt->cond->short_name()); if (if_stmt->true_statements) { @@ -1150,7 +1186,7 @@ class KernelGen : public IRVisitor { } public: - CompiledProgram get_compiled_program() { + CompiledTaichiKernel get_compiled_program() { // We have to set it at the last moment, to get all used feature. compiled_program_.set_used(used); return std::move(compiled_program_); @@ -1163,9 +1199,10 @@ class KernelGen : public IRVisitor { } // namespace -CompiledProgram OpenglCodeGen::gen(void) { +CompiledTaichiKernel OpenglCodeGen::gen(void) { #if defined(TI_WITH_OPENGL) - KernelGen codegen(kernel_, struct_compiled_, kernel_name_); + KernelGen codegen(kernel_, struct_compiled_, kernel_name_, + allows_nv_shader_ext_); codegen.run(); return codegen.get_compiled_program(); #else @@ -1177,8 +1214,7 @@ void OpenglCodeGen::lower() { auto ir = kernel_->ir.get(); auto &config = kernel_->program->config; config.demote_dense_struct_fors = true; - irpass::compile_to_executable(ir, config, kernel_, - /*vectorize=*/false, kernel_->grad, + irpass::compile_to_executable(ir, config, kernel_, kernel_->grad, /*ad_use_stack=*/false, config.print_ir, /*lower_global_access=*/true, /*make_thread_local=*/config.make_thread_local); @@ -1187,7 +1223,7 @@ void OpenglCodeGen::lower() { #endif } -CompiledProgram OpenglCodeGen::compile(Kernel &kernel) { +CompiledTaichiKernel OpenglCodeGen::compile(Kernel &kernel) { this->kernel_ = &kernel; this->lower(); diff --git a/taichi/backends/opengl/codegen_opengl.h b/taichi/backends/opengl/codegen_opengl.h index 2ac44e7b2b0bb..a630f291b3959 100644 --- a/taichi/backends/opengl/codegen_opengl.h +++ b/taichi/backends/opengl/codegen_opengl.h @@ -3,6 +3,7 @@ #include "taichi/inc/constants.h" #include "taichi/lang_util.h" #include "taichi/backends/opengl/struct_opengl.h" +#include "taichi/backends/opengl/opengl_api.h" #include #include @@ -17,20 +18,24 @@ namespace opengl { class OpenglCodeGen { public: OpenglCodeGen(const std::string &kernel_name, - const StructCompiledResult *struct_compiled) - : kernel_name_(kernel_name), struct_compiled_(struct_compiled) { + const StructCompiledResult *struct_compiled, + bool allows_nv_shader_ext) + : kernel_name_(kernel_name), + struct_compiled_(struct_compiled), + allows_nv_shader_ext_(allows_nv_shader_ext) { } - CompiledProgram compile(Kernel &kernel); + CompiledTaichiKernel compile(Kernel &kernel); private: void lower(); - CompiledProgram gen(); + CompiledTaichiKernel gen(); const std::string kernel_name_; [[maybe_unused]] const StructCompiledResult *struct_compiled_; Kernel *kernel_; + const bool allows_nv_shader_ext_; }; } // namespace opengl diff --git a/taichi/backends/opengl/opengl_api.cpp b/taichi/backends/opengl/opengl_api.cpp index e441b78edd893..9e870814e75f0 100644 --- a/taichi/backends/opengl/opengl_api.cpp +++ b/taichi/backends/opengl/opengl_api.cpp @@ -1,24 +1,25 @@ -//#define _GLSL_DEBUG 1 #include "opengl_api.h" +#include + #include "taichi/backends/opengl/opengl_kernel_util.h" +#include "taichi/backends/opengl/opengl_utils.h" +#include "taichi/backends/opengl/shaders/runtime.h" +#include "taichi/ir/transforms.h" #include "taichi/program/kernel.h" #include "taichi/program/program.h" #include "taichi/program/py_print_buffer.h" #include "taichi/util/environ_config.h" -#include "taichi/backends/opengl/shaders/runtime.h" -#include "taichi/backends/opengl/shaders/listman.h" -#include "taichi/ir/transforms.h" #ifdef TI_WITH_OPENGL -#include "glad/glad.h" +#include "glad/gl.h" +#include "glad/egl.h" #include "GLFW/glfw3.h" #include "taichi/backends/opengl/opengl_device.h" -#endif - -#include +#endif // TI_WITH_OPENGL -TLANG_NAMESPACE_BEGIN +namespace taichi { +namespace lang { namespace opengl { #define PER_OPENGL_EXTENSION(x) bool opengl_extension_##x; @@ -30,27 +31,16 @@ namespace opengl { int opengl_max_block_dim = 1024; int opengl_max_grid_dim = 1024; -#ifdef TI_WITH_OPENGL +// kUseGles is set at most once in initialize_opengl below. +// TODO: Properly support setting GLES/GLSL in opengl backend +// without this global static boolean. +static bool kUseGles = false; -static std::string add_line_markers(std::string x) { - std::string marker; - size_t pos = 0, npos; - int line = 0; - while (1) { - npos = x.find_first_of('\n', pos); - marker = fmt::format("{:3d} ", ++line); - if (npos == std::string::npos) - break; - x.insert(pos, marker); - pos = npos + 1 + marker.size(); - } - return x; -} +#ifdef TI_WITH_OPENGL struct OpenGlRuntimeImpl { struct { DeviceAllocation runtime = kDeviceNullAllocation; - DeviceAllocation listman = kDeviceNullAllocation; DeviceAllocation root = kDeviceNullAllocation; DeviceAllocation gtmp = kDeviceNullAllocation; } core_bufs; @@ -58,16 +48,16 @@ struct OpenGlRuntimeImpl { OpenGlRuntimeImpl() { } - std::unique_ptr runtime; - std::unique_ptr listman; - - std::vector> programs; + std::unique_ptr runtime{nullptr}; + std::vector> programs; }; -bool initialize_opengl(bool error_tolerance) { +// TODO: Move this into ProgramImpl class so that it naturally +// gets access to config->use_gles. +bool initialize_opengl(bool use_gles, bool error_tolerance) { static std::optional supported; // std::nullopt - TI_TRACE("initialize_opengl({}) called", error_tolerance); + TI_TRACE("initialize_opengl({}, {}) called", use_gles, error_tolerance); if (supported.has_value()) { // this function has been called before if (supported.value()) { // detected to be true in last call @@ -79,45 +69,137 @@ bool initialize_opengl(bool error_tolerance) { } } - glfwInit(); - // Compute Shader requires OpenGL 4.3+ (or OpenGL ES 3.1+) - glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE); - glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 4); - glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 3); - glfwWindowHint(GLFW_VISIBLE, GLFW_FALSE); - glfwWindowHint(GLFW_COCOA_MENUBAR, GLFW_FALSE); - // GL context needs a window (There's no true headless GL) - GLFWwindow *window = - glfwCreateWindow(1, 1, "Make OpenGL Context", nullptr, nullptr); - if (!window) { - const char *desc = nullptr; - int status = glfwGetError(&desc); - if (!desc) - desc = "Unknown Error"; - if (error_tolerance) { - // error tolerated, returning false + // Code below is guaranteed to be called at most once. + int opengl_version = 0; + + if (glfwInit()) { + // Compute Shader requires OpenGL 4.3+ (or OpenGL ES 3.1+) + if (use_gles) { + glfwWindowHint(GLFW_CLIENT_API, GLFW_OPENGL_ES_API); + glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 3); + glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 1); + } else { + glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE); + glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 4); + glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 3); + } + glfwWindowHint(GLFW_VISIBLE, GLFW_FALSE); + glfwWindowHint(GLFW_COCOA_MENUBAR, GLFW_FALSE); + // GL context needs a window (when using GLFW) + GLFWwindow *window = + glfwCreateWindow(1, 1, "Make OpenGL Context", nullptr, nullptr); + if (!window) { + const char *desc = nullptr; + int status = glfwGetError(&desc); + if (!desc) + desc = "Unknown Error"; TI_DEBUG("[glsl] cannot create GLFW window: error {}: {}", status, desc); - supported = std::make_optional(false); - return false; + } else { + glfwMakeContextCurrent(window); + if (use_gles) { + opengl_version = gladLoadGLES2(glfwGetProcAddress); + } else { + opengl_version = gladLoadGL(glfwGetProcAddress); + } + TI_DEBUG("OpenGL context loaded through GLFW"); } - TI_ERROR("[glsl] cannot create GLFW window: error {}: {}", status, desc); } - glfwMakeContextCurrent(window); - if (!gladLoadGLLoader((GLADloadproc)glfwGetProcAddress)) { + if (!opengl_version) { + TI_TRACE("Attempting to load with EGL"); + + // Try EGL instead + int egl_version = gladLoaderLoadEGL(nullptr); + + if (!egl_version) { + TI_DEBUG("Failed to load EGL"); + } else { + static const EGLint configAttribs[] = {EGL_SURFACE_TYPE, + EGL_PBUFFER_BIT, + EGL_BLUE_SIZE, + 8, + EGL_GREEN_SIZE, + 8, + EGL_RED_SIZE, + 8, + EGL_DEPTH_SIZE, + 8, + EGL_RENDERABLE_TYPE, + EGL_OPENGL_BIT, + EGL_NONE}; + + // Initialize EGL + EGLDisplay egl_display = eglGetDisplay(EGL_DEFAULT_DISPLAY); + + EGLint major, minor; + eglInitialize(egl_display, &major, &minor); + + egl_version = gladLoaderLoadEGL(egl_display); + + TI_DEBUG("Loaded EGL {}.{} on display {}", + GLAD_VERSION_MAJOR(egl_version), GLAD_VERSION_MINOR(egl_version), + egl_display); + + // Select an appropriate configuration + EGLint num_configs; + EGLConfig egl_config; + + eglChooseConfig(egl_display, configAttribs, &egl_config, 1, &num_configs); + + // Bind the API (EGL >= 1.2) + if (egl_version >= GLAD_MAKE_VERSION(1, 2)) { + eglBindAPI(use_gles ? EGL_OPENGL_ES_API : EGL_OPENGL_API); + } + + // Create a context and make it current + EGLContext egl_context = EGL_NO_CONTEXT; + if (use_gles) { + static const EGLint gl_attribs[] = { + EGL_CONTEXT_MAJOR_VERSION, + 3, + EGL_CONTEXT_MINOR_VERSION, + 1, + EGL_NONE, + }; + + egl_context = eglCreateContext(egl_display, egl_config, EGL_NO_CONTEXT, + gl_attribs); + } else { + egl_context = + eglCreateContext(egl_display, egl_config, EGL_NO_CONTEXT, nullptr); + } + + eglMakeCurrent(egl_display, EGL_NO_SURFACE, EGL_NO_SURFACE, egl_context); + + if (use_gles) { + opengl_version = gladLoadGLES2(glad_eglGetProcAddress); + } else { + opengl_version = gladLoadGL(glad_eglGetProcAddress); + } + } + } + + // Load OpenGL API + if (!opengl_version) { if (error_tolerance) { - TI_WARN("[glsl] cannot initialize GLAD"); + TI_WARN("Can not create OpenGL context"); supported = std::make_optional(false); return false; } - TI_ERROR("[glsl] cannot initialize GLAD"); + TI_ERROR("Can not create OpenGL context"); } + + TI_DEBUG("{} version {}.{}", use_gles ? "GLES" : "OpenGL", + GLAD_VERSION_MAJOR(opengl_version), + GLAD_VERSION_MINOR(opengl_version)); + #define PER_OPENGL_EXTENSION(x) \ if ((opengl_extension_##x = GLAD_##x)) \ TI_TRACE("[glsl] Found " #x); #include "taichi/inc/opengl_extension.inc.h" #undef PER_OPENGL_EXTENSION - if (!opengl_extension_GL_ARB_compute_shader) { + + if (!use_gles && !opengl_extension_GL_ARB_compute_shader) { if (error_tolerance) { TI_INFO("Your OpenGL does not support GL_ARB_compute_shader extension"); supported = std::make_optional(false); @@ -134,51 +216,71 @@ bool initialize_opengl(bool error_tolerance) { TI_TRACE("GL_MAX_COMPUTE_WORK_GROUP_SIZE: {}", opengl_max_grid_dim); supported = std::make_optional(true); + kUseGles = use_gles; return true; } -void CompiledProgram::init_args(Kernel *kernel) { +void CompiledTaichiKernel::init_args(Kernel *kernel) { arg_count = kernel->args.size(); - ret_count = kernel->rets.size(); + ret_count = 0; + for (auto &ret : kernel->rets) { + if (auto tensor_type = ret.dt->cast()) + ret_count += tensor_type->get_num_elements(); + else + ret_count += 1; + } for (int i = 0; i < arg_count; i++) { - if (kernel->args[i].is_external_array) { - ext_arr_map[i] = kernel->args[i].size; + const auto dtype_name = kernel->args[i].dt.to_string(); + if (kernel->args[i].is_array) { + arr_args[i] = CompiledArrayArg( + {/*dtype_enum=*/to_gl_dtype_enum(kernel->args[i].dt), dtype_name, + /*field_dim=*/kernel->args[i].total_dim - + kernel->args[i].element_shape.size(), + /*is_scalar=*/kernel->args[i].element_shape.size() == 0, + /*element_shape=*/kernel->args[i].element_shape, + /*shape_offset_in_bytes_in_args_buf=*/taichi_opengl_extra_args_base + + i * taichi_max_num_indices * sizeof(int), + /*total_size=*/kernel->args[i].size}); + } else { + scalar_args[i] = ScalarArg( + {dtype_name, /*offset_in_bytes_in_args_buf=*/i * sizeof(uint64_t)}); } } - for (const auto &[i, size] : ext_arr_map) { - total_ext_arr_size += size; - } - args_buf_size = arg_count * sizeof(uint64_t); - if (ext_arr_map.size()) { - args_buf_size = taichi_opengl_earg_base + + if (arr_args.size()) { + args_buf_size = taichi_opengl_extra_args_base + arg_count * taichi_max_num_indices * sizeof(int); } ret_buf_size = ret_count * sizeof(uint64_t); } -void CompiledProgram::add( - const std::string &kernel_name, - const std::string &kernel_source_code, - int num_workgrpus, +void CompiledTaichiKernel::add( + const std::string &name, + const std::string &source_code, + OffloadedTaskType type, + const std::string &range_hint, + int num_workgroups, int workgroup_size, std::unordered_map *ext_ptr_access) { - num_workgrpus = std::min(num_workgrpus, opengl_max_grid_dim); + num_workgroups = std::min(num_workgroups, opengl_max_grid_dim); workgroup_size = std::min(workgroup_size, opengl_max_block_dim); - size_t layout_pos = kernel_source_code.find("precision highp float;\n"); + size_t layout_pos = source_code.find("precision highp float;\n"); TI_ASSERT(layout_pos != std::string::npos); std::string source = - kernel_source_code.substr(0, layout_pos) + + source_code.substr(0, layout_pos) + fmt::format( "layout(local_size_x = {}, local_size_y = 1, local_size_z = " "1) in;\n", workgroup_size) + - kernel_source_code.substr(layout_pos); + source_code.substr(layout_pos); - kernels.push_back({kernel_name, source, workgroup_size, num_workgrpus}); + TI_DEBUG("[glsl]\ncompiling kernel {}<<<{}, {}>>>:\n{}", name, num_workgroups, + workgroup_size, source); + tasks.push_back( + {name, source, type, range_hint, workgroup_size, num_workgroups}); if (ext_ptr_access) { for (auto pair : *ext_ptr_access) { @@ -191,7 +293,7 @@ void CompiledProgram::add( } } -int CompiledProgram::lookup_or_add_string(const std::string &str) { +int CompiledTaichiKernel::lookup_or_add_string(const std::string &str) { int i; for (i = 0; i < str_table.size(); i++) { if (str_table[i] == str) { @@ -242,7 +344,7 @@ void dump_message_buffer(Device *device, device->unmap(runtime_buf); } -bool CompiledProgram::check_ext_arr_read(int i) const { +bool CompiledTaichiKernel::check_ext_arr_read(int i) const { auto iter = ext_arr_access.find(i); if (iter == ext_arr_access.end()) return false; @@ -251,7 +353,7 @@ bool CompiledProgram::check_ext_arr_read(int i) const { irpass::ExternalPtrAccess::NONE; } -bool CompiledProgram::check_ext_arr_write(int i) const { +bool CompiledTaichiKernel::check_ext_arr_write(int i) const { auto iter = ext_arr_access.find(i); if (iter == ext_arr_access.end()) return false; @@ -260,46 +362,84 @@ bool CompiledProgram::check_ext_arr_write(int i) const { irpass::ExternalPtrAccess::NONE; } -void CompiledProgram::set_used(const UsedFeature &used) { +void CompiledTaichiKernel::set_used(const UsedFeature &used) { this->used = used; } -OpenGlRuntime::~OpenGlRuntime() = default; - -void DeviceCompiledProgram::launch(Context &ctx, OpenGlRuntime *runtime) const { - std::array ext_arr_host_ptrs; +OpenGlRuntime::~OpenGlRuntime() { + saved_arg_bufs.clear(); + impl.reset(nullptr); + device.reset(); +} +void DeviceCompiledTaichiKernel::launch(RuntimeContext &ctx, + Kernel *kernel, + OpenGlRuntime *runtime) const { uint8_t *args_buf_mapped = nullptr; + auto args = kernel->args; + // If we have external array args we'll have to do host-device memcpy. + // Whether we get external array arg is runtime information. + bool has_ext_arr = false; + bool synced = false; + + if (program_.args_buf_size || program_.ret_buf_size) { + args_buf_ = + device_->allocate_memory_unique({taichi_opengl_external_arr_base, + /*host_write=*/true, + /*host_read=*/true, + /*export_sharing=*/false}); + } - // Prepare external array - if (program_.total_ext_arr_size) { - void *baseptr = device_->map(ext_arr_buf_); - - size_t accum_size = 0; - for (const auto &[i, size] : program_.ext_arr_map) { - auto ptr = (void *)ctx.args[i]; - ctx.args[i] = accum_size; - ext_arr_host_ptrs[i] = ptr; - if (program_.check_ext_arr_read(i)) { - std::memcpy((char *)baseptr + accum_size, ptr, size); + // Prepare external arrays/ndarrays + // - ctx.args[i] contains its ptr on host, it could be a raw ptr or + // DeviceAllocation* + // - For raw ptrs, its content will be synced to device through + // ext_arr_bufs_[i] which is its corresponding DeviceAllocation on device. + // Note shapes of these external arrays still reside in argument buffer, + // see more details below. + for (auto &item : program_.arr_args) { + int i = item.first; + TI_ASSERT(args[i].is_array); + if (args[i].size == 0 || ctx.is_device_allocation[i]) + continue; + has_ext_arr = true; + if (args[i].size != item.second.total_size || + ext_arr_bufs_[i] == kDeviceNullAllocation) { + if (ext_arr_bufs_[i] != kDeviceNullAllocation) { + device_->dealloc_memory(ext_arr_bufs_[i]); } - accum_size += size; + ext_arr_bufs_[i] = device_->allocate_memory( + {args[i].size, /*host_write=*/true, /*host_read=*/true, + /*export_sharing=*/false}); + item.second.total_size = args[i].size; } - - device_->unmap(ext_arr_buf_); + void *host_ptr = (void *)ctx.args[i]; + void *baseptr = device_->map(ext_arr_bufs_[i]); + if (program_.check_ext_arr_read(i)) { + std::memcpy((char *)baseptr, host_ptr, args[i].size); + } + device_->unmap(ext_arr_bufs_[i]); } - + // clang-format off // Prepare argument buffer + // Layout: + // | args | shape of ext arr | ret | + // baseptr + // |..taichi_opengl_extra_args_base..| + // |...............taichi_opengl_ret_base.................| + // |................taichi_opengl_external_arr_base..............| + // clang-format on if (program_.args_buf_size) { - args_buf_mapped = (uint8_t *)device_->map(args_buf_); + args_buf_mapped = (uint8_t *)device_->map(*args_buf_); std::memcpy(args_buf_mapped, ctx.args, program_.arg_count * sizeof(uint64_t)); - if (program_.ext_arr_map.size()) { + if (program_.arr_args.size()) { std::memcpy( - args_buf_mapped + size_t(taichi_opengl_earg_base), ctx.extra_args, + args_buf_mapped + size_t(taichi_opengl_extra_args_base), + ctx.extra_args, size_t(program_.arg_count * taichi_max_num_indices) * sizeof(int)); } - device_->unmap(args_buf_); + device_->unmap(*args_buf_); } // Prepare runtime @@ -315,30 +455,41 @@ void DeviceCompiledProgram::launch(Context &ctx, OpenGlRuntime *runtime) const { // Kernel dispatch int i = 0; - for (const auto &kernel : program_.kernels) { + for (const auto &task : program_.tasks) { auto binder = compiled_pipeline_[i]->resource_binder(); auto &core_bufs = runtime->impl->core_bufs; - binder->buffer(0, int(GLBufId::Runtime), core_bufs.runtime); - binder->buffer(0, int(GLBufId::Listman), core_bufs.listman); - binder->buffer(0, int(GLBufId::Root), core_bufs.root); - binder->buffer(0, int(GLBufId::Gtmp), core_bufs.gtmp); - if (program_.args_buf_size) - binder->buffer(0, int(GLBufId::Args), args_buf_); - if (program_.ret_buf_size) - binder->buffer(0, int(GLBufId::Retr), ret_buf_); - if (program_.total_ext_arr_size) - binder->buffer(0, int(GLBufId::Extr), ext_arr_buf_); + binder->buffer(0, static_cast(GLBufId::Runtime), core_bufs.runtime); + if (program_.used.buf_data) + binder->buffer(0, static_cast(GLBufId::Root), core_bufs.root); + binder->buffer(0, static_cast(GLBufId::Gtmp), core_bufs.gtmp); + if (program_.args_buf_size || program_.ret_buf_size) + binder->buffer(0, static_cast(GLBufId::Args), *args_buf_); + // TODO: properly assert and throw if we bind more than allowed SSBOs. + // On most devices this number is 8. But I need to look up how + // to query this information so currently this is thrown from OpenGl. + for (const auto [arg_id, bind_id] : program_.used.arr_arg_to_bind_idx) { + if (ctx.is_device_allocation[arg_id]) { + DeviceAllocation *ptr = + static_cast((void *)ctx.args[arg_id]); + + binder->buffer(0, bind_id, *ptr); + } else { + binder->buffer(0, bind_id, ext_arr_bufs_[arg_id]); + } + } cmdlist->bind_pipeline(compiled_pipeline_[i].get()); - cmdlist->bind_resources(binder); - cmdlist->dispatch(kernel.num_groups, 1, 1); + if (i == 0) + cmdlist->bind_resources(binder); + cmdlist->dispatch(task.num_groups, 1, 1); cmdlist->memory_barrier(); i++; } - if (program_.used.print || program_.total_ext_arr_size || - program_.ret_buf_size) { + if (program_.used.print || has_ext_arr || program_.ret_buf_size) { + // We'll do device->host memcpy later so sync is required. device_->get_compute_stream()->submit_synced(cmdlist.get()); + synced = true; } else { device_->get_compute_stream()->submit(cmdlist.get()); } @@ -349,58 +500,46 @@ void DeviceCompiledProgram::launch(Context &ctx, OpenGlRuntime *runtime) const { program_.str_table); } - if (program_.total_ext_arr_size) { - uint8_t *baseptr = (uint8_t *)device_->map(ext_arr_buf_); - for (const auto &[i, size] : program_.ext_arr_map) { - memcpy(ext_arr_host_ptrs[i], baseptr + size_t(ctx.args[i]), size); + if (has_ext_arr) { + for (auto &item : program_.arr_args) { + int i = item.first; + if (args[i].size != 0 && !ctx.is_device_allocation[i]) { + uint8_t *baseptr = (uint8_t *)device_->map(ext_arr_bufs_[i]); + memcpy((void *)ctx.args[i], baseptr, args[i].size); + device_->unmap(ext_arr_bufs_[i]); + } } - device_->unmap(ext_arr_buf_); } if (program_.ret_buf_size) { - memcpy(runtime->result_buffer, device_->map(ret_buf_), + uint8_t *baseptr = (uint8_t *)device_->map(*args_buf_); + memcpy(runtime->result_buffer, baseptr + taichi_opengl_ret_base, program_.ret_buf_size); - device_->unmap(ret_buf_); + device_->unmap(*args_buf_); } -} - -DeviceCompiledProgram::DeviceCompiledProgram(CompiledProgram &&program, - Device *device) - : program_(std::move(program)), device_(device) { - if (program_.args_buf_size) { - args_buf_ = - device->allocate_memory({program_.args_buf_size, /*host_write=*/true, - /*host_read=*/false, - /*export_sharing=*/false}); + if (program_.args_buf_size || program_.ret_buf_size) { + runtime->saved_arg_bufs.push_back(std::move(args_buf_)); } - if (program_.total_ext_arr_size) { - // Set both host write & host read for now - ext_arr_buf_ = device->allocate_memory({program_.total_ext_arr_size, - /*host_write=*/true, - /*host_read=*/true, - /*export_sharing=*/false}); - } - - if (program_.ret_buf_size) { - ret_buf_ = - device->allocate_memory({program_.ret_buf_size, /*host_write=*/false, - /*host_read=*/true, - /*export_sharing=*/false}); + if (synced) { + runtime->saved_arg_bufs.clear(); } +} - for (auto &k : program_.kernels) { - compiled_pipeline_.push_back( - device->create_pipeline({PipelineSourceType::glsl_src, - k.kernel_src.data(), k.kernel_src.size()}, - k.kernel_name)); +DeviceCompiledTaichiKernel::DeviceCompiledTaichiKernel( + CompiledTaichiKernel &&program, + Device *device) + : device_(device), program_(std::move(program)) { + for (auto &t : program_.tasks) { + compiled_pipeline_.push_back(device->create_pipeline( + {PipelineSourceType::glsl_src, t.src.data(), t.src.size()}, t.name)); } } OpenGlRuntime::OpenGlRuntime() { initialize_opengl(); - device = std::make_unique(); + device = std::make_shared(); impl = std::make_unique(); @@ -408,25 +547,21 @@ OpenGlRuntime::OpenGlRuntime() { impl->core_bufs.runtime = device->allocate_memory( {sizeof(GLSLRuntime), /*host_write=*/false, /*host_read=*/true}); - impl->listman = std::make_unique(); - impl->core_bufs.listman = device->allocate_memory({sizeof(GLSLListman)}); - impl->core_bufs.gtmp = device->allocate_memory({taichi_global_tmp_buffer_size}); auto cmdlist = device->get_compute_stream()->new_command_list(); cmdlist->buffer_fill(impl->core_bufs.runtime.get_ptr(0), sizeof(GLSLRuntime), 0); - cmdlist->buffer_fill(impl->core_bufs.listman.get_ptr(0), sizeof(GLSLListman), - 0); cmdlist->buffer_fill(impl->core_bufs.gtmp.get_ptr(0), taichi_global_tmp_buffer_size, 0); device->get_compute_stream()->submit_synced(cmdlist.get()); } -DeviceCompiledProgram *OpenGlRuntime::keep(CompiledProgram &&program) { - auto p = - std::make_unique(std::move(program), device.get()); +DeviceCompiledTaichiKernel *OpenGlRuntime::keep( + CompiledTaichiKernel &&program) { + auto p = std::make_unique(std::move(program), + device.get()); auto ptr = p.get(); impl->programs.push_back(std::move(p)); return ptr; @@ -440,10 +575,10 @@ void OpenGlRuntime::add_snode_tree(size_t size) { device->get_compute_stream()->submit_synced(cmdlist.get()); } -bool is_opengl_api_available() { +bool is_opengl_api_available(bool use_gles) { if (get_environ_config("TI_ENABLE_OPENGL", 1) == 0) return false; - return initialize_opengl(true); + return initialize_opengl(use_gles, true); } #else @@ -458,7 +593,8 @@ OpenGlRuntime::~OpenGlRuntime() { TI_NOT_IMPLEMENTED; } -DeviceCompiledProgram *OpenGlRuntime::keep(CompiledProgram &&program) { +DeviceCompiledTaichiKernel *OpenGlRuntime::keep( + CompiledTaichiKernel &&program) { TI_NOT_IMPLEMENTED; return nullptr; } @@ -467,15 +603,20 @@ void OpenGlRuntime::add_snode_tree(size_t size) { TI_NOT_IMPLEMENTED; } -bool is_opengl_api_available() { +bool is_opengl_api_available(bool use_gles) { return false; } -bool initialize_opengl(bool error_tolerance) { +bool initialize_opengl(bool use_gles, bool error_tolerance) { TI_NOT_IMPLEMENTED; } #endif // TI_WITH_OPENGL +bool is_gles() { + return kUseGles; +} + } // namespace opengl -TLANG_NAMESPACE_END +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/opengl/opengl_api.h b/taichi/backends/opengl/opengl_api.h index 680fb66f31f7f..6f4ee6e27a8f5 100644 --- a/taichi/backends/opengl/opengl_api.h +++ b/taichi/backends/opengl/opengl_api.h @@ -1,28 +1,30 @@ #pragma once -#include "taichi/common/core.h" -#include "taichi/ir/transforms.h" - +#include #include #include -#include #include "taichi/backends/device.h" -#include "taichi/backends/opengl/opengl_kernel_util.h" #include "taichi/backends/opengl/opengl_kernel_launcher.h" +#include "taichi/backends/opengl/opengl_kernel_util.h" +#include "taichi/common/core.h" +#include "taichi/ir/offloaded_task_type.h" +#include "taichi/ir/transforms.h" #define TI_RUNTIME_HOST #include "taichi/program/context.h" #undef TI_RUNTIME_HOST -TLANG_NAMESPACE_BEGIN +namespace taichi { +namespace lang { class Kernel; class OffloadedStmt; namespace opengl { -bool initialize_opengl(bool error_tolerance = false); -bool is_opengl_api_available(); +bool initialize_opengl(bool use_gles = false, bool error_tolerance = false); +bool is_opengl_api_available(bool use_gles = false); +bool is_gles(); #define PER_OPENGL_EXTENSION(x) extern bool opengl_extension_##x; #include "taichi/inc/opengl_extension.inc.h" @@ -39,19 +41,45 @@ extern int opengl_threads_per_block; return false; \ })() -struct CompiledKernel { - std::string kernel_name; - std::string kernel_src; +struct CompiledOffloadedTask { + std::string name; + std::string src; + OffloadedTaskType type; + std::string range_hint; int workgroup_size; int num_groups; - TI_IO_DEF(kernel_name, kernel_src, workgroup_size, num_groups); + TI_IO_DEF(name, src, workgroup_size, num_groups); +}; + +struct ScalarArg { + std::string dtype_name; + size_t offset_in_bytes_in_args_buf{0}; + + TI_IO_DEF(offset_in_bytes_in_args_buf); }; -struct CompiledProgram { +struct CompiledArrayArg { + uint32_t dtype; + std::string dtype_name; + std::size_t field_dim{0}; + bool is_scalar{false}; + std::vector element_shape; + size_t shape_offset_in_bytes_in_args_buf{0}; + size_t total_size{0}; // Runtime information + + TI_IO_DEF(field_dim, + is_scalar, + element_shape, + shape_offset_in_bytes_in_args_buf); +}; + +struct CompiledTaichiKernel { void init_args(Kernel *kernel); - void add(const std::string &kernel_name, - const std::string &kernel_source_code, + void add(const std::string &name, + const std::string &source_code, + OffloadedTaskType type, + const std::string &range_hint, int num_workgrous, int workgroup_size, std::unordered_map *ext_ptr_access = @@ -62,46 +90,51 @@ struct CompiledProgram { bool check_ext_arr_read(int i) const; bool check_ext_arr_write(int i) const; - std::vector kernels; + std::vector tasks; int arg_count{0}; int ret_count{0}; size_t args_buf_size{0}; - size_t total_ext_arr_size{0}; size_t ret_buf_size{0}; - std::unordered_map ext_arr_map; std::unordered_map ext_arr_access; std::vector str_table; UsedFeature used; + std::unordered_map scalar_args; + mutable std::unordered_map arr_args; - TI_IO_DEF(kernels, + TI_IO_DEF(tasks, arg_count, ret_count, args_buf_size, - total_ext_arr_size, ret_buf_size, - ext_arr_map, - ext_arr_access, - str_table); + scalar_args, + arr_args, + used.arr_arg_to_bind_idx); }; -class DeviceCompiledProgram { +class DeviceCompiledTaichiKernel { public: - DeviceCompiledProgram(CompiledProgram &&program, Device *device); - void launch(Context &ctx, OpenGlRuntime *runtime) const; + DeviceCompiledTaichiKernel(CompiledTaichiKernel &&program, Device *device); + void launch(RuntimeContext &ctx, + Kernel *kernel, + OpenGlRuntime *runtime) const; private: Device *device_; - CompiledProgram program_; + CompiledTaichiKernel program_; std::vector> compiled_pipeline_; - DeviceAllocation args_buf_{kDeviceNullAllocation}; - DeviceAllocation ext_arr_buf_{kDeviceNullAllocation}; + mutable std::unique_ptr args_buf_{nullptr}; DeviceAllocation ret_buf_{kDeviceNullAllocation}; + // Only saves numpy/torch cpu based external array since they don't have + // DeviceAllocation. + // Taichi |Ndarray| manages their own DeviceAllocation so it's not saved here. + mutable DeviceAllocation ext_arr_bufs_[taichi_max_num_args]{ + kDeviceNullAllocation}; }; } // namespace opengl - -TLANG_NAMESPACE_END +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/opengl/opengl_data_types.h b/taichi/backends/opengl/opengl_data_types.h index 95d1f4c425285..f3eada96df6df 100644 --- a/taichi/backends/opengl/opengl_data_types.h +++ b/taichi/backends/opengl/opengl_data_types.h @@ -57,13 +57,5 @@ inline int opengl_argument_address_shifter(DataType type) { return 3 - opengl_data_address_shifter(type); } -inline int opengl_get_snode_meta_size(const SNode &snode) { - if (snode.type == SNodeType::dynamic) { - return sizeof(int); - } else { - return 0; - } -} - } // namespace opengl TLANG_NAMESPACE_END diff --git a/taichi/backends/opengl/opengl_device.cpp b/taichi/backends/opengl/opengl_device.cpp index 417314ae4d19d..6794c53da666e 100644 --- a/taichi/backends/opengl/opengl_device.cpp +++ b/taichi/backends/opengl/opengl_device.cpp @@ -1,9 +1,159 @@ #include "opengl_device.h" +#include "opengl_api.h" namespace taichi { namespace lang { namespace opengl { +namespace { +const std::unordered_map format_to_gl_internal_format = { + {BufferFormat::r8, GL_R8}, + {BufferFormat::rg8, GL_RG8}, + {BufferFormat::rgba8, GL_RGBA8}, + {BufferFormat::rgba8srgb, GL_SRGB8_ALPHA8}, + {BufferFormat::bgra8, GL_BGRA8_EXT}, + {BufferFormat::bgra8srgb, GL_INVALID_ENUM}, + {BufferFormat::r8u, GL_R8UI}, + {BufferFormat::rg8u, GL_RG8UI}, + {BufferFormat::rgba8u, GL_RGBA8UI}, + {BufferFormat::r8i, GL_R8I}, + {BufferFormat::rg8i, GL_RG8I}, + {BufferFormat::rgba8i, GL_RGBA8I}, + {BufferFormat::r16, GL_R16}, + {BufferFormat::rg16, GL_RG16}, + {BufferFormat::rgb16, GL_RGB16}, + {BufferFormat::rgba16, GL_RGBA16}, + {BufferFormat::r16u, GL_R16UI}, + {BufferFormat::rg16u, GL_RG16UI}, + {BufferFormat::rgb16u, GL_RGB16UI}, + {BufferFormat::rgba16u, GL_RGBA16UI}, + {BufferFormat::r16i, GL_R16I}, + {BufferFormat::rg16i, GL_RG16I}, + {BufferFormat::rgb16i, GL_RGB16I}, + {BufferFormat::rgba16i, GL_RGBA16I}, + {BufferFormat::r16f, GL_R16F}, + {BufferFormat::rg16f, GL_RG16F}, + {BufferFormat::rgb16f, GL_RGB16F}, + {BufferFormat::rgba16f, GL_RGBA16F}, + {BufferFormat::r32u, GL_R32UI}, + {BufferFormat::rg32u, GL_RG32UI}, + {BufferFormat::rgb32u, GL_RGB32UI}, + {BufferFormat::rgba32u, GL_RGBA32UI}, + {BufferFormat::r32i, GL_R32I}, + {BufferFormat::rg32i, GL_RG32I}, + {BufferFormat::rgb32i, GL_RGB32I}, + {BufferFormat::rgba32i, GL_RGBA32I}, + {BufferFormat::r32f, GL_R32F}, + {BufferFormat::rg32f, GL_RG32F}, + {BufferFormat::rgb32f, GL_RGB32F}, + {BufferFormat::rgba32f, GL_RGBA32F}, + {BufferFormat::depth16, GL_INVALID_ENUM}, + {BufferFormat::depth24stencil8, GL_DEPTH24_STENCIL8}, + {BufferFormat::depth32f, GL_DEPTH32F_STENCIL8}}; + +const std::unordered_map gl_internal_format_to_type = { + {GL_R8, GL_UNSIGNED_BYTE}, + {GL_R8_SNORM, GL_BYTE}, + {GL_R8UI, GL_UNSIGNED_BYTE}, + {GL_R8I, GL_BYTE}, + {GL_R16, GL_UNSIGNED_SHORT}, + {GL_R16_SNORM, GL_SHORT}, + {GL_R16F, GL_HALF_FLOAT}, + {GL_R16UI, GL_UNSIGNED_SHORT}, + {GL_R16I, GL_SHORT}, + {GL_R32UI, GL_UNSIGNED_INT}, + {GL_R32I, GL_INT}, + {GL_R32F, GL_FLOAT}, + {GL_RG8, GL_UNSIGNED_BYTE}, + {GL_RG8_SNORM, GL_BYTE}, + {GL_RG8UI, GL_UNSIGNED_BYTE}, + {GL_RG8I, GL_BYTE}, + {GL_RG16, GL_UNSIGNED_SHORT}, + {GL_RG16_SNORM, GL_SHORT}, + {GL_RG16F, GL_HALF_FLOAT}, + {GL_RG16UI, GL_UNSIGNED_SHORT}, + {GL_RG16I, GL_SHORT}, + {GL_RG32UI, GL_UNSIGNED_INT}, + {GL_RG32I, GL_INT}, + {GL_RG32F, GL_FLOAT}, + {GL_RGB8, GL_UNSIGNED_BYTE}, + {GL_RGB8_SNORM, GL_BYTE}, + {GL_RGB8UI, GL_UNSIGNED_BYTE}, + {GL_RGB8I, GL_BYTE}, + {GL_RGB16, GL_UNSIGNED_SHORT}, + {GL_RGB16_SNORM, GL_SHORT}, + {GL_RGB16F, GL_HALF_FLOAT}, + {GL_RGB16UI, GL_UNSIGNED_SHORT}, + {GL_RGB16I, GL_SHORT}, + {GL_RGB32UI, GL_UNSIGNED_INT}, + {GL_RGB32I, GL_INT}, + {GL_RGB32F, GL_FLOAT}, + {GL_RGBA8, GL_UNSIGNED_BYTE}, + {GL_SRGB8_ALPHA8, GL_UNSIGNED_BYTE}, + {GL_RGBA8_SNORM, GL_BYTE}, + {GL_RGBA8UI, GL_UNSIGNED_BYTE}, + {GL_RGBA8I, GL_BYTE}, + {GL_RGBA16, GL_UNSIGNED_SHORT}, + {GL_RGBA16_SNORM, GL_SHORT}, + {GL_RGBA16F, GL_HALF_FLOAT}, + {GL_RGBA16UI, GL_UNSIGNED_SHORT}, + {GL_RGBA16I, GL_SHORT}, + {GL_RGBA32UI, GL_UNSIGNED_INT}, + {GL_RGBA32I, GL_INT}, + {GL_RGBA32F, GL_FLOAT}}; + +const std::unordered_map gl_internal_format_to_format = { + {GL_R8, GL_RED}, + {GL_R8_SNORM, GL_RED}, + {GL_R8UI, GL_RED}, + {GL_R8I, GL_RED}, + {GL_R16, GL_RED}, + {GL_R16_SNORM, GL_RED}, + {GL_R16F, GL_RED}, + {GL_R16UI, GL_RED}, + {GL_R16I, GL_RED}, + {GL_R32UI, GL_RED}, + {GL_R32I, GL_RED}, + {GL_R32F, GL_RED}, + {GL_RG8, GL_RG}, + {GL_RG8_SNORM, GL_RG}, + {GL_RG8UI, GL_RG}, + {GL_RG8I, GL_RG}, + {GL_RG16, GL_RG}, + {GL_RG16_SNORM, GL_RG}, + {GL_RG16F, GL_RG}, + {GL_RG16UI, GL_RG}, + {GL_RG16I, GL_RG}, + {GL_RG32UI, GL_RG}, + {GL_RG32I, GL_RG}, + {GL_RG32F, GL_RG}, + {GL_RGB8, GL_RGB}, + {GL_RGB8_SNORM, GL_RGB}, + {GL_RGB8UI, GL_RGB}, + {GL_RGB8I, GL_RGB}, + {GL_RGB16, GL_RGB}, + {GL_RGB16_SNORM, GL_RGB}, + {GL_RGB16F, GL_RGB}, + {GL_RGB16UI, GL_RGB}, + {GL_RGB16I, GL_RGB}, + {GL_RGB32UI, GL_RGB}, + {GL_RGB32I, GL_RGB}, + {GL_RGB32F, GL_RGB}, + {GL_RGBA8, GL_RGBA}, + {GL_SRGB8_ALPHA8, GL_RGBA}, + {GL_RGBA8_SNORM, GL_RGBA}, + {GL_RGBA8UI, GL_RGBA}, + {GL_RGBA8I, GL_RGBA}, + {GL_RGBA16, GL_RGBA}, + {GL_RGBA16_SNORM, GL_RGBA}, + {GL_RGBA16F, GL_RGBA}, + {GL_RGBA16UI, GL_RGBA}, + {GL_RGBA16I, GL_RGBA}, + {GL_RGBA32UI, GL_RGBA}, + {GL_RGBA32I, GL_RGBA}, + {GL_RGBA32F, GL_RGBA}}; +} // namespace + std::string get_opengl_error_string(GLenum err) { switch (err) { #define PER_GL_ERR(x) \ @@ -92,7 +242,8 @@ GLPipeline::GLPipeline(const PipelineSourceDesc &desc, shader_id = glCreateShader(GL_COMPUTE_SHADER); const GLchar *source_cstr = (const GLchar *)desc.data; - glShaderSource(shader_id, 1, &source_cstr, nullptr); + int length = desc.size; + glShaderSource(shader_id, 1, &source_cstr, &length); glCompileShader(shader_id); int status = GL_TRUE; @@ -239,21 +390,32 @@ void GLCommandList::draw_indexed(uint32_t num_indicies, void GLCommandList::image_transition(DeviceAllocation img, ImageLayout old_layout, ImageLayout new_layout) { - TI_NOT_IMPLEMENTED; + auto cmd = std::make_unique(); + recorded_commands_.push_back(std::move(cmd)); } void GLCommandList::buffer_to_image(DeviceAllocation dst_img, DevicePtr src_buf, ImageLayout img_layout, const BufferImageCopyParams ¶ms) { - TI_NOT_IMPLEMENTED; + auto cmd = std::make_unique(); + cmd->params = params; + cmd->image = dst_img.alloc_id; + cmd->buffer = src_buf.alloc_id; + cmd->offset = src_buf.offset; + recorded_commands_.push_back(std::move(cmd)); } void GLCommandList::image_to_buffer(DevicePtr dst_buf, DeviceAllocation src_img, ImageLayout img_layout, const BufferImageCopyParams ¶ms) { - TI_NOT_IMPLEMENTED; + auto cmd = std::make_unique(); + cmd->params = params; + cmd->image = src_img.alloc_id; + cmd->buffer = dst_buf.alloc_id; + cmd->offset = dst_buf.offset; + recorded_commands_.push_back(std::move(cmd)); } void GLCommandList::run_commands() { @@ -299,11 +461,11 @@ DeviceAllocation GLDevice::allocate_memory(const AllocParams ¶ms) { alloc.alloc_id = buffer; if (params.host_read && params.host_write) { - buffer_to_access_[buffer] = GL_READ_WRITE; + buffer_to_access_[buffer] = GL_MAP_READ_BIT | GL_MAP_WRITE_BIT; } else if (params.host_read) { - buffer_to_access_[buffer] = GL_READ_ONLY; + buffer_to_access_[buffer] = GL_MAP_READ_BIT; } else if (params.host_write) { - buffer_to_access_[buffer] = GL_WRITE_ONLY; + buffer_to_access_[buffer] = GL_MAP_WRITE_BIT; } return alloc; @@ -333,6 +495,11 @@ void *GLDevice::map_range(DevicePtr ptr, uint64_t size) { } void *GLDevice::map(DeviceAllocation alloc) { + int size = 0; + glBindBuffer(GL_SHADER_STORAGE_BUFFER, alloc.alloc_id); + glGetBufferParameteriv(GL_SHADER_STORAGE_BUFFER, GL_BUFFER_SIZE, &size); + return map_range(alloc.get_ptr(0), size); + /* TI_ASSERT_INFO( buffer_to_access_.find(alloc.alloc_id) != buffer_to_access_.end(), "Buffer not created with host_read or write"); @@ -342,6 +509,7 @@ void *GLDevice::map(DeviceAllocation alloc) { buffer_to_access_.at(alloc.alloc_id)); check_opengl_error("glMapBuffer"); return mapped; + */ } void GLDevice::unmap(DevicePtr ptr) { @@ -397,18 +565,55 @@ std::unique_ptr GLDevice::create_surface(const SurfaceConfig &config) { } DeviceAllocation GLDevice::create_image(const ImageParams ¶ms) { - TI_NOT_IMPLEMENTED; - return kDeviceNullAllocation; + GLuint tex; + glGenTextures(1, &tex); + check_opengl_error("glGenTextures"); + + auto gl_texture_dims = GL_TEXTURE_2D; + if (params.dimension == ImageDimension::d1D) { + gl_texture_dims = GL_TEXTURE_1D; + } else if (params.dimension == ImageDimension::d2D) { + gl_texture_dims = GL_TEXTURE_2D; + } + + auto format = format_to_gl_internal_format.at(params.format); + + glBindTexture(gl_texture_dims, tex); + check_opengl_error("glBindTexture"); + if (params.dimension == ImageDimension::d1D) { + glTexStorage1D(gl_texture_dims, 1, format, params.x); + check_opengl_error("glTexStorage1D"); + } else if (params.dimension == ImageDimension::d2D) { + glTexStorage2D(gl_texture_dims, 1, format, params.x, params.y); + check_opengl_error("glTexStorage2D"); + } else { + glTexStorage3D(gl_texture_dims, 1, format, params.x, params.y, params.z); + check_opengl_error("glTexStorage3D"); + } + + DeviceAllocation alloc; + alloc.device = this; + alloc.alloc_id = tex; + + image_to_dims_[tex] = gl_texture_dims; + image_to_int_format_[tex] = format; + + return alloc; } void GLDevice::destroy_image(DeviceAllocation handle) { - TI_NOT_IMPLEMENTED; + glDeleteTextures(1, &handle.alloc_id); + check_opengl_error("glDeleteTextures"); + image_to_dims_.erase(handle.alloc_id); + image_to_int_format_.erase(handle.alloc_id); } void GLDevice::image_transition(DeviceAllocation img, ImageLayout old_layout, ImageLayout new_layout) { - TI_NOT_IMPLEMENTED; + glMemoryBarrier(GL_TEXTURE_FETCH_BARRIER_BIT | GL_TEXTURE_UPDATE_BARRIER_BIT | + GL_SHADER_IMAGE_ACCESS_BARRIER_BIT | + GL_FRAMEBUFFER_BARRIER_BIT); } void GLDevice::buffer_to_image(DeviceAllocation dst_img, @@ -480,15 +685,79 @@ void GLCommandList::CmdBufferCopy::execute() { void GLCommandList::CmdBufferFill::execute() { glBindBuffer(GL_SHADER_STORAGE_BUFFER, buffer); check_opengl_error("glBindBuffer"); - glClearBufferSubData(GL_SHADER_STORAGE_BUFFER, GL_R32UI, offset, size, GL_RED, - GL_UNSIGNED_INT, &data); - check_opengl_error("glClearBufferSubData"); + if (is_gles()) { + int buf_size = 0; + glGetBufferParameteriv(GL_SHADER_STORAGE_BUFFER, GL_BUFFER_SIZE, &buf_size); + TI_ASSERT(offset == 0 && data == 0 && size == buf_size && + "GLES only supports full clear"); + glBufferData(GL_SHADER_STORAGE_BUFFER, buf_size, nullptr, GL_DYNAMIC_READ); + check_opengl_error("glBufferData"); + } else { + glClearBufferSubData(GL_SHADER_STORAGE_BUFFER, GL_R32F, offset, size, + GL_RED, GL_FLOAT, &data); + check_opengl_error("glClearBufferSubData"); + } } void GLCommandList::CmdDispatch::execute() { glDispatchCompute(x, y, z); } +void GLCommandList::CmdImageTransition::execute() { + glMemoryBarrier(GL_TEXTURE_FETCH_BARRIER_BIT | GL_TEXTURE_UPDATE_BARRIER_BIT | + GL_SHADER_IMAGE_ACCESS_BARRIER_BIT | + GL_FRAMEBUFFER_BARRIER_BIT); +} + +void GLCommandList::CmdBufferToImage::execute() { + auto image_dims = device->get_image_gl_dims(image); + auto image_format = device->get_image_gl_int_dims(image); + auto gl_type = gl_internal_format_to_type.at(image_format); + + glBindTexture(image_dims, image); + check_opengl_error("glBindTexture"); + glBindBuffer(GL_PIXEL_UNPACK_BUFFER, buffer); + check_opengl_error("glBindBuffer"); + if (image_dims == GL_TEXTURE_1D) { + glTexSubImage1D(image_dims, /*level=*/0, params.image_offset.x, + params.image_extent.x, image_format, gl_type, + (void *)offset); + } else if (image_dims == GL_TEXTURE_2D) { + glTexSubImage2D(image_dims, /*level=*/0, params.image_offset.x, + params.image_offset.y, params.image_extent.x, + params.image_extent.y, image_format, gl_type, + (void *)offset); + } else { + glTexSubImage3D( + image_dims, /*level=*/0, params.image_offset.x, params.image_offset.y, + params.image_offset.z, params.image_extent.x, params.image_extent.y, + params.image_extent.z, image_format, gl_type, (void *)offset); + } + check_opengl_error("glTexSubImage"); + glBindTexture(image_dims, /*target=*/0); + glBindBuffer(GL_PIXEL_UNPACK_BUFFER, /*target=*/0); +} + +void GLCommandList::CmdImageToBuffer::execute() { + auto image_dims = device->get_image_gl_dims(image); + auto image_format = device->get_image_gl_int_dims(image); + auto gl_type = gl_internal_format_to_type.at(image_format); + auto unsized_format = gl_internal_format_to_format.at(image_format); + + glBindTexture(image_dims, image); + check_opengl_error("glBindTexture"); + glBindBuffer(GL_PIXEL_UNPACK_BUFFER, buffer); + check_opengl_error("glBindBuffer"); + TI_ASSERT_INFO(params.image_offset.x == 0 && params.image_offset.y == 0 && + params.image_offset.z == 0, + "OpenGL can only copy full images to buffer"); + glGetTexImage(/*level=*/0, image_format, unsized_format, gl_type, + (void *)offset); + check_opengl_error("glGetTexImage"); + glBindTexture(image_dims, /*target=*/0); + glBindBuffer(GL_PIXEL_UNPACK_BUFFER, /*target=*/0); +} + } // namespace opengl } // namespace lang } // namespace taichi diff --git a/taichi/backends/opengl/opengl_device.h b/taichi/backends/opengl/opengl_device.h index a24d5de3e4811..b32b848930650 100644 --- a/taichi/backends/opengl/opengl_device.h +++ b/taichi/backends/opengl/opengl_device.h @@ -2,13 +2,15 @@ #include "taichi/backends/device.h" -#include "glad/glad.h" +#include "glad/gl.h" #include "GLFW/glfw3.h" namespace taichi { namespace lang { namespace opengl { +class GLDevice; + void check_opengl_error(const std::string &msg = "OpenGL"); class GLResourceBinder : public ResourceBinder { @@ -127,6 +129,8 @@ class GLCommandList : public CommandList { struct Cmd { virtual void execute() { } + virtual ~Cmd() { + } }; struct CmdBindPipeline : public Cmd { @@ -164,6 +168,28 @@ class GLCommandList : public CommandList { void execute() override; }; + struct CmdImageTransition : public Cmd { + void execute() override; + }; + + struct CmdBufferToImage : public Cmd { + BufferImageCopyParams params; + GLuint image{0}; + GLuint buffer{0}; + size_t offset{0}; + GLDevice *device{nullptr}; + void execute() override; + }; + + struct CmdImageToBuffer : public Cmd { + BufferImageCopyParams params; + GLuint image{0}; + GLuint buffer{0}; + size_t offset{0}; + GLDevice *device{nullptr}; + void execute() override; + }; + std::vector> recorded_commands_; }; @@ -227,9 +253,19 @@ class GLDevice : public GraphicsDevice { ImageLayout img_layout, const BufferImageCopyParams ¶ms) override; + GLuint get_image_gl_dims(GLuint image) const { + return image_to_dims_.at(image); + } + + GLuint get_image_gl_int_dims(GLuint image) const { + return image_to_int_format_.at(image); + } + private: GLStream stream_; std::unordered_map buffer_to_access_; + std::unordered_map image_to_dims_; + std::unordered_map image_to_int_format_; }; class GLSurface : public Surface { diff --git a/taichi/backends/opengl/opengl_kernel_launcher.h b/taichi/backends/opengl/opengl_kernel_launcher.h index 228bf81969d57..183a9f63ce073 100644 --- a/taichi/backends/opengl/opengl_kernel_launcher.h +++ b/taichi/backends/opengl/opengl_kernel_launcher.h @@ -2,6 +2,7 @@ #include "taichi/lang_util.h" #include "taichi/backends/device.h" +#include "taichi/ir/snode.h" #include @@ -9,18 +10,19 @@ TLANG_NAMESPACE_BEGIN namespace opengl { -struct CompiledProgram; +struct CompiledTaichiKernel; struct OpenGlRuntimeImpl; struct OpenGlRuntime; class GLBuffer; -class DeviceCompiledProgram; +class DeviceCompiledTaichiKernel; struct OpenGlRuntime { - std::unique_ptr impl; - std::unique_ptr device{nullptr}; + std::shared_ptr device{nullptr}; + std::unique_ptr impl{nullptr}; + std::vector> saved_arg_bufs; OpenGlRuntime(); ~OpenGlRuntime(); - DeviceCompiledProgram *keep(CompiledProgram &&program); + DeviceCompiledTaichiKernel *keep(CompiledTaichiKernel &&program); // FIXME: Currently GLSL codegen only supports single root void add_snode_tree(size_t size); @@ -30,10 +32,12 @@ struct OpenGlRuntime { using SNodeId = std::string; struct SNodeInfo { + const SNode *snode; size_t stride; size_t length; std::vector children_offsets; size_t elem_stride; + size_t mem_offset_in_root{0}; }; struct StructCompiledResult { diff --git a/taichi/backends/opengl/opengl_kernel_util.h b/taichi/backends/opengl/opengl_kernel_util.h index 22752dcdafdd9..03937c9f7b480 100644 --- a/taichi/backends/opengl/opengl_kernel_util.h +++ b/taichi/backends/opengl/opengl_kernel_util.h @@ -6,13 +6,20 @@ #include "taichi/ir/snode.h" -TLANG_NAMESPACE_BEGIN +namespace taichi { +namespace lang { class SNode; namespace opengl { -constexpr int taichi_opengl_earg_base = taichi_max_num_args * sizeof(uint64_t); +constexpr int taichi_opengl_extra_args_base = + taichi_max_num_args * sizeof(uint64_t); +constexpr int taichi_opengl_ret_base = + taichi_opengl_extra_args_base + + taichi_max_num_args * taichi_max_num_indices * sizeof(int); +constexpr int taichi_opengl_external_arr_base = + taichi_opengl_ret_base + taichi_max_num_ret_value * sizeof(uint64_t); struct UsedFeature { // types: @@ -25,15 +32,13 @@ struct UsedFeature { bool float64{false}; // buffers: + bool buf_data{false}; bool buf_args{false}; - bool buf_earg{false}; - bool buf_extr{false}; bool buf_gtmp{false}; - bool buf_retr{false}; + std::unordered_map arr_arg_to_bind_idx; // utilties: bool fast_pow{false}; - bool listman{false}; bool random{false}; bool print{false}; bool reduction{false}; @@ -46,19 +51,18 @@ struct UsedFeature { enum class GLBufId { Root = 0, - Runtime = 6, - Listman = 7, Gtmp = 1, Args = 2, - Retr = 3, - Extr = 4, + Runtime = 3, + // This is indeed the beginning id for |Arr|s so |Arr| MUST be the last item. + Arr = 4, }; struct IOV { - void *base; - size_t size; + void *base{nullptr}; + size_t size{0}; }; } // namespace opengl - -TLANG_NAMESPACE_END +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/opengl/opengl_program.cpp b/taichi/backends/opengl/opengl_program.cpp index a8fd7c7a2cd59..163ad3a089200 100644 --- a/taichi/backends/opengl/opengl_program.cpp +++ b/taichi/backends/opengl/opengl_program.cpp @@ -8,14 +8,15 @@ namespace lang { FunctionType OpenglProgramImpl::compile(Kernel *kernel, OffloadedStmt *offloaded) { #ifdef TI_WITH_OPENGL - opengl::OpenglCodeGen codegen(kernel->name, &opengl_struct_compiled_.value()); + opengl::OpenglCodeGen codegen(kernel->name, &opengl_struct_compiled_.value(), + config->allow_nv_shader_extension); auto ptr = opengl_runtime_->keep(codegen.compile(*kernel)); - return [ptr, runtime = opengl_runtime_.get()](Context &ctx) { - ptr->launch(ctx, runtime); + return [ptr, kernel, runtime = opengl_runtime_.get()](RuntimeContext &ctx) { + ptr->launch(ctx, kernel, runtime); }; #else - return [](Context &ctx) {}; + return [](RuntimeContext &ctx) {}; #endif } @@ -26,24 +27,39 @@ void OpenglProgramImpl::materialize_runtime(MemoryPool *memory_pool, *result_buffer_ptr = (uint64 *)memory_pool->allocate( sizeof(uint64) * taichi_result_buffer_entries, 8); opengl_runtime_ = std::make_unique(); + opengl_runtime_->result_buffer = *result_buffer_ptr; #else TI_NOT_IMPLEMENTED; #endif } +DeviceAllocation OpenglProgramImpl::allocate_memory_ndarray( + std::size_t alloc_size, + uint64 *result_buffer) { + return opengl_runtime_->device->allocate_memory( + {alloc_size, /*host_write=*/true, /*host_read=*/true, + /*export_sharing=*/false}); +} -void OpenglProgramImpl::materialize_snode_tree( +std::shared_ptr OpenglProgramImpl::get_device_shared() { + return opengl_runtime_->device; +} + +void OpenglProgramImpl::compile_snode_tree_types( SNodeTree *tree, - std::vector> &, - std::unordered_map &, - uint64 *result_buffer) { -#ifdef TI_WITH_OPENGL + std::vector> &snode_trees) { // TODO: support materializing multiple snode trees - auto *const root = tree->root(); opengl::OpenglStructCompiler scomp; - opengl_struct_compiled_ = scomp.run(*root); + opengl_struct_compiled_ = scomp.run(*(tree->root())); TI_TRACE("OpenGL root buffer size: {} B", opengl_struct_compiled_->root_size); +} + +void OpenglProgramImpl::materialize_snode_tree( + SNodeTree *tree, + std::vector> &snode_trees_, + uint64 *result_buffer) { +#ifdef TI_WITH_OPENGL + compile_snode_tree_types(tree, snode_trees_); opengl_runtime_->add_snode_tree(opengl_struct_compiled_->root_size); - opengl_runtime_->result_buffer = result_buffer; #else TI_NOT_IMPLEMENTED; #endif @@ -54,7 +70,7 @@ std::unique_ptr OpenglProgramImpl::make_aot_module_builder() { // fine to JIT to GLSL on systems without the OpenGL runtime. #ifdef TI_WITH_OPENGL return std::make_unique( - opengl_struct_compiled_.value()); + opengl_struct_compiled_.value(), config->allow_nv_shader_extension); #else TI_NOT_IMPLEMENTED; return nullptr; diff --git a/taichi/backends/opengl/opengl_program.h b/taichi/backends/opengl/opengl_program.h index d85f96dcf09f3..d6f8e9e120307 100644 --- a/taichi/backends/opengl/opengl_program.h +++ b/taichi/backends/opengl/opengl_program.h @@ -16,6 +16,7 @@ namespace taichi { namespace lang { + class OpenglProgramImpl : public ProgramImpl { public: OpenglProgramImpl(CompileConfig &config) : ProgramImpl(config) { @@ -32,27 +33,36 @@ class OpenglProgramImpl : public ProgramImpl { KernelProfilerBase *profiler, uint64 **result_buffer_ptr) override; + void compile_snode_tree_types( + SNodeTree *tree, + std::vector> &snode_trees) override; + void materialize_snode_tree( SNodeTree *tree, std::vector> &snode_trees_, - std::unordered_map &snodes, uint64 *result_buffer) override; void synchronize() override { } + DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size, + uint64 *result_buffer) override; + + std::shared_ptr get_device_shared() override; + std::unique_ptr make_aot_module_builder() override; - virtual void destroy_snode_tree(SNodeTree *snode_tree) override { + void destroy_snode_tree(SNodeTree *snode_tree) override { TI_NOT_IMPLEMENTED } - ~OpenglProgramImpl() { + ~OpenglProgramImpl() override { } private: std::optional opengl_struct_compiled_; std::unique_ptr opengl_runtime_; }; + } // namespace lang } // namespace taichi diff --git a/taichi/backends/opengl/opengl_utils.cpp b/taichi/backends/opengl/opengl_utils.cpp new file mode 100644 index 0000000000000..dee9d71ae950f --- /dev/null +++ b/taichi/backends/opengl/opengl_utils.cpp @@ -0,0 +1,35 @@ +#include "taichi/backends/opengl/opengl_utils.h" +#include "glad/gl.h" + +namespace taichi { +namespace lang { +namespace opengl { + +uint32_t to_gl_dtype_enum(DataType dt) { + if (dt == PrimitiveType::u64) { + return GL_UNSIGNED_INT64_ARB; + } else if (dt == PrimitiveType::i64) { + return GL_INT64_ARB; + } else if (dt == PrimitiveType::u32) { + return GL_UNSIGNED_INT; + } else if (dt == PrimitiveType::i32) { + return GL_INT; + } else if (dt == PrimitiveType::u16) { + return GL_UNSIGNED_SHORT; + } else if (dt == PrimitiveType::i16) { + return GL_SHORT; + } else if (dt == PrimitiveType::u8) { + return GL_UNSIGNED_BYTE; + } else if (dt == PrimitiveType::i8) { + return GL_BYTE; + } else if (dt == PrimitiveType::f64) { + return GL_DOUBLE; + } else if (dt == PrimitiveType::f32) { + return GL_FLOAT; + } else { + TI_NOT_IMPLEMENTED + } +} +} // namespace opengl +} // namespace lang +} // namespace taichi diff --git a/taichi/ir/ir_modified.h b/taichi/backends/opengl/opengl_utils.h similarity index 51% rename from taichi/ir/ir_modified.h rename to taichi/backends/opengl/opengl_utils.h index d5669690084e5..1767e6883cc90 100644 --- a/taichi/ir/ir_modified.h +++ b/taichi/backends/opengl/opengl_utils.h @@ -1,9 +1,12 @@ #pragma once +#include "taichi/ir/type.h" namespace taichi { namespace lang { +namespace opengl { -class IRModified {}; +uint32_t to_gl_dtype_enum(DataType dt); +} } // namespace lang } // namespace taichi diff --git a/taichi/backends/opengl/shaders/atomics_macro_f32.glsl.h b/taichi/backends/opengl/shaders/atomics_macro_f32.glsl.h index d0637dcf35e7f..8e80c619553eb 100644 --- a/taichi/backends/opengl/shaders/atomics_macro_f32.glsl.h +++ b/taichi/backends/opengl/shaders/atomics_macro_f32.glsl.h @@ -1,54 +1,44 @@ // vim: ft=glsl -// clang-format off +// NOLINTBEGIN(*) #include "taichi/util/macros.h" -#ifdef TI_INSIDE_OPENGL_CODEGEN -#define OPENGL_BEGIN_ATOMIC_F32_DEF constexpr auto kOpenGLAtomicF32SourceCode = -#define OPENGL_END_ATOMIC_F32_DEF ; -#else +#ifndef TI_INSIDE_OPENGL_CODEGEN static_assert(false, "Do not include"); -#define OPENGL_BEGIN_ATOMIC_F32_DEF -#define OPENGL_END_ATOMIC_F32_DEF #endif -OPENGL_BEGIN_ATOMIC_F32_DEF -"#define DEFINE_ATOMIC_F32_FUNCTIONS(NAME) " -STR( -float atomicAdd_##NAME##_f32(int addr, float rhs) { - int old, new, ret; - do { - old = _##NAME##_i32_[addr]; - new = floatBitsToInt((intBitsToFloat(old) + rhs)); - } while (old != atomicCompSwap(_##NAME##_i32_[addr], old, new)); - return intBitsToFloat(old); -} -float atomicSub_##NAME##_f32(int addr, float rhs) { - int old, new, ret; - do { - old = _##NAME##_i32_[addr]; - new = floatBitsToInt((intBitsToFloat(old) - rhs)); - } while (old != atomicCompSwap(_##NAME##_i32_[addr], old, new)); - return intBitsToFloat(old); -} -float atomicMax_##NAME##_f32(int addr, float rhs) { - int old, new, ret; - do { - old = _##NAME##_i32_[addr]; - new = floatBitsToInt(max(intBitsToFloat(old), rhs)); - } while (old != atomicCompSwap(_##NAME##_i32_[addr], old, new)); - return intBitsToFloat(old); -} -float atomicMin_##NAME##_f32(int addr, float rhs) { - int old, new, ret; - do { - old = _##NAME##_i32_[addr]; - new = floatBitsToInt(min(intBitsToFloat(old), rhs)); - } while (old != atomicCompSwap(_##NAME##_i32_[addr], old, new)); - return intBitsToFloat(old); -} -\n -) -OPENGL_END_ATOMIC_F32_DEF - -#undef OPENGL_BEGIN_ATOMIC_F32_DEF -#undef OPENGL_END_ATOMIC_F32_DEF +#define GENERATE_OPENGL_ATOMIC_F32(NAME) \ + constexpr auto kOpenGlAtomicF32Source_##NAME = STR( \ + float atomicAdd_##NAME##_f32(int addr, float rhs) { \ + int old_val, new_val, ret; \ + do { \ + old_val = _##NAME##_i32_[addr]; \ + new_val = floatBitsToInt((intBitsToFloat(old_val) + rhs)); \ + } while (old_val != \ + atomicCompSwap(_##NAME##_i32_[addr], old_val, new_val)); \ + return intBitsToFloat(old_val); \ + } float atomicSub_##NAME##_f32(int addr, float rhs) { \ + int old_val, new_val, ret; \ + do { \ + old_val = _##NAME##_i32_[addr]; \ + new_val = floatBitsToInt((intBitsToFloat(old_val) - rhs)); \ + } while (old_val != \ + atomicCompSwap(_##NAME##_i32_[addr], old_val, new_val)); \ + return intBitsToFloat(old_val); \ + } float atomicMax_##NAME##_f32(int addr, float rhs) { \ + int old_val, new_val, ret; \ + do { \ + old_val = _##NAME##_i32_[addr]; \ + new_val = floatBitsToInt(max(intBitsToFloat(old_val), rhs)); \ + } while (old_val != \ + atomicCompSwap(_##NAME##_i32_[addr], old_val, new_val)); \ + return intBitsToFloat(old_val); \ + } float atomicMin_##NAME##_f32(int addr, float rhs) { \ + int old_val, new_val, ret; \ + do { \ + old_val = _##NAME##_i32_[addr]; \ + new_val = floatBitsToInt(min(intBitsToFloat(old_val), rhs)); \ + } while (old_val != \ + atomicCompSwap(_##NAME##_i32_[addr], old_val, new_val)); \ + return intBitsToFloat(old_val); \ + }); +// NOLINTEND(*) diff --git a/taichi/backends/opengl/shaders/atomics_macro_f64.glsl.h b/taichi/backends/opengl/shaders/atomics_macro_f64.glsl.h index 784ec1641e8f4..2795d88ad4fa0 100644 --- a/taichi/backends/opengl/shaders/atomics_macro_f64.glsl.h +++ b/taichi/backends/opengl/shaders/atomics_macro_f64.glsl.h @@ -1,5 +1,6 @@ // vim: ft=glsl // clang-format off +// NOLINTBEGIN(*) #include "taichi/util/macros.h" #ifdef TI_INSIDE_OPENGL_CODEGEN @@ -15,36 +16,36 @@ OPENGL_BEGIN_ATOMIC_F64_DEF "#define DEFINE_ATOMIC_F64_FUNCTIONS(NAME) " STR( double atomicAdd_##NAME_f64(int addr, double rhs) { - int old, new, ret; + int old_val, new_val, ret; do { - old = _##NAME##_i64_[addr]; - new = floatBitsToInt((intBitsToFloat(old) + rhs)); - } while (old != atomicCompSwap(_##NAME##_i64_[addr], old, new)); - return intBitsToFloat(old); + old_val = _##NAME##_i64_[addr]; + new_val = floatBitsToInt((intBitsToFloat(old_val) + rhs)); + } while (old_val != atomicCompSwap(_##NAME##_i64_[addr], old_val, new_val)); + return intBitsToFloat(old_val); } double atomicSub_##NAME##_f64(int addr, double rhs) { - int old, new, ret; + int old_val, new_val, ret; do { - old = _##NAME##_i64_[addr]; - new = floatBitsToInt((intBitsToFloat(old) - rhs)); - } while (old != atomicCompSwap(_##NAME##_i64_[addr], old, new)); - return intBitsToFloat(old); + old_val = _##NAME##_i64_[addr]; + new_val = floatBitsToInt((intBitsToFloat(old_val) - rhs)); + } while (old_val != atomicCompSwap(_##NAME##_i64_[addr], old_val, new_val)); + return intBitsToFloat(old_val); } double atomicMax_##NAME##_f64(int addr, double rhs) { - int old, new, ret; + int old_val, new_val, ret; do { - old = _##NAME##_i64_[addr]; - new = floatBitsToInt(max(intBitsToFloat(old), rhs)); - } while (old != atomicCompSwap(_##NAME##_i64_[addr], old, new)); - return intBitsToFloat(old); + old_val = _##NAME##_i64_[addr]; + new_val = floatBitsToInt(max(intBitsToFloat(old_val), rhs)); + } while (old_val != atomicCompSwap(_##NAME##_i64_[addr], old_val, new_val)); + return intBitsToFloat(old_val); } double atomicMin_##NAME##_f64(int addr, double rhs) { - int old, new, ret; + int old_val, new_val, ret; do { - old = _##NAME##_i64_[addr]; - new = floatBitsToInt(min(intBitsToFloat(old), rhs)); - } while (old != atomicCompSwap(_##NAME##_i64_[addr], old, new)); - return intBitsToFloat(old); + old_val = _##NAME##_i64_[addr]; + new_val = floatBitsToInt(min(intBitsToFloat(old_val), rhs)); + } while (old_val != atomicCompSwap(_##NAME##_i64_[addr], old_val, new_val)); + return intBitsToFloat(old_val); } \n ) @@ -52,3 +53,4 @@ OPENGL_END_ATOMIC_F64_DEF #undef OPENGL_BEGIN_ATOMIC_F64_DEF #undef OPENGL_END_ATOMIC_F64_DEF +// NOLINTEND(*) diff --git a/taichi/backends/opengl/shaders/fast_pow.glsl.h b/taichi/backends/opengl/shaders/fast_pow.glsl.h index 98e18baba90da..1d5b33ffd8dc9 100644 --- a/taichi/backends/opengl/shaders/fast_pow.glsl.h +++ b/taichi/backends/opengl/shaders/fast_pow.glsl.h @@ -1,5 +1,6 @@ // vim: ft=glsl // clang-format off +// NOLINTBEGIN(*) #include "taichi/util/macros.h" #ifdef TI_INSIDE_OPENGL_CODEGEN @@ -51,3 +52,4 @@ OPENGL_END_FAST_POW_DEF #undef OPENGL_BEGIN_FAST_POW_DEF #undef OPENGL_END_FAST_POW_DEF +// NOLINTEND(*) diff --git a/taichi/backends/opengl/shaders/indirect.glsl.h b/taichi/backends/opengl/shaders/indirect.glsl.h index 546c96f989b3d..04b215a7ef50d 100644 --- a/taichi/backends/opengl/shaders/indirect.glsl.h +++ b/taichi/backends/opengl/shaders/indirect.glsl.h @@ -1,5 +1,6 @@ // vim: ft=glsl // clang-format off +// NOLINTBEGIN(*) #include "taichi/util/macros.h" "#version 430 core\nprecision highp float;\n" #define TI_INSIDE_OPENGL_CODEGEN @@ -39,3 +40,4 @@ void _compute_indirect( // get_indirect_evaluator() will prepend a main here, with template arguments ) +// NOLINTEND(*) diff --git a/taichi/backends/opengl/shaders/listman.h b/taichi/backends/opengl/shaders/listman.h deleted file mode 100644 index 37852e09fd270..0000000000000 --- a/taichi/backends/opengl/shaders/listman.h +++ /dev/null @@ -1,29 +0,0 @@ -#define MAX_LIST (1024 * 256) // * 4 = 1 MB - -#ifdef TI_INSIDE_OPENGL_CODEGEN -#define OPENG_BEGIN_LISTMAN_DEF constexpr auto kOpenGLListmanSourceCode = -// clang-format off - -#include "taichi/util/macros.h" -OPENG_BEGIN_LISTMAN_DEF -STR( -layout(std430, binding = 7) buffer listman { - int _list_len_; - int _list_[]; -}; -)"\n"; -#undef OPENG_BEGIN_LISTMAN_DEF - -// clang-format on -#else - -TLANG_NAMESPACE_BEGIN - -struct GLSLListman { - int list_len; - int list[MAX_LIST]; -}; - -TLANG_NAMESPACE_END - -#endif diff --git a/taichi/backends/opengl/shaders/print.glsl.h b/taichi/backends/opengl/shaders/print.glsl.h index a49e8f8f4a535..e4eabfc216581 100644 --- a/taichi/backends/opengl/shaders/print.glsl.h +++ b/taichi/backends/opengl/shaders/print.glsl.h @@ -1,5 +1,6 @@ // vim: ft=glsl // clang-format off +// NOLINTBEGIN(*) #include "taichi/util/macros.h" #ifdef TI_INSIDE_OPENGL_CODEGEN @@ -39,3 +40,4 @@ OPENGL_END_PRINT_DEF #undef OPENGL_BEGIN_PRINT_DEF #undef OPENGL_END_PRINT_DEF +// NOLINTEND(*) diff --git a/taichi/backends/opengl/shaders/random.glsl.h b/taichi/backends/opengl/shaders/random.glsl.h index 6355fbe5430b7..258ef7f465e65 100644 --- a/taichi/backends/opengl/shaders/random.glsl.h +++ b/taichi/backends/opengl/shaders/random.glsl.h @@ -1,5 +1,6 @@ // vim: ft=glsl // clang-format off +// NOLINTBEGIN(*) #include "taichi/util/macros.h" #ifdef TI_INSIDE_OPENGL_CODEGEN @@ -51,3 +52,4 @@ OPENGL_END_RANDOM_DEF #undef OPENGL_BEGIN_RANDOM_DEF #undef OPENGL_END_RANDOM_DEF +// NOLINTEND(*) diff --git a/taichi/backends/opengl/shaders/reduction.glsl.h b/taichi/backends/opengl/shaders/reduction.glsl.h index 25fd885e72abc..29a4d95fb4cb0 100644 --- a/taichi/backends/opengl/shaders/reduction.glsl.h +++ b/taichi/backends/opengl/shaders/reduction.glsl.h @@ -1,55 +1,48 @@ // vim: ft=glsl -// clang-format off +// NOLINTBEGIN(*) #include "taichi/util/macros.h" constexpr auto kOpenGLReductionCommon = STR( -shared float _reduction_temp_float[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z]; -shared int _reduction_temp_int[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z]; -shared uint _reduction_temp_uint[gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z]; -float add(float a, float b) { return a + b; } -int add(int a, int b) { return a + b; } -uint add(uint a, uint b) { return a + b; } -\n -); + shared float _reduction_temp_float[gl_WorkGroupSize.x * gl_WorkGroupSize.y * + gl_WorkGroupSize.z]; + shared int _reduction_temp_int[gl_WorkGroupSize.x * gl_WorkGroupSize.y * + gl_WorkGroupSize.z]; + shared uint _reduction_temp_uint[gl_WorkGroupSize.x * gl_WorkGroupSize.y * + gl_WorkGroupSize.z]; + float add(float a, float b) { return a + b; } int add(int a, int b) { + return a + b; + } uint add(uint a, uint b) { return a + b; } +\n); -#ifdef TI_INSIDE_OPENGL_CODEGEN -#define OPENGL_BEGIN_REDUCTION_DEF constexpr auto kOpenGLReductionSourceCode = -#define OPENGL_END_REDUCTION_DEF ; -#else +#ifndef TI_INSIDE_OPENGL_CODEGEN static_assert(false, "Do not include"); -#define OPENGL_BEGIN_REDUCTION_DEF -#define OPENGL_END_REDUCTION_DEF #endif -OPENGL_BEGIN_REDUCTION_DEF -"#define DEFINE_REDUCTION_FUNCTIONS(OP, TYPE) " -STR( -TYPE reduction_workgroup_##OP##_##TYPE##(in TYPE r) { - _reduction_temp_##TYPE##[gl_LocalInvocationIndex] = r; - barrier(); - memoryBarrierShared(); - const int group_size = int(gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z); - const int depth = int(ceil(log2(float(group_size)))); - for (int i = 0; i < depth; ++i) { - const int radix = 1 << (i + 1); - const int stride = 1 << i; - const int cmp_index = int(gl_LocalInvocationIndex) + stride; - if (gl_LocalInvocationIndex % radix == 0 && cmp_index < group_size) { - _reduction_temp_##TYPE##[gl_LocalInvocationIndex] = ##OP##( - _reduction_temp_##TYPE##[gl_LocalInvocationIndex], - _reduction_temp_##TYPE##[cmp_index] - ); - } - barrier(); - memoryBarrierShared(); - } - const TYPE result = _reduction_temp_##TYPE##[0]; - barrier(); - return result; -} -\n -) -OPENGL_END_REDUCTION_DEF +#define GENERATE_OPENGL_REDUCTION_FUNCTIONS(OP, TYPE) \ + constexpr auto kOpenGlReductionSource_##OP##_##TYPE = \ + STR(TYPE reduction_workgroup_##OP##_##TYPE(in TYPE r) { \ + _reduction_temp_##TYPE[gl_LocalInvocationIndex] = r; \ + barrier(); \ + memoryBarrierShared(); \ + const int group_size = \ + int(gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z); \ + const int depth = int(ceil(log2(float(group_size)))); \ + for (int i = 0; i < depth; ++i) { \ + const int radix = 1 << (i + 1); \ + const int stride = 1 << i; \ + const int cmp_index = int(gl_LocalInvocationIndex) + stride; \ + if (gl_LocalInvocationIndex % radix == 0 && \ + cmp_index < group_size) { \ + _reduction_temp_##TYPE[gl_LocalInvocationIndex] = \ + OP(_reduction_temp_##TYPE[gl_LocalInvocationIndex], \ + _reduction_temp_##TYPE[cmp_index]); \ + } \ + barrier(); \ + memoryBarrierShared(); \ + } \ + const TYPE result = _reduction_temp_##TYPE[0]; \ + barrier(); \ + return result; \ + }); -#undef OPENGL_BEGIN_REDUCTION_DEF -#undef OPENGL_END_REDUCTION_DEF +// NOLINTEND(*) diff --git a/taichi/backends/opengl/struct_opengl.cpp b/taichi/backends/opengl/struct_opengl.cpp index 638b5c693ae01..3fb0146be4147 100644 --- a/taichi/backends/opengl/struct_opengl.cpp +++ b/taichi/backends/opengl/struct_opengl.cpp @@ -1,30 +1,36 @@ #include "struct_opengl.h" #include "taichi/ir/snode.h" +#include TLANG_NAMESPACE_BEGIN namespace opengl { OpenglStructCompiler::CompiledResult OpenglStructCompiler::run(SNode &node) { TI_ASSERT(node.type == SNodeType::root); - collect_snodes(node); - // The host side has run this! - // infer_snode_properties(node); - auto snodes_rev = snodes_; - std::reverse(snodes_rev.begin(), snodes_rev.end()); + generate_snode_tree(node); - for (auto &n : snodes_rev) { - generate_types(*n); - } CompiledResult result; + result.root_size = snode_map_.at(node.node_type_name).stride; result.snode_map = std::move(snode_map_); - result.root_size = compute_snode_size(node); result.root_snode_type_name = node.node_type_name; return result; } -void OpenglStructCompiler::collect_snodes(SNode &snode) { +void OpenglStructCompiler::generate_snode_tree(const SNode &root) { + collect_snodes(root); + // The host side has run this! + // infer_snode_properties(node); + + for (int i = snodes_.size() - 1; i >= 0; i--) { + generate_types(*snodes_[i]); + } + snode_map_.at(root.node_type_name).mem_offset_in_root = 0; + align_as_elem_stride(root); +} + +void OpenglStructCompiler::collect_snodes(const SNode &snode) { snodes_.push_back(&snode); for (int ch_id = 0; ch_id < (int)snode.ch.size(); ch_id++) { auto &ch = snode.ch[ch_id]; @@ -37,7 +43,8 @@ void OpenglStructCompiler::generate_types(const SNode &snode) { const auto &node_name = snode.node_type_name; const auto child_name = node_name + "_ch"; auto &snode_info = snode_map_[node_name]; - auto &snode_child_info = snode_map_[child_name]; + snode_info.snode = &snode; + SNodeInfo snode_child_info; if (!is_place) { size_t stride_num = 0; snode_info.children_offsets.resize(snode.ch.size()); @@ -64,15 +71,10 @@ void OpenglStructCompiler::generate_types(const SNode &snode) { if (is_place) { const auto dt_name = opengl_data_type_name(snode.dt); snode_info.stride = data_type_size(snode.dt); - } else if (snode.type == SNodeType::dense || - snode.type == SNodeType::dynamic || - snode.type == SNodeType::root) { - int n = snode.num_cells_per_container; - // the `length` field of a dynamic SNode is at it's end: - // | x[0] | x[1] | x[2] | x[3] | ... | len | - int extension = opengl_get_snode_meta_size(snode); + } else if (snode.type == SNodeType::dense || snode.type == SNodeType::root) { + int64 n = snode.num_cells_per_container; snode_info.length = n; - snode_info.stride = snode_child_info.stride * n + extension; // my stride + snode_info.stride = snode_child_info.stride * n; // my stride snode_info.elem_stride = snode_child_info.stride; // my child stride } else { TI_ERROR( @@ -84,17 +86,45 @@ void OpenglStructCompiler::generate_types(const SNode &snode) { } } -size_t OpenglStructCompiler::compute_snode_size(const SNode &snode) { - if (snode.is_place()) { - return data_type_size(snode.dt); - } +namespace { +template +std::vector sort_index_by(const std::vector &v) { + std::vector idx(v.size()); + std::iota(idx.begin(), idx.end(), 0); + std::sort(idx.begin(), idx.end(), + [&v](size_t i1, size_t i2) { return v[i1] < v[i2]; }); + return idx; +} +} // namespace + +void OpenglStructCompiler::align_as_elem_stride(const SNode &snode) { size_t ch_size = 0; - for (const auto &ch : snode.ch) { - ch_size += compute_snode_size(*ch); + auto &snode_meta = snode_map_.at(snode.node_type_name); + if (snode.is_place()) { + ch_size = data_type_size(snode.dt); + } else { + // Sort snode.ch by snode_meta.children_offsets so that we compute + // the mem_offset_in_root in the right order. + auto sorted_indices = sort_index_by(snode_meta.children_offsets); + for (size_t i : sorted_indices) { + auto offset = ch_size + snode_meta.mem_offset_in_root; + // Pad so that the base address of snode.ch[i] is multiple of its + // elem_stride. + auto &ch_snode_meta = snode_map_.at(snode.ch[i]->node_type_name); + auto alignment = ch_snode_meta.elem_stride; + auto alignment_bytes = + alignment ? alignment - 1 - (offset + alignment - 1) % alignment : 0; + auto ch_mem_offset_in_root = offset + alignment_bytes; + ch_snode_meta.mem_offset_in_root = ch_mem_offset_in_root; + snode_meta.children_offsets[i] = + ch_mem_offset_in_root - snode_meta.mem_offset_in_root; + + align_as_elem_stride(*snode.ch[i]); + ch_size += (alignment_bytes + ch_snode_meta.stride); + } } - int n = snode.num_cells_per_container; - return n * ch_size + opengl_get_snode_meta_size(snode); + snode_meta.elem_stride = ch_size; + snode_meta.stride = snode.num_cells_per_container * ch_size; } - } // namespace opengl TLANG_NAMESPACE_END diff --git a/taichi/backends/opengl/struct_opengl.h b/taichi/backends/opengl/struct_opengl.h index 4fc0124cedac3..475bf97d52270 100644 --- a/taichi/backends/opengl/struct_opengl.h +++ b/taichi/backends/opengl/struct_opengl.h @@ -21,11 +21,12 @@ class OpenglStructCompiler { CompiledResult run(SNode &node); private: - void collect_snodes(SNode &snode); + void collect_snodes(const SNode &snode); void generate_types(const SNode &snode); - size_t compute_snode_size(const SNode &sn); + void generate_snode_tree(const SNode &root); + void align_as_elem_stride(const SNode &sn); - std::vector snodes_; + std::vector snodes_; std::unordered_map snode_map_; }; diff --git a/taichi/backends/vulkan/aot_module_builder_impl.cpp b/taichi/backends/vulkan/aot_module_builder_impl.cpp new file mode 100644 index 0000000000000..5ec480b2dd5a6 --- /dev/null +++ b/taichi/backends/vulkan/aot_module_builder_impl.cpp @@ -0,0 +1,202 @@ +#include "taichi/backends/vulkan/aot_module_builder_impl.h" + +#include +#include + +#include "taichi/aot/module_data.h" +#include "taichi/codegen/spirv/spirv_codegen.h" + +namespace taichi { +namespace lang { +namespace vulkan { + +namespace { +class AotDataConverter { + public: + static aot::ModuleData convert(const TaichiAotData &in) { + AotDataConverter c{}; + return c.visit(in); + } + + private: + explicit AotDataConverter() = default; + + aot::ModuleData visit(const TaichiAotData &in) const { + aot::ModuleData res{}; + for (const auto &ker : in.kernels) { + auto val = visit(ker); + res.kernels[ker.name] = val; + } + res.fields = in.fields; + res.root_buffer_size = in.root_buffer_size; + return res; + } + + aot::CompiledTaichiKernel visit( + const spirv::TaichiKernelAttributes &in) const { + aot::CompiledTaichiKernel res{}; + res.tasks.reserve(in.tasks_attribs.size()); + for (const auto &t : in.tasks_attribs) { + res.tasks.push_back(visit(t)); + } + res.args_count = in.ctx_attribs.args().size(); + res.rets_count = in.ctx_attribs.rets().size(); + res.args_buffer_size = in.ctx_attribs.args_bytes(); + res.rets_buffer_size = in.ctx_attribs.rets_bytes(); + for (const auto &arg : in.ctx_attribs.args()) { + if (!arg.is_array) { + aot::ScalarArg scalar_arg{}; + scalar_arg.dtype_name = arg.dt.to_string(); + scalar_arg.offset_in_args_buf = arg.offset_in_mem; + res.scalar_args[arg.index] = scalar_arg; + } else { + aot::ArrayArg arr_arg{}; + arr_arg.dtype_name = arg.dt.to_string(); + arr_arg.field_dim = arg.field_dim; + arr_arg.element_shape = arg.element_shape; + arr_arg.shape_offset_in_args_buf = arg.index * sizeof(int32_t); + res.arr_args[arg.index] = arr_arg; + } + } + return res; + } + + aot::CompiledOffloadedTask visit(const TaskAttributes &in) const { + aot::CompiledOffloadedTask res{}; + res.type = offloaded_task_type_name(in.task_type); + res.name = in.name; + // TODO: update range_hint after ndarray is supported on vulkan. + if (in.range_for_attribs && in.range_for_attribs->const_begin && + in.range_for_attribs->const_end) { + res.range_hint = std::to_string(in.range_for_attribs->end - + in.range_for_attribs->begin); + } + res.gpu_block_size = in.advisory_num_threads_per_group; + return res; + } +}; + +} // namespace +AotModuleBuilderImpl::AotModuleBuilderImpl( + const std::vector &compiled_structs) + : compiled_structs_(compiled_structs) { + aot_target_device_ = std::make_unique(Arch::vulkan); + if (!compiled_structs.empty()) { + ti_aot_data_.root_buffer_size = compiled_structs[0].root_size; + } +} + +uint32_t AotModuleBuilderImpl::to_vk_dtype_enum(DataType dt) { + if (dt == PrimitiveType::u64) { + return 0; + } else if (dt == PrimitiveType::i64) { + return 1; + } else if (dt == PrimitiveType::u32) { + return 2; + } else if (dt == PrimitiveType::i32) { + return 3; + } else if (dt == PrimitiveType::u16) { + return 4; + } else if (dt == PrimitiveType::i16) { + return 5; + } else if (dt == PrimitiveType::u8) { + return 6; + } else if (dt == PrimitiveType::i8) { + return 7; + } else if (dt == PrimitiveType::f64) { + return 8; + } else if (dt == PrimitiveType::f32) { + return 9; + } else { + TI_NOT_IMPLEMENTED + } +} + +std::string AotModuleBuilderImpl::write_spv_file( + const std::string &output_dir, + const TaskAttributes &k, + const std::vector &source_code) const { + const std::string spv_path = fmt::format("{}/{}.spv", output_dir, k.name); + std::ofstream fs(spv_path, std::ios_base::binary | std::ios::trunc); + fs.write((char *)source_code.data(), source_code.size() * sizeof(uint32_t)); + fs.close(); + return spv_path; +} + +void AotModuleBuilderImpl::dump(const std::string &output_dir, + const std::string &filename) const { + TI_WARN_IF(!filename.empty(), + "Filename prefix is ignored on vulkan backend."); + const std::string bin_path = fmt::format("{}/metadata.tcb", output_dir); + write_to_binary_file(ti_aot_data_, bin_path); + + auto converted = AotDataConverter::convert(ti_aot_data_); + for (int i = 0; i < ti_aot_data_.kernels.size(); ++i) { + auto &k = ti_aot_data_.kernels[i]; + for (int j = 0; j < k.tasks_attribs.size(); ++j) { + std::string spv_path = write_spv_file(output_dir, k.tasks_attribs[j], + ti_aot_data_.spirv_codes[i][j]); + converted.kernels[k.name].tasks[j].source_path = spv_path; + } + } + + const std::string json_path = fmt::format("{}/metadata.json", output_dir); + converted.dump_json(json_path); +} + +void AotModuleBuilderImpl::add_per_backend(const std::string &identifier, + Kernel *kernel) { + spirv::lower(kernel); + auto compiled = + run_codegen(kernel, aot_target_device_.get(), compiled_structs_); + compiled.kernel_attribs.name = identifier; + ti_aot_data_.kernels.push_back(compiled.kernel_attribs); + ti_aot_data_.spirv_codes.push_back(compiled.task_spirv_source_codes); +} + +void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier, + const SNode *rep_snode, + bool is_scalar, + DataType dt, + std::vector shape, + int row_num, + int column_num) { + // Note that currently we only support adding dense fields in AOT for all + // backends. In opengl backend we only error out when a non dense field is + // added to the aot module, but in metal backend we error out earlier when + // constructing aot module. Ideally we will unify this behavior but it doesn't + // matter too much for now. + TI_ERROR_IF(!all_fields_are_dense_in_container(rep_snode->parent), + "AOT: only supports dense field"); + + const auto &dense_desc = + compiled_structs_[0].snode_descriptors.at(rep_snode->parent->id); + + aot::CompiledFieldData field_data; + field_data.field_name = identifier; + field_data.is_scalar = is_scalar; + field_data.dtype = to_vk_dtype_enum(dt); + field_data.dtype_name = dt.to_string(); + field_data.shape = shape; + field_data.mem_offset_in_parent = dense_desc.mem_offset_in_parent_cell; + if (!is_scalar) { + field_data.element_shape = {row_num, column_num}; + } + ti_aot_data_.fields.push_back(field_data); +} + +void AotModuleBuilderImpl::add_per_backend_tmpl(const std::string &identifier, + const std::string &key, + Kernel *kernel) { + spirv::lower(kernel); + auto compiled = + run_codegen(kernel, aot_target_device_.get(), compiled_structs_); + + compiled.kernel_attribs.name = identifier + "|" + key; + ti_aot_data_.kernels.push_back(compiled.kernel_attribs); + ti_aot_data_.spirv_codes.push_back(compiled.task_spirv_source_codes); +} + +} // namespace vulkan +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/vulkan/aot_module_builder_impl.h b/taichi/backends/vulkan/aot_module_builder_impl.h new file mode 100644 index 0000000000000..408e24b1f3dc8 --- /dev/null +++ b/taichi/backends/vulkan/aot_module_builder_impl.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include + +#include "taichi/aot/module_builder.h" +#include "taichi/backends/vulkan/aot_utils.h" +#include "taichi/backends/vulkan/runtime.h" +#include "taichi/codegen/spirv/snode_struct_compiler.h" +#include "taichi/codegen/spirv/kernel_utils.h" + +namespace taichi { +namespace lang { +namespace vulkan { + +class AotModuleBuilderImpl : public AotModuleBuilder { + public: + explicit AotModuleBuilderImpl( + const std::vector &compiled_structs); + + void dump(const std::string &output_dir, + const std::string &filename) const override; + + private: + void add_per_backend(const std::string &identifier, Kernel *kernel) override; + + void add_field_per_backend(const std::string &identifier, + const SNode *rep_snode, + bool is_scalar, + DataType dt, + std::vector shape, + int row_num, + int column_num) override; + + void add_per_backend_tmpl(const std::string &identifier, + const std::string &key, + Kernel *kernel) override; + + std::string write_spv_file(const std::string &output_dir, + const TaskAttributes &k, + const std::vector &source_code) const; + + uint32_t to_vk_dtype_enum(DataType dt); + + const std::vector &compiled_structs_; + TaichiAotData ti_aot_data_; + std::unique_ptr aot_target_device_; +}; + +} // namespace vulkan +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/vulkan/aot_module_loader_impl.cpp b/taichi/backends/vulkan/aot_module_loader_impl.cpp new file mode 100644 index 0000000000000..249d0fea570f0 --- /dev/null +++ b/taichi/backends/vulkan/aot_module_loader_impl.cpp @@ -0,0 +1,128 @@ +#include "taichi/backends/vulkan/aot_module_loader_impl.h" + +#include +#include + +#include "taichi/backends/vulkan/runtime.h" + +namespace taichi { +namespace lang { +namespace vulkan { +namespace { + +using KernelHandle = VkRuntime::KernelHandle; + +class KernelImpl : public aot::Kernel { + public: + explicit KernelImpl(VkRuntime *runtime, KernelHandle handle) + : runtime_(runtime), handle_(handle) { + } + + void launch(RuntimeContext *ctx) override { + runtime_->launch_kernel(handle_, ctx); + } + + private: + VkRuntime *const runtime_; + const KernelHandle handle_; +}; + +class AotModuleImpl : public aot::Module { + public: + explicit AotModuleImpl(const AotModuleParams ¶ms) + : runtime_(params.runtime) { + const std::string bin_path = + fmt::format("{}/metadata.tcb", params.module_path); + read_from_binary_file(ti_aot_data_, bin_path); + + for (int i = 0; i < ti_aot_data_.kernels.size(); ++i) { + auto k = ti_aot_data_.kernels[i]; + + std::vector> spirv_sources_codes; + for (int j = 0; j < k.tasks_attribs.size(); ++j) { + std::vector res = + read_spv_file(params.module_path, k.tasks_attribs[j]); + spirv_sources_codes.push_back(res); + } + ti_aot_data_.spirv_codes.push_back(spirv_sources_codes); + } + } + + std::unique_ptr get_kernel(const std::string &name) override { + return make_new_kernel(name); + } + + std::unique_ptr get_field(const std::string &name) override { + TI_NOT_IMPLEMENTED; + } + + size_t get_root_size() const override { + return ti_aot_data_.root_buffer_size; + } + + // Module metadata + Arch arch() const override { + return Arch::vulkan; + } + uint64_t version() const override { + TI_NOT_IMPLEMENTED; + } + + private: + bool get_kernel_params_by_name(const std::string &name, + VkRuntime::RegisterParams &kernel) { + for (int i = 0; i < ti_aot_data_.kernels.size(); ++i) { + // Offloaded task names encode more than the name of the function, but for + // AOT, only use the name of the function which should be the first part + // of the struct + if (ti_aot_data_.kernels[i].name.rfind(name, 0) == 0) { + kernel.kernel_attribs = ti_aot_data_.kernels[i]; + kernel.task_spirv_source_codes = ti_aot_data_.spirv_codes[i]; + // We don't have to store the number of SNodeTree in |ti_aot_data_| yet, + // because right now we only support a single SNodeTree during AOT. + // TODO: Support multiple SNodeTrees in AOT. + kernel.num_snode_trees = 1; + return true; + } + } + return false; + } + + std::unique_ptr make_new_kernel( + const std::string &name) override { + VkRuntime::RegisterParams kparams; + if (!get_kernel_params_by_name(name, kparams)) { + TI_DEBUG("Failed to load kernel {}", name); + return nullptr; + } + auto handle = runtime_->register_taichi_kernel(kparams); + return std::make_unique(runtime_, handle); + } + + std::vector read_spv_file(const std::string &output_dir, + const TaskAttributes &k) { + const std::string spv_path = fmt::format("{}/{}.spv", output_dir, k.name); + std::vector source_code; + std::ifstream fs(spv_path, std::ios_base::binary | std::ios::ate); + size_t size = fs.tellg(); + fs.seekg(0, std::ios::beg); + source_code.resize(size / sizeof(uint32_t)); + fs.read((char *)source_code.data(), size); + fs.close(); + return source_code; + } + + TaichiAotData ti_aot_data_; + VkRuntime *runtime_{nullptr}; +}; + +} // namespace + +std::unique_ptr make_aot_module(std::any mod_params) { + AotModuleParams params = std::any_cast(mod_params); + return std::make_unique(params); +} + +} // namespace vulkan +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/vulkan/aot_module_loader_impl.h b/taichi/backends/vulkan/aot_module_loader_impl.h new file mode 100644 index 0000000000000..37e28d01f6388 --- /dev/null +++ b/taichi/backends/vulkan/aot_module_loader_impl.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include + +#include "taichi/backends/vulkan/aot_utils.h" +#include "taichi/backends/vulkan/runtime.h" +#include "taichi/codegen/spirv/kernel_utils.h" + +#include "taichi/aot/module_loader.h" + +namespace taichi { +namespace lang { +namespace vulkan { + +class VkRuntime; + +struct AotModuleParams { + std::string module_path; + VkRuntime *runtime{nullptr}; +}; + +std::unique_ptr make_aot_module(std::any mod_params); +} // namespace vulkan +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/vulkan/aot_utils.h b/taichi/backends/vulkan/aot_utils.h new file mode 100644 index 0000000000000..5c00d4023efa8 --- /dev/null +++ b/taichi/backends/vulkan/aot_utils.h @@ -0,0 +1,27 @@ +#pragma once + +#include + +#include "taichi/codegen/spirv/kernel_utils.h" +#include "taichi/aot/module_loader.h" + +namespace taichi { +namespace lang { +namespace vulkan { + +/** + * AOT module data for the vulkan backend. + */ +struct TaichiAotData { + // BufferMetaData metadata; + std::vector>> spirv_codes; + std::vector kernels; + std::vector fields; + size_t root_buffer_size{0}; + + TI_IO_DEF(kernels, fields, root_buffer_size); +}; + +} // namespace vulkan +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/vulkan/codegen_vulkan.h b/taichi/backends/vulkan/codegen_vulkan.h deleted file mode 100644 index ad61f92ddbda9..0000000000000 --- a/taichi/backends/vulkan/codegen_vulkan.h +++ /dev/null @@ -1,23 +0,0 @@ -#pragma once - -#include "taichi/lang_util.h" - -#include "taichi/backends/vulkan/snode_struct_compiler.h" - -namespace taichi { -namespace lang { - -class Kernel; - -namespace vulkan { - -class VkRuntime; - -void lower(Kernel *kernel); - -// These ASTs must have already been lowered at the CHI level. -FunctionType compile_to_executable(Kernel *kernel, VkRuntime *runtime); - -} // namespace vulkan -} // namespace lang -} // namespace taichi diff --git a/taichi/backends/vulkan/data_type_utils.h b/taichi/backends/vulkan/data_type_utils.h deleted file mode 100644 index c837a9d6d39d3..0000000000000 --- a/taichi/backends/vulkan/data_type_utils.h +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -#include -#include - -#include "taichi/lang_util.h" - -namespace taichi { -namespace lang { -namespace vulkan { - -inline std::size_t vk_data_type_size(DataType dt) { - // Vulkan buffers require a minimum alignment of 4 bytes. - // https://vulkan-tutorial.com/Uniform_buffers/Descriptor_pool_and_sets#page_Alignment-requirements - return std::max(data_type_size(dt), 4); -} - -} // namespace vulkan -} // namespace lang -} // namespace taichi diff --git a/taichi/backends/vulkan/kernel_utils.cpp b/taichi/backends/vulkan/kernel_utils.cpp deleted file mode 100644 index a5ed48d696cf0..0000000000000 --- a/taichi/backends/vulkan/kernel_utils.cpp +++ /dev/null @@ -1,124 +0,0 @@ -#include "taichi/backends/vulkan/kernel_utils.h" - -#include - -#include "taichi/backends/vulkan/data_type_utils.h" -#include "taichi/program/kernel.h" -#define TI_RUNTIME_HOST -#include "taichi/program/context.h" -#undef TI_RUNTIME_HOST - -namespace taichi { -namespace lang { -namespace vulkan { - -// static -std::string TaskAttributes::buffers_name(BufferInfo b) { - if (b.type == BufferType::Context) { - return "Context"; - } - if (b.type == BufferType::GlobalTmps) { - return "GlobalTmps"; - } - if (b.type == BufferType::Root) { - return std::string("Root: ") + std::to_string(b.root_id); - } - TI_ERROR("unrecognized buffer type"); -} - -std::string TaskAttributes::debug_string() const { - std::string result; - result += fmt::format( - "", - TaskAttributes::buffers_name(buffer), binding); -} - -KernelContextAttributes::KernelContextAttributes(const Kernel &kernel) - : args_bytes_(0), - rets_bytes_(0), - extra_args_bytes_(Context::extra_args_size) { - arg_attribs_vec_.reserve(kernel.args.size()); - for (const auto &ka : kernel.args) { - ArgAttributes aa; - aa.dt = ka.dt; - const size_t dt_bytes = vk_data_type_size(aa.dt); - if (dt_bytes != 4) { - TI_ERROR("Vulakn kernel only supports 32-bit data, got {}", - data_type_name(aa.dt)); - } - aa.is_array = ka.is_external_array; - // For array, |ka.size| is #elements * elements_size - aa.stride = aa.is_array ? ka.size : dt_bytes; - aa.index = arg_attribs_vec_.size(); - arg_attribs_vec_.push_back(aa); - } - for (const auto &kr : kernel.rets) { - RetAttributes ra; - ra.dt = kr.dt; - const size_t dt_bytes = vk_data_type_size(ra.dt); - if (dt_bytes != 4) { - // Metal doesn't support 64bit data buffers. - TI_ERROR("Vulakn kernel only supports 32-bit data, got {}", - data_type_name(ra.dt)); - } - ra.is_array = false; // TODO(#909): this is a temporary limitation - ra.stride = dt_bytes; - ra.index = ret_attribs_vec_.size(); - ret_attribs_vec_.push_back(ra); - } - - auto arrange_scalar_before_array = [](auto *vec, size_t offset) -> size_t { - std::vector scalar_indices; - std::vector array_indices; - for (int i = 0; i < vec->size(); ++i) { - if ((*vec)[i].is_array) { - array_indices.push_back(i); - } else { - scalar_indices.push_back(i); - } - } - size_t bytes = offset; - // Put scalar args in the memory first - for (int i : scalar_indices) { - auto &attribs = (*vec)[i]; - attribs.offset_in_mem = bytes; - bytes += attribs.stride; - TI_TRACE(" at={} scalar offset_in_mem={} stride={}", i, - attribs.offset_in_mem, attribs.stride); - } - // Then the array args - for (int i : array_indices) { - auto &attribs = (*vec)[i]; - attribs.offset_in_mem = bytes; - bytes += attribs.stride; - TI_TRACE(" at={} array offset_in_mem={} stride={}", i, - attribs.offset_in_mem, attribs.stride); - } - return bytes - offset; - }; - - TI_TRACE("args:"); - args_bytes_ = arrange_scalar_before_array(&arg_attribs_vec_, 0); - TI_TRACE("rets:"); - rets_bytes_ = arrange_scalar_before_array(&ret_attribs_vec_, args_bytes_); - TI_TRACE("sizes: args={} rets={} ctx={} total={}", args_bytes(), rets_bytes(), - ctx_bytes(), total_bytes()); - TI_ASSERT(has_rets() == (rets_bytes_ > 0)); -} - -} // namespace vulkan -} // namespace lang -} // namespace taichi diff --git a/taichi/backends/vulkan/loader.cpp b/taichi/backends/vulkan/loader.cpp deleted file mode 100644 index d71a5f069daaf..0000000000000 --- a/taichi/backends/vulkan/loader.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#pragma once - -#define VOLK_IMPLEMENTATION -#include - -#include "taichi/backends/vulkan/loader.h" -#include "taichi/common/logging.h" - -namespace taichi { -namespace lang { -namespace vulkan { - -VulkanLoader::VulkanLoader() { -} - -bool VulkanLoader::init() { - std::call_once(init_flag_, [&]() { - if (initialized) { - return; - } - VkResult result = volkInitialize(); - initialized = result == VK_SUCCESS; - }); - return initialized; -} - -void VulkanLoader::load_instance(VkInstance instance) { - vulkan_instance_ = instance; - volkLoadInstance(instance); -} -void VulkanLoader::load_device(VkDevice device) { - vulkan_device_ = device; - volkLoadDevice(device); -} - -PFN_vkVoidFunction VulkanLoader::load_function(const char *name) { - auto result = - vkGetInstanceProcAddr(VulkanLoader::instance().vulkan_instance_, name); - TI_WARN_IF(result == nullptr, "loaded vulkan function {} is nullptr", name); - return result; -} - -bool is_vulkan_api_available() { - return VulkanLoader::instance().init(); -} - -} // namespace vulkan -} // namespace lang -} // namespace taichi diff --git a/taichi/backends/vulkan/runtime.cpp b/taichi/backends/vulkan/runtime.cpp index b942c05e6b5bf..96325f2c848b6 100644 --- a/taichi/backends/vulkan/runtime.cpp +++ b/taichi/backends/vulkan/runtime.cpp @@ -1,4 +1,5 @@ #include "taichi/backends/vulkan/runtime.h" +#include "taichi/program/program.h" #include #include @@ -9,6 +10,8 @@ #include #include +#include "fp16.h" + #define TI_RUNTIME_HOST #include "taichi/program/context.h" #undef TI_RUNTIME_HOST @@ -23,7 +26,7 @@ class StopWatch { StopWatch() : begin_(std::chrono::system_clock::now()) { } - int GetMicros() { + int get_micros() { typedef std::chrono::duration fsec; auto now = std::chrono::system_clock::now(); @@ -41,26 +44,28 @@ class StopWatch { class HostDeviceContextBlitter { public: HostDeviceContextBlitter(const KernelContextAttributes *ctx_attribs, - Context *host_ctx, + RuntimeContext *host_ctx, Device *device, uint64_t *host_result_buffer, - DeviceAllocation *device_buffer, - DeviceAllocation *host_shadow_buffer) + DeviceAllocation *device_args_buffer, + DeviceAllocation *device_ret_buffer) : ctx_attribs_(ctx_attribs), host_ctx_(host_ctx), - device_(device), host_result_buffer_(host_result_buffer), - device_buffer_(device_buffer), - host_shadow_buffer_(host_shadow_buffer) { + device_args_buffer_(device_args_buffer), + device_ret_buffer_(device_ret_buffer), + device_(device) { } - void host_to_device() { - if (ctx_attribs_->empty()) { + void host_to_device( + const std::unordered_map &ext_arrays, + const std::unordered_map &ext_arr_size) { + if (!ctx_attribs_->has_args()) { return; } char *const device_base = - reinterpret_cast(device_->map(*device_buffer_)); + reinterpret_cast(device_->map(*device_args_buffer_)); #define TO_DEVICE(short_type, type) \ if (dt->is_primitive(PrimitiveTypeID::short_type)) { \ @@ -75,8 +80,23 @@ class HostDeviceContextBlitter { char *device_ptr = device_base + arg.offset_in_mem; do { if (arg.is_array) { - const void *host_ptr = host_ctx_->get_arg(i); - std::memcpy(device_ptr, host_ptr, arg.stride); + if (!host_ctx_->is_device_allocation[i] && ext_arr_size.at(i)) { + // Only need to blit ext arrs (host array) + DeviceAllocation buffer = ext_arrays.at(i); + char *const device_arr_ptr = + reinterpret_cast(device_->map(buffer)); + const void *host_ptr = host_ctx_->get_arg(i); + std::memcpy(device_arr_ptr, host_ptr, ext_arr_size.at(i)); + device_->unmap(buffer); + } + // Substitue in the device address if supported + if (device_->get_cap( + DeviceCapability::spirv_has_physical_storage_buffer)) { + uint64_t addr = + device_->get_memory_physical_pointer(ext_arrays.at(i)); + reinterpret_cast(device_ptr)[0] = addr; + } + // We should not process the rest break; } if (device_->get_cap(DeviceCapability::spirv_has_int8)) { @@ -97,20 +117,31 @@ class HostDeviceContextBlitter { if (device_->get_cap(DeviceCapability::spirv_has_float64)) { TO_DEVICE(f64, float64) } + if (device_->get_cap(DeviceCapability::spirv_has_float16)) { + if (dt->is_primitive(PrimitiveTypeID::f16)) { + auto d = fp16_ieee_from_fp32_value(host_ctx_->get_arg(i)); + reinterpret_cast(device_ptr)[0] = d; + break; + } + } TI_ERROR("Vulkan does not support arg type={}", data_type_name(arg.dt)); } while (0); } + char *device_ptr = device_base + ctx_attribs_->extra_args_mem_offset(); std::memcpy(device_ptr, host_ctx_->extra_args, ctx_attribs_->extra_args_bytes()); - device_->unmap(*device_buffer_); + device_->unmap(*device_args_buffer_); #undef TO_DEVICE } - void device_to_host() { + bool device_to_host( + CommandList *cmdlist, + const std::unordered_map &ext_arrays, + const std::unordered_map &ext_arr_size) { if (ctx_attribs_->empty()) { - return; + return false; } bool require_sync = ctx_attribs_->rets().size() > 0; @@ -118,35 +149,46 @@ class HostDeviceContextBlitter { for (int i = 0; i < ctx_attribs_->args().size(); ++i) { const auto &arg = ctx_attribs_->args()[i]; if (arg.is_array) { - require_sync = true; + if (!host_ctx_->is_device_allocation[i] && ext_arr_size.at(i)) { + require_sync = true; + } } } } if (require_sync) { - device_->get_compute_stream()->command_sync(); + device_->get_compute_stream()->submit_synced(cmdlist); } else { - return; + return false; } - char *const device_base = - reinterpret_cast(device_->map(*host_shadow_buffer_)); - for (int i = 0; i < ctx_attribs_->args().size(); ++i) { const auto &arg = ctx_attribs_->args()[i]; - char *device_ptr = device_base + arg.offset_in_mem; if (arg.is_array) { - void *host_ptr = host_ctx_->get_arg(i); - std::memcpy(host_ptr, device_ptr, arg.stride); + if (!host_ctx_->is_device_allocation[i] && ext_arr_size.at(i)) { + // Only need to blit ext arrs (host array) + DeviceAllocation buffer = ext_arrays.at(i); + char *const device_arr_ptr = + reinterpret_cast(device_->map(buffer)); + void *host_ptr = host_ctx_->get_arg(i); + std::memcpy(host_ptr, device_arr_ptr, ext_arr_size.at(i)); + device_->unmap(buffer); + } } } -#define TO_HOST(short_type, type) \ - if (dt->is_primitive(PrimitiveTypeID::short_type)) { \ - const type d = *reinterpret_cast(device_ptr); \ - host_result_buffer_[i] = \ - taichi_union_cast_with_different_sizes(d); \ - break; \ + if (!ctx_attribs_->has_rets()) + return require_sync; + + char *const device_base = + reinterpret_cast(device_->map(*device_ret_buffer_)); + +#define TO_HOST(short_type, type, offset) \ + if (dt->is_primitive(PrimitiveTypeID::short_type)) { \ + const type d = *(reinterpret_cast(device_ptr) + offset); \ + host_result_buffer_[offset] = \ + taichi_union_cast_with_different_sizes(d); \ + continue; \ } for (int i = 0; i < ctx_attribs_->rets().size(); ++i) { @@ -155,65 +197,75 @@ class HostDeviceContextBlitter { const auto &ret = ctx_attribs_->rets()[i]; char *device_ptr = device_base + ret.offset_in_mem; const auto dt = ret.dt; - do { - if (ret.is_array) { - void *host_ptr = host_ctx_->get_arg(i); - std::memcpy(host_ptr, device_ptr, ret.stride); - break; - } + const auto num = ret.stride / data_type_size(ret.dt); + for (int j = 0; j < num; ++j) { if (device_->get_cap(DeviceCapability::spirv_has_int8)) { - TO_HOST(i8, int8) - TO_HOST(u8, uint8) + TO_HOST(i8, int8, j) + TO_HOST(u8, uint8, j) } if (device_->get_cap(DeviceCapability::spirv_has_int16)) { - TO_HOST(i16, int16) - TO_HOST(u16, uint16) + TO_HOST(i16, int16, j) + TO_HOST(u16, uint16, j) } - TO_HOST(i32, int32) - TO_HOST(u32, uint32) - TO_HOST(f32, float32) + TO_HOST(i32, int32, j) + TO_HOST(u32, uint32, j) + TO_HOST(f32, float32, j) if (device_->get_cap(DeviceCapability::spirv_has_int64)) { - TO_HOST(i64, int64) - TO_HOST(u64, uint64) + TO_HOST(i64, int64, j) + TO_HOST(u64, uint64, j) } if (device_->get_cap(DeviceCapability::spirv_has_float64)) { - TO_HOST(f64, float64) + TO_HOST(f64, float64, j) + } + if (device_->get_cap(DeviceCapability::spirv_has_float16)) { + if (dt->is_primitive(PrimitiveTypeID::f16)) { + const float d = fp16_ieee_to_fp32_value( + *reinterpret_cast(device_ptr) + j); + host_result_buffer_[j] = + taichi_union_cast_with_different_sizes(d); + continue; + } } TI_ERROR("Vulkan does not support return value type={}", data_type_name(ret.dt)); - } while (0); + } } #undef TO_HOST - device_->unmap(*host_shadow_buffer_); + device_->unmap(*device_ret_buffer_); + + return true; } static std::unique_ptr maybe_make( const KernelContextAttributes *ctx_attribs, - Context *host_ctx, + RuntimeContext *host_ctx, Device *device, uint64_t *host_result_buffer, - DeviceAllocation *device_buffer, - DeviceAllocation *host_shadow_buffer) { + DeviceAllocation *device_args_buffer, + DeviceAllocation *device_ret_buffer) { if (ctx_attribs->empty()) { return nullptr; } return std::make_unique( - ctx_attribs, host_ctx, device, host_result_buffer, device_buffer, - host_shadow_buffer); + ctx_attribs, host_ctx, device, host_result_buffer, device_args_buffer, + device_ret_buffer); } private: const KernelContextAttributes *const ctx_attribs_; - Context *const host_ctx_; + RuntimeContext *const host_ctx_; uint64_t *const host_result_buffer_; - DeviceAllocation *const device_buffer_; - DeviceAllocation *const host_shadow_buffer_; + DeviceAllocation *const device_args_buffer_; + DeviceAllocation *const device_ret_buffer_; Device *const device_; }; } // namespace +constexpr size_t kGtmpBufferSize = 1024 * 1024; +constexpr size_t kListGenBufferSize = 32 << 20; + // Info for launching a compiled Taichi kernel, which consists of a series of // Vulkan pipelines. @@ -221,22 +273,24 @@ CompiledTaichiKernel::CompiledTaichiKernel(const Params &ti_params) : ti_kernel_attribs_(*ti_params.ti_kernel_attribs), device_(ti_params.device) { input_buffers_[BufferType::GlobalTmps] = ti_params.global_tmps_buffer; - for (int root = 0; root < ti_params.compiled_structs.size(); ++root) { + input_buffers_[BufferType::ListGen] = ti_params.listgen_buffer; + + // Compiled_structs can be empty if loading a kernel from an AOT module as + // the SNode are not re-compiled/structured. In thise case, we assume a + // single root buffer size configured from the AOT module. + for (int root = 0; root < ti_params.num_snode_trees; ++root) { BufferInfo buffer = {BufferType::Root, root}; input_buffers_[buffer] = ti_params.root_buffers[root]; } - const auto ctx_sz = ti_kernel_attribs_.ctx_attribs.total_bytes(); - if (!ti_kernel_attribs_.ctx_attribs.empty()) { - Device::AllocParams params; - ctx_buffer_ = ti_params.device->allocate_memory_unique( - {size_t(ctx_sz), - /*host_write=*/true, /*host_read=*/false, - /*export_sharing=*/false, AllocUsage::Storage}); - ctx_buffer_host_ = ti_params.device->allocate_memory_unique( - {size_t(ctx_sz), - /*host_write=*/false, /*host_read=*/true, - /*export_sharing=*/false, AllocUsage::Storage}); - input_buffers_[BufferType::Context] = ctx_buffer_.get(); + + const auto arg_sz = ti_kernel_attribs_.ctx_attribs.args_bytes(); + const auto ret_sz = ti_kernel_attribs_.ctx_attribs.rets_bytes(); + + args_buffer_size_ = arg_sz; + ret_buffer_size_ = ret_sz; + + if (arg_sz) { + args_buffer_size_ += ti_kernel_attribs_.ctx_attribs.extra_args_bytes(); } const auto &task_attribs = ti_kernel_attribs_.tasks_attribs; @@ -248,7 +302,7 @@ CompiledTaichiKernel::CompiledTaichiKernel(const Params &ti_params) (void *)spirv_bins[i].data(), spirv_bins[i].size() * sizeof(uint32_t)}; auto vp = - ti_params.device->create_pipeline(source_desc, ti_kernel_attribs_.name); + ti_params.device->create_pipeline(source_desc, task_attribs[i].name); pipelines_.push_back(std::move(vp)); } } @@ -261,15 +315,19 @@ size_t CompiledTaichiKernel::num_pipelines() const { return pipelines_.size(); } -DeviceAllocation *CompiledTaichiKernel::ctx_buffer() const { - return ctx_buffer_.get(); +size_t CompiledTaichiKernel::get_args_buffer_size() const { + return args_buffer_size_; } -DeviceAllocation *CompiledTaichiKernel::ctx_buffer_host() const { - return ctx_buffer_host_.get(); +size_t CompiledTaichiKernel::get_ret_buffer_size() const { + return ret_buffer_size_; } -void CompiledTaichiKernel::command_list(CommandList *cmdlist) const { +void CompiledTaichiKernel::generate_command_list( + CommandList *cmdlist, + DeviceAllocationGuard *args_buffer, + DeviceAllocationGuard *ret_buffer, + const std::unordered_map &ext_arrs) const { const auto &task_attribs = ti_kernel_attribs_.tasks_attribs; for (int i = 0; i < task_attribs.size(); ++i) { @@ -280,7 +338,30 @@ void CompiledTaichiKernel::command_list(CommandList *cmdlist) const { attribs.advisory_num_threads_per_group; ResourceBinder *binder = vp->resource_binder(); for (auto &bind : attribs.buffer_binds) { - binder->rw_buffer(0, bind.binding, *input_buffers_.at(bind.buffer)); + if (bind.buffer.type == BufferType::ExtArr) { + binder->rw_buffer(0, bind.binding, ext_arrs.at(bind.buffer.root_id)); + } else if (args_buffer && bind.buffer.type == BufferType::Args) { + binder->buffer(0, bind.binding, *args_buffer); + } else if (ret_buffer && bind.buffer.type == BufferType::Rets) { + binder->rw_buffer(0, bind.binding, *ret_buffer); + } else { + DeviceAllocation *alloc = input_buffers_.at(bind.buffer); + if (alloc) { + binder->rw_buffer(0, bind.binding, *alloc); + } + } + } + + if (attribs.task_type == OffloadedTaskType::listgen) { + for (auto &bind : attribs.buffer_binds) { + if (bind.buffer.type == BufferType::ListGen) { + // FIXME: properlly support multiple list + cmdlist->buffer_fill(input_buffers_.at(bind.buffer)->get_ptr(0), + kListGenBufferSize, + /*data=*/0); + cmdlist->buffer_barrier(*input_buffers_.at(bind.buffer)); + } + } } cmdlist->bind_pipeline(vp); @@ -288,22 +369,17 @@ void CompiledTaichiKernel::command_list(CommandList *cmdlist) const { cmdlist->dispatch(group_x); cmdlist->memory_barrier(); } - - const auto ctx_sz = ti_kernel_attribs_.ctx_attribs.total_bytes(); - if (!ti_kernel_attribs_.ctx_attribs.empty()) { - cmdlist->buffer_copy(ctx_buffer_host_->get_ptr(0), ctx_buffer_->get_ptr(0), - ctx_sz); - cmdlist->buffer_barrier(*ctx_buffer_host_); - } } VkRuntime::VkRuntime(const Params ¶ms) - : host_result_buffer_(params.host_result_buffer), device_(params.device) { + : device_(params.device), host_result_buffer_(params.host_result_buffer) { TI_ASSERT(host_result_buffer_ != nullptr); - init_buffers(); + current_cmdlist_pending_since_ = high_res_clock::now(); + init_nonroot_buffers(); } VkRuntime::~VkRuntime() { + synchronize(); { decltype(ti_kernels_) tmp; tmp.swap(ti_kernels_); @@ -311,47 +387,21 @@ VkRuntime::~VkRuntime() { global_tmps_buffer_.reset(); } -void VkRuntime::materialize_snode_tree(SNodeTree *tree) { - auto *const root = tree->root(); - CompiledSNodeStructs compiled_structs = vulkan::compile_snode_structs(*root); - add_root_buffer(compiled_structs.root_size); - compiled_snode_structs_.push_back(compiled_structs); -} - -void VkRuntime::destroy_snode_tree(SNodeTree *snode_tree) { - int root_id = -1; - for (int i = 0; i < compiled_snode_structs_.size(); ++i) { - if (compiled_snode_structs_[i].root == snode_tree->root()) { - root_id = i; - } - } - if (root_id == -1) { - TI_ERROR("the tree to be destroyed cannot be found"); - } - root_buffers_[root_id].reset(); -} - -const std::vector &VkRuntime::get_compiled_structs() - const { - return compiled_snode_structs_; -} - VkRuntime::KernelHandle VkRuntime::register_taichi_kernel( VkRuntime::RegisterParams reg_params) { CompiledTaichiKernel::Params params; params.ti_kernel_attribs = &(reg_params.kernel_attribs); - params.compiled_structs = get_compiled_structs(); + params.num_snode_trees = reg_params.num_snode_trees; params.device = device_; params.root_buffers = {}; for (int root = 0; root < root_buffers_.size(); ++root) { params.root_buffers.push_back(root_buffers_[root].get()); } params.global_tmps_buffer = global_tmps_buffer_.get(); + params.listgen_buffer = listgen_buffer_.get(); for (int i = 0; i < reg_params.task_spirv_source_codes.size(); ++i) { - const auto &attribs = reg_params.kernel_attribs.tasks_attribs[i]; const auto &spirv_src = reg_params.task_spirv_source_codes[i]; - const auto &task_name = attribs.name; // If we can reach here, we have succeeded. Otherwise // std::optional::value() would have killed us. @@ -363,28 +413,141 @@ VkRuntime::KernelHandle VkRuntime::register_taichi_kernel( return res; } -void VkRuntime::launch_kernel(KernelHandle handle, Context *host_ctx) { +void VkRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) { auto *ti_kernel = ti_kernels_[handle.id_].get(); + + std::unique_ptr args_buffer{nullptr}, + ret_buffer{nullptr}; + + if (ti_kernel->get_args_buffer_size()) { + args_buffer = device_->allocate_memory_unique( + {ti_kernel->get_args_buffer_size(), + /*host_write=*/true, /*host_read=*/false, + /*export_sharing=*/false, AllocUsage::Uniform}); + } + + if (ti_kernel->get_ret_buffer_size()) { + ret_buffer = device_->allocate_memory_unique( + {ti_kernel->get_ret_buffer_size(), + /*host_write=*/false, /*host_read=*/true, + /*export_sharing=*/false, AllocUsage::Storage}); + } + + // Create context blitter auto ctx_blitter = HostDeviceContextBlitter::maybe_make( &ti_kernel->ti_kernel_attribs().ctx_attribs, host_ctx, device_, - host_result_buffer_, ti_kernel->ctx_buffer(), - ti_kernel->ctx_buffer_host()); + host_result_buffer_, args_buffer.get(), ret_buffer.get()); + + // `any_arrays` contain both external arrays and NDArrays + std::unordered_map any_arrays; + // `ext_array_size` only holds the size of external arrays (host arrays) + // As buffer size information is only needed when it needs to be allocated + // and transferred by the host + std::unordered_map ext_array_size; + + // Prepare context buffers & arrays if (ctx_blitter) { - TI_ASSERT(ti_kernel->ctx_buffer() != nullptr); - ctx_blitter->host_to_device(); + TI_ASSERT(ti_kernel->get_args_buffer_size() || + ti_kernel->get_ret_buffer_size()); + + int i = 0; + const auto &args = ti_kernel->ti_kernel_attribs().ctx_attribs.args(); + for (auto &arg : args) { + if (arg.is_array) { + if (host_ctx->is_device_allocation[i]) { + // NDArray + if (host_ctx->args[i]) { + any_arrays[i] = *(DeviceAllocation *)(host_ctx->args[i]); + } else { + any_arrays[i] = kDeviceNullAllocation; + } + } else { + // Compute ext arr sizes + size_t size = arg.stride; + bool has_zero_axis = false; + + for (int ax = 0; ax < 8; ax++) { + // FIXME: how and when do we determine the size of ext arrs? + size_t axis_size = host_ctx->extra_args[i][ax]; + if (axis_size) { + if (has_zero_axis) { + // e.g. shape [1, 0, 1] + size = 0; + } else { + size *= host_ctx->extra_args[i][ax]; + } + } else { + has_zero_axis = true; + } + } + + ext_array_size[i] = size; + + // Alloc ext arr + if (size) { + DeviceAllocation extarr_buf = device_->allocate_memory( + {size, /*host_write=*/true, /*host_read=*/true, + /*export_sharing=*/false, AllocUsage::Storage}); + any_arrays[i] = extarr_buf; + } else { + any_arrays[i] = kDeviceNullAllocation; + } + } + } + i++; + } + + ctx_blitter->host_to_device(any_arrays, ext_array_size); } + // Create new command list if current one is nullptr if (!current_cmdlist_) { + ctx_buffers_.clear(); + current_cmdlist_pending_since_ = high_res_clock::now(); current_cmdlist_ = device_->get_compute_stream()->new_command_list(); } - ti_kernel->command_list(current_cmdlist_.get()); + // Record commands + ti_kernel->generate_command_list(current_cmdlist_.get(), args_buffer.get(), + ret_buffer.get(), any_arrays); + // Keep context buffers used in this dispatch + if (ti_kernel->get_args_buffer_size()) { + ctx_buffers_.push_back(std::move(args_buffer)); + } + if (ti_kernel->get_ret_buffer_size()) { + ctx_buffers_.push_back(std::move(ret_buffer)); + } + + // If we need to host sync, sync and remove in-flight references if (ctx_blitter) { - device_->get_compute_stream()->submit(current_cmdlist_.get()); - ctx_blitter->device_to_host(); + if (ctx_blitter->device_to_host(current_cmdlist_.get(), any_arrays, + ext_array_size)) { + current_cmdlist_ = nullptr; + ctx_buffers_.clear(); + } + } - current_cmdlist_ = nullptr; + // If we have accumulated some work but does not require sync + // and if the accumulated cmdlist has been pending for some time + // launch the cmdlist to start processing. + if (current_cmdlist_) { + constexpr uint64_t max_pending_time = 2000; // 2000us = 2ms + auto duration = high_res_clock::now() - current_cmdlist_pending_since_; + if (std::chrono::duration_cast(duration) + .count() > max_pending_time) { + device_->get_compute_stream()->submit(current_cmdlist_.get()); + current_cmdlist_ = nullptr; + } + } + + // Dealloc external arrays + for (auto pair : any_arrays) { + if (pair.second != kDeviceNullAllocation) { + if (!host_ctx->is_device_allocation[pair.first]) { + device_->dealloc_memory(pair.second); + } + } } } @@ -394,17 +557,21 @@ void VkRuntime::synchronize() { current_cmdlist_ = nullptr; } device_->get_compute_stream()->command_sync(); + ctx_buffers_.clear(); } Device *VkRuntime::get_ti_device() const { return device_; } -void VkRuntime::init_buffers() { - size_t gtmp_buffer_size = 1024 * 1024; - +void VkRuntime::init_nonroot_buffers() { global_tmps_buffer_ = device_->allocate_memory_unique( - {gtmp_buffer_size, + {kGtmpBufferSize, + /*host_write=*/false, /*host_read=*/false, + /*export_sharing=*/false, AllocUsage::Storage}); + + listgen_buffer_ = device_->allocate_memory_unique( + {kListGenBufferSize, /*host_write=*/false, /*host_read=*/false, /*export_sharing=*/false, AllocUsage::Storage}); @@ -412,7 +579,9 @@ void VkRuntime::init_buffers() { Stream *stream = device_->get_compute_stream(); auto cmdlist = stream->new_command_list(); - cmdlist->buffer_fill(global_tmps_buffer_->get_ptr(0), gtmp_buffer_size, + cmdlist->buffer_fill(global_tmps_buffer_->get_ptr(0), kGtmpBufferSize, + /*data=*/0); + cmdlist->buffer_fill(listgen_buffer_->get_ptr(0), kListGenBufferSize, /*data=*/0); stream->submit_synced(cmdlist.get()); } @@ -433,6 +602,27 @@ void VkRuntime::add_root_buffer(size_t root_buffer_size) { root_buffers_.push_back(std::move(new_buffer)); } +VkRuntime::RegisterParams run_codegen( + Kernel *kernel, + Device *device, + const std::vector &compiled_structs) { + const auto id = Program::get_kernel_id(); + const auto taichi_kernel_name(fmt::format("{}_k{:04d}_vk", kernel->name, id)); + TI_TRACE("VK codegen for Taichi kernel={}", taichi_kernel_name); + spirv::KernelCodegen::Params params; + params.ti_kernel_name = taichi_kernel_name; + params.kernel = kernel; + params.compiled_structs = compiled_structs; + params.device = device; + params.enable_spv_opt = + kernel->program->config.external_optimization_level > 0; + spirv::KernelCodegen codegen(params); + VkRuntime::RegisterParams res; + codegen.run(res.kernel_attribs, res.task_spirv_source_codes); + res.num_snode_trees = compiled_structs.size(); + return res; +} + } // namespace vulkan } // namespace lang } // namespace taichi diff --git a/taichi/backends/vulkan/runtime.h b/taichi/backends/vulkan/runtime.h index d1389277ad08a..d862dcf5cf755 100644 --- a/taichi/backends/vulkan/runtime.h +++ b/taichi/backends/vulkan/runtime.h @@ -2,10 +2,12 @@ #include "taichi/lang_util.h" #include +#include #include "taichi/backends/device.h" -#include "taichi/backends/vulkan/snode_struct_compiler.h" -#include "taichi/backends/vulkan/kernel_utils.h" +#include "taichi/codegen/spirv/snode_struct_compiler.h" +#include "taichi/codegen/spirv/kernel_utils.h" +#include "taichi/codegen/spirv/spirv_codegen.h" #include "taichi/program/compile_config.h" #include "taichi/struct/snode_tree.h" #include "taichi/program/snode_expr_utils.h" @@ -14,26 +16,33 @@ namespace taichi { namespace lang { namespace vulkan { +using namespace taichi::lang::spirv; + using BufferType = TaskAttributes::BufferType; using BufferInfo = TaskAttributes::BufferInfo; using BufferBind = TaskAttributes::BufferBind; using BufferInfoHasher = TaskAttributes::BufferInfoHasher; +using high_res_clock = std::chrono::high_resolution_clock; + // TODO: In the future this isn't necessarily a pointer, since DeviceAllocation // is already a pretty cheap handle> using InputBuffersMap = std::unordered_map; +class SNodeTreeManager; + class CompiledTaichiKernel { public: struct Params { const TaichiKernelAttributes *ti_kernel_attribs{nullptr}; std::vector> spirv_bins; - std::vector compiled_structs; + std::size_t num_snode_trees{0}; Device *device{nullptr}; std::vector root_buffers; DeviceAllocation *global_tmps_buffer{nullptr}; + DeviceAllocation *listgen_buffer{nullptr}; }; CompiledTaichiKernel(const Params &ti_params); @@ -42,26 +51,29 @@ class CompiledTaichiKernel { size_t num_pipelines() const; - DeviceAllocation *ctx_buffer() const; + size_t get_args_buffer_size() const; + size_t get_ret_buffer_size() const; - DeviceAllocation *ctx_buffer_host() const; - - void command_list(CommandList *cmdlist) const; + void generate_command_list( + CommandList *cmdlist, + DeviceAllocationGuard *args_buffer, + DeviceAllocationGuard *ret_buffer, + const std::unordered_map &ext_arrs) const; private: TaichiKernelAttributes ti_kernel_attribs_; std::vector tasks_attribs_; - Device *device_; + [[maybe_unused]] Device *device_; InputBuffersMap input_buffers_; - std::unique_ptr ctx_buffer_{nullptr}; - std::unique_ptr ctx_buffer_host_{nullptr}; + size_t args_buffer_size_{0}; + size_t ret_buffer_size_{0}; std::vector> pipelines_; }; -class VkRuntime { +class TI_DLL_EXPORT VkRuntime { public: struct Params { uint64_t *host_result_buffer{nullptr}; @@ -81,40 +93,45 @@ class VkRuntime { struct RegisterParams { TaichiKernelAttributes kernel_attribs; std::vector> task_spirv_source_codes; + std::size_t num_snode_trees{0}; }; KernelHandle register_taichi_kernel(RegisterParams params); - void launch_kernel(KernelHandle handle, Context *host_ctx); - - void materialize_snode_tree(SNodeTree *tree); - - void destroy_snode_tree(SNodeTree *snode_tree); + void launch_kernel(KernelHandle handle, RuntimeContext *host_ctx); void synchronize(); Device *get_ti_device() const; - const std::vector &get_compiled_structs() const; + void add_root_buffer(size_t root_buffer_size); private: - void init_buffers(); - void add_root_buffer(size_t root_buffer_size); + friend class taichi::lang::vulkan::SNodeTreeManager; - Device *device_; + void init_nonroot_buffers(); + Device *device_{nullptr}; uint64_t *const host_result_buffer_; std::vector> root_buffers_; std::unique_ptr global_tmps_buffer_; + // FIXME: Support proper multiple lists + std::unique_ptr listgen_buffer_; + + std::vector> ctx_buffers_; std::unique_ptr current_cmdlist_{nullptr}; + high_res_clock::time_point current_cmdlist_pending_since_; std::vector> ti_kernels_; - - std::vector compiled_snode_structs_; }; +VkRuntime::RegisterParams run_codegen( + Kernel *kernel, + Device *device, + const std::vector &compiled_structs); + } // namespace vulkan } // namespace lang } // namespace taichi diff --git a/taichi/backends/vulkan/snode_tree_manager.cpp b/taichi/backends/vulkan/snode_tree_manager.cpp new file mode 100644 index 0000000000000..0bfb6d2f01edd --- /dev/null +++ b/taichi/backends/vulkan/snode_tree_manager.cpp @@ -0,0 +1,38 @@ +#include "taichi/backends/vulkan/snode_tree_manager.h" + +#include "taichi/backends/vulkan/runtime.h" + +namespace taichi { +namespace lang { +namespace vulkan { + +SNodeTreeManager::SNodeTreeManager(VkRuntime *rtm) : runtime_(rtm) { +} + +void SNodeTreeManager::materialize_snode_tree(SNodeTree *tree) { + auto *const root = tree->root(); + CompiledSNodeStructs compiled_structs = vulkan::compile_snode_structs(*root); + runtime_->add_root_buffer(compiled_structs.root_size); + compiled_snode_structs_.push_back(compiled_structs); +} + +void SNodeTreeManager::destroy_snode_tree(SNodeTree *snode_tree) { + int root_id = -1; + for (int i = 0; i < compiled_snode_structs_.size(); ++i) { + if (compiled_snode_structs_[i].root == snode_tree->root()) { + root_id = i; + } + } + if (root_id == -1) { + TI_ERROR("the tree to be destroyed cannot be found"); + } + runtime_->root_buffers_[root_id].reset(); +} + +DevicePtr SNodeTreeManager::get_snode_tree_device_ptr(int tree_id) { + return runtime_->root_buffers_[tree_id]->get_ptr(); +} + +} // namespace vulkan +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/vulkan/snode_tree_manager.h b/taichi/backends/vulkan/snode_tree_manager.h new file mode 100644 index 0000000000000..d946c308632b9 --- /dev/null +++ b/taichi/backends/vulkan/snode_tree_manager.h @@ -0,0 +1,43 @@ +#pragma once + +#include + +#include "taichi/backends/device.h" +#include "taichi/codegen/spirv/snode_struct_compiler.h" +#include "taichi/struct/snode_tree.h" + +namespace taichi { +namespace lang { +namespace vulkan { + +class VkRuntime; + +/** + * @brief Manages the SNodeTrees for the Vulkan backend. + * + */ +class SNodeTreeManager { + private: + using CompiledSNodeStructs = taichi::lang::spirv::CompiledSNodeStructs; + + public: + explicit SNodeTreeManager(VkRuntime *rtm); + + const std::vector &get_compiled_structs() const { + return compiled_snode_structs_; + } + + void materialize_snode_tree(SNodeTree *tree); + + void destroy_snode_tree(SNodeTree *snode_tree); + + DevicePtr get_snode_tree_device_ptr(int tree_id); + + private: + VkRuntime *const runtime_; + std::vector compiled_snode_structs_; +}; + +} // namespace vulkan +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/vulkan/spirv_header.h b/taichi/backends/vulkan/spirv_header.h deleted file mode 100644 index 79f73790eba78..0000000000000 --- a/taichi/backends/vulkan/spirv_header.h +++ /dev/null @@ -1,7 +0,0 @@ -#pragma once - -#ifdef TI_WITH_VULKAN - -#include - -#endif // TI_WITH_VULKAN diff --git a/taichi/backends/vulkan/spirv_snode_compiler.cpp b/taichi/backends/vulkan/spirv_snode_compiler.cpp deleted file mode 100644 index dad4db357b6bf..0000000000000 --- a/taichi/backends/vulkan/spirv_snode_compiler.cpp +++ /dev/null @@ -1,98 +0,0 @@ -#include "taichi/backends/vulkan/spirv_snode_compiler.h" - -namespace taichi { -namespace lang { -namespace vulkan { - -namespace spirv { - -// Compile SNodes into Spirv-type-based struct -class SpirvSNodeCompiler { - public: - CompiledSpirvSNode run(IRBuilder *builder, - const CompiledSNodeStructs *compiled_structs) { - CompiledSpirvSNode result; - if (compiled_structs->root_size != 0) { - result.root_stype = compute_snode_stype( - builder, compiled_structs, - compiled_structs->snode_descriptors.find(compiled_structs->root->id) - ->second, - &result.snode_id_struct_stype_tbl, &result.snode_id_array_stype_tbl); - } else { // Use an arbitary default type to skip empty root buffer - result.root_stype = builder->i32_type(); - } - return result; - } - - SType compute_snode_stype(IRBuilder *ir_, - const CompiledSNodeStructs *compiled_structs, - const SNodeDescriptor &sn_desc, - SNodeSTypeTbl *snode_id_struct_stype_tbl_, - SNodeSTypeTbl *snode_id_array_stype_tbl_) { - const auto &sn = sn_desc.snode; - if (sn->is_place()) { - return ir_->get_primitive_buffer_type(true, sn->dt); - } else { - SType sn_type = ir_->get_null_type(); - sn_type.snode_desc = sn_desc; - sn_type.flag = TypeKind::kSNodeStruct; - ir_->debug(spv::OpName, sn_type, sn->node_type_name); - - uint32_t cn_cnt = 0; - for (const auto &ch : sn->ch) { - const SNodeDescriptor &ch_desc = - compiled_structs->snode_descriptors.find(ch->id)->second; - const auto &ch_sn = ch_desc.snode; - SType ch_type = compute_snode_stype(ir_, compiled_structs, ch_desc, - snode_id_struct_stype_tbl_, - snode_id_array_stype_tbl_); - SType ch_type_array; - - if (!ch_sn->is_place()) { - ch_type_array = ir_->get_null_type(); - ch_type_array.flag = TypeKind::kSNodeArray; - ch_type_array.element_type_id = ch_type.id; - - Value length = ir_->int_immediate_number( - ir_->i32_type(), ch_desc.cells_per_container_pot()); - ir_->declare_global(spv::OpTypeArray, ch_type_array, ch_type, - length); // Length - ir_->decorate(spv::OpDecorate, ch_type_array, - spv::DecorationArrayStride, - ch_desc.cell_stride); // Stride - } else { - ch_type_array = ch_type; - } - ir_->decorate(spv::OpMemberDecorate, sn_type, cn_cnt++, - spv::DecorationOffset, - ch_desc.mem_offset_in_parent_cell); // Offset - sn_type.snode_child_type_id.push_back(ch_type_array.id); - - TI_ASSERT(snode_id_struct_stype_tbl_->find(ch_sn->id) == - snode_id_struct_stype_tbl_->end()); - snode_id_struct_stype_tbl_->insert( - std::make_pair(ch_sn->id, std::move(ch_type))); - TI_ASSERT(snode_id_array_stype_tbl_->find(ch_sn->id) == - snode_id_array_stype_tbl_->end()); - snode_id_array_stype_tbl_->insert( - std::make_pair(ch_sn->id, std::move(ch_type_array))); - } - - ir_->declare_global(spv::OpTypeStruct, sn_type, - sn_type.snode_child_type_id); - return sn_type; - } - } -}; - -CompiledSpirvSNode compile_spirv_snode_structs( - IRBuilder *builder, - const CompiledSNodeStructs *compiled_structs) { - SpirvSNodeCompiler compiler; - return compiler.run(builder, compiled_structs); -} - -} // namespace spirv -} // namespace vulkan -} // namespace lang -} // namespace taichi diff --git a/taichi/backends/vulkan/spirv_snode_compiler.h b/taichi/backends/vulkan/spirv_snode_compiler.h deleted file mode 100644 index 500cadd00ad1d..0000000000000 --- a/taichi/backends/vulkan/spirv_snode_compiler.h +++ /dev/null @@ -1,37 +0,0 @@ -#pragma once - -#include "taichi/backends/vulkan/spirv_header.h" -#include "taichi/backends/vulkan/spirv_ir_builder.h" - -namespace taichi { -namespace lang { -namespace vulkan { - -namespace spirv { - -using SNodeSTypeTbl = std::unordered_map; - -struct CompiledSpirvSNode { - SType root_stype; - - // map from snode id to snode struct SType - SNodeSTypeTbl snode_id_struct_stype_tbl; - // map from snode id to snode array SType - SNodeSTypeTbl snode_id_array_stype_tbl; - - SType query_snode_struct_stype(const int &id) const { - return snode_id_struct_stype_tbl.find(id)->second; - } - SType query_snode_array_stype(const int &id) const { - return snode_id_array_stype_tbl.find(id)->second; - } -}; - -CompiledSpirvSNode compile_spirv_snode_structs( - IRBuilder *builer, - const CompiledSNodeStructs *compiled_structs); - -} // namespace spirv -} // namespace vulkan -} // namespace lang -} // namespace taichi diff --git a/taichi/backends/vulkan/vulkan_api.cpp b/taichi/backends/vulkan/vulkan_api.cpp index bea72fdeeaf4a..6032b6ed1a45d 100644 --- a/taichi/backends/vulkan/vulkan_api.cpp +++ b/taichi/backends/vulkan/vulkan_api.cpp @@ -1,4 +1,7 @@ +#define VOLK_IMPLEMENTATION + #include "taichi/backends/vulkan/vulkan_api.h" +#include "taichi/backends/vulkan/vulkan_loader.h" namespace vkapi { @@ -74,7 +77,12 @@ DeviceObjVkBufferView::~DeviceObjVkBufferView() { } DeviceObjVkAccelerationStructureKHR::~DeviceObjVkAccelerationStructureKHR() { - vkDestroyAccelerationStructureKHR(device, accel, nullptr); + PFN_vkDestroyAccelerationStructureKHR destroy_raytracing_pipeline_khr = + PFN_vkDestroyAccelerationStructureKHR(vkGetInstanceProcAddr( + taichi::lang::vulkan::VulkanLoader::instance().get_instance(), + "vkDestroyAccelerationStructureKHR")); + + destroy_raytracing_pipeline_khr(device, accel, nullptr); } IDeviceObj create_device_obj(VkDevice device) { @@ -341,7 +349,12 @@ IVkPipeline create_raytracing_pipeline( create_info->basePipelineIndex = 0; } - vkCreateRayTracingPipelinesKHR(device, deferredOperation, + PFN_vkCreateRayTracingPipelinesKHR create_raytracing_pipeline_khr = + PFN_vkCreateRayTracingPipelinesKHR(vkGetInstanceProcAddr( + taichi::lang::vulkan::VulkanLoader::instance().get_instance(), + "vkCreateRayTracingPipelinesKHR")); + + create_raytracing_pipeline_khr(device, deferredOperation, cache ? cache->cache : VK_NULL_HANDLE, 1, create_info, nullptr, &obj->pipeline); @@ -499,7 +512,13 @@ IVkAccelerationStructureKHR create_acceleration_structure( info.type = type; info.deviceAddress = 0; - vkCreateAccelerationStructureKHR(buffer->device, &info, nullptr, &obj->accel); + PFN_vkCreateAccelerationStructureKHR create_acceleration_structure_khr = + PFN_vkCreateAccelerationStructureKHR(vkGetInstanceProcAddr( + taichi::lang::vulkan::VulkanLoader::instance().get_instance(), + "vkCreateAccelerationStructureKHR")); + + create_acceleration_structure_khr(buffer->device, &info, nullptr, + &obj->accel); return obj; } diff --git a/taichi/backends/vulkan/vulkan_common.h b/taichi/backends/vulkan/vulkan_common.h index e6acdadc670cd..5ddf593bc8659 100644 --- a/taichi/backends/vulkan/vulkan_common.h +++ b/taichi/backends/vulkan/vulkan_common.h @@ -4,8 +4,13 @@ #define VK_USE_PLATFORM_WIN32_KHR 1 #endif +#ifdef ANDROID +#define VK_USE_PLATFORM_ANDROID_KHR +#endif + #include #define VK_NO_PROTOTYPES + #include #include @@ -15,13 +20,11 @@ namespace taichi { namespace lang { namespace vulkan { -#pragma message("BAIL_ON_VK_BAD_RESULT uses exception") - -#define BAIL_ON_VK_BAD_RESULT(result, msg) \ - do { \ - if ((result) != VK_SUCCESS) { \ - throw std::runtime_error((msg)); \ - }; \ +#define BAIL_ON_VK_BAD_RESULT(result, msg) \ + do { \ + if ((result) != VK_SUCCESS) { \ + TI_ERROR("Vulkan Error : {} : {}", result, (msg)); \ + }; \ } while (0) inline constexpr VkAllocationCallbacks *kNoVkAllocCallbacks = nullptr; diff --git a/taichi/backends/vulkan/vulkan_device.cpp b/taichi/backends/vulkan/vulkan_device.cpp index dd6de1e59604e..c58c29efc03b9 100644 --- a/taichi/backends/vulkan/vulkan_device.cpp +++ b/taichi/backends/vulkan/vulkan_device.cpp @@ -1,4 +1,4 @@ -#include "taichi/backends/vulkan/embedded_device.h" +#include "taichi/backends/vulkan/vulkan_device_creator.h" #include #include @@ -10,7 +10,7 @@ #include "taichi/backends/vulkan/vulkan_common.h" #include "taichi/backends/vulkan/vulkan_utils.h" -#include "taichi/backends/vulkan/loader.h" +#include "taichi/backends/vulkan/vulkan_loader.h" #include "taichi/backends/vulkan/vulkan_device.h" #include "taichi/common/logging.h" @@ -95,7 +95,8 @@ const std::unordered_map image_layout_ti_2_vk = { {ImageLayout::depth_attachment_read, VK_IMAGE_LAYOUT_DEPTH_READ_ONLY_OPTIMAL}, {ImageLayout::transfer_dst, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL}, - {ImageLayout::transfer_src, VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL}}; + {ImageLayout::transfer_src, VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL}, + {ImageLayout::present_src, VK_IMAGE_LAYOUT_PRESENT_SRC_KHR}}; VkImageLayout image_layout_ti_to_vk(ImageLayout layout) { if (image_layout_ti_2_vk.find(layout) == image_layout_ti_2_vk.end()) { @@ -164,7 +165,15 @@ vkapi::IVkPipeline VulkanPipeline::graphics_pipeline( blend_attachments[i].colorWriteMask = VK_COLOR_COMPONENT_R_BIT | VK_COLOR_COMPONENT_G_BIT | VK_COLOR_COMPONENT_B_BIT | VK_COLOR_COMPONENT_A_BIT; - blend_attachments[i].blendEnable = VK_FALSE; + blend_attachments[i].blendEnable = VK_TRUE; + blend_attachments[i].srcColorBlendFactor = VK_BLEND_FACTOR_SRC_ALPHA; + blend_attachments[i].dstColorBlendFactor = + VK_BLEND_FACTOR_ONE_MINUS_SRC_ALPHA; + blend_attachments[i].colorBlendOp = VK_BLEND_OP_ADD; + blend_attachments[i].srcAlphaBlendFactor = VK_BLEND_FACTOR_SRC_ALPHA; + blend_attachments[i].dstAlphaBlendFactor = + VK_BLEND_FACTOR_ONE_MINUS_SRC_ALPHA; + blend_attachments[i].alphaBlendOp = VK_BLEND_OP_ADD; } graphics_pipeline_template_->color_blending.attachmentCount = @@ -277,6 +286,7 @@ void VulkanPipeline::create_pipeline_layout() { } void VulkanPipeline::create_compute_pipeline(const Params ¶ms) { + TI_TRACE("Compiling Vulkan pipeline {}", params.name); pipeline_ = vkapi::create_compute_pipeline(device_, 0, shader_stages_[0], pipeline_layout_); } @@ -504,7 +514,9 @@ void VulkanResourceBinder::rw_buffer(uint32_t set, TI_WARN("Overriding last binding"); } } - bindings[binding] = {VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, ptr, size}; + + Binding new_binding = {VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, ptr, size}; + bindings[binding] = new_binding; } void VulkanResourceBinder::rw_buffer(uint32_t set, @@ -526,7 +538,9 @@ void VulkanResourceBinder::buffer(uint32_t set, TI_WARN("Overriding last binding"); } } - bindings[binding] = {VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, ptr, size}; + + Binding new_binding = {VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, ptr, size}; + bindings[binding] = new_binding; } void VulkanResourceBinder::buffer(uint32_t set, @@ -666,6 +680,9 @@ VulkanCommandList::~VulkanCommandList() { void VulkanCommandList::bind_pipeline(Pipeline *p) { auto pipeline = static_cast(p); + if (current_pipeline_ == pipeline) + return; + if (pipeline->is_graphics()) { vkapi::IVkPipeline vk_pipeline = pipeline->graphics_pipeline( current_renderpass_desc_, current_renderpass_); @@ -702,10 +719,23 @@ void VulkanCommandList::bind_resources(ResourceBinder *ti_binder) { VulkanResourceBinder *binder = static_cast(ti_binder); for (auto &pair : binder->get_sets()) { + VkPipelineLayout pipeline_layout = + current_pipeline_->pipeline_layout()->layout; + vkapi::IVkDescriptorSetLayout layout = ti_device_->get_desc_set_layout(pair.second); - vkapi::IVkDescriptorSet set = ti_device_->alloc_desc_set(layout); - binder->write_to_set(pair.first, *ti_device_, set); + + vkapi::IVkDescriptorSet set = nullptr; + + if (currently_used_sets_.find(pair.second) != currently_used_sets_.end()) { + set = currently_used_sets_.at(pair.second); + } + + if (!set) { + set = ti_device_->alloc_desc_set(layout); + binder->write_to_set(pair.first, *ti_device_, set); + currently_used_sets_[pair.second] = set; + } VkPipelineBindPoint bind_point; if (current_pipeline_->is_graphics()) { @@ -714,8 +744,7 @@ void VulkanCommandList::bind_resources(ResourceBinder *ti_binder) { bind_point = VK_PIPELINE_BIND_POINT_COMPUTE; } - vkCmdBindDescriptorSets(buffer_->buffer, bind_point, - current_pipeline_->pipeline_layout()->layout, + vkCmdBindDescriptorSets(buffer_->buffer, bind_point, pipeline_layout, /*firstSet=*/0, /*descriptorSetCount=*/1, &set->set, /*dynamicOffsetCount=*/0, @@ -868,15 +897,16 @@ void VulkanCommandList::begin_renderpass(int x0, rp_desc.color_attachments.emplace_back(format, color_clear[i]); fb_desc.attachments.push_back(view); clear_values[i].color = - VkClearColorValue{clear_colors[i][0], clear_colors[i][1], - clear_colors[i][2], clear_colors[i][3]}; + VkClearColorValue{{clear_colors[i][0], clear_colors[i][1], + clear_colors[i][2], clear_colors[i][3]}}; } if (has_depth) { - auto [image, view, format] = ti_device_->get_vk_image(*depth_attachment); + auto [depth_image, depth_view, depth_format] = + ti_device_->get_vk_image(*depth_attachment); clear_values[num_color_attachments].depthStencil = VkClearDepthStencilValue{0.0, 0}; - fb_desc.attachments.push_back(view); + fb_desc.attachments.push_back(depth_view); } current_renderpass_ = ti_device_->get_renderpass(rp_desc); @@ -956,17 +986,21 @@ void VulkanCommandList::image_transition(DeviceAllocation img, static std::unordered_map stages; stages[VK_IMAGE_LAYOUT_UNDEFINED] = VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT; stages[VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL] = VK_PIPELINE_STAGE_TRANSFER_BIT; + stages[VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL] = VK_PIPELINE_STAGE_TRANSFER_BIT; stages[VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL] = VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT; stages[VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL] = VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT; + stages[VK_IMAGE_LAYOUT_PRESENT_SRC_KHR] = VK_PIPELINE_STAGE_TRANSFER_BIT; static std::unordered_map access; access[VK_IMAGE_LAYOUT_UNDEFINED] = (VkAccessFlagBits)0; access[VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL] = VK_ACCESS_TRANSFER_WRITE_BIT; + access[VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL] = VK_ACCESS_TRANSFER_READ_BIT; access[VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL] = VK_ACCESS_SHADER_READ_BIT; access[VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL] = VK_ACCESS_COLOR_ATTACHMENT_WRITE_BIT; + access[VK_IMAGE_LAYOUT_PRESENT_SRC_KHR] = VK_ACCESS_MEMORY_READ_BIT; if (stages.find(old_layout) == stages.end() || stages.find(new_layout) == stages.end()) { @@ -1040,8 +1074,64 @@ void VulkanCommandList::image_to_buffer(DevicePtr dst_buf, buffer_->refs.push_back(buffer); } +void VulkanCommandList::copy_image(DeviceAllocation dst_img, + DeviceAllocation src_img, + ImageLayout dst_img_layout, + ImageLayout src_img_layout, + const ImageCopyParams ¶ms) { + VkImageCopy copy{}; + copy.srcSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT; + copy.srcSubresource.layerCount = 1; + copy.dstSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT; + copy.dstSubresource.layerCount = 1; + copy.extent.width = params.width; + copy.extent.height = params.height; + copy.extent.depth = params.depth; + + auto [dst_vk_image, dst_view, dst_format] = ti_device_->get_vk_image(dst_img); + auto [src_vk_image, src_view, src_format] = ti_device_->get_vk_image(src_img); + + vkCmdCopyImage(buffer_->buffer, src_vk_image->image, + image_layout_ti_to_vk(src_img_layout), dst_vk_image->image, + image_layout_ti_to_vk(dst_img_layout), 1, ©); + + buffer_->refs.push_back(dst_vk_image); + buffer_->refs.push_back(src_vk_image); +} + +void VulkanCommandList::blit_image(DeviceAllocation dst_img, + DeviceAllocation src_img, + ImageLayout dst_img_layout, + ImageLayout src_img_layout, + const ImageCopyParams ¶ms) { + VkOffset3D blit_size; + blit_size.x = params.width; + blit_size.y = params.height; + blit_size.z = params.depth; + VkImageBlit blit{}; + blit.srcSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT; + blit.srcSubresource.layerCount = 1; + blit.srcOffsets[1] = blit_size; + blit.dstSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT; + blit.dstSubresource.layerCount = 1; + blit.dstOffsets[1] = blit_size; + + auto [dst_vk_image, dst_view, dst_format] = ti_device_->get_vk_image(dst_img); + auto [src_vk_image, src_view, src_format] = ti_device_->get_vk_image(src_img); + + vkCmdBlitImage(buffer_->buffer, src_vk_image->image, + image_layout_ti_to_vk(src_img_layout), dst_vk_image->image, + image_layout_ti_to_vk(dst_img_layout), 1, &blit, + VK_FILTER_NEAREST); + + buffer_->refs.push_back(dst_vk_image); + buffer_->refs.push_back(src_vk_image); +} + void VulkanCommandList::set_line_width(float width) { - vkCmdSetLineWidth(buffer_->buffer, width); + if (ti_device_->get_cap(DeviceCapability::wide_lines)) { + vkCmdSetLineWidth(buffer_->buffer, width); + } } vkapi::IVkRenderPass VulkanCommandList::current_renderpass() { @@ -1077,8 +1167,8 @@ VulkanDevice::~VulkanDevice() { framebuffer_pools_.clear(); renderpass_pools_.clear(); - vmaDestroyPool(allocator_, export_pool_.pool); vmaDestroyAllocator(allocator_); + vmaDestroyAllocator(allocator_export_); } std::unique_ptr VulkanDevice::create_pipeline( @@ -1144,17 +1234,29 @@ DeviceAllocation VulkanDevice::allocate_memory(const AllocParams ¶ms) { VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD_BIT_KHR; #endif + bool export_sharing = params.export_sharing && + this->get_cap(DeviceCapability::vk_has_external_memory); + VmaAllocationCreateInfo alloc_info{}; - if (params.export_sharing) { + if (export_sharing) { buffer_info.pNext = &external_mem_buffer_create_info; - alloc_info.pool = export_pool_.pool; } - +#ifdef __APPLE__ + // weird behavior on apple: these flags are needed even if either read or + // write is required + if (params.host_read || params.host_write) { +#else if (params.host_read && params.host_write) { +#endif //__APPLE__ // This should be the unified memory on integrated GPUs alloc_info.requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; alloc_info.preferredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT | VK_MEMORY_PROPERTY_HOST_CACHED_BIT; +#ifdef __APPLE__ + // weird behavior on apple: if coherent bit is not set, then the memory + // writes between map() and unmap() cannot be seen by gpu + alloc_info.preferredFlags |= VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; +#endif //__APPLE__ } else if (params.host_read) { alloc_info.usage = VMA_MEMORY_USAGE_GPU_TO_CPU; } else if (params.host_write) { @@ -1163,15 +1265,29 @@ DeviceAllocation VulkanDevice::allocate_memory(const AllocParams ¶ms) { alloc_info.usage = VMA_MEMORY_USAGE_GPU_ONLY; } - alloc.buffer = - vkapi::create_buffer(device_, allocator_, &buffer_info, &alloc_info); - vmaGetAllocationInfo(allocator_, alloc.buffer->allocation, &alloc.alloc_info); + if (get_cap(DeviceCapability::spirv_has_physical_storage_buffer)) { + buffer_info.usage |= VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_KHR; + } + + alloc.buffer = vkapi::create_buffer( + device_, export_sharing ? allocator_export_ : allocator_, &buffer_info, + &alloc_info); + vmaGetAllocationInfo(alloc.buffer->allocator, alloc.buffer->allocation, + &alloc.alloc_info); #ifdef TI_VULKAN_DEBUG_ALLOCATIONS TI_TRACE("Allocate VK buffer {}, alloc_id={}", (void *)alloc.buffer, handle.alloc_id); #endif + if (get_cap(DeviceCapability::spirv_has_physical_storage_buffer)) { + VkBufferDeviceAddressInfoKHR info{}; + info.sType = VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO_KHR; + info.buffer = alloc.buffer->buffer; + info.pNext = nullptr; + alloc.addr = vkGetBufferDeviceAddressKHR(device_, &info); + } + return handle; } @@ -1191,15 +1307,26 @@ void VulkanDevice::dealloc_memory(DeviceAllocation handle) { allocations_.erase(handle.alloc_id); } +uint64_t VulkanDevice::get_memory_physical_pointer(DeviceAllocation handle) { + const auto &alloc_int = allocations_.at(handle.alloc_id); + return uint64_t(alloc_int.addr); +} + void *VulkanDevice::map_range(DevicePtr ptr, uint64_t size) { AllocationInternal &alloc_int = allocations_.at(ptr.alloc_id); TI_ASSERT_INFO(alloc_int.mapped == nullptr, "Memory can not be mapped multiple times"); - vkMapMemory(device_, alloc_int.alloc_info.deviceMemory, - alloc_int.alloc_info.offset + ptr.offset, size, 0, - &alloc_int.mapped); + if (alloc_int.buffer->allocator) { + vmaMapMemory(alloc_int.buffer->allocator, alloc_int.buffer->allocation, + &alloc_int.mapped); + alloc_int.mapped = (uint8_t *)(alloc_int.mapped) + ptr.offset; + } else { + vkMapMemory(device_, alloc_int.alloc_info.deviceMemory, + alloc_int.alloc_info.offset + ptr.offset, size, 0, + &alloc_int.mapped); + } return alloc_int.mapped; } @@ -1210,9 +1337,14 @@ void *VulkanDevice::map(DeviceAllocation alloc) { TI_ASSERT_INFO(alloc_int.mapped == nullptr, "Memory can not be mapped multiple times"); - vkMapMemory(device_, alloc_int.alloc_info.deviceMemory, - alloc_int.alloc_info.offset, alloc_int.alloc_info.size, 0, - &alloc_int.mapped); + if (alloc_int.buffer->allocator) { + vmaMapMemory(alloc_int.buffer->allocator, alloc_int.buffer->allocation, + &alloc_int.mapped); + } else { + vkMapMemory(device_, alloc_int.alloc_info.deviceMemory, + alloc_int.alloc_info.offset, alloc_int.alloc_info.size, 0, + &alloc_int.mapped); + } return alloc_int.mapped; } @@ -1222,7 +1354,12 @@ void VulkanDevice::unmap(DevicePtr ptr) { TI_ASSERT_INFO(alloc_int.mapped, "Memory is not mapped"); - vkUnmapMemory(device_, alloc_int.alloc_info.deviceMemory); + if (alloc_int.buffer->allocator) { + vmaUnmapMemory(alloc_int.buffer->allocator, alloc_int.buffer->allocation); + } else { + vkUnmapMemory(device_, alloc_int.alloc_info.deviceMemory); + } + alloc_int.mapped = nullptr; } @@ -1231,7 +1368,12 @@ void VulkanDevice::unmap(DeviceAllocation alloc) { TI_ASSERT_INFO(alloc_int.mapped, "Memory is not mapped"); - vkUnmapMemory(device_, alloc_int.alloc_info.deviceMemory); + if (alloc_int.buffer->allocator) { + vmaUnmapMemory(alloc_int.buffer->allocator, alloc_int.buffer->allocation); + } else { + vkUnmapMemory(device_, alloc_int.alloc_info.deviceMemory); + } + alloc_int.mapped = nullptr; } @@ -1287,11 +1429,26 @@ void VulkanStream::submit(CommandList *cmdlist_) { } */ + VkPipelineStageFlags stage_flag{VK_PIPELINE_STAGE_ALL_COMMANDS_BIT}; + VkSubmitInfo submit_info{}; submit_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; submit_info.commandBufferCount = 1; submit_info.pCommandBuffers = &buffer->buffer; + if (last_semaphore_) { + submit_info.waitSemaphoreCount = 1; + submit_info.pWaitSemaphores = &last_semaphore_->semaphore; + submit_info.pWaitDstStageMask = &stage_flag; + } + + auto semaphore = vkapi::create_semaphore(buffer->device, 0); + last_semaphore_ = semaphore; + buffer->refs.push_back(semaphore); + + submit_info.signalSemaphoreCount = 1; + submit_info.pSignalSemaphores = &semaphore->semaphore; + submitted_cmdbuffers_.push_back(buffer); BAIL_ON_VK_BAD_RESULT(vkQueueSubmit(queue_, /*submitCount=*/1, &submit_info, @@ -1308,20 +1465,31 @@ void VulkanStream::submit_synced(CommandList *cmdlist) { submit_info.commandBufferCount = 1; submit_info.pCommandBuffers = &buffer->buffer; + VkPipelineStageFlags stage_flag{VK_PIPELINE_STAGE_ALL_COMMANDS_BIT}; + + if (last_semaphore_) { + submit_info.waitSemaphoreCount = 1; + submit_info.pWaitSemaphores = &last_semaphore_->semaphore; + submit_info.pWaitDstStageMask = &stage_flag; + } + BAIL_ON_VK_BAD_RESULT(vkQueueSubmit(queue_, /*submitCount=*/1, &submit_info, /*fence=*/cmd_sync_fence_->fence), "failed to submit command buffer"); - // Timeout is in nanoseconds, 60s = 60,000ms = 60,000,000ns vkWaitForFences(device_.vk_device(), 1, &cmd_sync_fence_->fence, true, - (60 * 1000 * 1000)); + UINT64_MAX); vkResetFences(device_.vk_device(), 1, &cmd_sync_fence_->fence); + + submitted_cmdbuffers_.clear(); + last_semaphore_ = nullptr; } void VulkanStream::command_sync() { vkQueueWaitIdle(queue_); submitted_cmdbuffers_.clear(); + last_semaphore_ = nullptr; } std::unique_ptr VulkanDevice::create_raster_pipeline( @@ -1345,11 +1513,11 @@ std::unique_ptr VulkanDevice::create_raster_pipeline( } else if (src_desc.stage == PipelineStageType::vertex) { code.stage = VK_SHADER_STAGE_VERTEX_BIT; } else if (src_desc.stage == PipelineStageType::geometry) { - code.stage == VK_SHADER_STAGE_GEOMETRY_BIT; + code.stage = VK_SHADER_STAGE_GEOMETRY_BIT; } else if (src_desc.stage == PipelineStageType::tesselation_control) { - code.stage == VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT; + code.stage = VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT; } else if (src_desc.stage == PipelineStageType::tesselation_eval) { - code.stage == VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT; + code.stage = VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT; } } @@ -1472,8 +1640,11 @@ DeviceAllocation VulkanDevice::create_image(const ImageParams ¶ms) { alloc.format = image_info.format; + bool export_sharing = params.export_sharing && + this->get_cap(DeviceCapability::vk_has_external_memory); + VkExternalMemoryImageCreateInfo external_mem_image_create_info = {}; - if (params.export_sharing) { + if (export_sharing) { external_mem_image_create_info.sType = VK_STRUCTURE_TYPE_EXTERNAL_MEMORY_IMAGE_CREATE_INFO; external_mem_image_create_info.pNext = NULL; @@ -1490,13 +1661,14 @@ DeviceAllocation VulkanDevice::create_image(const ImageParams ¶ms) { VmaAllocationCreateInfo alloc_info{}; if (params.export_sharing) { - alloc_info.pool = export_pool_.pool; } alloc_info.usage = VMA_MEMORY_USAGE_GPU_ONLY; - alloc.image = - vkapi::create_image(device_, allocator_, &image_info, &alloc_info); - vmaGetAllocationInfo(allocator_, alloc.image->allocation, &alloc.alloc_info); + alloc.image = vkapi::create_image( + device_, export_sharing ? allocator_export_ : allocator_, &image_info, + &alloc_info); + vmaGetAllocationInfo(alloc.image->allocator, alloc.image->allocation, + &alloc.alloc_info); VkImageViewCreateInfo view_info{}; view_info.sType = VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO; @@ -1646,7 +1818,10 @@ vkapi::IVkDescriptorSetLayout VulkanDevice::get_desc_set_layout( create_info.bindingCount = bindings.size(); create_info.pBindings = bindings.data(); - return vkapi::create_descriptor_set_layout(device_, &create_info); + auto layout = vkapi::create_descriptor_set_layout(device_, &create_info); + desc_set_layouts_[set] = layout; + + return layout; } else { return desc_set_layouts_.at(set); } @@ -1669,8 +1844,15 @@ vkapi::IVkDescriptorSet VulkanDevice::alloc_desc_set( } void VulkanDevice::create_vma_allocator() { + VmaAllocatorCreateInfo allocatorInfo = {}; + allocatorInfo.vulkanApiVersion = + this->get_cap(DeviceCapability::vk_api_version); + allocatorInfo.physicalDevice = physical_device_; + allocatorInfo.device = device_; + allocatorInfo.instance = instance_; + VolkDeviceTable table; - VmaVulkanFunctions vk_vma_functions; + VmaVulkanFunctions vk_vma_functions{0}; volkLoadDeviceTable(&table, device_); vk_vma_functions.vkGetPhysicalDeviceProperties = @@ -1707,62 +1889,36 @@ void VulkanDevice::create_vma_allocator() { PFN_vkGetPhysicalDeviceMemoryProperties2KHR(vkGetInstanceProcAddr( volkGetLoadedInstance(), "vkGetPhysicalDeviceMemoryProperties2KHR")); - VmaAllocatorCreateInfo allocatorInfo = {}; - allocatorInfo.vulkanApiVersion = - this->get_cap(DeviceCapability::vk_api_version); - allocatorInfo.physicalDevice = physical_device_; - allocatorInfo.device = device_; - allocatorInfo.instance = instance_; allocatorInfo.pVulkanFunctions = &vk_vma_functions; - vmaCreateAllocator(&allocatorInfo, &allocator_); + if (get_cap(DeviceCapability::spirv_has_physical_storage_buffer)) { + allocatorInfo.flags |= VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT; + } - { - VkBufferCreateInfo export_buf_create_info = { - VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO}; - export_buf_create_info.size = 1024; // Whatever. - export_buf_create_info.usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | - VK_BUFFER_USAGE_TRANSFER_DST_BIT | - VK_BUFFER_USAGE_TRANSFER_SRC_BIT; + vmaCreateAllocator(&allocatorInfo, &allocator_); - VmaAllocationCreateInfo alloc_create_info = {}; - alloc_create_info.usage = VMA_MEMORY_USAGE_GPU_ONLY; + VkPhysicalDeviceMemoryProperties properties; + vkGetPhysicalDeviceMemoryProperties(physical_device_, &properties); - uint32_t memTypeIndex; - vmaFindMemoryTypeIndexForBufferInfo(allocator_, &export_buf_create_info, - &alloc_create_info, &memTypeIndex); + std::vector flags( + properties.memoryTypeCount); - export_pool_.export_mem_alloc_info.sType = - VK_STRUCTURE_TYPE_EXPORT_MEMORY_ALLOCATE_INFO_KHR; + for (int i = 0; i < properties.memoryTypeCount; i++) { + auto flag = properties.memoryTypes[i].propertyFlags; + if (flag & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT) { #ifdef _WIN64 - - export_pool_.export_mem_win32_handle_info.sType = - VK_STRUCTURE_TYPE_EXPORT_MEMORY_WIN32_HANDLE_INFO_KHR; - export_pool_.export_mem_win32_handle_info.pNext = NULL; - export_pool_.export_mem_win32_handle_info.pAttributes = - &export_pool_.win_security_attribs; - export_pool_.export_mem_win32_handle_info.dwAccess = - DXGI_SHARED_RESOURCE_READ | DXGI_SHARED_RESOURCE_WRITE; - export_pool_.export_mem_win32_handle_info.name = (LPCWSTR)NULL; - - export_pool_.export_mem_alloc_info.pNext = - &export_pool_.export_mem_win32_handle_info; - export_pool_.export_mem_alloc_info.handleTypes = - VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_BIT; + flags[i] = VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_BIT; #else - export_pool_.export_mem_alloc_info.pNext = NULL; - export_pool_.export_mem_alloc_info.handleTypes = - VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD_BIT_KHR; + flags[i] = VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD_BIT; #endif + } else { + flags[i] = 0; + } + } - VmaPoolCreateInfo pool_info{}; - pool_info.memoryTypeIndex = memTypeIndex; - pool_info.blockSize = kMemoryBlockSize; // 128MB - pool_info.maxBlockCount = 16; - pool_info.pMemoryAllocateNext = &export_pool_.export_mem_alloc_info; + allocatorInfo.pTypeExternalMemoryHandleTypes = flags.data(); - vmaCreatePool(allocator_, &pool_info, &export_pool_.pool); - } + vmaCreateAllocator(&allocatorInfo, &allocator_export_); } void VulkanDevice::new_descriptor_pool() { @@ -1825,24 +1981,54 @@ VkPresentModeKHR choose_swap_present_mode( } VulkanSurface::VulkanSurface(VulkanDevice *device, const SurfaceConfig &config) - : device_(device), config_(config) { - glfwWindowHint(GLFW_CLIENT_API, GLFW_NO_API); + : config_(config), device_(device) { +#if !defined(TI_EMSCRIPTENED) +#ifdef ANDROID + window_ = (ANativeWindow *)config.window_handle; +#else window_ = (GLFWwindow *)config.window_handle; - VkResult err = - glfwCreateWindowSurface(device->vk_instance(), window_, NULL, &surface_); - if (err) { - TI_ERROR("Failed to create window surface ({})", err); - return; - } +#endif + if (window_) { +#ifdef ANDROID + VkAndroidSurfaceCreateInfoKHR createInfo{ + .sType = VK_STRUCTURE_TYPE_ANDROID_SURFACE_CREATE_INFO_KHR, + .pNext = nullptr, + .flags = 0, + .window = window_}; + + vkCreateAndroidSurfaceKHR(device->vk_instance(), &createInfo, nullptr, + &surface_); +#else + glfwWindowHint(GLFW_CLIENT_API, GLFW_NO_API); + VkResult err = glfwCreateWindowSurface(device->vk_instance(), window_, NULL, + &surface_); + if (err) { + TI_ERROR("Failed to create window surface ({})", err); + return; + } +#endif - create_swap_chain(); + create_swap_chain(); - VkSemaphoreCreateInfo sema_create_info; - sema_create_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_CREATE_INFO; - sema_create_info.pNext = nullptr; - sema_create_info.flags = 0; - vkCreateSemaphore(device->vk_device(), &sema_create_info, kNoVkAllocCallbacks, - &image_available_); + VkSemaphoreCreateInfo sema_create_info; + sema_create_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_CREATE_INFO; + sema_create_info.pNext = nullptr; + sema_create_info.flags = 0; + vkCreateSemaphore(device->vk_device(), &sema_create_info, + kNoVkAllocCallbacks, &image_available_); + } else { + ImageParams params = {ImageDimension::d2D, + BufferFormat::rgba8, + ImageLayout::present_src, + config.width, + config.height, + 1, + false}; + // screenshot_image_ = device->create_image(params); + swapchain_images_.push_back(device->create_image(params)); + swapchain_images_.push_back(device->create_image(params)); + } +#endif } void VulkanSurface::create_swap_chain() { @@ -1895,7 +2081,12 @@ void VulkanSurface::create_swap_chain() { choose_swap_present_mode(present_modes, config_.vsync, config_.adaptive); int width, height; +#ifdef ANDROID + width = ANativeWindow_getWidth(window_); + height = ANativeWindow_getWidth(window_); +#elif !defined(TI_EMSCRIPTENED) glfwGetFramebufferSize(window_, &width, &height); +#endif VkExtent2D extent = {uint32_t(width), uint32_t(height)}; @@ -1909,7 +2100,8 @@ void VulkanSurface::create_swap_chain() { createInfo.imageColorSpace = surface_format.colorSpace; createInfo.imageExtent = extent; createInfo.imageArrayLayers = 1; - createInfo.imageUsage = VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT; + createInfo.imageUsage = + VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT; createInfo.imageSharingMode = VK_SHARING_MODE_EXCLUSIVE; createInfo.queueFamilyIndexCount = 0; createInfo.pQueueFamilyIndices = nullptr; @@ -1917,7 +2109,7 @@ void VulkanSurface::create_swap_chain() { createInfo.compositeAlpha = VK_COMPOSITE_ALPHA_OPAQUE_BIT_KHR; createInfo.presentMode = present_mode; createInfo.clipped = VK_TRUE; - createInfo.oldSwapchain = nullptr; + createInfo.oldSwapchain = VK_NULL_HANDLE; if (vkCreateSwapchainKHR(device_->vk_device(), &createInfo, kNoVkAllocCallbacks, &swapchain_) != VK_SUCCESS) { @@ -1969,10 +2161,24 @@ void VulkanSurface::destroy_swap_chain() { vkDestroySwapchainKHR(device_->vk_device(), swapchain_, nullptr); } +int VulkanSurface::get_image_count() { + return swapchain_images_.size(); +} + VulkanSurface::~VulkanSurface() { - destroy_swap_chain(); - vkDestroySemaphore(device_->vk_device(), image_available_, nullptr); - vkDestroySurfaceKHR(device_->vk_instance(), surface_, nullptr); + if (config_.window_handle) { + destroy_swap_chain(); + vkDestroySemaphore(device_->vk_device(), image_available_, nullptr); + vkDestroySurfaceKHR(device_->vk_instance(), surface_, nullptr); + } else { + for (auto &img : swapchain_images_) { + device_->destroy_image(img); + } + swapchain_images_.clear(); + } + if (screenshot_buffer_ != kDeviceNullAllocation) { + device_->dealloc_memory(screenshot_buffer_); + } } void VulkanSurface::resize(uint32_t width, uint32_t height) { @@ -1981,14 +2187,26 @@ void VulkanSurface::resize(uint32_t width, uint32_t height) { } std::pair VulkanSurface::get_size() { + if (!config_.window_handle) { + return std::make_pair(config_.width, config_.height); + } int width, height; +#ifdef ANDROID + width = ANativeWindow_getWidth(window_); + height = ANativeWindow_getWidth(window_); +#elif !defined(TI_EMSCRIPTENED) glfwGetFramebufferSize(window_, &width, &height); +#endif return std::make_pair(width, height); } DeviceAllocation VulkanSurface::get_target_image() { - vkAcquireNextImageKHR(device_->vk_device(), swapchain_, UINT64_MAX, - image_available_, VK_NULL_HANDLE, &image_index_); + if (!config_.window_handle) { + image_index_ = (image_index_ + 1) % swapchain_images_.size(); + } else { + vkAcquireNextImageKHR(device_->vk_device(), swapchain_, UINT64_MAX, + image_available_, VK_NULL_HANDLE, &image_index_); + } return swapchain_images_[image_index_]; } @@ -2013,6 +2231,70 @@ void VulkanSurface::present_image() { vkQueuePresentKHR(device_->graphics_queue(), &presentInfo); } +DeviceAllocation VulkanSurface::get_image_data() { + auto *stream = device_->get_graphics_stream(); + DeviceAllocation img_alloc = swapchain_images_[image_index_]; + auto [w, h] = get_size(); + size_t size_bytes = w * h * 4; + + /* + if (screenshot_image_ == kDeviceNullAllocation) { + ImageParams params = {ImageDimension::d2D, + BufferFormat::rgba8, + ImageLayout::transfer_dst, + w, + h, + 1, + false}; + screenshot_image_ = device_->create_image(params); + } + */ + + if (screenshot_buffer_ == kDeviceNullAllocation) { + Device::AllocParams params{size_bytes, /*host_wrtie*/ false, + /*host_read*/ true, /*export_sharing*/ false, + AllocUsage::Uniform}; + screenshot_buffer_ = device_->allocate_memory(params); + } + + device_->image_transition(img_alloc, ImageLayout::present_src, + ImageLayout::transfer_src); + + std::unique_ptr cmd_list{nullptr}; + + /* + if (config_.window_handle) { + // TODO: check if blit is suppoted, and use copy_image if not + cmd_list = stream->new_command_list(); + cmd_list->blit_image(screenshot_image_, img_alloc, + ImageLayout::transfer_dst, ImageLayout::transfer_src, + {w, h, 1}); + cmd_list->image_transition(screenshot_image_, ImageLayout::transfer_dst, + ImageLayout::transfer_src); + stream->submit_synced(cmd_list.get()); + } + */ + + BufferImageCopyParams copy_params; + copy_params.image_extent.x = w; + copy_params.image_extent.y = h; + cmd_list = stream->new_command_list(); + // TODO: directly map the image to cpu memory + cmd_list->image_to_buffer(screenshot_buffer_.get_ptr(), img_alloc, + ImageLayout::transfer_src, copy_params); + /* + if (config_.window_handle) { + cmd_list->image_transition(screenshot_image_, ImageLayout::transfer_src, + ImageLayout::transfer_dst); + } + */ + cmd_list->image_transition(img_alloc, ImageLayout::transfer_src, + ImageLayout::present_src); + stream->submit_synced(cmd_list.get()); + + return screenshot_buffer_; +} + VulkanStream::VulkanStream(VulkanDevice &device, VkQueue queue, uint32_t queue_family_index) diff --git a/taichi/backends/vulkan/vulkan_device.h b/taichi/backends/vulkan/vulkan_device.h index 0a5d8b4fb270f..8281984fec4cb 100644 --- a/taichi/backends/vulkan/vulkan_device.h +++ b/taichi/backends/vulkan/vulkan_device.h @@ -4,7 +4,11 @@ #include +#ifdef ANDROID +#include +#elif !defined(TI_EMSCRIPTENED) #include +#endif #include #include @@ -96,8 +100,17 @@ class VulkanResourceBinder : public ResourceBinder { struct Binding { VkDescriptorType type; DevicePtr ptr; - size_t size; + VkDeviceSize size; VkSampler sampler{VK_NULL_HANDLE}; // used only for images + + bool operator==(const Binding &other) const { + return other.type == type && other.ptr == ptr && other.size == size && + other.sampler == sampler; + } + + bool operator!=(const Binding &other) const { + return !(*this == other); + } }; struct Set { @@ -109,13 +122,21 @@ class VulkanResourceBinder : public ResourceBinder { return false; } for (auto &pair : bindings) { - const Binding &other_binding = other.bindings.at(pair.first); + auto other_binding_iter = other.bindings.find(pair.first); + if (other_binding_iter == other.bindings.end()) { + return false; + } + const Binding &other_binding = other_binding_iter->second; if (other_binding.type != pair.second.type) { return false; } } return true; } + + bool operator!=(const Set &other) const { + return !(*this == other); + } }; struct SetLayoutHasher { @@ -129,6 +150,45 @@ class VulkanResourceBinder : public ResourceBinder { } }; + struct DescSetCmp { + bool operator()(const Set &a, const Set &b) const { + if (a.bindings.size() != b.bindings.size()) { + return false; + } + for (auto &pair : a.bindings) { + auto other_binding_iter = b.bindings.find(pair.first); + if (other_binding_iter == b.bindings.end()) { + return false; + } + const Binding &other_binding = other_binding_iter->second; + if (other_binding != pair.second) { + return false; + } + } + return true; + } + }; + + struct DescSetHasher { + std::size_t operator()(const Set &set) const { + // TODO: Come up with a better hash + size_t hash = 0; + for (const auto &pair : set.bindings) { + size_t binding_hash = 0; + uint32_t *u32_ptr = (uint32_t *)&pair.second; + for (int i = 0; i < sizeof(Set) / sizeof(uint32_t); i++) { + binding_hash = binding_hash ^ u32_ptr[i]; + binding_hash = (binding_hash << 7) | (binding_hash >> (64 - 7)); + } + binding_hash = binding_hash ^ pair.first; + binding_hash = + (binding_hash << pair.first) | (binding_hash >> (64 - pair.first)); + hash = hash ^ binding_hash; + } + return hash; + } + }; + struct VulkanBindings : public Bindings { std::vector< std::pair> @@ -141,10 +201,18 @@ class VulkanResourceBinder : public ResourceBinder { std::unique_ptr materialize() override; - void rw_buffer(uint32_t set, uint32_t binding, DevicePtr ptr, size_t size); - void rw_buffer(uint32_t set, uint32_t binding, DeviceAllocation alloc); - void buffer(uint32_t set, uint32_t binding, DevicePtr ptr, size_t size); - void buffer(uint32_t set, uint32_t binding, DeviceAllocation alloc); + void rw_buffer(uint32_t set, + uint32_t binding, + DevicePtr ptr, + size_t size) override; + void rw_buffer(uint32_t set, + uint32_t binding, + DeviceAllocation alloc) override; + void buffer(uint32_t set, + uint32_t binding, + DevicePtr ptr, + size_t size) override; + void buffer(uint32_t set, uint32_t binding, DeviceAllocation alloc) override; void image(uint32_t set, uint32_t binding, DeviceAllocation alloc, @@ -314,6 +382,18 @@ class VulkanCommandList : public CommandList { ImageLayout img_layout, const BufferImageCopyParams ¶ms) override; + void copy_image(DeviceAllocation dst_img, + DeviceAllocation src_img, + ImageLayout dst_img_layout, + ImageLayout src_img_layout, + const ImageCopyParams ¶ms) override; + + void blit_image(DeviceAllocation dst_img, + DeviceAllocation src_img, + ImageLayout dst_img_layout, + ImageLayout src_img_layout, + const ImageCopyParams ¶ms) override; + vkapi::IVkRenderPass current_renderpass(); // Vulkan specific functions @@ -329,6 +409,12 @@ class VulkanCommandList : public CommandList { vkapi::IVkCommandBuffer buffer_; VulkanPipeline *current_pipeline_{nullptr}; + std::unordered_map + currently_used_sets_; + // Renderpass & raster pipeline VulkanRenderPassDesc current_renderpass_desc_; vkapi::IVkRenderPass current_renderpass_{VK_NULL_HANDLE}; @@ -345,8 +431,11 @@ class VulkanSurface : public Surface { void present_image() override; std::pair get_size() override; + int get_image_count() override; BufferFormat image_format() override; - virtual void resize(uint32_t width, uint32_t height); + void resize(uint32_t width, uint32_t height) override; + + DeviceAllocation get_image_data() override; private: void create_swap_chain(); @@ -358,25 +447,19 @@ class VulkanSurface : public Surface { VkSurfaceKHR surface_; VkSwapchainKHR swapchain_; VkSemaphore image_available_; +#ifdef ANDROID + ANativeWindow *window_; +#elif !defined(TI_EMSCRIPTENED) GLFWwindow *window_; +#endif BufferFormat image_format_; uint32_t image_index_{0}; std::vector swapchain_images_; -}; -struct VulkanMemoryPool { - VmaPool pool; - - // the lifetime of these needs to == the lifetime of the vmapool. - // because these are needed for allocating memory, which happens multiple - // times. - VkExportMemoryAllocateInfoKHR export_mem_alloc_info{}; -#ifdef _WIN64 - WindowsSecurityAttributes win_security_attribs; - VkExportMemoryWin32HandleInfoKHR export_mem_win32_handle_info{}; -#endif + // DeviceAllocation screenshot_image_{kDeviceNullAllocation}; + DeviceAllocation screenshot_buffer_{kDeviceNullAllocation}; }; struct DescPool { @@ -406,6 +489,8 @@ class VulkanStream : public Stream { VkQueue queue_; uint32_t queue_family_index_; + vkapi::IVkSemaphore last_semaphore_{nullptr}; + // Command pools are per-thread vkapi::IVkFence cmd_sync_fence_; vkapi::IVkCommandPool command_pool_; @@ -434,6 +519,8 @@ class VulkanDevice : public GraphicsDevice { DeviceAllocation allocate_memory(const AllocParams ¶ms) override; void dealloc_memory(DeviceAllocation handle) override; + uint64_t get_memory_physical_pointer(DeviceAllocation handle) override; + // Mapping can fail and will return nullptr void *map_range(DevicePtr ptr, uint64_t size) override; void *map(DeviceAllocation alloc) override; @@ -509,8 +596,6 @@ class VulkanDevice : public GraphicsDevice { VulkanResourceBinder::Set &set); vkapi::IVkDescriptorSet alloc_desc_set(vkapi::IVkDescriptorSetLayout layout); - static constexpr size_t kMemoryBlockSize = 128ull * 1024 * 1024; - private: void create_vma_allocator(); void new_descriptor_pool(); @@ -519,7 +604,7 @@ class VulkanDevice : public GraphicsDevice { VkDevice device_; VkPhysicalDevice physical_device_; VmaAllocator allocator_; - VulkanMemoryPool export_pool_; + VmaAllocator allocator_export_{nullptr}; VkQueue compute_queue_; uint32_t compute_queue_family_index_; @@ -536,6 +621,7 @@ class VulkanDevice : public GraphicsDevice { VmaAllocationInfo alloc_info; vkapi::IVkBuffer buffer; void *mapped{nullptr}; + VkDeviceAddress addr{0}; }; unordered_map allocations_; diff --git a/taichi/backends/vulkan/embedded_device.cpp b/taichi/backends/vulkan/vulkan_device_creator.cpp similarity index 68% rename from taichi/backends/vulkan/embedded_device.cpp rename to taichi/backends/vulkan/vulkan_device_creator.cpp index 689ce45b86d22..dbaf3675e7f43 100644 --- a/taichi/backends/vulkan/embedded_device.cpp +++ b/taichi/backends/vulkan/vulkan_device_creator.cpp @@ -1,4 +1,4 @@ -#include "taichi/backends/vulkan/embedded_device.h" +#include "taichi/backends/vulkan/vulkan_device_creator.h" #include #include @@ -7,7 +7,7 @@ #include #include "taichi/backends/vulkan/vulkan_common.h" -#include "taichi/backends/vulkan/loader.h" +#include "taichi/backends/vulkan/vulkan_loader.h" #include "taichi/backends/vulkan/vulkan_device.h" #include "taichi/common/logging.h" @@ -162,7 +162,7 @@ bool is_device_suitable(VkPhysicalDevice device, VkSurfaceKHR surface) { // this means we need ui VkPhysicalDeviceFeatures features{}; vkGetPhysicalDeviceFeatures(device, &features); - return indices.is_complete_for_ui() && features.wideLines == VK_TRUE; + return indices.is_complete_for_ui(); } else { return indices.is_complete(); } @@ -170,8 +170,8 @@ bool is_device_suitable(VkPhysicalDevice device, VkSurfaceKHR surface) { } // namespace -EmbeddedVulkanDevice::EmbeddedVulkanDevice( - const EmbeddedVulkanDevice::Params ¶ms) +VulkanDeviceCreator::VulkanDeviceCreator( + const VulkanDeviceCreator::Params ¶ms) : params_(params) { if (!VulkanLoader::instance().init()) { throw std::runtime_error("Error loading vulkan"); @@ -202,7 +202,7 @@ EmbeddedVulkanDevice::EmbeddedVulkanDevice( } } -EmbeddedVulkanDevice::~EmbeddedVulkanDevice() { +VulkanDeviceCreator::~VulkanDeviceCreator() { ti_device_.reset(); if (surface_ != VK_NULL_HANDLE) { vkDestroySurfaceKHR(instance_, surface_, kNoVkAllocCallbacks); @@ -215,7 +215,7 @@ EmbeddedVulkanDevice::~EmbeddedVulkanDevice() { vkDestroyInstance(instance_, kNoVkAllocCallbacks); } -void EmbeddedVulkanDevice::create_instance() { +void VulkanDeviceCreator::create_instance() { VkApplicationInfo app_info{}; app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; app_info.pApplicationName = "Taichi Vulkan Backend"; @@ -228,7 +228,7 @@ void EmbeddedVulkanDevice::create_instance() { app_info.apiVersion = params_.api_version.value(); } else { // The highest version designed to use - app_info.apiVersion = VK_API_VERSION_1_2; + app_info.apiVersion = VK_API_VERSION_1_3; } VkInstanceCreateInfo create_info{}; @@ -304,10 +304,11 @@ void EmbeddedVulkanDevice::create_instance() { if (res != VK_SUCCESS) { throw std::runtime_error("failed to create instance"); } + VulkanLoader::instance().load_instance(instance_); } -void EmbeddedVulkanDevice::setup_debug_messenger() { +void VulkanDeviceCreator::setup_debug_messenger() { if constexpr (!kEnableValidationLayers) { return; } @@ -320,11 +321,11 @@ void EmbeddedVulkanDevice::setup_debug_messenger() { "failed to set up debug messenger"); } -void EmbeddedVulkanDevice::create_surface() { +void VulkanDeviceCreator::create_surface() { surface_ = params_.surface_creator(instance_); } -void EmbeddedVulkanDevice::pick_physical_device() { +void VulkanDeviceCreator::pick_physical_device() { uint32_t device_count = 0; vkEnumeratePhysicalDevices(instance_, &device_count, nullptr); TI_ASSERT_INFO(device_count > 0, "failed to find GPUs with Vulkan support"); @@ -332,10 +333,34 @@ void EmbeddedVulkanDevice::pick_physical_device() { std::vector devices(device_count); vkEnumeratePhysicalDevices(instance_, &device_count, devices.data()); physical_device_ = VK_NULL_HANDLE; - for (const auto &device : devices) { - if (is_device_suitable(device, surface_)) { - physical_device_ = device; - break; + + for (int i = 0; i < device_count; i++) { + VkPhysicalDeviceProperties properties{}; + vkGetPhysicalDeviceProperties(devices[i], &properties); + TI_INFO("Found Vulkan Device {} ({})", i, properties.deviceName); + } + + auto device_id = VulkanLoader::instance().visible_device_id; + bool has_visible_device{0}; + if (!device_id.empty()) { + int id = std::stoi(device_id); + TI_ASSERT_INFO( + (id >= 0) && (id < device_count), + "TI_VISIBLE_DEVICE={} is not valid, found {} devices available", id, + device_count); + if (is_device_suitable(devices[id], surface_)) { + physical_device_ = devices[id]; + has_visible_device = 1; + } + } + + if (!has_visible_device) { + // could not find a user defined visible device, use the first one suitable + for (const auto &device : devices) { + if (is_device_suitable(device, surface_)) { + physical_device_ = device; + break; + } } } TI_ASSERT_INFO(physical_device_ != VK_NULL_HANDLE, @@ -344,15 +369,15 @@ void EmbeddedVulkanDevice::pick_physical_device() { queue_family_indices_ = find_queue_families(physical_device_, surface_); } -void EmbeddedVulkanDevice::create_logical_device() { +void VulkanDeviceCreator::create_logical_device() { std::vector queue_create_infos; std::unordered_set unique_families; - if (params_.is_for_ui) { - unique_families = {queue_family_indices_.graphics_family.value(), - queue_family_indices_.present_family.value()}; - } else { - unique_families = {queue_family_indices_.compute_family.value()}; + if (queue_family_indices_.compute_family.has_value()) { + unique_families.insert(queue_family_indices_.compute_family.value()); + } + if (queue_family_indices_.graphics_family.has_value()) { + unique_families.insert(queue_family_indices_.graphics_family.value()); } float queue_priority = 1.0f; @@ -371,13 +396,24 @@ void EmbeddedVulkanDevice::create_logical_device() { create_info.queueCreateInfoCount = queue_create_infos.size(); // Get device properties - VkPhysicalDeviceProperties physical_device_properties; + VkPhysicalDeviceProperties physical_device_properties{}; vkGetPhysicalDeviceProperties(physical_device_, &physical_device_properties); + TI_INFO("Vulkan Device \"{}\" supports Vulkan {} version {}.{}.{}", + physical_device_properties.deviceName, + VK_API_VERSION_VARIANT(physical_device_properties.apiVersion), + VK_API_VERSION_MAJOR(physical_device_properties.apiVersion), + VK_API_VERSION_MINOR(physical_device_properties.apiVersion), + VK_API_VERSION_PATCH(physical_device_properties.apiVersion)); + ti_device_->set_cap(DeviceCapability::vk_api_version, physical_device_properties.apiVersion); ti_device_->set_cap(DeviceCapability::spirv_version, 0x10000); - if (physical_device_properties.apiVersion >= VK_API_VERSION_1_1) { + if (physical_device_properties.apiVersion >= VK_API_VERSION_1_3) { + ti_device_->set_cap(DeviceCapability::spirv_version, 0x10600); + } else if (physical_device_properties.apiVersion >= VK_API_VERSION_1_2) { + ti_device_->set_cap(DeviceCapability::spirv_version, 0x10500); + } else if (physical_device_properties.apiVersion >= VK_API_VERSION_1_1) { ti_device_->set_cap(DeviceCapability::spirv_version, 0x10300); } @@ -391,7 +427,9 @@ void EmbeddedVulkanDevice::create_logical_device() { vkEnumerateDeviceExtensionProperties( physical_device_, nullptr, &extension_count, extension_properties.data()); - bool has_surface = false, has_swapchain = false; + bool has_swapchain = false; + + bool portability_subset_enabled = false; for (auto &ext : extension_properties) { TI_TRACE("Vulkan device extension {} ({})", ext.extensionName, @@ -403,33 +441,40 @@ void EmbeddedVulkanDevice::create_logical_device() { TI_WARN( "Potential non-conformant Vulkan implementation, enabling " "VK_KHR_portability_subset"); + portability_subset_enabled = true; enabled_extensions.push_back(ext.extensionName); } else if (name == VK_KHR_SWAPCHAIN_EXTENSION_NAME) { has_swapchain = true; enabled_extensions.push_back(ext.extensionName); } else if (name == VK_EXT_SHADER_ATOMIC_FLOAT_EXTENSION_NAME) { enabled_extensions.push_back(ext.extensionName); - } else if (name == "VK_EXT_shader_atomic_float2") { - // FIXME: This feature requires vulkan headers with - // VK_EXT_shader_atomic_float2 - /* + } else if (name == VK_EXT_SHADER_ATOMIC_FLOAT_2_EXTENSION_NAME) { enabled_extensions.push_back(ext.extensionName); - */ } else if (name == VK_KHR_SHADER_ATOMIC_INT64_EXTENSION_NAME) { - // ti_device_->set_cap(DeviceCapability::vk_has_atomic_i64, true); - // enabled_extensions.push_back(ext.extensionName); + enabled_extensions.push_back(ext.extensionName); } else if (name == VK_KHR_SYNCHRONIZATION_2_EXTENSION_NAME) { enabled_extensions.push_back(ext.extensionName); } else if (name == VK_KHR_SPIRV_1_4_EXTENSION_NAME) { - ti_device_->set_cap(DeviceCapability::spirv_version, 0x10400); - enabled_extensions.push_back(ext.extensionName); - } else if (name == VK_KHR_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME) { + if (ti_device_->get_cap(DeviceCapability::spirv_version) < 0x10400) { + ti_device_->set_cap(DeviceCapability::spirv_version, 0x10400); + enabled_extensions.push_back(ext.extensionName); + } + } else if (name == VK_KHR_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME || + name == VK_KHR_EXTERNAL_MEMORY_EXTENSION_NAME) { ti_device_->set_cap(DeviceCapability::vk_has_external_memory, true); enabled_extensions.push_back(ext.extensionName); } else if (name == VK_KHR_VARIABLE_POINTERS_EXTENSION_NAME) { enabled_extensions.push_back(ext.extensionName); } else if (name == VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME) { enabled_extensions.push_back(ext.extensionName); + } else if (name == VK_KHR_GET_MEMORY_REQUIREMENTS_2_EXTENSION_NAME) { + enabled_extensions.push_back(ext.extensionName); + } else if (name == VK_KHR_DEDICATED_ALLOCATION_EXTENSION_NAME) { + enabled_extensions.push_back(ext.extensionName); + } else if (name == VK_KHR_BIND_MEMORY_2_EXTENSION_NAME) { + enabled_extensions.push_back(ext.extensionName); + } else if (name == VK_KHR_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME) { + enabled_extensions.push_back(ext.extensionName); } else if (std::find(params_.additional_device_extensions.begin(), params_.additional_device_extensions.end(), name) != params_.additional_device_extensions.end()) { @@ -460,11 +505,45 @@ void EmbeddedVulkanDevice::create_logical_device() { } if (device_supported_features.wideLines) { device_features.wideLines = true; + ti_device_->set_cap(DeviceCapability::wide_lines, true); } else if (params_.is_for_ui) { TI_WARN_IF(!device_features.wideLines, "Taichi GPU GUI requires wide lines support"); } + if (physical_device_properties.apiVersion >= VK_API_VERSION_1_1) { + VkPhysicalDeviceSubgroupProperties subgroup_properties{}; + subgroup_properties.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES; + subgroup_properties.pNext = NULL; + + VkPhysicalDeviceProperties2 physical_device_properties{}; + physical_device_properties.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2; + physical_device_properties.pNext = &subgroup_properties; + + vkGetPhysicalDeviceProperties2(physical_device_, + &physical_device_properties); + + if (subgroup_properties.supportedOperations & + VK_SUBGROUP_FEATURE_BASIC_BIT) { + ti_device_->set_cap(DeviceCapability::spirv_has_subgroup_basic, true); + } + if (subgroup_properties.supportedOperations & + VK_SUBGROUP_FEATURE_VOTE_BIT) { + ti_device_->set_cap(DeviceCapability::spirv_has_subgroup_vote, true); + } + if (subgroup_properties.supportedOperations & + VK_SUBGROUP_FEATURE_ARITHMETIC_BIT) { + ti_device_->set_cap(DeviceCapability::spirv_has_subgroup_arithmetic, + true); + } + if (subgroup_properties.supportedOperations & + VK_SUBGROUP_FEATURE_BALLOT_BIT) { + ti_device_->set_cap(DeviceCapability::spirv_has_subgroup_ballot, true); + } + } + create_info.pEnabledFeatures = &device_features; create_info.enabledExtensionCount = enabled_extensions.size(); create_info.ppEnabledExtensionNames = enabled_extensions.data(); @@ -478,16 +557,32 @@ void EmbeddedVulkanDevice::create_logical_device() { VkPhysicalDeviceShaderAtomicFloatFeaturesEXT shader_atomic_float_feature{}; shader_atomic_float_feature.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT; + VkPhysicalDeviceShaderAtomicFloat2FeaturesEXT shader_atomic_float_2_feature{}; + shader_atomic_float_2_feature.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_2_FEATURES_EXT; VkPhysicalDeviceFloat16Int8FeaturesKHR shader_f16_i8_feature{}; shader_f16_i8_feature.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FLOAT16_INT8_FEATURES_KHR; + VkPhysicalDeviceBufferDeviceAddressFeaturesKHR + buffer_device_address_feature{}; + buffer_device_address_feature.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES_KHR; if (ti_device_->get_cap(DeviceCapability::vk_has_physical_features2)) { VkPhysicalDeviceFeatures2KHR features2{}; features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; +#define CHECK_EXTENSION(ext) \ + std::find(enabled_extensions.begin(), enabled_extensions.end(), ext) != \ + enabled_extensions.end() + +#define CHECK_VERSION(major, minor) \ + physical_device_properties.apiVersion >= \ + VK_MAKE_API_VERSION(0, major, minor, 0) + // Variable ptr - { + if (CHECK_VERSION(1, 1) || + CHECK_EXTENSION(VK_KHR_VARIABLE_POINTERS_EXTENSION_NAME)) { features2.pNext = &variable_ptr_feature; vkGetPhysicalDeviceFeatures2KHR(physical_device_, &features2); @@ -500,38 +595,93 @@ void EmbeddedVulkanDevice::create_logical_device() { } // Atomic float - { + if (CHECK_EXTENSION(VK_EXT_SHADER_ATOMIC_FLOAT_EXTENSION_NAME)) { features2.pNext = &shader_atomic_float_feature; vkGetPhysicalDeviceFeatures2KHR(physical_device_, &features2); - if (shader_atomic_float_feature.shaderBufferFloat32AtomicAdd) { ti_device_->set_cap(DeviceCapability::spirv_has_atomic_float_add, true); - } else if (shader_atomic_float_feature.shaderBufferFloat64AtomicAdd) { + } + if (shader_atomic_float_feature.shaderBufferFloat64AtomicAdd) { ti_device_->set_cap(DeviceCapability::spirv_has_atomic_float64_add, true); - } else if (shader_atomic_float_feature.shaderBufferFloat32Atomics) { + } + if (shader_atomic_float_feature.shaderBufferFloat32Atomics) { ti_device_->set_cap(DeviceCapability::spirv_has_atomic_float, true); - } else if (shader_atomic_float_feature.shaderBufferFloat64Atomics) { + } + if (shader_atomic_float_feature.shaderBufferFloat64Atomics) { ti_device_->set_cap(DeviceCapability::spirv_has_atomic_float64, true); } *pNextEnd = &shader_atomic_float_feature; pNextEnd = &shader_atomic_float_feature.pNext; } + // Atomic float 2 + if (CHECK_EXTENSION(VK_EXT_SHADER_ATOMIC_FLOAT_2_EXTENSION_NAME)) { + features2.pNext = &shader_atomic_float_2_feature; + vkGetPhysicalDeviceFeatures2KHR(physical_device_, &features2); + if (shader_atomic_float_2_feature.shaderBufferFloat16AtomicAdd) { + ti_device_->set_cap(DeviceCapability::spirv_has_atomic_float_add, true); + } + if (shader_atomic_float_2_feature.shaderBufferFloat16AtomicMinMax) { + ti_device_->set_cap(DeviceCapability::spirv_has_atomic_float16_minmax, + true); + } + if (shader_atomic_float_2_feature.shaderBufferFloat16Atomics) { + ti_device_->set_cap(DeviceCapability::spirv_has_atomic_float16, true); + } + if (shader_atomic_float_2_feature.shaderBufferFloat32AtomicMinMax) { + ti_device_->set_cap(DeviceCapability::spirv_has_atomic_float_minmax, + true); + } + if (shader_atomic_float_2_feature.shaderBufferFloat64AtomicMinMax) { + ti_device_->set_cap(DeviceCapability::spirv_has_atomic_float64_minmax, + true); + } + *pNextEnd = &shader_atomic_float_2_feature; + pNextEnd = &shader_atomic_float_2_feature.pNext; + } + // F16 / I8 +#ifdef __APPLE__ { +#else + if (CHECK_VERSION(1, 2) || + CHECK_EXTENSION(VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME)) { +#endif features2.pNext = &shader_f16_i8_feature; vkGetPhysicalDeviceFeatures2KHR(physical_device_, &features2); if (shader_f16_i8_feature.shaderFloat16) { ti_device_->set_cap(DeviceCapability::spirv_has_float16, true); - } else if (shader_f16_i8_feature.shaderInt8) { + } + if (shader_f16_i8_feature.shaderInt8) { + ti_device_->set_cap(DeviceCapability::spirv_has_int8, true); + } + if (portability_subset_enabled) { + // TODO: investigate why MoltenVK isn't reporting int8 caps. See #3252 ti_device_->set_cap(DeviceCapability::spirv_has_int8, true); } *pNextEnd = &shader_f16_i8_feature; pNextEnd = &shader_f16_i8_feature.pNext; } + // Buffer Device Address + if (CHECK_VERSION(1, 2) || + CHECK_EXTENSION(VK_KHR_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME)) { + features2.pNext = &buffer_device_address_feature; + vkGetPhysicalDeviceFeatures2KHR(physical_device_, &features2); + + if (CHECK_VERSION(1, 3) || + buffer_device_address_feature.bufferDeviceAddress) { + if (device_supported_features.shaderInt64) { + ti_device_->set_cap( + DeviceCapability::spirv_has_physical_storage_buffer, true); + } + } + *pNextEnd = &buffer_device_address_feature; + pNextEnd = &buffer_device_address_feature.pNext; + } + // TODO: add atomic min/max feature } @@ -546,14 +696,18 @@ void EmbeddedVulkanDevice::create_logical_device() { "failed to create logical device"); VulkanLoader::instance().load_device(device_); - if (params_.is_for_ui) { + if (queue_family_indices_.compute_family.has_value()) { + vkGetDeviceQueue(device_, queue_family_indices_.compute_family.value(), 0, + &compute_queue_); + } + if (queue_family_indices_.graphics_family.has_value()) { vkGetDeviceQueue(device_, queue_family_indices_.graphics_family.value(), 0, &graphics_queue_); } - vkGetDeviceQueue(device_, queue_family_indices_.compute_family.value(), 0, - &compute_queue_); -} // namespace vulkan + // Dump capabilities + ti_device_->print_all_cap(); +} } // namespace vulkan } // namespace lang diff --git a/taichi/backends/vulkan/embedded_device.h b/taichi/backends/vulkan/vulkan_device_creator.h similarity index 91% rename from taichi/backends/vulkan/embedded_device.h rename to taichi/backends/vulkan/vulkan_device_creator.h index 1be3f7e11e94e..c9da956d7c409 100644 --- a/taichi/backends/vulkan/embedded_device.h +++ b/taichi/backends/vulkan/vulkan_device_creator.h @@ -4,10 +4,7 @@ #define VK_USE_PLATFORM_WIN32_KHR 1 #endif -#include -#define VK_NO_PROTOTYPES -#include -#include +#include "taichi/backends/vulkan/vulkan_common.h" #include @@ -45,7 +42,7 @@ struct VulkanQueueFamilyIndices { * This class creates a VulkanDevice instance. The underlying Vk* resources are * embedded directly inside the class. */ -class EmbeddedVulkanDevice { +class TI_DLL_EXPORT VulkanDeviceCreator { public: struct Params { std::optional api_version; @@ -58,8 +55,8 @@ class EmbeddedVulkanDevice { std::function surface_creator; }; - explicit EmbeddedVulkanDevice(const Params ¶ms); - ~EmbeddedVulkanDevice(); + explicit VulkanDeviceCreator(const Params ¶ms); + ~VulkanDeviceCreator(); const VulkanDevice *device() const { return ti_device_.get(); diff --git a/taichi/backends/vulkan/vulkan_loader.cpp b/taichi/backends/vulkan/vulkan_loader.cpp new file mode 100644 index 0000000000000..6f0044245abb9 --- /dev/null +++ b/taichi/backends/vulkan/vulkan_loader.cpp @@ -0,0 +1,137 @@ +#include "taichi/backends/vulkan/vulkan_common.h" + +#include "taichi/lang_util.h" +#include "taichi/backends/vulkan/vulkan_loader.h" +#include "taichi/common/logging.h" + +namespace taichi { +namespace lang { +namespace vulkan { + +VulkanLoader::VulkanLoader() { +} + +bool VulkanLoader::check_vulkan_device() { + bool found_device_with_compute = false; + + // We create an temporary Vulkan instance to probe the Vulkan devices. + // Otherwise, in the case of a CPU only VM with Vulkan installed, Vulkan will + // not run as there is no GPU available, but the fallback will not happen + // because Vulkan API is available. + + VkApplicationInfo app_info{}; + app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; + app_info.pApplicationName = "Checking Vulkan Device"; + app_info.applicationVersion = VK_MAKE_VERSION(1, 0, 0); + app_info.pEngineName = "No Engine"; + app_info.engineVersion = VK_MAKE_VERSION(1, 0, 0); + app_info.apiVersion = VK_API_VERSION_1_0; + + VkInstanceCreateInfo create_info{}; + create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; + create_info.pApplicationInfo = &app_info; + + VkInstance instance{VK_NULL_HANDLE}; + VkResult res = vkCreateInstance(&create_info, kNoVkAllocCallbacks, &instance); + + do { + if (res != VK_SUCCESS) { + TI_WARN("Can not create Vulkan instance"); + break; + } + + load_instance(instance); + + uint32_t device_count = 0; + vkEnumeratePhysicalDevices(instance, &device_count, nullptr); + + if (device_count == 0) { + TI_WARN("Can not find Vulkan capable devices"); + break; + } + + std::vector devices(device_count); + vkEnumeratePhysicalDevices(instance, &device_count, devices.data()); + + for (int i = 0; i < devices.size(); i++) { + const auto &physical_device = devices[i]; + + uint32_t queue_family_count = 0; + vkGetPhysicalDeviceQueueFamilyProperties(physical_device, + &queue_family_count, nullptr); + if (queue_family_count > 0) { + std::vector queue_families(queue_family_count); + vkGetPhysicalDeviceQueueFamilyProperties( + physical_device, &queue_family_count, queue_families.data()); + + for (auto &queue : queue_families) { + if (queue.queueFlags & VK_QUEUE_COMPUTE_BIT) { + found_device_with_compute = true; + } + } + } + } + } while (false); + + if (instance) { + vkDestroyInstance(instance, kNoVkAllocCallbacks); + } + + return found_device_with_compute; +} + +bool VulkanLoader::init() { + std::call_once(init_flag_, [&]() { + if (initialized) { + return; + } +#if defined(TI_EMSCRIPTENED) + initialized = true; +#elif defined(__APPLE__) + vulkan_rt_ = std::make_unique(runtime_lib_dir() + "/libMoltenVK.dylib"); + PFN_vkGetInstanceProcAddr get_proc_addr = (PFN_vkGetInstanceProcAddr)vulkan_rt_->load_function("vkGetInstanceProcAddr"); + + volkInitializeCustom(get_proc_addr); + initialized = true; +#else + VkResult result = volkInitialize(); + initialized = result == VK_SUCCESS; +#endif + initialized = initialized && check_vulkan_device(); + }); + return initialized; +} + +void VulkanLoader::load_instance(VkInstance instance) { + vulkan_instance_ = instance; +#if defined(TI_EMSCRIPTENED) +#else + volkLoadInstance(instance); +#endif +} +void VulkanLoader::load_device(VkDevice device) { + vulkan_device_ = device; +#if defined(TI_EMSCRIPTENED) +#else + volkLoadDevice(device); +#endif +} + +PFN_vkVoidFunction VulkanLoader::load_function(const char *name) { + auto result = + vkGetInstanceProcAddr(VulkanLoader::instance().vulkan_instance_, name); + TI_WARN_IF(result == nullptr, "loaded vulkan function {} is nullptr", name); + return result; +} + +bool is_vulkan_api_available() { + return VulkanLoader::instance().init(); +} + +void set_vulkan_visible_device(std::string id) { + VulkanLoader::instance().visible_device_id = id; +} + +} // namespace vulkan +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/vulkan/loader.h b/taichi/backends/vulkan/vulkan_loader.h similarity index 68% rename from taichi/backends/vulkan/loader.h rename to taichi/backends/vulkan/vulkan_loader.h index 220a7411e551f..408b0262d71c0 100644 --- a/taichi/backends/vulkan/loader.h +++ b/taichi/backends/vulkan/vulkan_loader.h @@ -3,7 +3,8 @@ #include #include -#include +#include "taichi/backends/vulkan/vulkan_common.h" +#include "taichi/system/dynamic_loader.h" namespace taichi { namespace lang { @@ -20,10 +21,16 @@ class VulkanLoader { VulkanLoader(VulkanLoader const &) = delete; void operator=(VulkanLoader const &) = delete; + bool check_vulkan_device(); + void load_instance(VkInstance instance_); void load_device(VkDevice device_); bool init(); PFN_vkVoidFunction load_function(const char *name); + VkInstance get_instance() { + return vulkan_instance_; + } + std::string visible_device_id; private: std::once_flag init_flag_; @@ -31,12 +38,18 @@ class VulkanLoader { VulkanLoader(); +#if defined(__APPLE__) + std::unique_ptr vulkan_rt_{nullptr}; +#endif + VkInstance vulkan_instance_{VK_NULL_HANDLE}; VkDevice vulkan_device_{VK_NULL_HANDLE}; }; bool is_vulkan_api_available(); +void set_vulkan_visible_device(std::string id); + } // namespace vulkan } // namespace lang } // namespace taichi diff --git a/taichi/backends/vulkan/vulkan_memory_allocator.cpp b/taichi/backends/vulkan/vulkan_memory_allocator.cpp index 61ed8d8f2ec44..7c3523845c418 100644 --- a/taichi/backends/vulkan/vulkan_memory_allocator.cpp +++ b/taichi/backends/vulkan/vulkan_memory_allocator.cpp @@ -1,5 +1,6 @@ -#include "embedded_device.h" +#include "vulkan_device_creator.h" #define VMA_IMPLEMENTATION #define VMA_DEDICATED_ALLOCATION 0 +#define VMA_DYNAMIC_VULKAN_FUNCTIONS 0 #include "vk_mem_alloc.h" diff --git a/taichi/backends/vulkan/vulkan_program.cpp b/taichi/backends/vulkan/vulkan_program.cpp index 80f749456fb8a..78d996243da4c 100644 --- a/taichi/backends/vulkan/vulkan_program.cpp +++ b/taichi/backends/vulkan/vulkan_program.cpp @@ -1,13 +1,88 @@ #include "taichi/backends/vulkan/vulkan_program.h" + +#include "taichi/backends/vulkan/aot_module_builder_impl.h" +#include "taichi/backends/vulkan/snode_tree_manager.h" + +#if !defined(ANDROID) && !defined(TI_EMSCRIPTENED) +#include "GLFW/glfw3.h" +#endif + using namespace taichi::lang::vulkan; namespace taichi { namespace lang { +namespace { +std::vector get_required_instance_extensions() { +#ifdef ANDROID + std::vector extensions; + + extensions.push_back(VK_KHR_SURFACE_EXTENSION_NAME); + extensions.push_back(VK_KHR_ANDROID_SURFACE_EXTENSION_NAME); + extensions.push_back(VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME); + + return extensions; +#else + std::vector extensions; + +#ifndef TI_EMSCRIPTENED + uint32_t glfw_ext_count = 0; + const char **glfw_extensions; + glfw_extensions = glfwGetRequiredInstanceExtensions(&glfw_ext_count); + + for (int i = 0; i < glfw_ext_count; ++i) { + extensions.push_back(glfw_extensions[i]); + } +#endif + // VulkanDeviceCreator will check that these are supported + extensions.push_back(VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME); +#if TI_WITH_CUDA + // so that we can do cuda-vk interop + extensions.push_back(VK_KHR_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME); + extensions.push_back(VK_EXT_DEBUG_UTILS_EXTENSION_NAME); +#endif // TI_WITH_CUDA + return extensions; +#endif +} + +std::vector get_required_device_extensions() { + static std::vector extensions { + VK_KHR_SWAPCHAIN_EXTENSION_NAME, +#if TI_WITH_CUDA + // so that we can do cuda-vk interop + VK_KHR_EXTERNAL_MEMORY_EXTENSION_NAME, +#ifdef _WIN64 + VK_KHR_EXTERNAL_MEMORY_WIN32_EXTENSION_NAME, +#else + VK_KHR_EXTERNAL_MEMORY_FD_EXTENSION_NAME, +#endif +#endif // TI_WITH_CUDA + }; + + return extensions; +} +} // namespace + +VulkanProgramImpl::VulkanProgramImpl(CompileConfig &config) + : ProgramImpl(config) { +} + +FunctionType compile_to_executable(Kernel *kernel, + VkRuntime *runtime, + SNodeTreeManager *snode_tree_mgr) { + auto handle = runtime->register_taichi_kernel( + std::move(run_codegen(kernel, runtime->get_ti_device(), + snode_tree_mgr->get_compiled_structs()))); + return [runtime, handle](RuntimeContext &ctx) { + runtime->launch_kernel(handle, &ctx); + }; +} + FunctionType VulkanProgramImpl::compile(Kernel *kernel, OffloadedStmt *offloaded) { - vulkan::lower(kernel); - return vulkan::compile_to_executable(kernel, vulkan_runtime_.get()); + spirv::lower(kernel); + return compile_to_executable(kernel, vulkan_runtime_.get(), + snode_tree_mgr_.get()); } void VulkanProgramImpl::materialize_runtime(MemoryPool *memory_pool, @@ -16,25 +91,104 @@ void VulkanProgramImpl::materialize_runtime(MemoryPool *memory_pool, *result_buffer_ptr = (uint64 *)memory_pool->allocate( sizeof(uint64) * taichi_result_buffer_entries, 8); - EmbeddedVulkanDevice::Params evd_params; +#ifndef TI_EMSCRIPTENED +// Android is meant to be embedded in other application only so the creation of +// the device and other states is left to the caller/host. +// The following code is only used when Taichi is running on its own. +#ifndef ANDROID + GLFWwindow *glfw_window = nullptr; +#ifdef __APPLE__ + glfwInitVulkanLoader(vkGetInstanceProcAddr); +#endif + + if (glfwInit()) { + // glfw init success + glfwWindowHint(GLFW_VISIBLE, GLFW_FALSE); + glfwWindowHint(GLFW_CLIENT_API, GLFW_NO_API); + glfwWindowHint(GLFW_COCOA_MENUBAR, GLFW_FALSE); + glfw_window = glfwCreateWindow(1, 1, "Dummy Window", nullptr, nullptr); + + if (glfwVulkanSupported() != GLFW_TRUE) { + TI_WARN("GLFW reports no Vulkan support"); + } + } +#endif +#endif + + VulkanDeviceCreator::Params evd_params; evd_params.api_version = VulkanEnvSettings::kApiVersion(); - embedded_device_ = std::make_unique(evd_params); +#if !defined(ANDROID) && !defined(TI_EMSCRIPTENED) + if (glfw_window) { + // then we should be able to create a device with graphics abilities + evd_params.additional_instance_extensions = + get_required_instance_extensions(); + evd_params.additional_device_extensions = get_required_device_extensions(); + evd_params.is_for_ui = true; + evd_params.surface_creator = [&](VkInstance instance) -> VkSurfaceKHR { + VkSurfaceKHR surface = VK_NULL_HANDLE; + TI_TRACE("before glfwCreateWindowSurface {} {}", (void *)glfw_window, + (void *)instance); + uint status = VK_SUCCESS; + if ((status = glfwCreateWindowSurface(instance, glfw_window, nullptr, + &surface)) != VK_SUCCESS) { + TI_ERROR("Failed to create window surface! err: {}", status); + throw std::runtime_error("failed to create window surface!"); + } + return surface; + }; + } +#endif + + embedded_device_ = std::make_unique(evd_params); vulkan::VkRuntime::Params params; params.host_result_buffer = *result_buffer_ptr; params.device = embedded_device_->device(); vulkan_runtime_ = std::make_unique(std::move(params)); + snode_tree_mgr_ = + std::make_unique(vulkan_runtime_.get()); +} + +void VulkanProgramImpl::compile_snode_tree_types( + SNodeTree *tree, + std::vector> &snode_trees) { + if (vulkan_runtime_) { + snode_tree_mgr_->materialize_snode_tree(tree); + } else { + CompiledSNodeStructs compiled_structs = + vulkan::compile_snode_structs(*tree->root()); + aot_compiled_snode_structs_.push_back(compiled_structs); + } } void VulkanProgramImpl::materialize_snode_tree( SNodeTree *tree, std::vector> &, - std::unordered_map &, uint64 *result_buffer) { - vulkan_runtime_->materialize_snode_tree(tree); + snode_tree_mgr_->materialize_snode_tree(tree); +} + +std::unique_ptr VulkanProgramImpl::make_aot_module_builder() { + if (vulkan_runtime_) { + return std::make_unique( + snode_tree_mgr_->get_compiled_structs()); + } else { + return std::make_unique(aot_compiled_snode_structs_); + } +} + +DeviceAllocation VulkanProgramImpl::allocate_memory_ndarray( + std::size_t alloc_size, + uint64 *result_buffer) { + auto &ndarray = + ref_ndarry_.emplace_back(get_compute_device()->allocate_memory_unique( + {alloc_size, /*host_write=*/false, /*host_read=*/false, + /*export_sharing=*/false})); + return *ndarray; } VulkanProgramImpl::~VulkanProgramImpl() { + ref_ndarry_.clear(); vulkan_runtime_.reset(); embedded_device_.reset(); } diff --git a/taichi/backends/vulkan/vulkan_program.h b/taichi/backends/vulkan/vulkan_program.h index dd2fcf55ea872..a94f2abbb2ba7 100644 --- a/taichi/backends/vulkan/vulkan_program.h +++ b/taichi/backends/vulkan/vulkan_program.h @@ -1,19 +1,22 @@ #pragma once -#include "taichi/backends/vulkan/codegen_vulkan.h" +#include "taichi/codegen/spirv/spirv_codegen.h" +#include "taichi/codegen/spirv/snode_struct_compiler.h" +#include "taichi/codegen/spirv/kernel_utils.h" + +#include "taichi/backends/vulkan/vulkan_device_creator.h" +#include "taichi/backends/vulkan/vulkan_utils.h" +#include "taichi/backends/vulkan/vulkan_loader.h" #include "taichi/backends/vulkan/runtime.h" -#include "taichi/backends/vulkan/snode_struct_compiler.h" +#include "taichi/backends/vulkan/snode_tree_manager.h" +#include "taichi/backends/vulkan/vulkan_device.h" +#include "vk_mem_alloc.h" + #include "taichi/system/memory_pool.h" #include "taichi/common/logging.h" #include "taichi/struct/snode_tree.h" #include "taichi/program/snode_expr_utils.h" #include "taichi/program/program_impl.h" - -#include "taichi/backends/vulkan/embedded_device.h" -#include "taichi/backends/vulkan/vulkan_utils.h" -#include "taichi/backends/vulkan/loader.h" - -#include "vk_mem_alloc.h" -#include "taichi/backends/vulkan/vulkan_device.h" +#include "taichi/program/program.h" #include @@ -21,13 +24,12 @@ namespace taichi { namespace lang { namespace vulkan { -class EmbeddedVulkanDevice; +class VulkanDeviceCreator; } class VulkanProgramImpl : public ProgramImpl { public: - VulkanProgramImpl(CompileConfig &config) : ProgramImpl(config) { - } + VulkanProgramImpl(CompileConfig &config); FunctionType compile(Kernel *kernel, OffloadedStmt *offloaded) override; std::size_t get_snode_num_dynamically_allocated( @@ -36,33 +38,60 @@ class VulkanProgramImpl : public ProgramImpl { return 0; // TODO: support sparse in vulkan } + void compile_snode_tree_types( + SNodeTree *tree, + std::vector> &snode_trees) override; + void materialize_runtime(MemoryPool *memory_pool, KernelProfilerBase *profiler, uint64 **result_buffer_ptr) override; void materialize_snode_tree(SNodeTree *tree, std::vector> &, - std::unordered_map &, uint64 *result_buffer) override; void synchronize() override { vulkan_runtime_->synchronize(); } - std::unique_ptr make_aot_module_builder() override { - // TODO: implement vk aot + std::unique_ptr make_aot_module_builder() override; + + virtual void destroy_snode_tree(SNodeTree *snode_tree) override { + TI_ASSERT(snode_tree_mgr_ != nullptr); + snode_tree_mgr_->destroy_snode_tree(snode_tree); + } + + DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size, + uint64 *result_buffer) override; + + Device *get_compute_device() override { + if (embedded_device_) { + return embedded_device_->device(); + } return nullptr; } - virtual void destroy_snode_tree(SNodeTree *snode_tree) override { - vulkan_runtime_->destroy_snode_tree(snode_tree); + Device *get_graphics_device() override { + if (embedded_device_) { + return embedded_device_->device(); + } + return nullptr; } - ~VulkanProgramImpl() override; + DevicePtr get_snode_tree_device_ptr(int tree_id) override { + return snode_tree_mgr_->get_snode_tree_device_ptr(tree_id); + } + + ~VulkanProgramImpl(); private: - std::unique_ptr embedded_device_{nullptr}; - std::unique_ptr vulkan_runtime_; + std::unique_ptr embedded_device_{nullptr}; + std::unique_ptr vulkan_runtime_{nullptr}; + std::unique_ptr snode_tree_mgr_{nullptr}; + std::vector aot_compiled_snode_structs_; + + // This is a hack until NDArray is properlly owned by programs + std::vector> ref_ndarry_; }; } // namespace lang } // namespace taichi diff --git a/taichi/backends/vulkan/vulkan_utils.h b/taichi/backends/vulkan/vulkan_utils.h index 0be74a99e3487..675bacd80587e 100644 --- a/taichi/backends/vulkan/vulkan_utils.h +++ b/taichi/backends/vulkan/vulkan_utils.h @@ -8,9 +8,7 @@ #include #endif -#include -#define VK_NO_PROTOTYPES -#include +#include "taichi/backends/vulkan/vulkan_common.h" #include #include diff --git a/taichi/backends/wasm/aot_module_builder_impl.cpp b/taichi/backends/wasm/aot_module_builder_impl.cpp index ea5031a74c4a0..7f50aaa606f60 100644 --- a/taichi/backends/wasm/aot_module_builder_impl.cpp +++ b/taichi/backends/wasm/aot_module_builder_impl.cpp @@ -42,7 +42,8 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier, name_list_.push_back(name); } -void AotModuleBuilderImpl::add_per_backend_field(const std::string &identifier, +void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier, + const SNode *rep_snode, bool is_scalar, DataType dt, std::vector shape, diff --git a/taichi/backends/wasm/aot_module_builder_impl.h b/taichi/backends/wasm/aot_module_builder_impl.h index 10b73dcff69b3..d4bef96444446 100644 --- a/taichi/backends/wasm/aot_module_builder_impl.h +++ b/taichi/backends/wasm/aot_module_builder_impl.h @@ -3,8 +3,9 @@ #include #include -#include "taichi/program/aot_module_builder.h" +#include "taichi/aot/module_builder.h" #include "taichi/program/kernel.h" +#include "taichi/llvm/llvm_fwd.h" #include "taichi/backends/wasm/codegen_wasm.h" @@ -24,7 +25,8 @@ class AotModuleBuilderImpl : public AotModuleBuilder { void add_per_backend_tmpl(const std::string &identifier, const std::string &key, Kernel *kernel) override; - void add_per_backend_field(const std::string &Identifier, + void add_field_per_backend(const std::string &Identifier, + const SNode *rep_snode, bool is_scalar, DataType dt, std::vector shape, diff --git a/taichi/backends/wasm/codegen_wasm.cpp b/taichi/backends/wasm/codegen_wasm.cpp index 81c590040a396..a49d70b2f3886 100644 --- a/taichi/backends/wasm/codegen_wasm.cpp +++ b/taichi/backends/wasm/codegen_wasm.cpp @@ -231,7 +231,7 @@ class CodeGenLLVMWASM : public CodeGenLLVM { }); tlctx->add_module(std::move(module)); auto kernel_symbol = tlctx->lookup_function_pointer(offloaded_task_name); - return [=](Context &context) { + return [=](RuntimeContext &context) { TI_TRACE("Launching Taichi Kernel Function"); auto func = (int32(*)(void *))kernel_symbol; func(&context); diff --git a/taichi/backends/wasm/codegen_wasm.h b/taichi/backends/wasm/codegen_wasm.h index 0171afc3f8465..2a92711ecfd87 100644 --- a/taichi/backends/wasm/codegen_wasm.h +++ b/taichi/backends/wasm/codegen_wasm.h @@ -4,11 +4,14 @@ #include "taichi/codegen/codegen.h" +#ifdef TI_WITH_LLVM #include "llvm/IR/Module.h" +#endif namespace taichi { namespace lang { +#ifdef TI_WITH_LLVM class ModuleGenValue { public: ModuleGenValue(std::unique_ptr module, @@ -18,6 +21,7 @@ class ModuleGenValue { std::unique_ptr module; std::vector name_list; }; +#endif class CodeGenWASM : public KernelCodeGen { public: @@ -25,10 +29,12 @@ class CodeGenWASM : public KernelCodeGen { : KernelCodeGen(kernel, ir) { } - virtual FunctionType codegen() override; + FunctionType codegen() override; +#ifdef TI_WITH_LLVM std::unique_ptr modulegen( std::unique_ptr &&module); // AOT Module Gen +#endif }; } // namespace lang diff --git a/taichi/codegen/codegen.cpp b/taichi/codegen/codegen.cpp index ce85aad4ef6cf..263f15d039d4d 100644 --- a/taichi/codegen/codegen.cpp +++ b/taichi/codegen/codegen.cpp @@ -3,8 +3,10 @@ #include "codegen.h" #include "taichi/util/statistics.h" +#if defined(TI_WITH_LLVM) #include "taichi/backends/cpu/codegen_cpu.h" #include "taichi/backends/wasm/codegen_wasm.h" +#endif #if defined(TI_WITH_CUDA) #include "taichi/backends/cuda/codegen_cuda.h" #endif @@ -31,6 +33,7 @@ KernelCodeGen::KernelCodeGen(Kernel *kernel, IRNode *ir) std::unique_ptr KernelCodeGen::create(Arch arch, Kernel *kernel, Stmt *stmt) { +#ifdef TI_WITH_LLVM if (arch_is_cpu(arch) && arch != Arch::wasm) { return std::make_unique(kernel, stmt); } else if (arch == Arch::wasm) { @@ -44,6 +47,9 @@ std::unique_ptr KernelCodeGen::create(Arch arch, } else { TI_NOT_IMPLEMENTED } +#else + TI_ERROR("Llvm disabled"); +#endif } TLANG_NAMESPACE_END diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index a372ea8f481dd..241cb15161bb1 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1,9 +1,14 @@ +#ifdef TI_WITH_LLVM #include "taichi/codegen/codegen_llvm.h" #include "taichi/ir/statements.h" #include "taichi/struct/struct_llvm.h" #include "taichi/util/file_sequence_writer.h" +#include "llvm/IR/Module.h" +#include "llvm/Bitcode/BitcodeReader.h" +#include "llvm/Linker/Linker.h" + TLANG_NAMESPACE_BEGIN // TODO: sort function definitions to match declaration order in header @@ -22,7 +27,7 @@ void OffloadedTask::end() { codegen->offloaded_tasks.push_back(*this); } -void OffloadedTask::operator()(Context *context) { +void OffloadedTask::operator()(RuntimeContext *context) { TI_ASSERT(func); func(context); } @@ -119,9 +124,6 @@ CodeGenStmtGuard make_while_after_loop_guard(CodeGenLLVM *cg) { } // namespace // CodeGenLLVM - -uint64 CodeGenLLVM::task_counter = 0; - void CodeGenLLVM::visit(Block *stmt_list) { for (auto &stmt : stmt_list->statements) { stmt->accept(this); @@ -135,18 +137,6 @@ void CodeGenLLVM::visit(AllocaStmt *stmt) { auto array_size = tlctx->get_constant(tensor_type->get_num_elements()); // Return type is [array_size x type]*. llvm_val[stmt] = create_entry_block_alloca(type, 0, array_size); - // Initialize as zero - for (int i = 0; i < tensor_type->get_num_elements(); ++i) { - auto origin_address = builder->CreatePtrToInt( - llvm_val[stmt], llvm::Type::getInt64Ty(*llvm_context)); - int address_offset = i * data_type_size(tensor_type->get_element_type()); - auto target_address = builder->CreateAdd( - origin_address, tlctx->get_constant((int64)address_offset)); - auto target_ptr = builder->CreateIntToPtr( - target_address, llvm::PointerType::get(type, 0)); - builder->CreateStore( - tlctx->get_constant(tensor_type->get_element_type(), 0), target_ptr); - } } else { TI_ASSERT(stmt->width() == 1); llvm_val[stmt] = @@ -159,27 +149,39 @@ void CodeGenLLVM::visit(AllocaStmt *stmt) { } void CodeGenLLVM::visit(RandStmt *stmt) { - llvm_val[stmt] = create_call( - fmt::format("rand_{}", data_type_name(stmt->ret_type)), {get_context()}); + if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { + // Promoting to f32 since there's no rand_f16 support in runtime.cpp. + auto val_f32 = create_call("rand_f32", {get_context()}); + llvm_val[stmt] = + builder->CreateFPTrunc(val_f32, llvm::Type::getHalfTy(*llvm_context)); + } else { + llvm_val[stmt] = + create_call(fmt::format("rand_{}", data_type_name(stmt->ret_type)), + {get_context()}); + } } void CodeGenLLVM::emit_extra_unary(UnaryOpStmt *stmt) { auto input = llvm_val[stmt->operand]; auto input_taichi_type = stmt->operand->ret_type; + if (input_taichi_type->is_primitive(PrimitiveTypeID::f16)) { + // Promote to f32 since we don't have f16 support for extra unary ops in in + // runtime.cpp. + input = builder->CreateFPExt(input, llvm::Type::getFloatTy(*llvm_context)); + input_taichi_type = PrimitiveType::f32; + } + auto op = stmt->op_type; auto input_type = input->getType(); #define UNARY_STD(x) \ else if (op == UnaryOpType::x) { \ if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) { \ - llvm_val[stmt] = \ - builder->CreateCall(get_runtime_function(#x "_f32"), input); \ + llvm_val[stmt] = create_call(#x "_f32", input); \ } else if (input_taichi_type->is_primitive(PrimitiveTypeID::f64)) { \ - llvm_val[stmt] = \ - builder->CreateCall(get_runtime_function(#x "_f64"), input); \ + llvm_val[stmt] = create_call(#x "_f64", input); \ } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) { \ - llvm_val[stmt] = \ - builder->CreateCall(get_runtime_function(#x "_i32"), input); \ + llvm_val[stmt] = create_call(#x "_i32", input); \ } else { \ TI_NOT_IMPLEMENTED \ } \ @@ -206,6 +208,11 @@ void CodeGenLLVM::emit_extra_unary(UnaryOpStmt *stmt) { TI_NOT_IMPLEMENTED } #undef UNARY_STD + if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { + // Convert back to f16 + llvm_val[stmt] = builder->CreateFPTrunc( + llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context)); + } } std::unique_ptr CodeGenLLVM::emit_struct_meta_object( @@ -306,7 +313,7 @@ CodeGenLLVM::CodeGenLLVM(Kernel *kernel, this->ir = kernel->ir.get(); initialize_context(); - context_ty = get_runtime_type("Context"); + context_ty = get_runtime_type("RuntimeContext"); physical_coordinate_ty = get_runtime_type(kLLVMPhysicalCoordinatesName); kernel_name = kernel->name + "_kernel"; @@ -334,6 +341,9 @@ llvm::Value *CodeGenLLVM::cast_int(llvm::Value *input_val, } } +void CodeGenLLVM::visit(DecorationStmt *stmt) { +} + void CodeGenLLVM::visit(UnaryOpStmt *stmt) { auto input = llvm_val[stmt->operand]; auto input_type = input->getType(); @@ -352,24 +362,41 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) { llvm_val[stmt] = llvm_val[stmt->operand]; } else if (is_real(from) != is_real(to)) { if (is_real(from) && is_integral(to)) { - cast_op = llvm::Instruction::CastOps::FPToSI; + cast_op = is_signed(to) ? llvm::Instruction::CastOps::FPToSI + : llvm::Instruction::CastOps::FPToUI; } else if (is_integral(from) && is_real(to)) { - cast_op = llvm::Instruction::CastOps::SIToFP; + cast_op = is_signed(from) ? llvm::Instruction::CastOps::SIToFP + : llvm::Instruction::CastOps::UIToFP; } else { TI_P(data_type_name(from)); TI_P(data_type_name(to)); TI_NOT_IMPLEMENTED; } - llvm_val[stmt] = - builder->CreateCast(cast_op, llvm_val[stmt->operand], - tlctx->get_data_type(stmt->cast_type)); + auto cast_type = to->is_primitive(PrimitiveTypeID::f16) + ? PrimitiveType::f32 + : stmt->cast_type; + + llvm_val[stmt] = builder->CreateCast(cast_op, llvm_val[stmt->operand], + tlctx->get_data_type(cast_type)); + + if (to->is_primitive(PrimitiveTypeID::f16)) { + llvm_val[stmt] = builder->CreateFPTrunc( + llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context)); + } } else if (is_real(from) && is_real(to)) { if (data_type_size(from) < data_type_size(to)) { llvm_val[stmt] = builder->CreateFPExt( llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type)); } else { - llvm_val[stmt] = builder->CreateFPTrunc( - llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type)); + if (to->is_primitive(PrimitiveTypeID::f16)) { + llvm_val[stmt] = builder->CreateFPTrunc( + builder->CreateFPTrunc(llvm_val[stmt->operand], + llvm::Type::getFloatTy(*llvm_context)), + llvm::Type::getHalfTy(*llvm_context)); + } else { + llvm_val[stmt] = builder->CreateFPTrunc( + llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type)); + } } } else if (!is_real(from) && !is_real(to)) { // TODO: implement casting into custom integer type @@ -402,15 +429,19 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) { llvm_val[stmt] = builder->CreateNeg(input, "neg"); } } + UNARY_INTRINSIC(round) UNARY_INTRINSIC(floor) UNARY_INTRINSIC(ceil) - else emit_extra_unary(stmt); + else { + emit_extra_unary(stmt); + } #undef UNARY_INTRINSIC } void CodeGenLLVM::visit(BinaryOpStmt *stmt) { auto op = stmt->op_type; auto ret_type = stmt->ret_type; + if (op == BinaryOpType::add) { if (is_real(stmt->ret_type)) { llvm_val[stmt] = @@ -477,88 +508,44 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { builder->CreateLShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } } else if (op == BinaryOpType::max) { +#define BINARYOP_MAX(x) \ + else if (ret_type->is_primitive(PrimitiveTypeID::x)) { \ + llvm_val[stmt] = \ + create_call("max_" #x, {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); \ + } + if (is_real(ret_type)) { llvm_val[stmt] = builder->CreateMaxNum(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); - } else if (ret_type->is_primitive(PrimitiveTypeID::i32)) { - llvm_val[stmt] = - create_call("max_i32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else { - TI_P(data_type_name(ret_type)); - TI_NOT_IMPLEMENTED } - } else if (op == BinaryOpType::atan2) { - if (arch_is_cpu(current_arch())) { - if (ret_type->is_primitive(PrimitiveTypeID::f32)) { - llvm_val[stmt] = create_call( - "atan2_f32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) { - llvm_val[stmt] = create_call( - "atan2_f64", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else { - TI_P(data_type_name(ret_type)); - TI_NOT_IMPLEMENTED - } - } else if (current_arch() == Arch::cuda) { - if (ret_type->is_primitive(PrimitiveTypeID::f32)) { - llvm_val[stmt] = create_call( - "__nv_atan2f", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) { - llvm_val[stmt] = create_call( - "__nv_atan2", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else { - TI_P(data_type_name(ret_type)); - TI_NOT_IMPLEMENTED - } - } else { - TI_NOT_IMPLEMENTED - } - } else if (op == BinaryOpType::pow) { - if (arch_is_cpu(current_arch())) { - if (ret_type->is_primitive(PrimitiveTypeID::f32)) { - llvm_val[stmt] = - create_call("pow_f32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) { - llvm_val[stmt] = - create_call("pow_f64", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else if (ret_type->is_primitive(PrimitiveTypeID::i32)) { - llvm_val[stmt] = - create_call("pow_i32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else if (ret_type->is_primitive(PrimitiveTypeID::i64)) { - llvm_val[stmt] = - create_call("pow_i64", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else { - TI_P(data_type_name(ret_type)); - TI_NOT_IMPLEMENTED - } - } else if (current_arch() == Arch::cuda) { - if (ret_type->is_primitive(PrimitiveTypeID::f32)) { - llvm_val[stmt] = create_call( - "__nv_powf", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) { - llvm_val[stmt] = - create_call("__nv_pow", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else if (ret_type->is_primitive(PrimitiveTypeID::i32)) { - llvm_val[stmt] = - create_call("pow_i32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else if (ret_type->is_primitive(PrimitiveTypeID::i64)) { - llvm_val[stmt] = - create_call("pow_i64", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else { - TI_P(data_type_name(ret_type)); - TI_NOT_IMPLEMENTED - } - } else { + BINARYOP_MAX(u16) + BINARYOP_MAX(i16) + BINARYOP_MAX(u32) + BINARYOP_MAX(i32) + BINARYOP_MAX(u64) + BINARYOP_MAX(i64) + else { + TI_P(data_type_name(ret_type)); TI_NOT_IMPLEMENTED } } else if (op == BinaryOpType::min) { +#define BINARYOP_MIN(x) \ + else if (ret_type->is_primitive(PrimitiveTypeID::x)) { \ + llvm_val[stmt] = \ + create_call("min_" #x, {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); \ + } + if (is_real(ret_type)) { llvm_val[stmt] = builder->CreateMinNum(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); - } else if (ret_type->is_primitive(PrimitiveTypeID::i32)) { - llvm_val[stmt] = - create_call("min_i32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else { + } + BINARYOP_MIN(u16) + BINARYOP_MIN(i16) + BINARYOP_MIN(u32) + BINARYOP_MIN(i32) + BINARYOP_MIN(u64) + BINARYOP_MIN(i64) + else { TI_P(data_type_name(ret_type)); TI_NOT_IMPLEMENTED } @@ -630,8 +617,61 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { } llvm_val[stmt] = builder->CreateSExt(cmp, llvm_type(PrimitiveType::i32)); } else { - TI_P(binary_op_type_name(op)); - TI_NOT_IMPLEMENTED + // This branch contains atan2 and pow which use runtime.cpp function for + // **real** type. We don't have f16 support there so promoting to f32 is + // necessary. + llvm::Value *lhs = llvm_val[stmt->lhs]; + llvm::Value *rhs = llvm_val[stmt->rhs]; + if (stmt->lhs->ret_type->is_primitive(PrimitiveTypeID::f16)) { + lhs = builder->CreateFPExt(lhs, llvm::Type::getFloatTy(*llvm_context)); + } + if (stmt->rhs->ret_type->is_primitive(PrimitiveTypeID::f16)) { + rhs = builder->CreateFPExt(rhs, llvm::Type::getFloatTy(*llvm_context)); + } + if (ret_type->is_primitive(PrimitiveTypeID::f16)) { + ret_type = PrimitiveType::f32; + } + + if (op == BinaryOpType::atan2) { + if (arch_is_cpu(current_arch())) { + if (ret_type->is_primitive(PrimitiveTypeID::f32)) { + llvm_val[stmt] = create_call("atan2_f32", {lhs, rhs}); + } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) { + llvm_val[stmt] = create_call("atan2_f64", {lhs, rhs}); + } else { + TI_P(data_type_name(ret_type)); + TI_NOT_IMPLEMENTED + } + } else { + TI_NOT_IMPLEMENTED + } + } else if (op == BinaryOpType::pow) { + if (arch_is_cpu(current_arch())) { + if (ret_type->is_primitive(PrimitiveTypeID::f32)) { + llvm_val[stmt] = create_call("pow_f32", {lhs, rhs}); + } else if (ret_type->is_primitive(PrimitiveTypeID::f64)) { + llvm_val[stmt] = create_call("pow_f64", {lhs, rhs}); + } else if (ret_type->is_primitive(PrimitiveTypeID::i32)) { + llvm_val[stmt] = create_call("pow_i32", {lhs, rhs}); + } else if (ret_type->is_primitive(PrimitiveTypeID::i64)) { + llvm_val[stmt] = create_call("pow_i64", {lhs, rhs}); + } else { + TI_P(data_type_name(ret_type)); + TI_NOT_IMPLEMENTED + } + } else { + TI_NOT_IMPLEMENTED + } + } else { + TI_P(binary_op_type_name(op)); + TI_NOT_IMPLEMENTED + } + + // Convert back to f16 if applicable. + if (stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { + llvm_val[stmt] = builder->CreateFPTrunc( + llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context)); + } } } @@ -654,6 +694,8 @@ llvm::Type *CodeGenLLVM::llvm_type(DataType dt) { return llvm::Type::getFloatTy(*llvm_context); } else if (dt->is_primitive(PrimitiveTypeID::f64)) { return llvm::Type::getDoubleTy(*llvm_context); + } else if (dt->is_primitive(PrimitiveTypeID::f16)) { + return llvm::Type::getHalfTy(*llvm_context); } else { TI_NOT_IMPLEMENTED; } @@ -712,7 +754,7 @@ llvm::Value *CodeGenLLVM::create_print(std::string tag, value = builder->CreateFPExt(value, tlctx->get_data_type(PrimitiveType::f64)); args.push_back(value); - return builder->CreateCall(runtime_printf, args); + return create_call(runtime_printf, args); } llvm::Value *CodeGenLLVM::create_print(std::string tag, llvm::Value *value) { @@ -726,6 +768,23 @@ llvm::Value *CodeGenLLVM::create_print(std::string tag, llvm::Value *value) { tag, TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i32), value); + else if (value->getType() == llvm::Type::getHalfTy(*llvm_context)) { + auto extended = + builder->CreateFPExt(value, llvm::Type::getFloatTy(*llvm_context)); + return create_print( + tag, + TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::f32), + extended); + } else if (value->getType() == llvm::Type::getInt64Ty(*llvm_context)) + return create_print( + tag, + TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i64), + value); + else if (value->getType() == llvm::Type::getInt16Ty(*llvm_context)) + return create_print( + tag, + TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i16), + value); else TI_NOT_IMPLEMENTED } @@ -738,7 +797,8 @@ void CodeGenLLVM::visit(PrintStmt *stmt) { if (std::holds_alternative(content)) { auto arg_stmt = std::get(content); auto value = llvm_val[arg_stmt]; - if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32)) + if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) || + arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) value = builder->CreateFPExt(value, tlctx->get_data_type(PrimitiveType::f64)); args.push_back(value); @@ -754,7 +814,7 @@ void CodeGenLLVM::visit(PrintStmt *stmt) { args.insert(args.begin(), builder->CreateGlobalStringPtr(formats.c_str(), "format_string")); - llvm_val[stmt] = builder->CreateCall(runtime_printf, args); + llvm_val[stmt] = create_call(runtime_printf, args); } void CodeGenLLVM::visit(ConstStmt *stmt) { @@ -763,9 +823,18 @@ void CodeGenLLVM::visit(ConstStmt *stmt) { if (val.dt->is_primitive(PrimitiveTypeID::f32)) { llvm_val[stmt] = llvm::ConstantFP::get(*llvm_context, llvm::APFloat(val.val_float32())); + } else if (val.dt->is_primitive(PrimitiveTypeID::f16)) { + llvm_val[stmt] = llvm::ConstantFP::get(llvm::Type::getHalfTy(*llvm_context), + val.val_float32()); } else if (val.dt->is_primitive(PrimitiveTypeID::f64)) { llvm_val[stmt] = llvm::ConstantFP::get(*llvm_context, llvm::APFloat(val.val_float64())); + } else if (val.dt->is_primitive(PrimitiveTypeID::i16)) { + llvm_val[stmt] = llvm::ConstantInt::get( + *llvm_context, llvm::APInt(16, (uint64)val.val_int16(), true)); + } else if (val.dt->is_primitive(PrimitiveTypeID::u16)) { + llvm_val[stmt] = llvm::ConstantInt::get( + *llvm_context, llvm::APInt(16, (uint64)val.val_uint16(), false)); } else if (val.dt->is_primitive(PrimitiveTypeID::i32)) { llvm_val[stmt] = llvm::ConstantInt::get( *llvm_context, llvm::APInt(32, (uint64)val.val_int32(), true)); @@ -798,7 +867,16 @@ void CodeGenLLVM::visit(WhileControlStmt *stmt) { void CodeGenLLVM::visit(ContinueStmt *stmt) { using namespace llvm; - if (stmt->as_return()) { + auto stmt_in_off_range_for = [stmt]() { + TI_ASSERT(stmt->scope != nullptr); + if (auto *offl = stmt->scope->cast(); offl) { + TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for || + offl->task_type == OffloadedStmt::TaskType::struct_for); + return offl->task_type == OffloadedStmt::TaskType::range_for; + } + return false; + }; + if (stmt_in_off_range_for()) { builder->CreateRetVoid(); } else { TI_ASSERT(current_loop_reentry != nullptr); @@ -858,12 +936,13 @@ void CodeGenLLVM::emit_gc(OffloadedStmt *stmt) { } llvm::Value *CodeGenLLVM::create_call(llvm::Value *func, - std::vector args) { + llvm::ArrayRef args) { check_func_call_signature(func, args); return builder->CreateCall(func, args); } + llvm::Value *CodeGenLLVM::create_call(std::string func_name, - std::vector args) { + llvm::ArrayRef args) { auto func = get_runtime_function(func_name); return create_call(func, args); } @@ -940,8 +1019,56 @@ void CodeGenLLVM::visit(RangeForStmt *for_stmt) { create_naive_range_for(for_stmt); } +llvm::Value *CodeGenLLVM::bitcast_from_u64(llvm::Value *val, DataType type) { + llvm::Type *dest_ty = nullptr; + TI_ASSERT(!type->is()); + if (auto cit = type->cast()) { + if (cit->get_is_signed()) + dest_ty = tlctx->get_data_type(PrimitiveType::i32); + else + dest_ty = tlctx->get_data_type(PrimitiveType::u32); + } else { + dest_ty = tlctx->get_data_type(type); + } + auto dest_bits = dest_ty->getPrimitiveSizeInBits(); + if (dest_ty == llvm::Type::getHalfTy(*llvm_context)) { + // if dest_ty == half, CreateTrunc will only keep low 16bits of mantissa + // which doesn't mean anything. + // So we truncate to 32 bits first and then fptrunc to half if applicable + auto truncated = + builder->CreateTrunc(val, llvm::Type::getIntNTy(*llvm_context, 32)); + auto casted = builder->CreateBitCast(truncated, + llvm::Type::getFloatTy(*llvm_context)); + return builder->CreateFPTrunc(casted, llvm::Type::getHalfTy(*llvm_context)); + } else { + auto truncated = builder->CreateTrunc( + val, llvm::Type::getIntNTy(*llvm_context, dest_bits)); + + return builder->CreateBitCast(truncated, dest_ty); + } +} + +llvm::Value *CodeGenLLVM::bitcast_to_u64(llvm::Value *val, DataType type) { + auto intermediate_bits = 0; + if (auto cit = type->cast()) { + intermediate_bits = data_type_bits(cit->get_compute_type()); + } else { + intermediate_bits = tlctx->get_data_type(type)->getPrimitiveSizeInBits(); + } + llvm::Type *dest_ty = tlctx->get_data_type(); + llvm::Type *intermediate_type = nullptr; + if (val->getType() == llvm::Type::getHalfTy(*llvm_context)) { + val = builder->CreateFPExt(val, tlctx->get_data_type()); + intermediate_type = tlctx->get_data_type(); + } else { + intermediate_type = llvm::Type::getIntNTy(*llvm_context, intermediate_bits); + } + return builder->CreateZExt(builder->CreateBitCast(val, intermediate_type), + dest_ty); +} + void CodeGenLLVM::visit(ArgLoadStmt *stmt) { - auto raw_arg = call(builder.get(), "Context_get_args", get_context(), + auto raw_arg = call(builder.get(), "RuntimeContext_get_args", get_context(), tlctx->get_constant(stmt->arg_id)); llvm::Type *dest_ty = nullptr; @@ -950,41 +1077,24 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) { llvm::PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0); llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty); } else { - TI_ASSERT(!stmt->ret_type->is()); - if (auto cit = stmt->ret_type->cast()) { - if (cit->get_is_signed()) - dest_ty = tlctx->get_data_type(PrimitiveType::i32); - else - dest_ty = tlctx->get_data_type(PrimitiveType::u32); - } else { - dest_ty = tlctx->get_data_type(stmt->ret_type); - } - auto dest_bits = dest_ty->getPrimitiveSizeInBits(); - auto truncated = builder->CreateTrunc( - raw_arg, llvm::Type::getIntNTy(*llvm_context, dest_bits)); - llvm_val[stmt] = builder->CreateBitCast(truncated, dest_ty); + llvm_val[stmt] = bitcast_from_u64(raw_arg, stmt->ret_type); } } void CodeGenLLVM::visit(ReturnStmt *stmt) { - if (stmt->ret_type.is_pointer()) { + auto types = stmt->element_types(); + if (std::any_of(types.begin(), types.end(), + [](const DataType &t) { return t.is_pointer(); })) { TI_NOT_IMPLEMENTED } else { - auto intermediate_bits = 0; - if (auto cit = stmt->value->ret_type->cast()) { - intermediate_bits = data_type_bits(cit->get_compute_type()); - } else { - intermediate_bits = - tlctx->get_data_type(stmt->value->ret_type)->getPrimitiveSizeInBits(); + TI_ASSERT(stmt->values.size() <= taichi_max_num_ret_value); + int idx{0}; + for (auto &value : stmt->values) { + create_call( + "RuntimeContext_store_result", + {get_context(), bitcast_to_u64(llvm_val[value], value->ret_type), + tlctx->get_constant(idx++)}); } - llvm::Type *intermediate_type = - llvm::Type::getIntNTy(*llvm_context, intermediate_bits); - llvm::Type *dest_ty = tlctx->get_data_type(); - auto extended = builder->CreateZExt( - builder->CreateBitCast(llvm_val[stmt->value], intermediate_type), - dest_ty); - builder->CreateCall(get_runtime_function("LLVMRuntime_store_result"), - {get_runtime(), extended}); } } @@ -1071,6 +1181,41 @@ void CodeGenLLVM::visit(SNodeOpStmt *stmt) { } } +llvm::Value *CodeGenLLVM::atomic_op_using_cas( + llvm::Value *dest, + llvm::Value *val, + std::function op) { + using namespace llvm; + BasicBlock *body = BasicBlock::Create(*llvm_context, "while_loop_body", func); + BasicBlock *after_loop = + BasicBlock::Create(*llvm_context, "after_while", func); + + builder->CreateBr(body); + builder->SetInsertPoint(body); + + llvm::Value *old_val; + + { + old_val = builder->CreateLoad(dest); + auto new_val = op(old_val, val); + dest = + builder->CreateBitCast(dest, llvm::Type::getInt16PtrTy(*llvm_context)); + auto atomicCmpXchg = builder->CreateAtomicCmpXchg( + dest, + builder->CreateBitCast(old_val, llvm::Type::getInt16Ty(*llvm_context)), + builder->CreateBitCast(new_val, llvm::Type::getInt16Ty(*llvm_context)), + AtomicOrdering::SequentiallyConsistent, + AtomicOrdering::SequentiallyConsistent); + // Check whether CAS was succussful + auto ok = builder->CreateExtractValue(atomicCmpXchg, 1); + builder->CreateCondBr(builder->CreateNot(ok), body, after_loop); + } + + builder->SetInsertPoint(after_loop); + + return old_val; +} + void CodeGenLLVM::visit(AtomicOpStmt *stmt) { // auto mask = stmt->parent->mask(); // TODO: deal with mask when vectorized @@ -1098,33 +1243,49 @@ void CodeGenLLVM::visit(AtomicOpStmt *stmt) { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::min) { - if (is_integral(stmt->val->ret_type)) { + if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::u32)) { + old_value = create_call("atomic_min_u32", + {llvm_val[stmt->dest], llvm_val[stmt->val]}); + } else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::u64)) { + old_value = create_call("atomic_min_u64", + {llvm_val[stmt->dest], llvm_val[stmt->val]}); + } else if (is_integral(stmt->val->ret_type)) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::Min, llvm_val[stmt->dest], llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent); + } else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f16)) { + old_value = atomic_op_using_cas( + llvm_val[stmt->dest], llvm_val[stmt->val], + [&](auto v1, auto v2) { return builder->CreateMinNum(v1, v2); }); } else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f32)) { - old_value = - builder->CreateCall(get_runtime_function("atomic_min_f32"), + old_value = create_call("atomic_min_f32", {llvm_val[stmt->dest], llvm_val[stmt->val]}); } else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f64)) { - old_value = - builder->CreateCall(get_runtime_function("atomic_min_f64"), + old_value = create_call("atomic_min_f64", {llvm_val[stmt->dest], llvm_val[stmt->val]}); } else { TI_NOT_IMPLEMENTED } } else if (stmt->op_type == AtomicOpType::max) { - if (is_integral(stmt->val->ret_type)) { + if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::u32)) { + old_value = create_call("atomic_max_u32", + {llvm_val[stmt->dest], llvm_val[stmt->val]}); + } else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::u64)) { + old_value = create_call("atomic_max_u64", + {llvm_val[stmt->dest], llvm_val[stmt->val]}); + } else if (is_integral(stmt->val->ret_type)) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::Max, llvm_val[stmt->dest], llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent); + } else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f16)) { + old_value = atomic_op_using_cas( + llvm_val[stmt->dest], llvm_val[stmt->val], + [&](auto v1, auto v2) { return builder->CreateMaxNum(v1, v2); }); } else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f32)) { - old_value = - builder->CreateCall(get_runtime_function("atomic_max_f32"), + old_value = create_call("atomic_max_f32", {llvm_val[stmt->dest], llvm_val[stmt->val]}); } else if (stmt->val->ret_type->is_primitive(PrimitiveTypeID::f64)) { - old_value = - builder->CreateCall(get_runtime_function("atomic_max_f64"), + old_value = create_call("atomic_max_f64", {llvm_val[stmt->dest], llvm_val[stmt->val]}); } else { TI_NOT_IMPLEMENTED @@ -1199,7 +1360,7 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { auto val_type = ptr_type->get_pointee_type(); if (val_type->is()) { llvm_val[stmt] = load_as_custom_int(llvm_val[stmt->src], val_type); - } else if (auto cft = val_type->cast()) { + } else if (val_type->cast()) { TI_ASSERT(stmt->src->is()); llvm_val[stmt] = load_custom_float(stmt->src); } else { @@ -1401,14 +1562,19 @@ void CodeGenLLVM::visit(GetChStmt *stmt) { } void CodeGenLLVM::visit(PtrOffsetStmt *stmt) { - auto origin_address = builder->CreatePtrToInt( - llvm_val[stmt->origin], llvm::Type::getInt64Ty(*llvm_context)); - auto address_offset = builder->CreateSExt( - llvm_val[stmt->offset], llvm::Type::getInt64Ty(*llvm_context)); - auto target_address = builder->CreateAdd(origin_address, address_offset); - auto dt = stmt->ret_type.ptr_removed(); - llvm_val[stmt] = builder->CreateIntToPtr( - target_address, llvm::PointerType::get(tlctx->get_data_type(dt), 0)); + if (stmt->is_local_ptr()) { + llvm_val[stmt] = + builder->CreateGEP(llvm_val[stmt->origin], llvm_val[stmt->offset]); + } else { + auto origin_address = builder->CreatePtrToInt( + llvm_val[stmt->origin], llvm::Type::getInt64Ty(*llvm_context)); + auto address_offset = builder->CreateSExt( + llvm_val[stmt->offset], llvm::Type::getInt64Ty(*llvm_context)); + auto target_address = builder->CreateAdd(origin_address, address_offset); + auto dt = stmt->ret_type.ptr_removed(); + llvm_val[stmt] = builder->CreateIntToPtr( + target_address, llvm::PointerType::get(tlctx->get_data_type(dt), 0)); + } } void CodeGenLLVM::visit(ExternalPtrStmt *stmt) { @@ -1420,8 +1586,8 @@ void CodeGenLLVM::visit(ExternalPtrStmt *stmt) { std::vector sizes(num_indices); for (int i = 0; i < num_indices; i++) { - auto raw_arg = builder->CreateCall( - get_runtime_function("Context_get_extra_args"), + auto raw_arg = create_call( + "RuntimeContext_get_extra_args", {get_context(), tlctx->get_constant(arg_id), tlctx->get_constant(i)}); sizes[i] = raw_arg; } @@ -1443,8 +1609,8 @@ void CodeGenLLVM::visit(ExternalPtrStmt *stmt) { void CodeGenLLVM::visit(ExternalTensorShapeAlongAxisStmt *stmt) { const auto arg_id = stmt->arg_id; const auto axis = stmt->axis; - llvm_val[stmt] = builder->CreateCall( - get_runtime_function("Context_get_extra_args"), + llvm_val[stmt] = create_call( + "RuntimeContext_get_extra_args", {get_context(), tlctx->get_constant(arg_id), tlctx->get_constant(axis)}); } @@ -1457,9 +1623,9 @@ std::string CodeGenLLVM::init_offloaded_task_function(OffloadedStmt *stmt, llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context), {llvm::PointerType::get(context_ty, 0)}, false); - auto task_kernel_name = fmt::format("{}_{}_{}{}", kernel_name, task_counter, - stmt->task_name(), suffix); - task_counter += 1; + auto task_kernel_name = + fmt::format("{}_{}_{}{}", kernel_name, kernel->get_next_task_id(), + stmt->task_name(), suffix); func = llvm::Function::Create(task_function_type, llvm::Function::ExternalLinkage, task_kernel_name, module.get()); @@ -1549,7 +1715,7 @@ void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) { { // Create the loop body function auto guard = get_function_creation_guard({ - llvm::PointerType::get(get_runtime_type("Context"), 0), + llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0), get_tls_buffer_type(), llvm::PointerType::get(get_runtime_type("Element"), 0), tlctx->get_data_type(), @@ -1645,6 +1811,9 @@ void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) { auto struct_for_body_bb = BasicBlock::Create(*llvm_context, "struct_for_body_body", func); + auto lrg = make_loop_reentry_guard(this); + current_loop_reentry = body_tail_bb; + builder->CreateBr(loop_test_bb); { @@ -1840,8 +2009,10 @@ void CodeGenLLVM::visit(LoopIndexStmt *stmt) { void CodeGenLLVM::visit(LoopLinearIndexStmt *stmt) { if (stmt->loop->is() && - stmt->loop->as()->task_type == - OffloadedStmt::TaskType::struct_for) { + (stmt->loop->as()->task_type == + OffloadedStmt::TaskType::struct_for || + stmt->loop->as()->task_type == + OffloadedStmt::TaskType::mesh_for)) { llvm_val[stmt] = create_call("thread_idx"); } else { TI_NOT_IMPLEMENTED; @@ -1862,12 +2033,6 @@ void CodeGenLLVM::visit(BlockCornerIndexStmt *stmt) { } } -void CodeGenLLVM::visit(BlockDimStmt *stmt) { - TI_NOT_IMPLEMENTED // No need for this statement for now. Untested so mark - // it as a loud failure. - llvm_val[stmt] = create_call("block_dim", {}); -} - void CodeGenLLVM::visit(GlobalTemporaryStmt *stmt) { auto runtime = get_runtime(); auto buffer = call("get_temporary_pointer", runtime, @@ -1991,6 +2156,75 @@ void CodeGenLLVM::visit(LoopUniqueStmt *stmt) { llvm_val[stmt] = llvm_val[stmt->input]; } +void CodeGenLLVM::visit_call_bitcode(ExternalFuncCallStmt *stmt) { + TI_ASSERT(stmt->type == ExternalFuncCallStmt::BITCODE); + std::vector arg_values; + for (const auto &s : stmt->arg_stmts) + arg_values.push_back(llvm_val[s]); + // Link external module to the core module + if (linked_modules.find(stmt->bc_filename) == linked_modules.end()) { + linked_modules.insert(stmt->bc_filename); + std::unique_ptr external_module = + module_from_bitcode_file(stmt->bc_filename, llvm_context); + auto *func_ptr = external_module->getFunction(stmt->bc_funcname); + TI_ASSERT_INFO(func_ptr != nullptr, "{} is not found in {}.", + stmt->bc_funcname, stmt->bc_filename); + auto link_error = + llvm::Linker::linkModules(*module, std::move(external_module)); + TI_ASSERT(!link_error); + } + // Retrieve function again. Do it here to detect name conflicting. + auto *func_ptr = module->getFunction(stmt->bc_funcname); + // Convert pointer type from a[n * m] to a[n][m] + for (int i = 0; i < func_ptr->getFunctionType()->getNumParams(); ++i) { + TI_ASSERT_INFO(func_ptr->getArg(i)->getType()->getTypeID() == + arg_values[i]->getType()->getTypeID(), + "TypeID {} != {} with {}", + (int)func_ptr->getArg(i)->getType()->getTypeID(), + (int)arg_values[i]->getType()->getTypeID(), i); + auto tmp_value = arg_values[i]; + arg_values[i] = + builder->CreatePointerCast(tmp_value, func_ptr->getArg(i)->getType()); + } + create_call(func_ptr, arg_values); +} + +void CodeGenLLVM::visit_call_shared_object(ExternalFuncCallStmt *stmt) { + TI_ASSERT(stmt->type == ExternalFuncCallStmt::SHARED_OBJECT); + std::vector arg_types; + std::vector arg_values; + + for (const auto &s : stmt->arg_stmts) { + TI_ASSERT(s->width() == 1); + arg_types.push_back(tlctx->get_data_type(s->ret_type)); + arg_values.push_back(llvm_val[s]); + } + + for (const auto &s : stmt->output_stmts) { + TI_ASSERT(s->width() == 1); + auto t = tlctx->get_data_type(s->ret_type); + auto ptr = llvm::PointerType::get(t, 0); + arg_types.push_back(ptr); + arg_values.push_back(llvm_val[s]); + } + + auto func_type = llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context), + arg_types, false); + auto func_ptr_type = llvm::PointerType::get(func_type, 0); + + auto addr = tlctx->get_constant((std::size_t)stmt->so_func); + auto func = builder->CreateIntToPtr(addr, func_ptr_type); + create_call(func, arg_values); +} + +void CodeGenLLVM::visit(ExternalFuncCallStmt *stmt) { + TI_NOT_IMPLEMENTED +} + +void CodeGenLLVM::visit(MeshPatchIndexStmt *stmt) { + llvm_val[stmt] = get_arg(2); +} + void CodeGenLLVM::eliminate_unused_functions() { TaichiLLVMContext::eliminate_unused_functions( module.get(), [&](std::string func_name) { @@ -2013,8 +2247,23 @@ FunctionType CodeGenLLVM::compile_module_to_executable() { } auto offloaded_tasks_local = offloaded_tasks; auto kernel_name_ = kernel_name; - return [=](Context &context) { + return [offloaded_tasks_local, kernel_name_, + kernel = this->kernel](RuntimeContext &context) { TI_TRACE("Launching kernel {}", kernel_name_); + auto args = kernel->args; + // For taichi ndarrays, context.args saves pointer to its + // |DeviceAllocation|, CPU backend actually want to use the raw ptr here. + for (int i = 0; i < (int)args.size(); i++) { + if (args[i].is_array && context.is_device_allocation[i] && + args[i].size > 0) { + DeviceAllocation *ptr = + static_cast(context.get_arg(i)); + uint64 host_ptr = (uint64)kernel->program->get_llvm_program_impl() + ->get_ndarray_alloc_info_ptr(*ptr); + context.set_arg(i, host_ptr); + context.set_device_allocation(i, false); + } + } for (auto task : offloaded_tasks_local) { task(&context); } @@ -2053,22 +2302,32 @@ llvm::Type *CodeGenLLVM::get_tls_buffer_type() { } std::vector CodeGenLLVM::get_xlogue_argument_types() { - return {llvm::PointerType::get(get_runtime_type("Context"), 0), + return {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0), get_tls_buffer_type()}; } +std::vector CodeGenLLVM::get_mesh_xlogue_argument_types() { + return {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0), + get_tls_buffer_type(), tlctx->get_data_type()}; +} + llvm::Type *CodeGenLLVM::get_xlogue_function_type() { return llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context), get_xlogue_argument_types(), false); } +llvm::Type *CodeGenLLVM::get_mesh_xlogue_function_type() { + return llvm::FunctionType::get(llvm::Type::getVoidTy(*llvm_context), + get_mesh_xlogue_argument_types(), false); +} + llvm::Value *CodeGenLLVM::get_root(int snode_tree_id) { return create_call("LLVMRuntime_get_roots", {get_runtime(), tlctx->get_constant(snode_tree_id)}); } llvm::Value *CodeGenLLVM::get_runtime() { - auto runtime_ptr = create_call("Context_get_runtime", {get_context()}); + auto runtime_ptr = create_call("RuntimeContext_get_runtime", {get_context()}); return builder->CreateBitCast( runtime_ptr, llvm::PointerType::get(get_runtime_type("LLVMRuntime"), 0)); } @@ -2106,4 +2365,52 @@ llvm::Value *CodeGenLLVM::create_xlogue(std::unique_ptr &block) { return xlogue; } +llvm::Value *CodeGenLLVM::create_mesh_xlogue(std::unique_ptr &block) { + llvm::Value *xlogue; + + auto xlogue_type = get_mesh_xlogue_function_type(); + auto xlogue_ptr_type = llvm::PointerType::get(xlogue_type, 0); + + if (block) { + auto guard = get_function_creation_guard(get_mesh_xlogue_argument_types()); + block->accept(this); + xlogue = guard.body; + } else { + xlogue = llvm::ConstantPointerNull::get(xlogue_ptr_type); + } + + return xlogue; +} + +void CodeGenLLVM::visit(FuncCallStmt *stmt) { + if (!func_map.count(stmt->func)) { + auto guard = get_function_creation_guard( + {llvm::PointerType::get(get_runtime_type("RuntimeContext"), 0)}); + func_map.insert({stmt->func, guard.body}); + stmt->func->ir->accept(this); + } + llvm::Function *llvm_func = func_map[stmt->func]; + auto *new_ctx = builder->CreateAlloca(get_runtime_type("RuntimeContext")); + call("RuntimeContext_set_runtime", new_ctx, get_runtime()); + for (int i = 0; i < stmt->args.size(); i++) { + auto *val = + bitcast_to_u64(llvm_val[stmt->args[i]], stmt->args[i]->ret_type); + call("RuntimeContext_set_args", new_ctx, + llvm::ConstantInt::get(*llvm_context, llvm::APInt(32, i, true)), val); + } + llvm::Value *result_buffer = nullptr; + if (stmt->ret_type->is() && + !stmt->ret_type->is_primitive(PrimitiveTypeID::unknown)) { + result_buffer = builder->CreateAlloca(tlctx->get_data_type()); + call("RuntimeContext_set_result_buffer", new_ctx, result_buffer); + create_call(llvm_func, {new_ctx}); + auto *ret_val_u64 = builder->CreateLoad(result_buffer); + llvm_val[stmt] = bitcast_from_u64(ret_val_u64, stmt->ret_type); + } else { + create_call(llvm_func, {new_ctx}); + } +} + TLANG_NAMESPACE_END + +#endif // #ifdef TI_WITH_LLVM diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index 86ffb2f182439..e32ee3cdce298 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -1,5 +1,6 @@ // The LLVM backend for CPUs/NVPTX/AMDGPU #pragma once +#ifdef TI_WITH_LLVM #include #include @@ -30,7 +31,7 @@ class OffloadedTask { void compile(); - void operator()(Context *context); + void operator()(RuntimeContext *context); }; class FunctionCreationGuard { @@ -48,8 +49,6 @@ class FunctionCreationGuard { class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { public: - static uint64 task_counter; - Kernel *kernel; IRNode *ir; Program *prog; @@ -72,9 +71,12 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { std::unique_ptr current_task; std::vector offloaded_tasks; llvm::BasicBlock *func_body_bb; + std::set linked_modules; std::unordered_map> loop_vars_llvm; + std::unordered_map func_map; + using IRVisitor::visit; using LLVMModuleBuilder::call; @@ -98,8 +100,12 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { std::vector get_xlogue_argument_types(); + std::vector get_mesh_xlogue_argument_types(); + llvm::Type *get_xlogue_function_type(); + llvm::Type *get_mesh_xlogue_function_type(); + llvm::Value *get_root(int snode_tree_id); llvm::Value *get_runtime(); @@ -136,10 +142,10 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void emit_gc(OffloadedStmt *stmt); llvm::Value *create_call(llvm::Value *func, - std::vector args = {}); + llvm::ArrayRef args = {}); llvm::Value *create_call(std::string func_name, - std::vector args = {}); + llvm::ArrayRef args = {}); llvm::Value *call(SNode *snode, llvm::Value *node_ptr, const std::string &method, @@ -166,6 +172,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { virtual void emit_extra_unary(UnaryOpStmt *stmt); + void visit(DecorationStmt *stmt) override; + void visit(UnaryOpStmt *stmt) override; void visit(BinaryOpStmt *stmt) override; @@ -306,6 +314,10 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { virtual void create_offload_range_for(OffloadedStmt *stmt) = 0; + virtual void create_offload_mesh_for(OffloadedStmt *stmt) { + TI_NOT_IMPLEMENTED; + } + void create_offload_struct_for(OffloadedStmt *stmt, bool spmd = false); void visit(LoopIndexStmt *stmt) override; @@ -314,8 +326,6 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(BlockCornerIndexStmt *stmt) override; - void visit(BlockDimStmt *stmt) override; - void visit(GlobalTemporaryStmt *stmt) override; void visit(ThreadLocalPtrStmt *stmt) override; @@ -344,8 +354,18 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(LoopUniqueStmt *stmt) override; + void visit_call_bitcode(ExternalFuncCallStmt *stmt); + + void visit_call_shared_object(ExternalFuncCallStmt *stmt); + + void visit(ExternalFuncCallStmt *stmt) override; + + void visit(MeshPatchIndexStmt *stmt) override; + llvm::Value *create_xlogue(std::unique_ptr &block); + llvm::Value *create_mesh_xlogue(std::unique_ptr &block); + llvm::Value *extract_exponent_from_float(llvm::Value *f); llvm::Value *extract_digits_from_float(llvm::Value *f, bool full); @@ -355,7 +375,19 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *get_exponent_offset(llvm::Value *exponent, CustomFloatType *cft); - ~CodeGenLLVM() = default; + llvm::Value *atomic_op_using_cas( + llvm::Value *dest, + llvm::Value *val, + std::function op); + + void visit(FuncCallStmt *stmt) override; + + llvm::Value *bitcast_from_u64(llvm::Value *val, DataType type); + llvm::Value *bitcast_to_u64(llvm::Value *val, DataType type); + + ~CodeGenLLVM() override = default; }; TLANG_NAMESPACE_END + +#endif // #ifdef TI_WITH_LLVM diff --git a/taichi/codegen/codegen_llvm_quant.cpp b/taichi/codegen/codegen_llvm_quant.cpp index 0943c01b373d7..f136f2fb8e970 100644 --- a/taichi/codegen/codegen_llvm_quant.cpp +++ b/taichi/codegen/codegen_llvm_quant.cpp @@ -1,3 +1,4 @@ +#ifdef TI_WITH_LLVM #include "taichi/codegen/codegen_llvm.h" #include "taichi/ir/statements.h" @@ -669,3 +670,5 @@ llvm::Value *CodeGenLLVM::load_custom_float(Stmt *ptr_stmt) { } TLANG_NAMESPACE_END + +#endif // #ifdef TI_WITH_LLVM diff --git a/taichi/codegen/spirv/kernel_utils.cpp b/taichi/codegen/spirv/kernel_utils.cpp new file mode 100644 index 0000000000000..b29e03b460372 --- /dev/null +++ b/taichi/codegen/spirv/kernel_utils.cpp @@ -0,0 +1,120 @@ +#include "taichi/codegen/spirv/kernel_utils.h" + +#include + +#include "taichi/program/kernel.h" +#define TI_RUNTIME_HOST +#include "taichi/program/context.h" +#undef TI_RUNTIME_HOST + +namespace taichi { +namespace lang { +namespace spirv { + +// static +std::string TaskAttributes::buffers_name(BufferInfo b) { + if (b.type == BufferType::Args) { + return "Args"; + } + if (b.type == BufferType::Rets) { + return "Rets"; + } + if (b.type == BufferType::GlobalTmps) { + return "GlobalTmps"; + } + if (b.type == BufferType::Root) { + return std::string("Root: ") + std::to_string(b.root_id); + } + TI_ERROR("unrecognized buffer type"); +} + +std::string TaskAttributes::debug_string() const { + std::string result; + result += fmt::format( + "", + TaskAttributes::buffers_name(buffer), binding); +} + +KernelContextAttributes::KernelContextAttributes(const Kernel &kernel) + : args_bytes_(0), + rets_bytes_(0), + extra_args_bytes_(RuntimeContext::extra_args_size) { + arg_attribs_vec_.reserve(kernel.args.size()); + for (const auto &ka : kernel.args) { + ArgAttributes aa; + aa.dt = ka.dt; + const size_t dt_bytes = data_type_size(aa.dt); + aa.is_array = ka.is_array; + if (aa.is_array) { + aa.field_dim = ka.total_dim - ka.element_shape.size(); + aa.element_shape = ka.element_shape; + } + aa.stride = dt_bytes; + aa.index = arg_attribs_vec_.size(); + arg_attribs_vec_.push_back(aa); + } + for (const auto &kr : kernel.rets) { + RetAttributes ra; + size_t dt_bytes{0}; + if (auto tensor_type = kr.dt->cast()) { + ra.dt = tensor_type->get_element_type(); + dt_bytes = data_type_size(ra.dt); + ra.is_array = true; + ra.stride = tensor_type->get_num_elements() * dt_bytes; + } else { + ra.dt = kr.dt; + dt_bytes = data_type_size(ra.dt); + ra.is_array = false; + ra.stride = dt_bytes; + } + ra.index = ret_attribs_vec_.size(); + ret_attribs_vec_.push_back(ra); + } + + auto arange_args = [](auto *vec, size_t offset, bool is_ret) -> size_t { + size_t bytes = offset; + for (int i = 0; i < vec->size(); ++i) { + auto &attribs = (*vec)[i]; + const size_t dt_bytes = (attribs.is_array && !is_ret) + ? sizeof(uint64_t) + : data_type_size(attribs.dt); + // Align bytes to the nearest multiple of dt_bytes + bytes = (bytes + dt_bytes - 1) / dt_bytes * dt_bytes; + attribs.offset_in_mem = bytes; + bytes += is_ret ? attribs.stride : dt_bytes; + TI_TRACE( + " at={} {} offset_in_mem={} stride={}", + (*vec)[i].is_array ? (is_ret ? "array" : "vector ptr") : "scalar", i, + attribs.offset_in_mem, attribs.stride); + } + return bytes - offset; + }; + + TI_TRACE("args:"); + args_bytes_ = arange_args(&arg_attribs_vec_, 0, false); + // Align to extra args + args_bytes_ = (args_bytes_ + 4 - 1) / 4 * 4; + + TI_TRACE("rets:"); + rets_bytes_ = arange_args(&ret_attribs_vec_, 0, true); + + TI_TRACE("sizes: args={} rets={}", args_bytes(), rets_bytes()); + TI_ASSERT(has_rets() == (rets_bytes_ > 0)); +} + +} // namespace spirv +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/vulkan/kernel_utils.h b/taichi/codegen/spirv/kernel_utils.h similarity index 79% rename from taichi/backends/vulkan/kernel_utils.h rename to taichi/codegen/spirv/kernel_utils.h index 0e1d0ed691aa5..21d528dba12c3 100644 --- a/taichi/backends/vulkan/kernel_utils.h +++ b/taichi/codegen/spirv/kernel_utils.h @@ -6,6 +6,7 @@ #include "taichi/ir/offloaded_task_type.h" #include "taichi/ir/type.h" +#include "taichi/backends/device.h" namespace taichi { namespace lang { @@ -13,21 +14,17 @@ namespace lang { class Kernel; class SNode; -namespace vulkan { +namespace spirv { /** * Per offloaded task attributes. */ struct TaskAttributes { - enum class BufferType { - Root, - GlobalTmps, - Context, - }; + enum class BufferType { Root, GlobalTmps, Args, Rets, ListGen, ExtArr }; struct BufferInfo { BufferType type; - int root_id{-1}; // only used if type==Root + int root_id{-1}; // only used if type==Root or type==ExtArr BufferInfo() = default; @@ -42,11 +39,13 @@ struct TaskAttributes { if (type != other.type) { return false; } - if (type == BufferType::Root) { + if (type == BufferType::Root || type == BufferType::ExtArr) { return root_id == other.root_id; } return true; } + + TI_IO_DEF(type, root_id); }; struct BufferInfoHasher { @@ -55,7 +54,7 @@ struct TaskAttributes { using std::size_t; using std::string; - return hash()(buf.type); + return hash()(buf.type) ^ buf.root_id; } }; @@ -64,9 +63,12 @@ struct TaskAttributes { int binding{0}; std::string debug_string() const; + + TI_IO_DEF(buffer, binding); }; std::string name; + std::string source_path; // Total number of threads to launch (i.e. threads per grid). Note that this // is only advisory, because eventually this number is also determined by the // runtime config. This works because grid strided loop is supported. @@ -89,6 +91,8 @@ struct TaskAttributes { inline bool const_range() const { return (const_begin && const_end); } + + TI_IO_DEF(begin, end, const_begin, const_end); }; std::vector buffer_binds; // Only valid when |task_type| is range_for. @@ -97,14 +101,21 @@ struct TaskAttributes { static std::string buffers_name(BufferInfo b); std::string debug_string() const; + + TI_IO_DEF(name, + advisory_total_num_threads, + advisory_num_threads_per_group, + task_type, + buffer_binds, + range_for_attribs); }; /** * This class contains the attributes descriptors for both the input args and * the return values of a Taichi kernel. * - * Note that all Vulkan tasks (shaders) belonging to the same Taichi kernel will - * share the same kernel args (i.e. they use the same Vulkan buffer for input + * Note that all SPIRV tasks (shaders) belonging to the same Taichi kernel will + * share the same kernel args (i.e. they use the same device buffer for input * args and return values). This is because kernel arguments is a Taichi-level * concept. * @@ -131,16 +142,20 @@ class KernelContextAttributes { int index{-1}; DataType dt; bool is_array{false}; + std::vector element_shape; + std::size_t field_dim{0}; + + TI_IO_DEF(stride, offset_in_mem, index, is_array, element_shape, field_dim); }; public: /** - * This is mostly the same as Kernel::Arg, with Vulkan specific attributes. + * This is mostly the same as Kernel::Arg, with device specific attributes. */ struct ArgAttributes : public AttribsBase {}; /** - * This is mostly the same as Kernel::Ret, with Vulkan specific attributes. + * This is mostly the same as Kernel::Ret, with device specific attributes. */ struct RetAttributes : public AttribsBase {}; @@ -190,21 +205,6 @@ class KernelContextAttributes { return rets_bytes_; } - /** - * Offset (in bytes) of the return values in the memory. - */ - inline size_t rets_mem_offset() const { - return args_bytes(); - } - - /** - * Total size in bytes of the input args and return values. - * - * This *excludes* the extra args bytes. - */ - inline size_t ctx_bytes() const { - return args_bytes() + rets_bytes(); - } /** * Number of bytes needed by the extra arguments. * @@ -219,15 +219,14 @@ class KernelContextAttributes { * Offset (in bytes) of the extra arguments in the memory. */ inline size_t extra_args_mem_offset() const { - return ctx_bytes(); + return args_bytes(); } - /** - * Total bytes needed for allocating the Vulkan buffer. - */ - inline size_t total_bytes() const { - return ctx_bytes() + extra_args_bytes(); - } + TI_IO_DEF(arg_attribs_vec_, + ret_attribs_vec_, + args_bytes_, + rets_bytes_, + extra_args_bytes_); private: std::vector arg_attribs_vec_; @@ -239,7 +238,7 @@ class KernelContextAttributes { }; /** - * Groups all the Vulkan kernels generated from a single ti.kernel. + * Groups all the device kernels generated from a single ti.kernel. */ struct TaichiKernelAttributes { // Taichi kernel name @@ -250,8 +249,10 @@ struct TaichiKernelAttributes { std::vector tasks_attribs; KernelContextAttributes ctx_attribs; + + TI_IO_DEF(name, is_jit_evaluator, tasks_attribs, ctx_attribs); }; -} // namespace vulkan +} // namespace spirv } // namespace lang } // namespace taichi diff --git a/taichi/backends/vulkan/snode_struct_compiler.cpp b/taichi/codegen/spirv/snode_struct_compiler.cpp similarity index 60% rename from taichi/backends/vulkan/snode_struct_compiler.cpp rename to taichi/codegen/spirv/snode_struct_compiler.cpp index c66dc9dbecaf5..6be96bb8c8663 100644 --- a/taichi/backends/vulkan/snode_struct_compiler.cpp +++ b/taichi/codegen/spirv/snode_struct_compiler.cpp @@ -1,10 +1,8 @@ -#include "taichi/backends/vulkan/snode_struct_compiler.h" - -#include "taichi/backends/vulkan/data_type_utils.h" +#include "taichi/codegen/spirv/snode_struct_compiler.h" namespace taichi { namespace lang { -namespace vulkan { +namespace spirv { namespace { class StructCompiler { @@ -16,7 +14,7 @@ class StructCompiler { result.root = &root; result.root_size = compute_snode_size(&root); result.snode_descriptors = std::move(snode_descriptors_); - TI_TRACE("Vulkan RootBuffer size={}", result.root_size); + TI_TRACE("RootBuffer size={}", result.root_size); return result; } @@ -27,7 +25,7 @@ class StructCompiler { SNodeDescriptor sn_desc; sn_desc.snode = sn; if (is_place) { - sn_desc.cell_stride = vk_data_type_size(sn->dt); + sn_desc.cell_stride = data_type_size(sn->dt); sn_desc.container_stride = sn_desc.cell_stride; } else { std::size_t cell_stride = 0; @@ -39,8 +37,17 @@ class StructCompiler { ->second.mem_offset_in_parent_cell = child_offset; } sn_desc.cell_stride = cell_stride; - sn_desc.container_stride = - cell_stride * sn_desc.cells_per_container_pot(); + + if (sn->type == SNodeType::bitmasked) { + size_t num_cells = sn_desc.cells_per_container_pot(); + size_t bitmask_num_words = + num_cells % 32 == 0 ? (num_cells / 32) : (num_cells / 32 + 1); + sn_desc.container_stride = + cell_stride * num_cells + bitmask_num_words * 4; + } else { + sn_desc.container_stride = + cell_stride * sn_desc.cells_per_container_pot(); + } } sn->cell_size_bytes = sn_desc.cell_stride; @@ -54,6 +61,28 @@ class StructCompiler { sn_desc.total_num_cells_from_root *= e.num_elements_from_root; } + // Sum the bits per axis + SNode *snode_head = sn; + do { + for (int i = 0; i < taichi_max_num_indices; i++) { + const AxisExtractor &extractor = snode_head->extractors[i]; + if (extractor.active) { + sn_desc.axis_bits_sum[i] += extractor.num_bits; + } + } + } while ((snode_head = snode_head->parent)); + // Find the start bit + sn_desc.axis_start_bit[0] = 0; + for (int i = 1; i < taichi_max_num_indices; i++) { + sn_desc.axis_start_bit[i] = + sn_desc.axis_bits_sum[i - 1] + sn_desc.axis_start_bit[i - 1]; + } + TI_TRACE("Indices at SNode {}", sn->get_name()); + for (int i = 0; i < taichi_max_num_indices; i++) { + TI_TRACE("Index {}: {}..{}", i, sn_desc.axis_start_bit[i], + sn_desc.axis_start_bit[i] + sn_desc.axis_bits_sum[i]); + } + TI_TRACE("SNodeDescriptor"); TI_TRACE("* snode={}", sn_desc.snode->id); TI_TRACE("* type={} (is_place={})", sn_desc.snode->node_type_name, @@ -75,7 +104,7 @@ class StructCompiler { } // namespace -int SNodeDescriptor::cells_per_container_pot() const { +size_t SNodeDescriptor::cells_per_container_pot() const { return snode->num_cells_per_container; } @@ -84,6 +113,6 @@ CompiledSNodeStructs compile_snode_structs(SNode &root) { return compiler.run(root); } -} // namespace vulkan +} // namespace spirv } // namespace lang } // namespace taichi diff --git a/taichi/backends/vulkan/snode_struct_compiler.h b/taichi/codegen/spirv/snode_struct_compiler.h similarity index 76% rename from taichi/backends/vulkan/snode_struct_compiler.h rename to taichi/codegen/spirv/snode_struct_compiler.h index 068556fd90c9b..f2b4cba6cbe5d 100644 --- a/taichi/backends/vulkan/snode_struct_compiler.h +++ b/taichi/codegen/spirv/snode_struct_compiler.h @@ -7,18 +7,18 @@ namespace taichi { namespace lang { -namespace vulkan { +namespace spirv { struct SNodeDescriptor { const SNode *snode = nullptr; // Stride (bytes) of a single cell. - int cell_stride = 0; + size_t cell_stride = 0; // Number of cells per container, padded to Power of Two (pot). - int cells_per_container_pot() const; + size_t cells_per_container_pot() const; // Bytes of a single container. - int container_stride = 0; + size_t container_stride = 0; // Total number of CELLS of this SNode, NOT padded to PoT. // For example, for a layout of @@ -27,10 +27,13 @@ struct SNodeDescriptor { // .dense(ti.ij, (5, 3)) // S2 // |total_num_cells_from_root| for S2 is 3x2x5x3 = 90. That is, S2 has a total // of 90 cells. Note that the number of S2 (container) itself is 3x2=6! - int total_num_cells_from_root = 0; + size_t total_num_cells_from_root = 0; // An SNode can have multiple number of components, where each component // starts at a fixed offset in its parent cell's memory. - int mem_offset_in_parent_cell = 0; + size_t mem_offset_in_parent_cell = 0; + + int axis_bits_sum[taichi_max_num_indices] = {0}; + int axis_start_bit[taichi_max_num_indices] = {0}; SNode *get_child(int ch_i) const { return snode->ch[ch_i].get(); @@ -41,15 +44,15 @@ using SNodeDescriptorsMap = std::unordered_map; struct CompiledSNodeStructs { // Root buffer size in bytes. - size_t root_size; + size_t root_size{0}; // Root SNode - const SNode *root; + const SNode *root{nullptr}; // Map from SNode ID to its descriptor. SNodeDescriptorsMap snode_descriptors; }; CompiledSNodeStructs compile_snode_structs(SNode &root); -} // namespace vulkan +} // namespace spirv } // namespace lang } // namespace taichi diff --git a/taichi/backends/vulkan/codegen_vulkan.cpp b/taichi/codegen/spirv/spirv_codegen.cpp similarity index 51% rename from taichi/backends/vulkan/codegen_vulkan.cpp rename to taichi/codegen/spirv/spirv_codegen.cpp index 8eb6dcf556004..821ccb12b91f6 100644 --- a/taichi/backends/vulkan/codegen_vulkan.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -1,4 +1,4 @@ -#include "taichi/backends/vulkan/codegen_vulkan.h" +#include "taichi/codegen/spirv/spirv_codegen.h" #include #include @@ -8,24 +8,26 @@ #include "taichi/ir/statements.h" #include "taichi/ir/ir.h" #include "taichi/util/line_appender.h" -#include "taichi/backends/vulkan/kernel_utils.h" -#include "taichi/backends/vulkan/runtime.h" +#include "taichi/codegen/spirv/kernel_utils.h" #include "taichi/backends/opengl/opengl_data_types.h" -#include "taichi/backends/vulkan/spirv_ir_builder.h" -#include "taichi/backends/vulkan/spirv_snode_compiler.h" +#include "taichi/codegen/spirv/spirv_ir_builder.h" #include "taichi/ir/transforms.h" +#include "taichi/math/arithmetic.h" #include #include namespace taichi { namespace lang { -namespace vulkan { +namespace spirv { namespace { constexpr char kRootBufferName[] = "root_buffer"; constexpr char kGlobalTmpsBufferName[] = "global_tmps_buffer"; -constexpr char kContextBufferName[] = "context_buffer"; +constexpr char kArgsBufferName[] = "args_buffer"; +constexpr char kRetBufferName[] = "ret_buffer"; +constexpr char kListgenBufferName[] = "listgen_buffer"; +constexpr char kExtArrBufferName[] = "ext_arr_buffer"; constexpr int kMaxNumThreadsGridStrideLoop = 65536; @@ -41,8 +43,14 @@ std::string buffer_instance_name(BufferInfo b) { return std::string(kRootBufferName) + "_" + std::to_string(b.root_id); case BufferType::GlobalTmps: return kGlobalTmpsBufferName; - case BufferType::Context: - return kContextBufferName; + case BufferType::Args: + return kArgsBufferName; + case BufferType::Rets: + return kRetBufferName; + case BufferType::ListGen: + return kListgenBufferName; + case BufferType::ExtArr: + return kExtArrBufferName; default: TI_NOT_IMPLEMENTED; break; @@ -61,14 +69,16 @@ class TaskCodegen : public IRVisitor { int task_id_in_kernel; }; + const bool use_64bit_pointers = false; + explicit TaskCodegen(const Params ¶ms) - : task_ir_(params.task_ir), + : device_(params.device), + task_ir_(params.task_ir), compiled_structs_(params.compiled_structs), ctx_attribs_(params.ctx_attribs), task_name_(fmt::format("{}_t{:02d}", params.ti_kernel_name, - params.task_id_in_kernel)), - device_(params.device) { + params.task_id_in_kernel)) { allow_undefined_visitor = true; invoke_default_visitor = true; @@ -94,13 +104,20 @@ class TaskCodegen : public IRVisitor { kernel_function_ = ir_->new_function(); // void main(); ir_->debug(spv::OpName, kernel_function_, "main"); + compile_args_struct(); + compile_ret_struct(); + if (task_ir_->task_type == OffloadedTaskType::serial) { generate_serial_kernel(task_ir_); } else if (task_ir_->task_type == OffloadedTaskType::range_for) { // struct_for is automatically lowered to ranged_for for dense snodes generate_range_for_kernel(task_ir_); + } else if (task_ir_->task_type == OffloadedTaskType::listgen) { + generate_listgen_kernel(task_ir_); + } else if (task_ir_->task_type == OffloadedTaskType::struct_for) { + generate_struct_for_kernel(task_ir_); } else { - TI_ERROR("Unsupported offload type={} on Vulkan arch", + TI_ERROR("Unsupported offload type={} on SPIR-V codegen", task_ir_->task_name()); } // Headers need global information, so it has to be delayed after visiting @@ -120,7 +137,9 @@ class TaskCodegen : public IRVisitor { void visit(Block *stmt) override { for (auto &s : stmt->statements) { - s->accept(this); + if (offload_loop_motion_.find(s.get()) == offload_loop_motion_.end()) { + s->accept(this); + } } } @@ -206,16 +225,10 @@ class TaskCodegen : public IRVisitor { } void visit(GetRootStmt *stmt) override { - // Should we assert |root_stmt_| is assigned only once? const int root_id = snode_to_root_.at(stmt->root()->id); root_stmts_[root_id] = stmt; - get_buffer_value({BufferType::Root, root_id}); - spirv::SType root_ptr = ir_->get_pointer_type( - spirv_snodes_.at(root_id).root_stype, spv::StorageClassStorageBuffer); - spirv::Value root_val = - ir_->make_value(spv::OpAccessChain, root_ptr, - get_buffer_value({BufferType::Root, root_id}), - ir_->const_i32_zero_, ir_->const_i32_zero_); + // get_buffer_value({BufferType::Root, root_id}, PrimitiveType::u32); + spirv::Value root_val = make_pointer(0); ir_->register_value(stmt->raw_name(), root_val); } @@ -228,27 +241,104 @@ class TaskCodegen : public IRVisitor { TI_ASSERT(snode_descs.at(stmt->input_snode->id).get_child(stmt->chid) == out_snode); + const auto &desc = snode_descs.at(out_snode->id); + spirv::Value input_ptr_val = ir_->query_value(stmt->input_ptr->raw_name()); - spirv::Value offset = - ir_->int_immediate_number(ir_->i32_type(), stmt->chid); - spirv::Value val; + spirv::Value offset = make_pointer(desc.mem_offset_in_parent_cell); + spirv::Value val = ir_->add(input_ptr_val, offset); + ir_->register_value(stmt->raw_name(), val); + if (out_snode->is_place()) { TI_ASSERT(ptr_to_buffers_.count(stmt) == 0); ptr_to_buffers_[stmt] = BufferInfo(BufferType::Root, root); + } + } - spirv::SType dt_ptr = ir_->get_pointer_type( - ir_->get_primitive_buffer_type(true, out_snode->dt), - spv::StorageClassStorageBuffer); - val = ir_->make_value(spv::OpAccessChain, dt_ptr, input_ptr_val, offset); + enum class ActivationOp { activate, deactivate, query }; + + spirv::Value bitmasked_activation(ActivationOp op, + spirv::Value parent_ptr, + int root_id, + const SNode *sn, + spirv::Value input_index) { + spirv::SType ptr_dt = parent_ptr.stype; + const auto &snode_descs = compiled_structs_[root_id].snode_descriptors; + const auto &desc = snode_descs.at(sn->id); + + auto bitmask_word_index = + ir_->make_value(spv::OpShiftRightLogical, ptr_dt, input_index, + ir_->uint_immediate_number(ptr_dt, 5)); + auto bitmask_bit_index = + ir_->make_value(spv::OpBitwiseAnd, ptr_dt, input_index, + ir_->uint_immediate_number(ptr_dt, 31)); + auto bitmask_mask = ir_->make_value(spv::OpShiftLeftLogical, ptr_dt, + ir_->const_i32_one_, bitmask_bit_index); + + auto buffer = get_buffer_value(BufferInfo(BufferType::Root, root_id), + PrimitiveType::u32); + auto bitmask_word_ptr = + ir_->make_value(spv::OpShiftLeftLogical, ptr_dt, bitmask_word_index, + ir_->uint_immediate_number(ir_->u32_type(), 2)); + bitmask_word_ptr = ir_->add( + bitmask_word_ptr, + make_pointer(desc.cell_stride * desc.cells_per_container_pot())); + bitmask_word_ptr = ir_->add(parent_ptr, bitmask_word_ptr); + bitmask_word_ptr = ir_->make_value( + spv::OpShiftRightLogical, ir_->u32_type(), bitmask_word_ptr, + ir_->uint_immediate_number(ir_->u32_type(), 2)); + bitmask_word_ptr = + ir_->struct_array_access(ir_->u32_type(), buffer, bitmask_word_ptr); + + if (op == ActivationOp::activate) { + return ir_->make_value(spv::OpAtomicOr, ir_->u32_type(), bitmask_word_ptr, + /*scope=*/ir_->const_i32_one_, + /*semantics=*/ir_->const_i32_zero_, bitmask_mask); + } else if (op == ActivationOp::deactivate) { + bitmask_mask = ir_->make_value(spv::OpNot, ir_->u32_type(), bitmask_mask); + return ir_->make_value(spv::OpAtomicAnd, ir_->u32_type(), + bitmask_word_ptr, + /*scope=*/ir_->const_i32_one_, + /*semantics=*/ir_->const_i32_zero_, bitmask_mask); } else { - spirv::SType snode_array = - spirv_snodes_[root].query_snode_array_stype(out_snode->id); - spirv::SType snode_array_ptr = - ir_->get_pointer_type(snode_array, spv::StorageClassStorageBuffer); - val = ir_->make_value(spv::OpAccessChain, snode_array_ptr, input_ptr_val, - offset); + auto bitmask_val = ir_->load_variable(bitmask_word_ptr, ir_->u32_type()); + auto bit = ir_->make_value(spv::OpShiftRightLogical, ir_->u32_type(), + bitmask_val, bitmask_bit_index); + bit = ir_->make_value(spv::OpBitwiseAnd, ir_->u32_type(), bit, + ir_->uint_immediate_number(ir_->u32_type(), 1)); + return ir_->make_value(spv::OpUGreaterThan, ir_->bool_type(), bit, + ir_->uint_immediate_number(ir_->u32_type(), 0)); + } + } + + void visit(SNodeOpStmt *stmt) override { + const int root_id = snode_to_root_.at(stmt->snode->id); + std::string parent = stmt->ptr->raw_name(); + spirv::Value parent_val = ir_->query_value(parent); + + if (stmt->snode->type == SNodeType::bitmasked) { + spirv::Value input_index_val = + ir_->cast(parent_val.stype, ir_->query_value(stmt->val->raw_name())); + + if (stmt->op_type == SNodeOpType::is_active) { + auto is_active = + bitmasked_activation(ActivationOp::query, parent_val, root_id, + stmt->snode, input_index_val); + is_active = + ir_->cast(ir_->get_primitive_type(stmt->ret_type), is_active); + is_active = ir_->make_value(spv::OpSNegate, is_active.stype, is_active); + ir_->register_value(stmt->raw_name(), is_active); + } else if (stmt->op_type == SNodeOpType::deactivate) { + bitmasked_activation(ActivationOp::deactivate, parent_val, root_id, + stmt->snode, input_index_val); + } else if (stmt->op_type == SNodeOpType::activate) { + bitmasked_activation(ActivationOp::activate, parent_val, root_id, + stmt->snode, input_index_val); + } else { + TI_NOT_IMPLEMENTED; + } + } else { + TI_NOT_IMPLEMENTED; } - ir_->register_value(stmt->raw_name(), val); } void visit(SNodeLookupStmt *stmt) override { @@ -256,52 +346,59 @@ class TaskCodegen : public IRVisitor { bool is_root{false}; // Eliminate first root snode access const int root_id = snode_to_root_.at(stmt->snode->id); std::string parent; - spirv::SType snode_struct; + if (stmt->input_snode) { parent = stmt->input_snode->raw_name(); - if (stmt->snode->id == compiled_structs_[root_id].root->id) { - is_root = true; - snode_struct = spirv_snodes_.at(root_id).root_stype; - } else if (!is_root) { - snode_struct = - spirv_snodes_.at(root_id).query_snode_struct_stype(stmt->snode->id); - } } else { TI_ASSERT(root_stmts_.at(root_id) != nullptr); parent = root_stmts_.at(root_id)->raw_name(); - snode_struct = spirv_snodes_.at(root_id).root_stype; - is_root = true; } const auto *sn = stmt->snode; - if (stmt->activate && !(sn->type == SNodeType::dense)) { - // Sparse SNode not supported yet. - TI_NOT_IMPLEMENTED; - } spirv::Value parent_val = ir_->query_value(parent); + + if (stmt->activate) { + if (sn->type == SNodeType::dense) { + // Do nothing + } else if (sn->type == SNodeType::bitmasked) { + spirv::Value input_index_val = + ir_->query_value(stmt->input_index->raw_name()); + bitmasked_activation(ActivationOp::activate, parent_val, root_id, sn, + input_index_val); + } else { + TI_NOT_IMPLEMENTED; + } + } + spirv::Value val; if (is_root) { val = parent_val; // Assert Root[0] access at first time } else { - spirv::Value input_index_val = - ir_->query_value(stmt->input_index->raw_name()); - spirv::SType snode_struct_ptr = - ir_->get_pointer_type(snode_struct, spv::StorageClassStorageBuffer); - val = ir_->make_value(spv::OpAccessChain, snode_struct_ptr, parent_val, - input_index_val); + const auto &snode_descs = compiled_structs_[root_id].snode_descriptors; + const auto &desc = snode_descs.at(sn->id); + + spirv::Value input_index_val = ir_->cast( + parent_val.stype, ir_->query_value(stmt->input_index->raw_name())); + spirv::Value stride = make_pointer(desc.cell_stride); + spirv::Value offset = ir_->mul(input_index_val, stride); + val = ir_->add(parent_val, offset); } ir_->register_value(stmt->raw_name(), val); } void visit(RandStmt *stmt) override { spirv::Value val; - spirv::Value global_tmp = get_buffer_value(BufferType::GlobalTmps); + spirv::Value global_tmp = + get_buffer_value(BufferType::GlobalTmps, PrimitiveType::u32); if (stmt->element_type()->is_primitive(PrimitiveTypeID::i32)) { val = ir_->rand_i32(global_tmp); } else if (stmt->element_type()->is_primitive(PrimitiveTypeID::u32)) { val = ir_->rand_u32(global_tmp); } else if (stmt->element_type()->is_primitive(PrimitiveTypeID::f32)) { val = ir_->rand_f32(global_tmp); + } else if (stmt->element_type()->is_primitive(PrimitiveTypeID::f16)) { + auto highp_val = ir_->rand_f32(global_tmp); + val = ir_->cast(ir_->f16_type(), highp_val); } else { TI_ERROR("rand only support 32-bit type"); } @@ -321,29 +418,46 @@ class TaskCodegen : public IRVisitor { void visit(BitExtractStmt *stmt) override { spirv::Value input_val = ir_->query_value(stmt->input->raw_name()); - spirv::Value tmp0 = - ir_->int_immediate_number(ir_->i32_type(), stmt->bit_begin); - spirv::Value tmp1 = ir_->int_immediate_number( - ir_->i32_type(), stmt->bit_end - stmt->bit_begin); - spirv::Value tmp2 = ir_->make_value(spv::OpShiftRightArithmetic, - ir_->i32_type(), input_val, tmp0); - spirv::Value tmp3 = ir_->make_value( - spv::OpShiftLeftLogical, ir_->i32_type(), ir_->const_i32_one_, tmp1); - spirv::Value tmp4 = ir_->sub(tmp3, ir_->const_i32_one_); - spirv::Value val = - ir_->make_value(spv::OpBitwiseAnd, ir_->i32_type(), tmp2, tmp4); + auto stype = input_val.stype; + spirv::Value tmp0 = ir_->int_immediate_number(stype, stmt->bit_begin); + spirv::Value tmp1 = + ir_->int_immediate_number(stype, stmt->bit_end - stmt->bit_begin); + spirv::Value tmp2 = + ir_->make_value(spv::OpShiftRightArithmetic, stype, input_val, tmp0); + spirv::Value tmp3 = + ir_->make_value(spv::OpShiftLeftLogical, stype, + ir_->int_immediate_number(stype, 1), tmp1); + spirv::Value tmp4 = ir_->sub(tmp3, ir_->int_immediate_number(stype, 1)); + spirv::Value val = ir_->make_value(spv::OpBitwiseAnd, stype, tmp2, tmp4); ir_->register_value(stmt->raw_name(), val); } void visit(LoopIndexStmt *stmt) override { - TI_ASSERT(stmt->index == 0); // TODO: multiple indices const auto stmt_name = stmt->raw_name(); if (stmt->loop->is()) { const auto type = stmt->loop->as()->task_type; if (type == OffloadedTaskType::range_for) { TI_ASSERT(stmt->index == 0); spirv::Value loop_var = ir_->query_value("ii"); - spirv::Value val = ir_->add(loop_var, ir_->const_i32_zero_); + // spirv::Value val = ir_->add(loop_var, ir_->const_i32_zero_); + ir_->register_value(stmt_name, loop_var); + } else if (type == OffloadedTaskType::struct_for) { + SNode *snode = stmt->loop->as()->snode; + spirv::Value val = ir_->query_value("ii"); + // FIXME: packed layout (non POT) + int root_id = snode_to_root_[snode->id]; + const auto &snode_descs = compiled_structs_[root_id].snode_descriptors; + const int *axis_start_bit = snode_descs.at(snode->id).axis_start_bit; + const int *axis_bits_sum = snode_descs.at(snode->id).axis_bits_sum; + val = + ir_->make_value(spv::OpShiftRightLogical, ir_->u32_type(), val, + ir_->uint_immediate_number( + ir_->u32_type(), axis_start_bit[stmt->index])); + val = ir_->make_value( + spv::OpBitwiseAnd, ir_->u32_type(), val, + ir_->uint_immediate_number(ir_->u32_type(), + (1 << axis_bits_sum[stmt->index]) - 1)); + val = ir_->cast(ir_->i32_type(), val); ir_->register_value(stmt_name, val); } else { TI_NOT_IMPLEMENTED; @@ -361,55 +475,19 @@ class TaskCodegen : public IRVisitor { void visit(GlobalStoreStmt *stmt) override { TI_ASSERT(stmt->width() == 1); const auto dt = stmt->val->element_type(); - bool struct_compiled = false; - spirv::Value buffer_ptr; + const auto &primitive_buffer_type = ir_->get_primitive_type(dt); + spirv::Value val = ir_->query_value(stmt->val->raw_name()); - if (ptr_to_buffers_.at(stmt->dest).type == BufferType::Root) { - buffer_ptr = ir_->query_value(stmt->dest->raw_name()); - buffer_ptr.flag = - spirv::ValueKind::kVariablePtr; // make this value could store/load - struct_compiled = true; - } else { - buffer_ptr = at_buffer(stmt->dest, dt); - } - const auto &primitive_buffer_type = - ir_->get_primitive_buffer_type(struct_compiled, dt); - if (buffer_ptr.stype.element_type_id == val.stype.id) { - // No bit cast - ir_->store_variable(buffer_ptr, val); - } else { - ir_->store_variable( - buffer_ptr, - ir_->make_value(spv::OpBitcast, primitive_buffer_type, val)); - } + store_buffer(stmt->dest, val); } void visit(GlobalLoadStmt *stmt) override { TI_ASSERT(stmt->width() == 1); auto dt = stmt->element_type(); - bool struct_compiled = false; - spirv::Value buffer_ptr; - spirv::Value val; - if (ptr_to_buffers_.at(stmt->src).type == BufferType::Root) { - buffer_ptr = ir_->query_value(stmt->src->raw_name()); - buffer_ptr.flag = - spirv::ValueKind::kVariablePtr; // make this value could store/load - struct_compiled = true; - } else { - buffer_ptr = at_buffer(stmt->src, dt); - } + const auto &primitive_buffer_type = ir_->get_primitive_type(dt); - const auto &primitive_buffer_type = - ir_->get_primitive_buffer_type(struct_compiled, dt); - if (buffer_ptr.stype.element_type_id == val.stype.id) { - // No bit cast - val = ir_->load_variable(buffer_ptr, primitive_buffer_type); - } else { - val = ir_->make_value( - spv::OpBitcast, ir_->get_primitive_type(dt), - ir_->load_variable(buffer_ptr, primitive_buffer_type)); - } + auto val = load_buffer(stmt->src, dt); ir_->register_value(stmt->raw_name(), val); } @@ -420,33 +498,37 @@ class TaskCodegen : public IRVisitor { const auto offset_in_mem = arg_attribs.offset_in_mem; if (stmt->is_ptr) { // Do not shift! We are indexing the buffers at byte granularity. - spirv::Value val = - ir_->int_immediate_number(ir_->i32_type(), offset_in_mem); - ir_->register_value(stmt->raw_name(), val); + // spirv::Value val = + // ir_->int_immediate_number(ir_->i32_type(), offset_in_mem); + // ir_->register_value(stmt->raw_name(), val); } else { const auto dt = arg_attribs.dt; - spirv::Value idx_val = ir_->int_immediate_number( - ir_->i32_type(), (offset_in_mem / sizeof(int32_t))); - spirv::Value buffer_val = ir_->struct_array_access( - ir_->i32_type(), get_buffer_value(BufferType::Context), idx_val); - spirv::Value val = - ir_->make_value(spv::OpBitcast, ir_->get_primitive_type(dt), - ir_->load_variable(buffer_val, ir_->i32_type())); + const auto val_type = ir_->get_primitive_type(dt); + spirv::Value buffer_val = ir_->make_value( + spv::OpAccessChain, + ir_->get_pointer_type(val_type, spv::StorageClassUniform), + get_buffer_value(BufferType::Args, PrimitiveType::i32), + ir_->int_immediate_number(ir_->i32_type(), arg_id)); + buffer_val.flag = ValueKind::kVariablePtr; + spirv::Value val = ir_->load_variable(buffer_val, val_type); ir_->register_value(stmt->raw_name(), val); } } void visit(ReturnStmt *stmt) override { - // TODO: use stmt->ret_id instead of 0 as index - const auto &ret_attribs = ctx_attribs_->rets()[0]; - const int index_in_buffer = ret_attribs.offset_in_mem / sizeof(int32_t); - spirv::Value idx_val = - ir_->int_immediate_number(ir_->i32_type(), index_in_buffer); - spirv::Value buffer_val = ir_->struct_array_access( - ir_->i32_type(), get_buffer_value(BufferType::Context), idx_val); - spirv::Value val = ir_->query_value(stmt->value->raw_name()); - ir_->store_variable(buffer_val, - ir_->make_value(spv::OpBitcast, ir_->i32_type(), val)); + // Now we only support one ret + auto dt = stmt->element_types()[0]; + for (int i = 0; i < stmt->values.size(); i++) { + spirv::Value buffer_val = ir_->make_value( + spv::OpAccessChain, + ir_->get_storage_pointer_type(ir_->get_primitive_type(dt)), + get_buffer_value(BufferType::Rets, dt), + ir_->int_immediate_number(ir_->i32_type(), 0), + ir_->int_immediate_number(ir_->i32_type(), i)); + buffer_val.flag = ValueKind::kVariablePtr; + spirv::Value val = ir_->query_value(stmt->values[i]->raw_name()); + ir_->store_variable(buffer_val, val); + } } void visit(GlobalTemporaryStmt *stmt) override { @@ -461,40 +543,41 @@ class TaskCodegen : public IRVisitor { const auto name = stmt->raw_name(); const auto arg_id = stmt->arg_id; const auto axis = stmt->axis; - const auto extra_args_mem_offset = ctx_attribs_->extra_args_mem_offset(); - const auto extra_args_index_base = - (extra_args_mem_offset / sizeof(int32_t)); - spirv::Value index = ir_->int_immediate_number( - ir_->i32_type(), - extra_args_index_base + arg_id * taichi_max_num_indices + axis); - spirv::Value var_ptr = ir_->struct_array_access( - ir_->i32_type(), get_buffer_value(BufferType::Context), index); + + const auto extra_args_member_index = ctx_attribs_->args().size(); + + const auto extra_arg_index = (arg_id * taichi_max_num_indices) + axis; + spirv::Value var_ptr = ir_->make_value( + spv::OpAccessChain, + ir_->get_pointer_type(ir_->i32_type(), spv::StorageClassUniform), + get_buffer_value(BufferType::Args, PrimitiveType::i32), + ir_->int_immediate_number(ir_->i32_type(), + extra_args_member_index + extra_arg_index)); spirv::Value var = ir_->load_variable(var_ptr, ir_->i32_type()); + ir_->register_value(name, var); } void visit(ExternalPtrStmt *stmt) override { // Used mostly for transferring data between host (e.g. numpy array) and - // Vulkan. + // device. TI_ASSERT(stmt->width() == 1); spirv::Value linear_offset = ir_->int_immediate_number(ir_->i32_type(), 0); + const auto *argload = stmt->base_ptrs[0]->as(); + const int arg_id = argload->arg_id; { - const auto *argload = stmt->base_ptrs[0]->as(); - const int arg_id = argload->arg_id; const int num_indices = stmt->indices.size(); std::vector size_var_names; - const auto extra_args_mem_offset = ctx_attribs_->extra_args_mem_offset(); - const auto extra_args_index_base = - (extra_args_mem_offset / sizeof(int32_t)); + const auto extra_args_member_index = ctx_attribs_->args().size(); for (int i = 0; i < num_indices; i++) { std::string var_name = fmt::format("{}_size{}_", stmt->raw_name(), i); - const auto extra_arg_linear_index_offset = - (arg_id * taichi_max_num_indices) + i; - const auto extra_arg_linear_index = - extra_args_index_base + extra_arg_linear_index_offset; - spirv::Value var_ptr = ir_->struct_array_access( - ir_->i32_type(), get_buffer_value(BufferType::Context), - ir_->int_immediate_number(ir_->i32_type(), extra_arg_linear_index)); + const auto extra_arg_index = (arg_id * taichi_max_num_indices) + i; + spirv::Value var_ptr = ir_->make_value( + spv::OpAccessChain, + ir_->get_pointer_type(ir_->i32_type(), spv::StorageClassUniform), + get_buffer_value(BufferType::Args, PrimitiveType::i32), + ir_->int_immediate_number( + ir_->i32_type(), extra_args_member_index + extra_arg_index)); spirv::Value var = ir_->load_variable(var_ptr, ir_->i32_type()); ir_->register_value(var_name, var); size_var_names.push_back(std::move(var_name)); @@ -502,19 +585,40 @@ class TaskCodegen : public IRVisitor { for (int i = 0; i < num_indices; i++) { spirv::Value size_var = ir_->query_value(size_var_names[i]); spirv::Value indices = ir_->query_value(stmt->indices[i]->raw_name()); - spirv::Value tmp; linear_offset = ir_->mul(linear_offset, size_var); linear_offset = ir_->add(linear_offset, indices); } linear_offset = ir_->make_value( spv::OpShiftLeftLogical, ir_->i32_type(), linear_offset, - ir_->int_immediate_number(ir_->i32_type(), 2)); + ir_->int_immediate_number(ir_->i32_type(), + log2int(ir_->get_primitive_type_size( + argload->ret_type.ptr_removed())))); + ir_->decorate(spv::OpDecorate, linear_offset, + spv::DecorationNoSignedWrap); } - spirv::Value val = ir_->add( - ir_->query_value(stmt->base_ptrs[0]->raw_name()), linear_offset); - ir_->register_value(stmt->raw_name(), val); - ptr_to_buffers_[stmt] = BufferType::Context; + if (device_->get_cap(DeviceCapability::spirv_has_physical_storage_buffer)) { + spirv::Value addr_ptr = ir_->make_value( + spv::OpAccessChain, + ir_->get_pointer_type(ir_->u64_type(), spv::StorageClassUniform), + get_buffer_value(BufferType::Args, PrimitiveType::i32), + ir_->int_immediate_number(ir_->i32_type(), arg_id)); + spirv::Value addr = ir_->load_variable(addr_ptr, ir_->u64_type()); + addr = ir_->add(addr, ir_->make_value(spv::OpSConvert, ir_->u64_type(), + linear_offset)); + ir_->register_value(stmt->raw_name(), addr); + } else { + ir_->register_value(stmt->raw_name(), linear_offset); + } + + if (ctx_attribs_->args()[arg_id].is_array) { + ptr_to_buffers_[stmt] = {BufferType::ExtArr, arg_id}; + } else { + ptr_to_buffers_[stmt] = BufferType::Args; + } + } + + void visit(DecorationStmt *stmt) override { } void visit(UnaryOpStmt *stmt) override { @@ -623,15 +727,15 @@ class TaskCodegen : public IRVisitor { const uint32_t instruction = instruction_id; \ if (is_real(src_dt)) { \ if (data_type_bits(src_dt) > max_bits) { \ - TI_ERROR( \ - "[glsl450] the operand type of instruction {}({}) must <= {}bits", \ - #instruction, instruction_id, max_bits); \ + TI_ERROR("Instruction {}({}) does not {}bits operation", #instruction, \ + instruction_id, data_type_bits(src_dt)); \ } \ val = ir_->call_glsl450(src_type, instruction, operand_val); \ } else { \ TI_NOT_IMPLEMENTED \ } \ } + UNARY_OP_TO_SPIRV(round, Round, 1, 64) UNARY_OP_TO_SPIRV(floor, Floor, 8, 64) UNARY_OP_TO_SPIRV(ceil, Ceil, 9, 64) UNARY_OP_TO_SPIRV(sin, Sin, 13, 32) @@ -658,6 +762,11 @@ class TaskCodegen : public IRVisitor { spirv::Value rhs_value = ir_->query_value(rhs_name); spirv::Value bin_value = spirv::Value(); + TI_WARN_IF(lhs_value.stype.id != rhs_value.stype.id, + "${} type {} != ${} type {}", lhs_name, + lhs_value.stype.dt->to_string(), rhs_name, + rhs_value.stype.dt->to_string()); + if (false) { } #define BINARY_OP_TO_SPIRV_ARTHIMATIC(op, func) \ @@ -812,25 +921,124 @@ class TaskCodegen : public IRVisitor { TI_ASSERT(stmt->width() == 1); const auto dt = stmt->dest->element_type().ptr_removed(); + spirv::Value data = ir_->query_value(stmt->val->raw_name()); + spirv::Value val; + bool use_subgroup_reduction = false; + + if (stmt->is_reduction && + device_->get_cap(DeviceCapability::spirv_has_subgroup_arithmetic)) { + spv::Op atomic_op = spv::OpNop; + bool negation = false; + if (is_integral(dt)) { + if (stmt->op_type == AtomicOpType::add) { + atomic_op = spv::OpGroupIAdd; + } else if (stmt->op_type == AtomicOpType::sub) { + atomic_op = spv::OpGroupIAdd; + negation = true; + } else if (stmt->op_type == AtomicOpType::min) { + atomic_op = is_signed(dt) ? spv::OpGroupSMin : spv::OpGroupUMin; + } else if (stmt->op_type == AtomicOpType::max) { + atomic_op = is_signed(dt) ? spv::OpGroupSMax : spv::OpGroupUMax; + } + } else if (is_real(dt)) { + if (stmt->op_type == AtomicOpType::add) { + atomic_op = spv::OpGroupFAdd; + } else if (stmt->op_type == AtomicOpType::sub) { + atomic_op = spv::OpGroupFAdd; + negation = true; + } else if (stmt->op_type == AtomicOpType::min) { + atomic_op = spv::OpGroupFMin; + } else if (stmt->op_type == AtomicOpType::max) { + atomic_op = spv::OpGroupFMax; + } + } + + if (atomic_op != spv::OpNop) { + spirv::Value scope_subgroup = + ir_->int_immediate_number(ir_->i32_type(), 3); + spirv::Value operation_reduce = ir_->const_i32_zero_; + if (negation) { + if (is_integral(dt)) { + data = ir_->make_value(spv::OpSNegate, data.stype, data); + } else { + data = ir_->make_value(spv::OpFNegate, data.stype, data); + } + } + data = ir_->make_value(atomic_op, ir_->get_primitive_type(dt), + scope_subgroup, operation_reduce, data); + val = data; + use_subgroup_reduction = true; + } + } + + spirv::Label then_label; + spirv::Label merge_label; + + if (use_subgroup_reduction) { + spirv::Value subgroup_id = ir_->get_subgroup_invocation_id(); + spirv::Value cond = ir_->make_value(spv::OpIEqual, ir_->bool_type(), + subgroup_id, ir_->const_i32_zero_); + + then_label = ir_->new_label(); + merge_label = ir_->new_label(); + ir_->make_inst(spv::OpSelectionMerge, merge_label, + spv::SelectionControlMaskNone); + ir_->make_inst(spv::OpBranchConditional, cond, then_label, merge_label); + ir_->start_label(then_label); + } + spirv::Value addr_ptr; - bool is_compiled_struct = false; - if (ptr_to_buffers_.at(stmt->dest).type == BufferType::Root) { - addr_ptr = ir_->query_value(stmt->dest->raw_name()); - addr_ptr.flag = - spirv::ValueKind::kVariablePtr; // make this value could store/load - is_compiled_struct = true; + + if (dt->is_primitive(PrimitiveTypeID::f64)) { + if (device_->get_cap(DeviceCapability::spirv_has_atomic_float64_add) && + stmt->op_type == AtomicOpType::add) { + addr_ptr = at_buffer(stmt->dest, dt); + } else { + addr_ptr = at_buffer(stmt->dest, ir_->get_taichi_uint_type(dt)); + } + } else if (dt->is_primitive(PrimitiveTypeID::f32)) { + if (device_->get_cap(DeviceCapability::spirv_has_atomic_float_add) && + stmt->op_type == AtomicOpType::add) { + addr_ptr = at_buffer(stmt->dest, dt); + } else { + addr_ptr = at_buffer(stmt->dest, ir_->get_taichi_uint_type(dt)); + } } else { addr_ptr = at_buffer(stmt->dest, dt); } - spirv::Value data = ir_->query_value(stmt->val->raw_name()); - spirv::Value val; - if (dt->is_primitive(PrimitiveTypeID::f32)) { - if (device_->get_cap(DeviceCapability::spirv_has_atomic_float_add) && - stmt->op_type == AtomicOpType::add && is_compiled_struct) { - val = ir_->make_value( - spv::OpAtomicFAddEXT, ir_->get_primitive_type(dt), addr_ptr, - ir_->uint_immediate_number(ir_->u32_type(), 1), - ir_->uint_immediate_number(ir_->u32_type(), 0), data); + + auto ret_type = ir_->get_primitive_type(dt); + + if (is_real(dt)) { + spv::Op atomic_fp_op; + if (stmt->op_type == AtomicOpType::add) { + atomic_fp_op = spv::OpAtomicFAddEXT; + } + + bool use_native_atomics = false; + + if (dt->is_primitive(PrimitiveTypeID::f64)) { + if (device_->get_cap(DeviceCapability::spirv_has_atomic_float64_add) && + stmt->op_type == AtomicOpType::add) { + use_native_atomics = true; + } + } else if (dt->is_primitive(PrimitiveTypeID::f32)) { + if (device_->get_cap(DeviceCapability::spirv_has_atomic_float_add) && + stmt->op_type == AtomicOpType::add) { + use_native_atomics = true; + } + } else if (dt->is_primitive(PrimitiveTypeID::f16)) { + if (device_->get_cap(DeviceCapability::spirv_has_atomic_float16_add) && + stmt->op_type == AtomicOpType::add) { + use_native_atomics = true; + } + } + + if (use_native_atomics) { + val = + ir_->make_value(atomic_fp_op, ir_->get_primitive_type(dt), addr_ptr, + /*scope=*/ir_->const_i32_one_, + /*semantics=*/ir_->const_i32_zero_, data); } else { val = ir_->float_atomic(stmt->op_type, addr_ptr, data); } @@ -854,13 +1062,34 @@ class TaskCodegen : public IRVisitor { TI_NOT_IMPLEMENTED } - val = - ir_->make_value(op, ir_->get_primitive_type(dt), addr_ptr, - ir_->uint_immediate_number(ir_->u32_type(), 1), - ir_->uint_immediate_number(ir_->u32_type(), 0), data); + auto uint_type = ir_->get_primitive_uint_type(dt); + + if (data.stype.id != addr_ptr.stype.element_type_id) { + data = ir_->make_value(spv::OpBitcast, ret_type, data); + } + + // Semantics = (UniformMemory 0x40) | (AcquireRelease 0x8) + ir_->make_inst( + spv::OpMemoryBarrier, ir_->const_i32_one_, + ir_->uint_immediate_number( + ir_->u32_type(), spv::MemorySemanticsAcquireReleaseMask | + spv::MemorySemanticsUniformMemoryMask)); + val = ir_->make_value(op, ret_type, addr_ptr, + /*scope=*/ir_->const_i32_one_, + /*semantics=*/ir_->const_i32_zero_, data); + + if (val.stype.id != ret_type.id) { + val = ir_->make_value(spv::OpBitcast, ret_type, val); + } } else { - TI_ERROR("Vulkan only supports 32-bit atomic data types"); + TI_NOT_IMPLEMENTED } + + if (use_subgroup_reduction) { + ir_->make_inst(spv::OpBranch, merge_label); + ir_->start_label(merge_label); + } + ir_->register_value(stmt->raw_name(), val); } @@ -1005,7 +1234,16 @@ class TaskCodegen : public IRVisitor { } void visit(ContinueStmt *stmt) override { - if (stmt->as_return()) { + auto stmt_in_off_for = [stmt]() { + TI_ASSERT(stmt->scope != nullptr); + if (auto *offl = stmt->scope->cast(); offl) { + TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for || + offl->task_type == OffloadedStmt::TaskType::struct_for); + return true; + } + return false; + }; + if (stmt_in_off_for()) { // Return means end THIS main loop and start next loop, not exit kernel ir_->make_inst(spv::OpBranch, return_label()); } else { @@ -1017,18 +1255,21 @@ class TaskCodegen : public IRVisitor { private: void emit_headers() { + /* for (int root = 0; root < compiled_structs_.size(); ++root) { get_buffer_value({BufferType::Root, root}); } + */ std::array group_size = { task_attribs_.advisory_num_threads_per_group, 1, 1}; ir_->set_work_group_size(group_size); std::vector buffers; if (device_->get_cap(DeviceCapability::spirv_version) > 0x10300) { for (const auto &bb : task_attribs_.buffer_binds) { - const auto it = buffer_value_map_.find(bb.buffer); - if (it != buffer_value_map_.end()) { - buffers.push_back(it->second); + for (auto &it : buffer_value_map_) { + if (it.first.first == bb.buffer) { + buffers.push_back(it.second); + } } } } @@ -1039,7 +1280,7 @@ class TaskCodegen : public IRVisitor { void generate_serial_kernel(OffloadedStmt *stmt) { task_attribs_.name = task_name_; task_attribs_.task_type = OffloadedTaskType::serial; - task_attribs_.buffer_binds = get_common_buffer_binds(); + // task_attribs_.buffer_binds = get_common_buffer_binds(); task_attribs_.advisory_total_num_threads = 1; task_attribs_.advisory_num_threads_per_group = 1; @@ -1066,12 +1307,23 @@ class TaskCodegen : public IRVisitor { ir_->start_label(merge_label); ir_->make_inst(spv::OpReturn); // return; ir_->make_inst(spv::OpFunctionEnd); // } Close kernel + + task_attribs_.buffer_binds = get_buffer_binds(); + } + + void gen_array_range(Stmt *stmt) { + int num_operands = stmt->num_operands(); + for (int i = 0; i < num_operands; i++) { + gen_array_range(stmt->operand(i)); + } + offload_loop_motion_.insert(stmt); + stmt->accept(this); } void generate_range_for_kernel(OffloadedStmt *stmt) { task_attribs_.name = task_name_; task_attribs_.task_type = OffloadedTaskType::range_for; - task_attribs_.buffer_binds = get_common_buffer_binds(); + // task_attribs_.buffer_binds = get_common_buffer_binds(); task_attribs_.range_for_attribs = TaskAttributes::RangeForAttributes(); auto &range_for_attribs = task_attribs_.range_for_attribs.value(); @@ -1094,34 +1346,46 @@ class TaskCodegen : public IRVisitor { false); // Named Constant task_attribs_.advisory_total_num_threads = num_elems; } else { - if (!stmt->const_begin) { - spirv::Value begin_idx = ir_->make_value( - spv::OpShiftRightArithmetic, ir_->i32_type(), - ir_->int_immediate_number(ir_->i32_type(), stmt->begin_offset), - ir_->int_immediate_number(ir_->i32_type(), 2)); - begin_expr_value = ir_->load_variable( - ir_->struct_array_access(ir_->i32_type(), - get_buffer_value(BufferType::GlobalTmps), - begin_idx), - ir_->i32_type()); - } else { - begin_expr_value = ir_->int_immediate_number( - ir_->i32_type(), stmt->begin_value, false); // Named Constant - } spirv::Value end_expr_value; - if (!stmt->const_end) { - spirv::Value end_idx = ir_->make_value( - spv::OpShiftRightArithmetic, ir_->i32_type(), - ir_->int_immediate_number(ir_->i32_type(), stmt->end_offset), - ir_->int_immediate_number(ir_->i32_type(), 2)); - end_expr_value = ir_->load_variable( - ir_->struct_array_access(ir_->i32_type(), - get_buffer_value(BufferType::GlobalTmps), - end_idx), - ir_->i32_type()); + if (stmt->end_stmt) { + // Range from args + TI_ASSERT(stmt->const_begin); + begin_expr_value = ir_->int_immediate_number(ir_->i32_type(), + stmt->begin_value, false); + gen_array_range(stmt->end_stmt); + end_expr_value = ir_->query_value(stmt->end_stmt->raw_name()); } else { - end_expr_value = - ir_->int_immediate_number(ir_->i32_type(), stmt->end_value, true); + // Range from gtmp / constant + if (!stmt->const_begin) { + spirv::Value begin_idx = ir_->make_value( + spv::OpShiftRightArithmetic, ir_->i32_type(), + ir_->int_immediate_number(ir_->i32_type(), stmt->begin_offset), + ir_->int_immediate_number(ir_->i32_type(), 2)); + begin_expr_value = ir_->load_variable( + ir_->struct_array_access( + ir_->i32_type(), + get_buffer_value(BufferType::GlobalTmps, PrimitiveType::i32), + begin_idx), + ir_->i32_type()); + } else { + begin_expr_value = ir_->int_immediate_number( + ir_->i32_type(), stmt->begin_value, false); // Named Constant + } + if (!stmt->const_end) { + spirv::Value end_idx = ir_->make_value( + spv::OpShiftRightArithmetic, ir_->i32_type(), + ir_->int_immediate_number(ir_->i32_type(), stmt->end_offset), + ir_->int_immediate_number(ir_->i32_type(), 2)); + end_expr_value = ir_->load_variable( + ir_->struct_array_access( + ir_->i32_type(), + get_buffer_value(BufferType::GlobalTmps, PrimitiveType::i32), + end_idx), + ir_->i32_type()); + } else { + end_expr_value = + ir_->int_immediate_number(ir_->i32_type(), stmt->end_value, true); + } } total_elems = ir_->sub(end_expr_value, begin_expr_value); task_attribs_.advisory_total_num_threads = kMaxNumThreadsGridStrideLoop; @@ -1142,20 +1406,20 @@ class TaskCodegen : public IRVisitor { // https://www.khronos.org/opengl/wiki/Compute_Shader#Inputs // HLSL & WGSL cross compilers do not support this builtin - /* spirv::Value total_invocs = ir_->cast( ir_->i32_type(), ir_->mul(ir_->get_num_work_groups(0), ir_->uint_immediate_number( ir_->u32_type(), task_attribs_.advisory_num_threads_per_group, true))); - */ + /* const int group_x = (task_attribs_.advisory_total_num_threads + task_attribs_.advisory_num_threads_per_group - 1) / task_attribs_.advisory_num_threads_per_group; spirv::Value total_invocs = ir_->uint_immediate_number( ir_->i32_type(), group_x * task_attribs_.advisory_num_threads_per_group, false); + */ ir_->debug(spv::OpName, total_invocs, total_invocs_name); @@ -1198,65 +1462,399 @@ class TaskCodegen : public IRVisitor { ir_->make_inst(spv::OpReturn); ir_->make_inst(spv::OpFunctionEnd); + + task_attribs_.buffer_binds = get_buffer_binds(); + } + + void generate_listgen_kernel(OffloadedStmt *stmt) { + task_attribs_.name = task_name_; + task_attribs_.task_type = OffloadedTaskType::listgen; + // task_attribs_.buffer_binds = get_common_buffer_binds(); + task_attribs_.advisory_total_num_threads = 1; + task_attribs_.advisory_num_threads_per_group = 32; + + auto snode = stmt->snode; + + TI_TRACE("Listgen for {}", snode->get_name()); + + std::vector snode_path; + std::vector snode_path_num_cells; + std::vector> + snode_path_index_start_bit; + int total_num_cells = 1; + int root_id = 0; + { + // Construct the SNode path to the chosen node + auto snode_head = snode; + std::array start_indices{0}; + do { + snode_path.push_back(snode_head); + snode_path_num_cells.push_back(total_num_cells); + snode_path_index_start_bit.push_back(start_indices); + total_num_cells *= snode_head->num_cells_per_container; + root_id = snode_head->id; + for (int i = 0; i < taichi_max_num_indices; i++) { + start_indices[i] += snode_head->extractors[i].num_bits; + } + } while ((snode_head = snode_head->parent)); + } + + const auto &snode_descs = compiled_structs_[root_id].snode_descriptors; + const auto sn_desc = snode_descs.at(snode->id); + + for (int i = snode_path.size() - 1; i >= 0; i--) { + const auto &desc = snode_descs.at(snode_path[i]->id); + TI_TRACE("- {} ({})", snode_path[i]->get_name(), + snode_path[i]->type_name()); + TI_TRACE(" is_place: {}, num_axis: {}, num_cells: {}", + snode_path[i]->is_place(), snode_path[i]->num_active_indices, + desc.cells_per_container_pot()); + } + + ir_->start_function(kernel_function_); + + if (snode->type == SNodeType::bitmasked) { + task_attribs_.advisory_total_num_threads = total_num_cells; + int num_cells = snode->num_cells_per_container; + + TI_INFO("ListGen {} * {}", total_num_cells / num_cells, num_cells); + + auto listgen_buffer = + get_buffer_value(BufferInfo(BufferType::ListGen), PrimitiveType::i32); + auto invoc_index = ir_->get_global_invocation_id(0); + + auto container_ptr = make_pointer(0); + std::vector linear_indices(snode_path.size()); + for (int i = snode_path.size() - 1; i >= 0; i--) { + // Offset the ptr to the cell on layer up + SNode *this_snode = snode_path[i]; + const auto &this_snode_desc = snode_descs.at(this_snode->id); + + auto snode_linear_index = + ir_->uint_immediate_number(ir_->u32_type(), 0); + if (this_snode->num_active_indices) { + for (int idx = 0; idx < taichi_max_num_indices; idx++) { + if (this_snode->extractors[idx].active) { + auto axis_local_index = ir_->make_value( + spv::OpShiftRightLogical, ir_->u32_type(), invoc_index, + ir_->uint_immediate_number( + ir_->u32_type(), sn_desc.axis_start_bit[idx] + + snode_path_index_start_bit[i][idx])); + axis_local_index = ir_->make_value( + spv::OpBitwiseAnd, ir_->u32_type(), axis_local_index, + ir_->uint_immediate_number( + ir_->u32_type(), + (1 << this_snode->extractors[idx].num_bits) - 1)); + snode_linear_index = ir_->make_value( + spv::OpBitwiseOr, ir_->u32_type(), + ir_->make_value(spv::OpShiftLeftLogical, ir_->u32_type(), + snode_linear_index, + ir_->uint_immediate_number( + ir_->u32_type(), + this_snode->extractors[idx].num_bits)), + axis_local_index); + } + } + } + + if (i > 0) { + const auto &next_snode_desc = snode_descs.at(snode_path[i - 1]->id); + if (this_snode->num_active_indices) { + container_ptr = ir_->add( + container_ptr, + ir_->mul(snode_linear_index, + ir_->uint_immediate_number( + ir_->u32_type(), this_snode_desc.cell_stride))); + } else { + container_ptr = ir_->add( + container_ptr, + make_pointer(next_snode_desc.mem_offset_in_parent_cell)); + } + } + + linear_indices[i] = snode_linear_index; + } + + // Check current bitmask mask within the cell + auto index_is_active = + bitmasked_activation(ActivationOp::query, container_ptr, root_id, + snode, linear_indices[0]); + + auto if_then_label = ir_->new_label(); + auto if_merge_label = ir_->new_label(); + + ir_->make_inst(spv::OpSelectionMerge, if_merge_label, + spv::SelectionControlMaskNone); + ir_->make_inst(spv::OpBranchConditional, index_is_active, if_then_label, + if_merge_label); + // if (is_active) + { + ir_->start_label(if_then_label); + + auto listgen_count_ptr = ir_->struct_array_access( + ir_->u32_type(), listgen_buffer, ir_->const_i32_zero_); + auto index_count = ir_->make_value( + spv::OpAtomicIAdd, ir_->u32_type(), listgen_count_ptr, + /*scope=*/ir_->const_i32_one_, + /*semantics=*/ir_->const_i32_zero_, + ir_->uint_immediate_number(ir_->u32_type(), 1)); + auto listgen_index_ptr = ir_->struct_array_access( + ir_->u32_type(), listgen_buffer, + ir_->add(ir_->uint_immediate_number(ir_->u32_type(), 1), + index_count)); + ir_->store_variable(listgen_index_ptr, invoc_index); + ir_->make_inst(spv::OpBranch, if_merge_label); + } + ir_->start_label(if_merge_label); + } else if (snode->type == SNodeType::dense) { + // Why?? + } else { + TI_NOT_IMPLEMENTED; + } + + ir_->make_inst(spv::OpReturn); // return; + ir_->make_inst(spv::OpFunctionEnd); // } Close kernel + + task_attribs_.buffer_binds = get_buffer_binds(); + } + + void generate_struct_for_kernel(OffloadedStmt *stmt) { + task_attribs_.name = task_name_; + task_attribs_.task_type = OffloadedTaskType::struct_for; + // task_attribs_.buffer_binds = get_common_buffer_binds(); + task_attribs_.advisory_total_num_threads = 65536; + task_attribs_.advisory_num_threads_per_group = 128; + + // The computation for a single work is wrapped inside a function, so that + // we can do grid-strided loop. + ir_->start_function(kernel_function_); + const spirv::Label func_label = ir_->current_label(); + + auto snode = stmt->snode; + + auto listgen_buffer = + get_buffer_value(BufferType::ListGen, PrimitiveType::u32); + auto listgen_count_ptr = ir_->struct_array_access( + ir_->u32_type(), listgen_buffer, ir_->const_i32_zero_); + auto listgen_count = ir_->load_variable(listgen_count_ptr, ir_->u32_type()); + + auto invoc_index = ir_->get_global_invocation_id(0); + + spirv::Label loop_head = ir_->new_label(); + spirv::Label loop_body = ir_->new_label(); + spirv::Label loop_merge = ir_->new_label(); + + auto loop_index_var = ir_->alloca_variable(ir_->u32_type()); + ir_->store_variable(loop_index_var, invoc_index); + + ir_->make_inst(spv::OpBranch, loop_head); + ir_->start_label(loop_head); + // for (; index < list_size; index += gl_NumWorkGroups.x * + // gl_WorkGroupSize.x) + auto loop_index = ir_->load_variable(loop_index_var, ir_->u32_type()); + auto loop_cond = ir_->make_value(spv::OpULessThan, ir_->bool_type(), + loop_index, listgen_count); + ir_->make_inst(spv::OpLoopMerge, loop_merge, loop_body, + spv::LoopControlMaskNone); + ir_->make_inst(spv::OpBranchConditional, loop_cond, loop_body, loop_merge); + { + ir_->start_label(loop_body); + auto listgen_index_ptr = ir_->struct_array_access( + ir_->u32_type(), listgen_buffer, + ir_->add(ir_->uint_immediate_number(ir_->u32_type(), 1), loop_index)); + auto listgen_index = + ir_->load_variable(listgen_index_ptr, ir_->u32_type()); + + // kernel + ir_->register_value("ii", listgen_index); + stmt->body->accept(this); + + // continue + spirv::Value total_invocs = ir_->cast( + ir_->i32_type(), + ir_->mul(ir_->get_num_work_groups(0), + ir_->uint_immediate_number( + ir_->u32_type(), + task_attribs_.advisory_num_threads_per_group, true))); + auto next_index = ir_->add(loop_index, total_invocs); + ir_->store_variable(loop_index_var, next_index); + ir_->make_inst(spv::OpBranch, loop_head); + } + ir_->start_label(loop_merge); + + ir_->make_inst(spv::OpReturn); // return; + ir_->make_inst(spv::OpFunctionEnd); // } Close kernel + + task_attribs_.buffer_binds = get_buffer_binds(); } spirv::Value at_buffer(const Stmt *ptr, DataType dt) { - spirv::Value buffer = get_buffer_value(ptr_to_buffers_.at(ptr)); - // Hardcoded ">> 2" because we only support 32-bit for now. - // return fmt::format("({} >> 2)", s->raw_name()); spirv::Value ptr_val = ir_->query_value(ptr->raw_name()); - spirv::Value idx_val = - ir_->make_value(spv::OpShiftRightArithmetic, ir_->i32_type(), ptr_val, - ir_->int_immediate_number(ir_->i32_type(), 2)); - spirv::Value ret = ir_->struct_array_access( - ir_->get_primitive_buffer_type( - ptr_to_buffers_.at(ptr).type == BufferType::Root, dt), - buffer, idx_val); + + if (ptr_val.stype.dt == PrimitiveType::u64) { + spirv::Value paddr_ptr = ir_->make_value( + spv::OpConvertUToPtr, + ir_->get_pointer_type(ir_->get_primitive_type(dt), + spv::StorageClassPhysicalStorageBuffer), + ptr_val); + paddr_ptr.flag = ValueKind::kPhysicalPtr; + return paddr_ptr; + } + + spirv::Value buffer = get_buffer_value(ptr_to_buffers_.at(ptr), dt); + size_t width = ir_->get_primitive_type_size(dt); + spirv::Value idx_val = ir_->make_value( + spv::OpShiftRightLogical, ptr_val.stype, ptr_val, + ir_->uint_immediate_number(ptr_val.stype, size_t(std::log2(width)))); + spirv::Value ret = + ir_->struct_array_access(ir_->get_primitive_type(dt), buffer, idx_val); + return ret; + } + + spirv::Value load_buffer(const Stmt *ptr, DataType dt) { + spirv::Value ptr_val = ir_->query_value(ptr->raw_name()); + + DataType ti_buffer_type = ir_->get_taichi_uint_type(dt); + + if (ptr_val.stype.dt == PrimitiveType::u64) { + ti_buffer_type = dt; + } + + auto buf_ptr = at_buffer(ptr, ti_buffer_type); + auto val_bits = + ir_->load_variable(buf_ptr, ir_->get_primitive_type(ti_buffer_type)); + auto ret = ti_buffer_type == dt + ? val_bits + : ir_->make_value(spv::OpBitcast, + ir_->get_primitive_type(dt), val_bits); return ret; } - spirv::Value get_buffer_value(BufferInfo buffer) { - const auto it = buffer_value_map_.find(buffer); + void store_buffer(const Stmt *ptr, spirv::Value val) { + spirv::Value ptr_val = ir_->query_value(ptr->raw_name()); + + DataType ti_buffer_type = ir_->get_taichi_uint_type(val.stype.dt); + + if (ptr_val.stype.dt == PrimitiveType::u64) { + ti_buffer_type = val.stype.dt; + } + + auto buf_ptr = at_buffer(ptr, ti_buffer_type); + auto val_bits = + val.stype.dt == ti_buffer_type + ? val + : ir_->make_value(spv::OpBitcast, + ir_->get_primitive_type(ti_buffer_type), val); + ir_->store_variable(buf_ptr, val_bits); + } + + spirv::Value get_buffer_value(BufferInfo buffer, DataType dt) { + auto type = ir_->get_primitive_type(dt); + auto key = std::make_pair(buffer, type.id); + + const auto it = buffer_value_map_.find(key); if (it != buffer_value_map_.end()) { return it->second; } - spirv::Value buffer_value; - if (buffer.type == BufferType::Root) { - spirv_snodes_[buffer.root_id] = compile_spirv_snode_structs( - ir_.get(), &compiled_structs_[buffer.root_id]); // Support native - // SNode structure - buffer_value = - ir_->buffer_argument(spirv_snodes_.at(buffer.root_id).root_stype, 0, - buffer_binding_map_[buffer]); - } else { - buffer_value = - ir_->buffer_argument(ir_->i32_type(), 0, buffer_binding_map_[buffer]); + if (buffer.type == BufferType::Args) { + buffer_binding_map_[key] = 0; + buffer_value_map_[key] = args_buffer_value_; + return args_buffer_value_; + } + + if (buffer.type == BufferType::Rets) { + buffer_binding_map_[key] = 1; + buffer_value_map_[key] = ret_buffer_value_; + return ret_buffer_value_; } - ir_->debug(spv::OpName, buffer_value, buffer_instance_name(buffer)); - buffer_value_map_[buffer] = buffer_value; + + // Binding head starts at 2, so we don't break args and rets + int binding = binding_head_++; + buffer_binding_map_[key] = binding; + + spirv::Value buffer_value = + ir_->buffer_argument(type, 0, binding, buffer_instance_name(buffer)); + buffer_value_map_[key] = buffer_value; TI_TRACE("buffer name = {}, value = {}", buffer_instance_name(buffer), buffer_value.id); return buffer_value; } - std::vector get_common_buffer_binds() { - std::vector result; - int binding = 0; - auto bind_buffer = [&](BufferInfo buffer) { - result.push_back({buffer, binding}); - buffer_binding_map_[buffer] = binding++; - }; + spirv::Value make_pointer(size_t offset) { + if (use_64bit_pointers) { + // This is hacky, should check out how to encode uint64 values in spirv + return ir_->cast(ir_->u64_type(), ir_->uint_immediate_number( + ir_->u32_type(), uint32_t(offset))); + } else { + return ir_->uint_immediate_number(ir_->u32_type(), uint32_t(offset)); + } + } - for (int root = 0; root < compiled_structs_.size(); ++root) { - bind_buffer({BufferType::Root, root}); + void compile_args_struct() { + if (!ctx_attribs_->has_args()) + return; + + std::vector> + struct_components_; + for (auto &arg : ctx_attribs_->args()) { + if (arg.is_array && + device_->get_cap( + DeviceCapability::spirv_has_physical_storage_buffer)) { + struct_components_.emplace_back(ir_->u64_type(), + "arg_ptr" + std::to_string(arg.index), + arg.offset_in_mem); + } else { + struct_components_.emplace_back(ir_->get_primitive_type(arg.dt), + "arg" + std::to_string(arg.index), + arg.offset_in_mem); + } + } + // A compromise for use in constants buffer + // where scalar arrays follow very weird packing rules + for (int i = 0; i < ctx_attribs_->extra_args_bytes() / 4; i++) { + struct_components_.emplace_back( + ir_->i32_type(), "extra_args" + std::to_string(i), + ctx_attribs_->extra_args_mem_offset() + i * 4); + } + args_struct_type_ = ir_->create_struct_type(struct_components_); + + args_buffer_value_ = + ir_->uniform_struct_argument(args_struct_type_, 0, 0, "args"); + } + + void compile_ret_struct() { + if (!ctx_attribs_->has_rets()) + return; + + std::vector> + struct_components_; + // Now we only have one ret + TI_ASSERT(ctx_attribs_->rets().size() == 1); + for (auto &ret : ctx_attribs_->rets()) { + if (auto tensor_type = ret.dt->cast()) { + struct_components_.emplace_back( + ir_->get_array_type( + ir_->get_primitive_type(tensor_type->get_element_type()), + tensor_type->get_num_elements()), + "ret" + std::to_string(ret.index), ret.offset_in_mem); + } else { + struct_components_.emplace_back( + ir_->get_array_type(ir_->get_primitive_type(ret.dt), 1), + "ret" + std::to_string(ret.index), ret.offset_in_mem); + } } + ret_struct_type_ = ir_->create_struct_type(struct_components_); - bind_buffer(BufferType::GlobalTmps); + ret_buffer_value_ = + ir_->buffer_struct_argument(ret_struct_type_, 0, 1, "rets"); + } - if (!ctx_attribs_->empty()) { - bind_buffer(BufferType::Context); + std::vector get_buffer_binds() { + std::vector result; + for (auto &[key, val] : buffer_binding_map_) { + result.push_back(BufferBind{key.first, int(val)}); } return result; } @@ -1286,16 +1884,37 @@ class TaskCodegen : public IRVisitor { Device *device_; + struct BufferInfoTypeTupleHasher { + std::size_t operator()(const std::pair &buf) const { + return BufferInfoHasher()(buf.first) ^ (buf.second << 5); + } + }; + + spirv::SType args_struct_type_; + spirv::Value args_buffer_value_; + + spirv::SType ret_struct_type_; + spirv::Value ret_buffer_value_; + std::shared_ptr ir_; // spirv binary code builder - std::unordered_map + std::unordered_map, + spirv::Value, + BufferInfoTypeTupleHasher> buffer_value_map_; - std::unordered_map + std::unordered_map, + uint32_t, + BufferInfoTypeTupleHasher> buffer_binding_map_; spirv::Value kernel_function_; spirv::Label kernel_return_label_; bool gen_label_{false}; + + int binding_head_{2}; // Args:0, Ret:1 + + /* std::unordered_map spirv_snodes_; // maps root id to spirv snode + */ OffloadedStmt *const task_ir_; // not owned std::vector compiled_structs_; @@ -1305,11 +1924,14 @@ class TaskCodegen : public IRVisitor { std::vector continue_label_stack_; std::vector merge_label_stack_; + std::unordered_set offload_loop_motion_; + TaskAttributes task_attribs_; std::unordered_map root_stmts_; // maps root id to get root stmt std::unordered_map ptr_to_buffers_; }; +} // namespace static void spriv_message_consumer(spv_message_level_t level, const char *source, @@ -1331,126 +1953,119 @@ static void spriv_message_consumer(spv_message_level_t level, } } -class KernelCodegen { - public: - struct Params { - std::string ti_kernel_name; - Kernel *kernel; - std::vector compiled_structs; - Device *device; - bool enable_spv_opt{true}; - }; - - explicit KernelCodegen(const Params ¶ms) - : params_(params), ctx_attribs_(*params.kernel) { - spirv_opt_ = std::make_unique(SPV_ENV_VULKAN_1_2); - spirv_tools_ = std::make_unique(SPV_ENV_VULKAN_1_2); - - spirv_opt_->SetMessageConsumer(spriv_message_consumer); - - // TODO: Utilize this if KHR_memory_model is supported - // TODO: Profile these passes, remove ones we don't need to speed up JIT - // ref: - // https://github.com/KhronosGroup/SPIRV-Tools/blob/f9893c4549406eb9643e0eb05a521ab70a320fff/source/opt/optimizer.cpp - if (params.enable_spv_opt) { - spirv_opt_->RegisterPerformancePasses(); - } +KernelCodegen::KernelCodegen(const Params ¶ms) + : params_(params), ctx_attribs_(*params.kernel) { + spv_target_env target_env = SPV_ENV_VULKAN_1_0; + uint32_t spirv_version = + params.device->get_cap(DeviceCapability::spirv_version); + + if (spirv_version >= 0x10600) { + target_env = SPV_ENV_VULKAN_1_3; + } else if (spirv_version >= 0x10500) { + target_env = SPV_ENV_VULKAN_1_2; + } else if (spirv_version >= 0x10400) { + target_env = SPV_ENV_VULKAN_1_1_SPIRV_1_4; + } else if (spirv_version >= 0x10300) { + target_env = SPV_ENV_VULKAN_1_1; + } - _spirv_opt_options.set_run_validator(false); + spirv_opt_ = std::make_unique(target_env); + spirv_opt_->SetMessageConsumer(spriv_message_consumer); + if (params.enable_spv_opt) { + // From: SPIRV-Tools/source/opt/optimizer.cpp + spirv_opt_->RegisterPass(spvtools::CreateWrapOpKillPass()) + .RegisterPass(spvtools::CreateDeadBranchElimPass()) + .RegisterPass(spvtools::CreateMergeReturnPass()) + .RegisterPass(spvtools::CreateInlineExhaustivePass()) + .RegisterPass(spvtools::CreateEliminateDeadFunctionsPass()) + .RegisterPass(spvtools::CreateAggressiveDCEPass()) + .RegisterPass(spvtools::CreatePrivateToLocalPass()) + .RegisterPass(spvtools::CreateLocalSingleBlockLoadStoreElimPass()) + .RegisterPass(spvtools::CreateLocalSingleStoreElimPass()) + .RegisterPass(spvtools::CreateScalarReplacementPass()) + .RegisterPass(spvtools::CreateLocalAccessChainConvertPass()) + .RegisterPass(spvtools::CreateLocalMultiStoreElimPass()) + .RegisterPass(spvtools::CreateCCPPass()) + .RegisterPass(spvtools::CreateLoopUnrollPass(true)) + .RegisterPass(spvtools::CreateRedundancyEliminationPass()) + .RegisterPass(spvtools::CreateCombineAccessChainsPass()) + .RegisterPass(spvtools::CreateSimplificationPass()) + .RegisterPass(spvtools::CreateSSARewritePass()) + .RegisterPass(spvtools::CreateVectorDCEPass()) + .RegisterPass(spvtools::CreateDeadInsertElimPass()) + .RegisterPass(spvtools::CreateIfConversionPass()) + .RegisterPass(spvtools::CreateCopyPropagateArraysPass()) + .RegisterPass(spvtools::CreateReduceLoadSizePass()) + .RegisterPass(spvtools::CreateBlockMergePass()); } + spirv_opt_options_.set_run_validator(false); - using Result = VkRuntime::RegisterParams; + spirv_tools_ = std::make_unique(target_env); +} - Result run() { - Result res; - auto &kernel_attribs = res.kernel_attribs; - auto *root = params_.kernel->ir->as(); - auto &tasks = root->statements; - for (int i = 0; i < tasks.size(); ++i) { - TaskCodegen::Params tp; - tp.task_ir = tasks[i]->as(); - tp.task_id_in_kernel = i; - tp.compiled_structs = params_.compiled_structs; - tp.ctx_attribs = &ctx_attribs_; - tp.ti_kernel_name = params_.ti_kernel_name; - tp.device = params_.device; - - TaskCodegen cgen(tp); - auto task_res = cgen.run(); - - std::vector optimized_spv; - - TI_WARN_IF(!spirv_opt_->Run(task_res.spirv_code.data(), - task_res.spirv_code.size(), &optimized_spv, - _spirv_opt_options), - "SPIRV optimization failed"); - - TI_TRACE("SPIRV-Tools-opt: binary size, before={}, after={}", - task_res.spirv_code.size(), optimized_spv.size()); - - // Enable to dump SPIR-V assembly of kernels +void KernelCodegen::run(TaichiKernelAttributes &kernel_attribs, + std::vector> &generated_spirv) { + auto *root = params_.kernel->ir->as(); + auto &tasks = root->statements; + for (int i = 0; i < tasks.size(); ++i) { + TaskCodegen::Params tp; + tp.task_ir = tasks[i]->as(); + tp.task_id_in_kernel = i; + tp.compiled_structs = params_.compiled_structs; + tp.ctx_attribs = &ctx_attribs_; + tp.ti_kernel_name = fmt::format("{}_{}", params_.ti_kernel_name, i); + tp.device = params_.device; + + TaskCodegen cgen(tp); + auto task_res = cgen.run(); + + std::vector optimized_spv(task_res.spirv_code); + + size_t last_size; + do { + last_size = optimized_spv.size(); + bool result = false; + TI_WARN_IF( + (result = !spirv_opt_->Run(optimized_spv.data(), optimized_spv.size(), + &optimized_spv, spirv_opt_options_)), + "SPIRV optimization failed"); + if (result) + break; + } while (last_size != optimized_spv.size()); + + TI_TRACE("SPIRV-Tools-opt: binary size, before={}, after={}", + task_res.spirv_code.size(), optimized_spv.size()); + + // Enable to dump SPIR-V assembly of kernels #if 0 - std::string spirv_asm; - spirv_tools_->Disassemble(optimized_spv, &spirv_asm); - TI_WARN("SPIR-V Assembly dump for {} :\n{}\n\n", params_.ti_kernel_name, - spirv_asm); - - std::ofstream fout((params_.ti_kernel_name).c_str(), - std::ios::binary | std::ios::out); - fout.write(reinterpret_cast(task_res.spirv_code.data()), - task_res.spirv_code.size() * sizeof(uint32_t)); - fout.close(); + std::string spirv_asm; + spirv_tools_->Disassemble(optimized_spv, &spirv_asm); + auto kernel_name = tp.ti_kernel_name; + TI_WARN("SPIR-V Assembly dump for {} :\n{}\n\n", kernel_name, spirv_asm); + + std::ofstream fout(kernel_name + ".spv", std::ios::binary | std::ios::out); + fout.write(reinterpret_cast(optimized_spv.data()), + optimized_spv.size() * sizeof(uint32_t)); + fout.close(); #endif - kernel_attribs.tasks_attribs.push_back(std::move(task_res.task_attribs)); - res.task_spirv_source_codes.push_back(std::move(optimized_spv)); - } - kernel_attribs.ctx_attribs = std::move(ctx_attribs_); - kernel_attribs.name = params_.ti_kernel_name; - kernel_attribs.is_jit_evaluator = params_.kernel->is_evaluator; - return res; + kernel_attribs.tasks_attribs.push_back(std::move(task_res.task_attribs)); + generated_spirv.push_back(std::move(optimized_spv)); } - - private: - Params params_; - KernelContextAttributes ctx_attribs_; - - std::unique_ptr spirv_opt_; - std::unique_ptr spirv_tools_; - spvtools::OptimizerOptions _spirv_opt_options; -}; - -} // namespace + kernel_attribs.ctx_attribs = std::move(ctx_attribs_); + kernel_attribs.name = params_.ti_kernel_name; + kernel_attribs.is_jit_evaluator = params_.kernel->is_evaluator; +} void lower(Kernel *kernel) { auto &config = kernel->program->config; config.demote_dense_struct_fors = true; - irpass::compile_to_executable(kernel->ir.get(), config, kernel, - /*vectorize=*/false, kernel->grad, + irpass::compile_to_executable(kernel->ir.get(), config, kernel, kernel->grad, /*ad_use_stack=*/false, config.print_ir, /*lower_global_access=*/true, /*make_thread_local=*/false); } -FunctionType compile_to_executable(Kernel *kernel, VkRuntime *runtime) { - const auto id = Program::get_kernel_id(); - const auto taichi_kernel_name(fmt::format("{}_k{:04d}_vk", kernel->name, id)); - TI_TRACE("VK codegen for Taichi kernel={}", taichi_kernel_name); - KernelCodegen::Params params; - params.ti_kernel_name = taichi_kernel_name; - params.kernel = kernel; - params.compiled_structs = runtime->get_compiled_structs(); - params.device = runtime->get_ti_device(); - params.enable_spv_opt = - kernel->program->config.external_optimization_level > 0; - KernelCodegen codegen(params); - auto res = codegen.run(); - auto handle = runtime->register_taichi_kernel(std::move(res)); - return [runtime, handle, taichi_kernel_name](Context &ctx) { - runtime->launch_kernel(handle, &ctx); - }; -} - -} // namespace vulkan +} // namespace spirv } // namespace lang } // namespace taichi diff --git a/taichi/codegen/spirv/spirv_codegen.h b/taichi/codegen/spirv/spirv_codegen.h new file mode 100644 index 0000000000000..9858cd8fb183f --- /dev/null +++ b/taichi/codegen/spirv/spirv_codegen.h @@ -0,0 +1,46 @@ +#pragma once + +#include "taichi/lang_util.h" + +#include "taichi/codegen/spirv/snode_struct_compiler.h" +#include "taichi/codegen/spirv/kernel_utils.h" + +#include +#include + +namespace taichi { +namespace lang { + +class Kernel; + +namespace spirv { + +void lower(Kernel *kernel); + +class KernelCodegen { + public: + struct Params { + std::string ti_kernel_name; + Kernel *kernel; + std::vector compiled_structs; + Device *device; + bool enable_spv_opt{true}; + }; + + explicit KernelCodegen(const Params ¶ms); + + void run(TaichiKernelAttributes &kernel_attribs, + std::vector> &generated_spirv); + + private: + Params params_; + KernelContextAttributes ctx_attribs_; + + std::unique_ptr spirv_opt_{nullptr}; + std::unique_ptr spirv_tools_{nullptr}; + spvtools::OptimizerOptions spirv_opt_options_; +}; + +} // namespace spirv +} // namespace lang +} // namespace taichi diff --git a/taichi/backends/vulkan/spirv_ir_builder.cpp b/taichi/codegen/spirv/spirv_ir_builder.cpp similarity index 68% rename from taichi/backends/vulkan/spirv_ir_builder.cpp rename to taichi/codegen/spirv/spirv_ir_builder.cpp index 0203558c36ff1..18a95b67c57c4 100644 --- a/taichi/backends/vulkan/spirv_ir_builder.cpp +++ b/taichi/codegen/spirv/spirv_ir_builder.cpp @@ -1,8 +1,7 @@ -#include "taichi/backends/vulkan/spirv_ir_builder.h" +#include "taichi/codegen/spirv/spirv_ir_builder.h" namespace taichi { namespace lang { -namespace vulkan { namespace spirv { @@ -70,6 +69,11 @@ void IRBuilder::init_header() { if (device_->get_cap(cap::spirv_has_float64)) { ib_.begin(spv::OpCapability).add(spv::CapabilityFloat64).commit(&header_); } + if (device_->get_cap(cap::spirv_has_physical_storage_buffer)) { + ib_.begin(spv::OpCapability) + .add(spv::CapabilityPhysicalStorageBufferAddresses) + .commit(&header_); + } ib_.begin(spv::OpExtension) .add("SPV_KHR_storage_buffer_storage_class") @@ -93,10 +97,21 @@ void IRBuilder::init_header() { .commit(&header_); } - // memory model - ib_.begin(spv::OpMemoryModel) - .add_seq(spv::AddressingModelLogical, spv::MemoryModelGLSL450) - .commit(&entry_); + if (device_->get_cap(cap::spirv_has_physical_storage_buffer)) { + ib_.begin(spv::OpExtension) + .add("SPV_KHR_physical_storage_buffer") + .commit(&header_); + + // memory model + ib_.begin(spv::OpMemoryModel) + .add_seq(spv::AddressingModelPhysicalStorageBuffer64, + spv::MemoryModelGLSL450) + .commit(&entry_); + } else { + ib_.begin(spv::OpMemoryModel) + .add_seq(spv::AddressingModelLogical, spv::MemoryModelGLSL450) + .commit(&entry_); + } this->init_pre_defs(); } @@ -135,6 +150,9 @@ void IRBuilder::init_pre_defs() { t_uint64_ = declare_primitive_type(get_data_type()); } t_fp32_ = declare_primitive_type(get_data_type()); + if (device_->get_cap(cap::spirv_has_float16)) { + t_fp16_ = declare_primitive_type(PrimitiveType::f16); + } if (device_->get_cap(cap::spirv_has_float64)) { t_fp64_ = declare_primitive_type(get_data_type()); } @@ -176,25 +194,25 @@ PhiValue IRBuilder::make_phi(const SType &out_type, uint32_t num_incoming) { Value IRBuilder::int_immediate_number(const SType &dtype, int64_t value, bool cache) { - return get_const_(dtype, reinterpret_cast(&value), cache); + return get_const(dtype, reinterpret_cast(&value), cache); } Value IRBuilder::uint_immediate_number(const SType &dtype, uint64_t value, bool cache) { - return get_const_(dtype, &value, cache); + return get_const(dtype, &value, cache); } Value IRBuilder::float_immediate_number(const SType &dtype, double value, bool cache) { if (data_type_bits(dtype.dt) == 64) { - return get_const_(dtype, reinterpret_cast(&value), cache); + return get_const(dtype, reinterpret_cast(&value), cache); } else if (data_type_bits(dtype.dt) == 32) { float fvalue = static_cast(value); uint32_t *ptr = reinterpret_cast(&fvalue); uint64_t data = ptr[0]; - return get_const_(dtype, &data, cache); + return get_const(dtype, &data, cache); } else { TI_ERROR("Type {} not supported.", dtype.dt->to_string()); } @@ -209,6 +227,10 @@ SType IRBuilder::get_null_type() { SType IRBuilder::get_primitive_type(const DataType &dt) const { if (dt->is_primitive(PrimitiveTypeID::u1)) { return t_bool_; + } else if (dt->is_primitive(PrimitiveTypeID::f16)) { + if (!device_->get_cap(cap::spirv_has_float16)) + TI_ERROR("Type {} not supported.", dt->to_string()); + return t_fp16_; } else if (dt->is_primitive(PrimitiveTypeID::f32)) { return t_fp32_; } else if (dt->is_primitive(PrimitiveTypeID::f64)) { @@ -248,21 +270,49 @@ SType IRBuilder::get_primitive_type(const DataType &dt) const { } } -SType IRBuilder::get_primitive_buffer_type(const bool struct_compiled, - const DataType &dt) const { - if (struct_compiled) { - if (dt->is_primitive(PrimitiveTypeID::f32) && - device_->get_cap(cap::spirv_has_atomic_float_add)) { - return t_fp32_; - } else if (dt->is_primitive(PrimitiveTypeID::f64) && - device_->get_cap(cap::spirv_has_atomic_float64_add)) { - return t_fp64_; - } else if (dt->is_primitive(PrimitiveTypeID::i64) && - device_->get_cap(cap::spirv_has_atomic_i64)) { - return t_int64_; - } +size_t IRBuilder::get_primitive_type_size(const DataType &dt) const { + if (dt == PrimitiveType::i64 || dt == PrimitiveType::u64 || + dt == PrimitiveType::f64) { + return 8; + } else if (dt == PrimitiveType::i32 || dt == PrimitiveType::u32 || + dt == PrimitiveType::f32) { + return 4; + } else if (dt == PrimitiveType::i16 || dt == PrimitiveType::u16 || + dt == PrimitiveType::f16) { + return 2; + } else { + return 1; + } +} + +SType IRBuilder::get_primitive_uint_type(const DataType &dt) const { + if (dt == PrimitiveType::i64 || dt == PrimitiveType::u64 || + dt == PrimitiveType::f64) { + return t_uint64_; + } else if (dt == PrimitiveType::i32 || dt == PrimitiveType::u32 || + dt == PrimitiveType::f32) { + return t_uint32_; + } else if (dt == PrimitiveType::i16 || dt == PrimitiveType::u16 || + dt == PrimitiveType::f16) { + return t_uint16_; + } else { + return t_uint8_; + } +} + +DataType IRBuilder::get_taichi_uint_type(const DataType &dt) const { + if (dt == PrimitiveType::i64 || dt == PrimitiveType::u64 || + dt == PrimitiveType::f64) { + return PrimitiveType::u64; + } else if (dt == PrimitiveType::i32 || dt == PrimitiveType::u32 || + dt == PrimitiveType::f32) { + return PrimitiveType::u32; + } else if (dt == PrimitiveType::i16 || dt == PrimitiveType::u16 || + dt == PrimitiveType::f16) { + return PrimitiveType::u16; + } else { + return PrimitiveType::u8; } - return t_int32_; } SType IRBuilder::get_pointer_type(const SType &value_type, @@ -284,8 +334,18 @@ SType IRBuilder::get_pointer_type(const SType &value_type, return t; } -SType IRBuilder::get_struct_array_type(const SType &value_type, - uint32_t num_elems) { +SType IRBuilder::get_storage_pointer_type(const SType &value_type) { + spv::StorageClass storage_class; + if (device_->get_cap(cap::spirv_version) < 0x10300) { + storage_class = spv::StorageClassUniform; + } else { + storage_class = spv::StorageClassStorageBuffer; + } + + return get_pointer_type(value_type, storage_class); +} + +SType IRBuilder::get_array_type(const SType &value_type, uint32_t num_elems) { SType arr_type; arr_type.id = id_counter_++; arr_type.flag = TypeKind::kPtr; @@ -323,6 +383,14 @@ SType IRBuilder::get_struct_array_type(const SType &value_type, // decorate the array type this->decorate(spv::OpDecorate, arr_type, spv::DecorationArrayStride, nbytes); + + return arr_type; +} + +SType IRBuilder::get_struct_array_type(const SType &value_type, + uint32_t num_elems) { + SType arr_type = get_array_type(value_type, num_elems); + // declare struct of array SType struct_type; struct_type.id = id_counter_++; @@ -348,9 +416,105 @@ SType IRBuilder::get_struct_array_type(const SType &value_type, return struct_type; } +SType IRBuilder::create_struct_type( + std::vector> &components) { + SType struct_type; + struct_type.id = id_counter_++; + struct_type.flag = TypeKind::kStruct; + + auto &builder = ib_.begin(spv::OpTypeStruct).add_seq(struct_type); + + for (auto &[type, name, offset] : components) { + builder.add_seq(type); + } + + builder.commit(&global_); + + int i = 0; + for (auto &[type, name, offset] : components) { + this->decorate(spv::OpMemberDecorate, struct_type, i, spv::DecorationOffset, + offset); + this->debug(spv::OpMemberName, struct_type, i, name); + i++; + } + + return struct_type; +} + +Value IRBuilder::buffer_struct_argument(const SType &struct_type, + uint32_t descriptor_set, + uint32_t binding, + const std::string &name) { + // NOTE: BufferBlock was deprecated in SPIRV 1.3 + // use StorageClassStorageBuffer instead. + spv::StorageClass storage_class; + if (device_->get_cap(cap::spirv_version) < 0x10300) { + storage_class = spv::StorageClassUniform; + } else { + storage_class = spv::StorageClassStorageBuffer; + } + + this->debug(spv::OpName, struct_type, name + "_t"); + + if (device_->get_cap(cap::spirv_version) < 0x10300) { + // NOTE: BufferBlock was deprecated in SPIRV 1.3 + // use StorageClassStorageBuffer instead. + // runtime array are always decorated as BufferBlock(shader storage buffer) + this->decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock); + } else { + this->decorate(spv::OpDecorate, struct_type, spv::DecorationBlock); + } + + SType ptr_type = get_pointer_type(struct_type, storage_class); + + this->debug(spv::OpName, ptr_type, name + "_ptr"); + + Value val = new_value(ptr_type, ValueKind::kStructArrayPtr); + ib_.begin(spv::OpVariable) + .add_seq(ptr_type, val, storage_class) + .commit(&global_); + + this->debug(spv::OpName, val, name); + + this->decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, + descriptor_set); + this->decorate(spv::OpDecorate, val, spv::DecorationBinding, binding); + return val; +} + +Value IRBuilder::uniform_struct_argument(const SType &struct_type, + uint32_t descriptor_set, + uint32_t binding, + const std::string &name) { + // NOTE: BufferBlock was deprecated in SPIRV 1.3 + // use StorageClassStorageBuffer instead. + spv::StorageClass storage_class = spv::StorageClassUniform; + + this->debug(spv::OpName, struct_type, name + "_t"); + + this->decorate(spv::OpDecorate, struct_type, spv::DecorationBlock); + + SType ptr_type = get_pointer_type(struct_type, storage_class); + + this->debug(spv::OpName, ptr_type, name + "_ptr"); + + Value val = new_value(ptr_type, ValueKind::kStructArrayPtr); + ib_.begin(spv::OpVariable) + .add_seq(ptr_type, val, storage_class) + .commit(&global_); + + this->debug(spv::OpName, val, name); + + this->decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, + descriptor_set); + this->decorate(spv::OpDecorate, val, spv::DecorationBinding, binding); + return val; +} + Value IRBuilder::buffer_argument(const SType &value_type, uint32_t descriptor_set, - uint32_t binding) { + uint32_t binding, + const std::string &name) { // NOTE: BufferBlock was deprecated in SPIRV 1.3 // use StorageClassStorageBuffer instead. spv::StorageClass storage_class; @@ -361,12 +525,22 @@ Value IRBuilder::buffer_argument(const SType &value_type, } SType sarr_type = get_struct_array_type(value_type, 0); + + auto typed_name = name + "_" + value_type.dt.to_string(); + + this->debug(spv::OpName, sarr_type, typed_name + "_struct_array"); + SType ptr_type = get_pointer_type(sarr_type, storage_class); + + this->debug(spv::OpName, sarr_type, typed_name + "_ptr"); + Value val = new_value(ptr_type, ValueKind::kStructArrayPtr); ib_.begin(spv::OpVariable) .add_seq(ptr_type, val, storage_class) .commit(&global_); + this->debug(spv::OpName, val, typed_name); + this->decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, descriptor_set); this->decorate(spv::OpDecorate, val, spv::DecorationBinding, binding); @@ -391,6 +565,7 @@ Value IRBuilder::struct_array_access(const SType &res_type, ib_.begin(spv::OpAccessChain) .add_seq(ptr_type, ret, buffer, const_i32_zero_, index) .commit(&function_); + return ret; } @@ -402,51 +577,70 @@ void IRBuilder::set_work_group_size(const std::array group_size) { Value size_z = uint_immediate_number(t_uint32_, static_cast(group_size[2])); - if (gl_work_group_size.id == 0) { - gl_work_group_size.id = id_counter_++; + if (gl_work_group_size_.id == 0) { + gl_work_group_size_.id = id_counter_++; } ib_.begin(spv::OpConstantComposite) - .add_seq(t_v3_uint_, gl_work_group_size, size_x, size_y, size_z) + .add_seq(t_v3_uint_, gl_work_group_size_, size_x, size_y, size_z) .commit(&global_); - this->decorate(spv::OpDecorate, gl_work_group_size, spv::DecorationBuiltIn, + this->decorate(spv::OpDecorate, gl_work_group_size_, spv::DecorationBuiltIn, spv::BuiltInWorkgroupSize); } Value IRBuilder::get_num_work_groups(uint32_t dim_index) { - if (gl_num_work_groups.id == 0) { + if (gl_num_work_groups_.id == 0) { SType ptr_type = this->get_pointer_type(t_v3_uint_, spv::StorageClassInput); - gl_num_work_groups = new_value(ptr_type, ValueKind::kVectorPtr); + gl_num_work_groups_ = new_value(ptr_type, ValueKind::kVectorPtr); ib_.begin(spv::OpVariable) - .add_seq(ptr_type, gl_num_work_groups, spv::StorageClassInput) + .add_seq(ptr_type, gl_num_work_groups_, spv::StorageClassInput) .commit(&global_); - this->decorate(spv::OpDecorate, gl_num_work_groups, spv::DecorationBuiltIn, + this->decorate(spv::OpDecorate, gl_num_work_groups_, spv::DecorationBuiltIn, spv::BuiltInNumWorkgroups); } SType pint_type = this->get_pointer_type(t_uint32_, spv::StorageClassInput); Value ptr = this->make_value( - spv::OpAccessChain, pint_type, gl_num_work_groups, + spv::OpAccessChain, pint_type, gl_num_work_groups_, uint_immediate_number(t_uint32_, static_cast(dim_index))); return this->make_value(spv::OpLoad, t_uint32_, ptr); } + Value IRBuilder::get_global_invocation_id(uint32_t dim_index) { - if (gl_global_invocation_id.id == 0) { + if (gl_global_invocation_id_.id == 0) { SType ptr_type = this->get_pointer_type(t_v3_uint_, spv::StorageClassInput); - gl_global_invocation_id = new_value(ptr_type, ValueKind::kVectorPtr); + gl_global_invocation_id_ = new_value(ptr_type, ValueKind::kVectorPtr); ib_.begin(spv::OpVariable) - .add_seq(ptr_type, gl_global_invocation_id, spv::StorageClassInput) + .add_seq(ptr_type, gl_global_invocation_id_, spv::StorageClassInput) .commit(&global_); - this->decorate(spv::OpDecorate, gl_global_invocation_id, + this->decorate(spv::OpDecorate, gl_global_invocation_id_, spv::DecorationBuiltIn, spv::BuiltInGlobalInvocationId); } SType pint_type = this->get_pointer_type(t_uint32_, spv::StorageClassInput); Value ptr = this->make_value( - spv::OpAccessChain, pint_type, gl_global_invocation_id, + spv::OpAccessChain, pint_type, gl_global_invocation_id_, uint_immediate_number(t_uint32_, static_cast(dim_index))); return this->make_value(spv::OpLoad, t_uint32_, ptr); } +Value IRBuilder::get_subgroup_invocation_id() { + if (subgroup_local_invocation_id_.id == 0) { + SType ptr_type = this->get_pointer_type(t_uint32_, spv::StorageClassInput); + subgroup_local_invocation_id_ = + new_value(ptr_type, ValueKind::kVariablePtr); + ib_.begin(spv::OpVariable) + .add_seq(ptr_type, subgroup_local_invocation_id_, + spv::StorageClassInput) + .commit(&global_); + this->decorate(spv::OpDecorate, subgroup_local_invocation_id_, + spv::DecorationBuiltIn, + spv::BuiltInSubgroupLocalInvocationId); + } + + return this->make_value(spv::OpLoad, t_uint32_, + subgroup_local_invocation_id_); +} + #define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ Value IRBuilder::_OpName(Value a, Value b) { \ TI_ASSERT(a.stype.id == b.stype.id); \ @@ -479,19 +673,8 @@ DEFINE_BUILDER_BINARY_SIGN_OP(div, Div); Value IRBuilder::mod(Value a, Value b) { TI_ASSERT(a.stype.id == b.stype.id); if (is_integral(a.stype.dt) && is_signed(a.stype.dt)) { - // a - b * int(float(a) / float(b)) - Value tmp1 = cast(t_fp32_, a); - Value tmp2 = cast(t_fp32_, b); - Value tmp3 = make_value(spv::OpFDiv, t_fp32_, tmp1, tmp2); - // Float division may lose precision - // FIXME: Could we have a better way to do this? - Value eps_p = float_immediate_number(t_fp32_, /*+eps=*/1e-5f, false); - Value eps_n = float_immediate_number(t_fp32_, /*-eps=*/-1e-5f, false); - Value eps = select(ge(tmp3, eps_p), eps_p, eps_n); - Value tmp3_float_fixed = make_value(spv::OpFAdd, t_fp32_, tmp3, eps); - Value tmp4 = cast(a.stype, tmp3_float_fixed); - Value tmp5 = make_value(spv::OpIMul, a.stype, b, tmp4); - return make_value(spv::OpISub, a.stype, a, tmp5); + // FIXME: figure out why OpSRem does not work + return sub(a, mul(b, div(a, b))); } else if (is_integral(a.stype.dt)) { return make_value(spv::OpUMod, a.stype, a, b); } else { @@ -638,15 +821,25 @@ Value IRBuilder::alloca_variable(const SType &type) { Value IRBuilder::load_variable(Value pointer, const SType &res_type) { TI_ASSERT(pointer.flag == ValueKind::kVariablePtr || - pointer.flag == ValueKind::kStructArrayPtr); + pointer.flag == ValueKind::kStructArrayPtr || + pointer.flag == ValueKind::kPhysicalPtr); Value ret = new_value(res_type, ValueKind::kNormal); ib_.begin(spv::OpLoad).add_seq(res_type, ret, pointer).commit(&function_); return ret; } void IRBuilder::store_variable(Value pointer, Value value) { - TI_ASSERT(pointer.flag == ValueKind::kVariablePtr); + TI_ASSERT(pointer.flag == ValueKind::kVariablePtr || + pointer.flag == ValueKind::kPhysicalPtr); TI_ASSERT(value.stype.id == pointer.stype.element_type_id); - ib_.begin(spv::OpStore).add_seq(pointer, value).commit(&function_); + if (pointer.flag == ValueKind::kPhysicalPtr) { + Value alignment = uint_immediate_number( + t_uint32_, get_primitive_type_size(value.stype.dt)); + ib_.begin(spv::OpStore) + .add_seq(pointer, value, spv::MemoryAccessAlignedMask, alignment) + .commit(&function_); + } else { + ib_.begin(spv::OpStore).add_seq(pointer, value).commit(&function_); + } } void IRBuilder::register_value(std::string name, Value value) { @@ -654,7 +847,9 @@ void IRBuilder::register_value(std::string name, Value value) { if (it != value_name_tbl_.end()) { TI_ERROR("{} is existed.", name); } - this->debug(spv::OpName, value, name); // Debug info + this->debug( + spv::OpName, value, + fmt::format("{}_{}", name, value.stype.dt.to_string())); // Debug info value_name_tbl_[name] = value; } @@ -663,127 +858,89 @@ Value IRBuilder::query_value(std::string name) const { if (it != value_name_tbl_.end()) { return it->second; } - TI_ERROR("{} is not existed.", name); + TI_ERROR("Value \"{}\" does not yet exist.", name); +} + +bool IRBuilder::check_value_existence(const std::string &name) const { + return value_name_tbl_.find(name) != value_name_tbl_.end(); } Value IRBuilder::float_atomic(AtomicOpType op_type, Value addr_ptr, Value data) { - auto atomic_func_ = [&](std::function atomic_op) { - // inline function begin - auto &func_ = function_; - Value old_val = alloca_variable(t_int32_); - Value new_val = alloca_variable(t_int32_); - Value cas_val = alloca_variable(t_int32_); - Value ok = alloca_variable(t_int32_); - - store_variable(old_val, const_i32_zero_); - store_variable(new_val, const_i32_zero_); - store_variable(cas_val, const_i32_zero_); - store_variable(ok, const_i32_zero_); - - // while - Label head_label = new_label(); - Label body_label = new_label(); - Label continue_label = new_label(); - Label merge_label = new_label(); - Label true_label = new_label(); - ib_.begin(spv::OpBranch).add(head_label).commit(&func_); - ib_.begin(spv::OpLabel).add(head_label).commit(&func_); - ib_.begin(spv::OpLoopMerge) - .add_seq(merge_label, continue_label, spv::LoopControlMaskNone) - .commit(&func_); - ib_.begin(spv::OpBranch).add(body_label).commit(&func_); - - // body part - ib_.begin(spv::OpLabel).add(body_label).commit(&func_); - Value tmp0 = load_variable(ok, t_int32_); - Value tmp1 = new_value(t_bool_, ValueKind::kNormal); - ib_.begin(spv::OpIEqual) - .add_seq(t_bool_, tmp1, tmp0, const_i32_zero_) - .commit(&func_); - ib_.begin(spv::OpBranchConditional) - .add_seq(tmp1, true_label, merge_label) - .commit(&func_); - ib_.begin(spv::OpLabel).add(true_label).commit(&func_); - Value tmp2 = load_variable(addr_ptr, t_fp32_); - Value tmp2_int = new_value(t_int32_, ValueKind::kNormal); - ib_.begin(spv::OpBitcast).add_seq(t_int32_, tmp2_int, tmp2).commit(&func_); - store_variable(old_val, tmp2_int); - Value tmp3 = load_variable(old_val, t_int32_); - Value tmp4 = new_value(t_fp32_, ValueKind::kNormal); - ib_.begin(spv::OpBitcast).add_seq(t_fp32_, tmp4, tmp3).commit(&func_); - Value tmp5 = new_value(t_fp32_, ValueKind::kNormal); - - // atomic operation - atomic_op(tmp5, tmp4, data); - - Value tmp6 = new_value(t_int32_, ValueKind::kNormal); - ib_.begin(spv::OpBitcast).add_seq(t_int32_, tmp6, tmp5).commit(&func_); - store_variable(new_val, tmp6); - Value tmp7 = load_variable(old_val, t_int32_); - Value tmp8 = load_variable(new_val, t_int32_); - Value tmp9 = new_value(t_int32_, ValueKind::kNormal); - auto const_u32_1 = uint_immediate_number(t_uint32_, 1); - auto const_u32_0 = uint_immediate_number(t_uint32_, 0); - ib_.begin(spv::OpAtomicCompareExchange) - .add_seq(t_int32_, tmp9, addr_ptr, const_u32_1, const_u32_0, - const_u32_0, tmp8, tmp7) - .commit(&func_); - store_variable(cas_val, tmp9); - Value tmp10 = load_variable(cas_val, t_int32_); - Value tmp11 = load_variable(old_val, t_int32_); - Value tmp12 = new_value(t_bool_, ValueKind::kNormal); - ib_.begin(spv::OpIEqual) - .add_seq(t_bool_, tmp12, tmp10, tmp11) - .commit(&func_); - Value tmp13 = new_value(t_int32_, ValueKind::kNormal); - ib_.begin(spv::OpSelect) - .add_seq(t_int32_, tmp13, tmp12, const_i32_one_, const_i32_zero_) - .commit(&func_); - store_variable(ok, tmp13); - ib_.begin(spv::OpBranch).add(continue_label).commit(&func_); - - // continue part - ib_.begin(spv::OpLabel).add(continue_label).commit(&func_); - ib_.begin(spv::OpBranch).add(head_label).commit(&func_); - - // merge part - ib_.begin(spv::OpLabel).add(merge_label).commit(&func_); - Value tmp14 = load_variable(old_val, t_int32_); - Value tmp15 = new_value(t_fp32_, ValueKind::kNormal); - ib_.begin(spv::OpBitcast).add_seq(t_fp32_, tmp15, tmp14).commit(&func_); - return tmp15; - // function end + auto atomic_func_ = [&](std::function atomic_op) { + Value ret_val_int = alloca_variable(t_uint32_); + + // do-while + Label head = new_label(); + Label body = new_label(); + Label branch_true = new_label(); + Label branch_false = new_label(); + Label merge = new_label(); + Label exit = new_label(); + + make_inst(spv::OpBranch, head); + make_inst(spv::OpLabel, head); + make_inst(spv::OpLoopMerge, branch_true, merge, 0); + make_inst(spv::OpBranch, body); + make_inst(spv::OpLabel, body); + // while (true) + { + // int old = addr_ptr[0]; + Value old_val = load_variable(addr_ptr, t_uint32_); + // int new = floatBitsToInt(atomic_op(intBitsToFloat(old), data)); + Value old_float = make_value(spv::OpBitcast, t_fp32_, old_val); + Value new_float = atomic_op(old_float, data); + Value new_val = make_value(spv::OpBitcast, t_uint32_, new_float); + // int loaded = atomicCompSwap(vals[0], old, new); + /* + * Don't need this part, theoretically + auto semantics = uint_immediate_number( + t_uint32_, spv::MemorySemanticsAcquireReleaseMask | + spv::MemorySemanticsUniformMemoryMask); + make_inst(spv::OpMemoryBarrier, const_i32_one_, semantics); + */ + Value loaded = make_value( + spv::OpAtomicCompareExchange, t_uint32_, addr_ptr, + /*scope=*/const_i32_one_, /*semantics if equal=*/const_i32_zero_, + /*semantics if unequal=*/const_i32_zero_, new_val, old_val); + // bool ok = (loaded == old); + Value ok = make_value(spv::OpIEqual, t_bool_, loaded, old_val); + // int ret_val_int = loaded; + store_variable(ret_val_int, loaded); + // if (ok) + make_inst(spv::OpSelectionMerge, branch_false, 0); + make_inst(spv::OpBranchConditional, ok, branch_true, branch_false); + { + make_inst(spv::OpLabel, branch_true); + make_inst(spv::OpBranch, exit); + } + // else + { + make_inst(spv::OpLabel, branch_false); + make_inst(spv::OpBranch, merge); + } + // continue; + make_inst(spv::OpLabel, merge); + make_inst(spv::OpBranch, head); + } + make_inst(spv::OpLabel, exit); + + return make_value(spv::OpBitcast, t_fp32_, + load_variable(ret_val_int, t_uint32_)); }; if (op_type == AtomicOpType::add) { - return atomic_func_([&](Value res, Value lhs, Value rhs) { - ib_.begin(spv::OpFAdd).add_seq(t_fp32_, res, lhs, rhs).commit(&function_); - }); + return atomic_func_([&](Value lhs, Value rhs) { return add(lhs, rhs); }); } else if (op_type == AtomicOpType::sub) { - return atomic_func_([&](Value res, Value lhs, Value rhs) { - ib_.begin(spv::OpFSub).add_seq(t_fp32_, res, lhs, rhs).commit(&function_); - }); + return atomic_func_([&](Value lhs, Value rhs) { return sub(lhs, rhs); }); } else if (op_type == AtomicOpType::min) { - return atomic_func_([&](Value res, Value lhs, Value rhs) { - Value cond = new_value(t_bool_, ValueKind::kNormal); - ib_.begin(spv::OpFOrdLessThan) - .add_seq(t_bool_, cond, lhs, rhs) - .commit(&function_); - ib_.begin(spv::OpSelect) - .add_seq(t_fp32_, res, cond, lhs, rhs) - .commit(&function_); + return atomic_func_([&](Value lhs, Value rhs) { + return call_glsl450(t_fp32_, /*FMin*/ 37, lhs, rhs); }); } else if (op_type == AtomicOpType::max) { - return atomic_func_([&](Value res, Value lhs, Value rhs) { - Value cond = new_value(t_bool_, ValueKind::kNormal); - ib_.begin(spv::OpFOrdGreaterThan) - .add_seq(t_bool_, cond, lhs, rhs) - .commit(&function_); - ib_.begin(spv::OpSelect) - .add_seq(t_fp32_, res, cond, lhs, rhs) - .commit(&function_); + return atomic_func_([&](Value lhs, Value rhs) { + return call_glsl450(t_fp32_, /*FMax*/ 40, lhs, rhs); }); } else { TI_NOT_IMPLEMENTED @@ -799,19 +956,19 @@ Value IRBuilder::rand_u32(Value global_tmp_) { Value _19u = uint_immediate_number(t_uint32_, 19u); Value _8u = uint_immediate_number(t_uint32_, 8u); Value _1000000007u = uint_immediate_number(t_uint32_, 1000000007u); - Value tmp0 = load_variable(_rand_x_, t_uint32_); + Value tmp0 = load_variable(rand_x_, t_uint32_); Value tmp1 = make_value(spv::OpShiftLeftLogical, t_uint32_, tmp0, _11u); Value tmp_t = make_value(spv::OpBitwiseXor, t_uint32_, tmp0, tmp1); // t - store_variable(_rand_x_, load_variable(_rand_y_, t_uint32_)); - store_variable(_rand_y_, load_variable(_rand_z_, t_uint32_)); - Value tmp_w = load_variable(_rand_w_, t_uint32_); // reuse w - store_variable(_rand_z_, tmp_w); + store_variable(rand_x_, load_variable(rand_y_, t_uint32_)); + store_variable(rand_y_, load_variable(rand_z_, t_uint32_)); + Value tmp_w = load_variable(rand_w_, t_uint32_); // reuse w + store_variable(rand_z_, tmp_w); Value tmp2 = make_value(spv::OpShiftRightLogical, t_uint32_, tmp_w, _19u); Value tmp3 = make_value(spv::OpBitwiseXor, t_uint32_, tmp_w, tmp2); Value tmp4 = make_value(spv::OpShiftRightLogical, t_uint32_, tmp_t, _8u); Value tmp5 = make_value(spv::OpBitwiseXor, t_uint32_, tmp_t, tmp4); Value new_w = make_value(spv::OpBitwiseXor, t_uint32_, tmp3, tmp5); - store_variable(_rand_w_, new_w); + store_variable(rand_w_, new_w); Value val = make_value(spv::OpIMul, t_uint32_, new_w, _1000000007u); return val; @@ -840,9 +997,9 @@ Value IRBuilder::rand_i32(Value global_tmp_) { return val; } -Value IRBuilder::get_const_(const SType &dtype, - const uint64_t *pvalue, - bool cache) { +Value IRBuilder::get_const(const SType &dtype, + const uint64_t *pvalue, + bool cache) { auto key = std::make_pair(dtype.id, pvalue[0]); if (cache) { auto it = const_tbl_.find(key); @@ -908,31 +1065,31 @@ SType IRBuilder::declare_primitive_type(DataType dt) { void IRBuilder::init_random_function(Value global_tmp_) { // variables declare SType local_type = get_pointer_type(t_uint32_, spv::StorageClassPrivate); - _rand_x_ = new_value(local_type, ValueKind::kVariablePtr); - _rand_y_ = new_value(local_type, ValueKind::kVariablePtr); - _rand_z_ = new_value(local_type, ValueKind::kVariablePtr); - _rand_w_ = new_value(local_type, ValueKind::kVariablePtr); - global_values.push_back(_rand_x_); - global_values.push_back(_rand_y_); - global_values.push_back(_rand_z_); - global_values.push_back(_rand_w_); + rand_x_ = new_value(local_type, ValueKind::kVariablePtr); + rand_y_ = new_value(local_type, ValueKind::kVariablePtr); + rand_z_ = new_value(local_type, ValueKind::kVariablePtr); + rand_w_ = new_value(local_type, ValueKind::kVariablePtr); + global_values.push_back(rand_x_); + global_values.push_back(rand_y_); + global_values.push_back(rand_z_); + global_values.push_back(rand_w_); ib_.begin(spv::OpVariable) - .add_seq(local_type, _rand_x_, spv::StorageClassPrivate) + .add_seq(local_type, rand_x_, spv::StorageClassPrivate) .commit(&global_); ib_.begin(spv::OpVariable) - .add_seq(local_type, _rand_y_, spv::StorageClassPrivate) + .add_seq(local_type, rand_y_, spv::StorageClassPrivate) .commit(&global_); ib_.begin(spv::OpVariable) - .add_seq(local_type, _rand_z_, spv::StorageClassPrivate) + .add_seq(local_type, rand_z_, spv::StorageClassPrivate) .commit(&global_); ib_.begin(spv::OpVariable) - .add_seq(local_type, _rand_w_, spv::StorageClassPrivate) + .add_seq(local_type, rand_w_, spv::StorageClassPrivate) .commit(&global_); - debug(spv::OpName, _rand_x_, "_rand_x"); - debug(spv::OpName, _rand_y_, "_rand_y"); - debug(spv::OpName, _rand_z_, "_rand_z"); - debug(spv::OpName, _rand_w_, "_rand_w"); - SType gtmp_type = get_pointer_type(t_int32_, spv::StorageClassStorageBuffer); + debug(spv::OpName, rand_x_, "_rand_x"); + debug(spv::OpName, rand_y_, "_rand_y"); + debug(spv::OpName, rand_z_, "_rand_z"); + debug(spv::OpName, rand_w_, "_rand_w"); + SType gtmp_type = get_pointer_type(t_uint32_, spv::StorageClassStorageBuffer); Value rand_gtmp_ = new_value(gtmp_type, ValueKind::kVariablePtr); debug(spv::OpName, rand_gtmp_, "rand_gtmp"); @@ -961,8 +1118,8 @@ void IRBuilder::init_random_function(Value global_tmp_) { Value _362436069u = uint_immediate_number(t_uint32_, 362436069u); Value _521288629u = uint_immediate_number(t_uint32_, 521288629u); Value _88675123u = uint_immediate_number(t_uint32_, 88675123u); - Value _1 = int_immediate_number(t_int32_, 1); - Value _1024 = int_immediate_number(t_int32_, 1024); + Value _1 = int_immediate_number(t_uint32_, 1); + Value _1024 = int_immediate_number(t_uint32_, 1024); // init_rand_ segment (inline to main) // ad-hoc: hope no kernel will use more than 1024 gtmp variables... @@ -974,11 +1131,11 @@ void IRBuilder::init_random_function(Value global_tmp_) { SType pint_type = this->get_pointer_type(t_uint32_, spv::StorageClassInput); Value tmp0 = new_value(pint_type, ValueKind::kVariablePtr); ib_.begin(spv::OpAccessChain) - .add_seq(pint_type, tmp0, gl_global_invocation_id, + .add_seq(pint_type, tmp0, gl_global_invocation_id_, uint_immediate_number(t_uint32_, 0)) .commit(&func_header_); Value tmp1 = load_var(tmp0, t_uint32_); - Value tmp2_ = load_var(rand_gtmp_, t_int32_); + Value tmp2_ = load_var(rand_gtmp_, t_uint32_); Value tmp2 = new_value(t_uint32_, ValueKind::kNormal); ib_.begin(spv::OpBitcast) .add_seq(t_uint32_, tmp2, tmp2_) @@ -1007,19 +1164,19 @@ void IRBuilder::init_random_function(Value global_tmp_) { ib_.begin(spv::OpIMul) .add_seq(t_uint32_, tmp8, _1000000007u, tmp7) .commit(&func_header_); - store_var(_rand_x_, tmp8); - store_var(_rand_y_, _362436069u); - store_var(_rand_z_, _521288629u); - store_var(_rand_w_, _88675123u); + store_var(rand_x_, tmp8); + store_var(rand_y_, _362436069u); + store_var(rand_z_, _521288629u); + store_var(rand_w_, _88675123u); // Yes, this is not an atomic operation, but just fine since no matter // how RAND_STATE changes, `gl_GlobalInvocationID.x` can still help // us to set different seeds for different threads. // Discussion: // https://github.com/taichi-dev/taichi/pull/912#discussion_r419021918 - Value tmp9 = load_var(rand_gtmp_, t_int32_); - Value tmp10 = new_value(t_int32_, ValueKind::kNormal); + Value tmp9 = load_var(rand_gtmp_, t_uint32_); + Value tmp10 = new_value(t_uint32_, ValueKind::kNormal); ib_.begin(spv::OpIAdd) - .add_seq(t_int32_, tmp10, tmp9, _1) + .add_seq(t_uint32_, tmp10, tmp9, _1) .commit(&func_header_); store_var(rand_gtmp_, tmp10); @@ -1027,6 +1184,5 @@ void IRBuilder::init_random_function(Value global_tmp_) { } } // namespace spirv -} // namespace vulkan } // namespace lang } // namespace taichi diff --git a/taichi/backends/vulkan/spirv_ir_builder.h b/taichi/codegen/spirv/spirv_ir_builder.h similarity index 83% rename from taichi/backends/vulkan/spirv_ir_builder.h rename to taichi/codegen/spirv/spirv_ir_builder.h index dcc116bff5fb7..9066e13971fc6 100644 --- a/taichi/backends/vulkan/spirv_ir_builder.h +++ b/taichi/codegen/spirv/spirv_ir_builder.h @@ -2,18 +2,16 @@ #include -#include "taichi/backends/vulkan/spirv_header.h" -#include "taichi/backends/vulkan/embedded_device.h" +#include #include "taichi/lang_util.h" #include "taichi/ir/type.h" #include "taichi/util/testing.h" -#include "taichi/backends/vulkan/snode_struct_compiler.h" +#include "taichi/codegen/spirv/snode_struct_compiler.h" +#include "taichi/backends/device.h" #include "taichi/ir/statements.h" namespace taichi { namespace lang { -namespace vulkan { - namespace spirv { template @@ -74,6 +72,7 @@ enum class ValueKind { kVectorPtr, kStructArrayPtr, kVariablePtr, + kPhysicalPtr, kFunction, kExtInst }; @@ -169,7 +168,7 @@ class InstrBuilder { InstrBuilder &add_seq(Args &&... args) { AddSeqHelper helper; helper.builder = this; - vulkan::spirv::for_each(helper, std::forward(args)...); + for_each(helper, std::forward(args)...); return *this; } @@ -271,6 +270,9 @@ class IRBuilder { Value make_value(spv::Op op, const SType &out_type, Args &&... args) { Value val = new_value(out_type, ValueKind::kNormal); make_inst(op, out_type, val, std::forward(args)...); + if (out_type.flag == TypeKind::kPtr) { + val.flag = ValueKind::kVariablePtr; + } return val; } @@ -311,19 +313,38 @@ class IRBuilder { SType get_null_type(); // Get the spirv type for a given Taichi data type SType get_primitive_type(const DataType &dt) const; - // Get the spirv type for the buffer for a given Taichi data type - SType get_primitive_buffer_type(const bool struct_compiled, - const DataType &dt) const; + // Get the size in bytes of a given Taichi data type + size_t get_primitive_type_size(const DataType &dt) const; + // Get the spirv uint type with the same size of a given Taichi data type + SType get_primitive_uint_type(const DataType &dt) const; + // Get the Taichi uint type with the same size of a given Taichi data type + DataType get_taichi_uint_type(const DataType &dt) const; + // Get the pointer type that points to value_type + SType get_storage_pointer_type(const SType &value_type); // Get the pointer type that points to value_type SType get_pointer_type(const SType &value_type, spv::StorageClass storage_class); + // Get a value_type[num_elems] type + SType get_array_type(const SType &value_type, uint32_t num_elems); // Get a struct{ value_type[num_elems] } type SType get_struct_array_type(const SType &value_type, uint32_t num_elems); + // Construct a struct type + SType create_struct_type( + std::vector> &components); // Declare buffer argument of function + Value buffer_struct_argument(const SType &struct_type, + uint32_t descriptor_set, + uint32_t binding, + const std::string &name); + Value uniform_struct_argument(const SType &struct_type, + uint32_t descriptor_set, + uint32_t binding, + const std::string &name); Value buffer_argument(const SType &value_type, uint32_t descriptor_set, - uint32_t binding); + uint32_t binding, + const std::string &name); Value struct_array_access(const SType &res_type, Value buffer, Value index); // Declare a new function @@ -349,11 +370,11 @@ class IRBuilder { ib_.add(v); } } - if (gl_global_invocation_id.id != 0) { - ib_.add(gl_global_invocation_id); + if (gl_global_invocation_id_.id != 0) { + ib_.add(gl_global_invocation_id_); } - if (gl_num_work_groups.id != 0) { - ib_.add(gl_num_work_groups); + if (gl_num_work_groups_.id != 0) { + ib_.add(gl_num_work_groups_); } ib_.commit(&entry_); ib_.begin(spv::OpExecutionMode) @@ -379,6 +400,7 @@ class IRBuilder { Value get_work_group_size(uint32_t dim_index); Value get_num_work_groups(uint32_t dim_index); Value get_global_invocation_id(uint32_t dim_index); + Value get_subgroup_invocation_id(); // Expressions Value add(Value a, Value b); @@ -417,8 +439,20 @@ class IRBuilder { void register_value(std::string name, Value value); // Query Value/VariablePointer by name Value query_value(std::string name) const; + // Check whether a value has been evaluated + bool check_value_existence(const std::string &name) const; // Support easy access to trivial data types + SType i64_type() const { + return t_int64_; + } + SType u64_type() const { + return t_uint64_; + } + SType f64_type() const { + return t_fp64_; + } + SType i32_type() const { return t_int32_; } @@ -428,6 +462,24 @@ class IRBuilder { SType f32_type() const { return t_fp32_; } + + SType i16_type() const { + return t_int16_; + } + SType u16_type() const { + return t_uint16_; + } + SType f16_type() const { + return t_fp16_; + } + + SType i8_type() const { + return t_int8_; + } + SType u8_type() const { + return t_uint8_; + } + SType bool_type() const { return t_bool_; } @@ -451,7 +503,7 @@ class IRBuilder { return val; } - Value get_const_(const SType &dtype, const uint64_t *pvalue, bool cache); + Value get_const(const SType &dtype, const uint64_t *pvalue, bool cache); SType declare_primitive_type(DataType dt); void init_random_function(Value global_tmp_); @@ -477,22 +529,24 @@ class IRBuilder { SType t_uint16_; SType t_uint32_; SType t_uint64_; + SType t_fp16_; SType t_fp32_; SType t_fp64_; SType t_void_; SType t_void_func_; // gl compute shader related type(s) and variables SType t_v3_uint_; - Value gl_global_invocation_id; - Value gl_num_work_groups; - Value gl_work_group_size; + Value gl_global_invocation_id_; + Value gl_num_work_groups_; + Value gl_work_group_size_; + Value subgroup_local_invocation_id_; // Random function and variables bool init_rand_{false}; - Value _rand_x_; - Value _rand_y_; - Value _rand_z_; - Value _rand_w_; // per-thread local variable + Value rand_x_; + Value rand_y_; + Value rand_z_; + Value rand_w_; // per-thread local variable // map from value to its pointer type std::map, SType> pointer_type_tbl_; @@ -519,6 +573,5 @@ class IRBuilder { std::vector function_; }; } // namespace spirv -} // namespace vulkan } // namespace lang } // namespace taichi diff --git a/taichi/common/commit_hash.h.in b/taichi/common/commit_hash.h.in new file mode 100644 index 0000000000000..4420f658ba3d1 --- /dev/null +++ b/taichi/common/commit_hash.h.in @@ -0,0 +1 @@ +#define TI_COMMIT_HASH "@TI_COMMIT_HASH@" diff --git a/taichi/common/core.cpp b/taichi/common/core.cpp index b7decb32a2616..6ab16c92b2a74 100644 --- a/taichi/common/core.cpp +++ b/taichi/common/core.cpp @@ -17,18 +17,6 @@ TI_NAMESPACE_BEGIN -extern "C" { -#if defined(TI_PLATFORM_LINUX) && defined(TI_ARCH_x64) -// Avoid dependency on glibc 2.27 -// log2f is used by a third party .a file, so we have to define a wrapper. -// https://stackoverflow.com/questions/8823267/linking-against-older-symbol-version-in-a-so-file -__asm__(".symver log2f,log2f@GLIBC_2.2.5"); -float __wrap_log2f(float x) { - return log2f(x); -} -#endif -} - std::string python_package_dir; std::string get_python_package_dir() { @@ -63,15 +51,15 @@ std::string get_version_string() { } int get_version_major() { - return std::atoi(TI_VERSION_MAJOR); + return TI_VERSION_MAJOR; } int get_version_minor() { - return std::atoi(TI_VERSION_MINOR); + return TI_VERSION_MINOR; } int get_version_patch() { - return std::atoi(TI_VERSION_PATCH); + return TI_VERSION_PATCH; } std::string get_commit_hash() { @@ -79,7 +67,7 @@ std::string get_commit_hash() { } std::string get_cuda_version_string() { - return TI_CUDAVERSION; + return CUDA_VERSION; } int PID::get_pid() { diff --git a/taichi/common/core.h b/taichi/common/core.h index ce8f752356f9d..6ec323eeab36f 100644 --- a/taichi/common/core.h +++ b/taichi/common/core.h @@ -90,10 +90,7 @@ static_assert(__cplusplus >= 201402L, "C++14 required."); #include "taichi/platform/windows/windows.h" #pragma warning(pop) #include -#define TI_EXPORT __declspec(dllexport) -#else -#define TI_EXPORT -#endif +#endif // _WIN64 #ifndef _WIN64 #define sscanf_s sscanf @@ -125,7 +122,7 @@ static_assert(__cplusplus >= 201402L, "C++14 required."); } \ } -TI_EXPORT void taichi_raise_assertion_failure_in_python(const char *msg); +void taichi_raise_assertion_failure_in_python(const char *msg); TI_NAMESPACE_BEGIN @@ -220,8 +217,6 @@ float64 constexpr operator"" _fd(unsigned long long v) { return float64(v); } -TI_EXPORT void print_traceback(); - TI_NAMESPACE_END //****************************************************************************** // Meta-programming @@ -309,15 +304,15 @@ void trash(T &&t) { } class DeferedExecution { - std::function statement; + std::function statement_; public: DeferedExecution(const std::function &statement) - : statement(statement) { + : statement_(statement) { } ~DeferedExecution() { - statement(); + statement_(); } }; diff --git a/taichi/common/dict.h b/taichi/common/dict.h index b485f25211437..2675765357e1e 100644 --- a/taichi/common/dict.h +++ b/taichi/common/dict.h @@ -26,10 +26,10 @@ TI_NAMESPACE_BEGIN class Dict { private: - std::map data; + std::map data_; public: - TI_IO_DEF(data); + TI_IO_DEF(data_); Dict() = default; @@ -40,14 +40,14 @@ class Dict { std::vector get_keys() const { std::vector keys; - for (auto it = data.begin(); it != data.end(); ++it) { + for (auto it = data_.begin(); it != data_.end(); ++it) { keys.push_back(it->first); } return keys; } void clear() { - data.clear(); + data_.clear(); } template @@ -156,7 +156,7 @@ class Dict { T get(std::string key, const T &default_val) const; bool has_key(std::string key) const { - return data.find(key) != data.end(); + return data_.find(key) != data_.end(); } std::vector get_string_arr(std::string key) const { @@ -206,56 +206,56 @@ class Dict { Dict &set(std::string name, T val) { std::stringstream ss; ss << val; - data[name] = ss.str(); + data_[name] = ss.str(); return *this; } Dict &set(std::string name, const char *val) { std::stringstream ss; ss << val; - data[name] = ss.str(); + data_[name] = ss.str(); return *this; } Dict &set(std::string name, const Vector2 &val) { std::stringstream ss; ss << "(" << val.x << "," << val.y << ")"; - data[name] = ss.str(); + data_[name] = ss.str(); return *this; } Dict &set(std::string name, const Vector3 &val) { std::stringstream ss; ss << "(" << val.x << "," << val.y << "," << val.z << ")"; - data[name] = ss.str(); + data_[name] = ss.str(); return *this; } Dict &set(std::string name, const Vector4 &val) { std::stringstream ss; ss << "(" << val.x << "," << val.y << "," << val.z << "," << val.w << ")"; - data[name] = ss.str(); + data_[name] = ss.str(); return *this; } Dict &set(std::string name, const Vector2i &val) { std::stringstream ss; ss << "(" << val.x << "," << val.y << ")"; - data[name] = ss.str(); + data_[name] = ss.str(); return *this; } Dict &set(std::string name, const Vector3i &val) { std::stringstream ss; ss << "(" << val.x << "," << val.y << "," << val.z << ")"; - data[name] = ss.str(); + data_[name] = ss.str(); return *this; } Dict &set(std::string name, const Vector4i &val) { std::stringstream ss; ss << "(" << val.x << "," << val.y << "," << val.z << "," << val.w << ")"; - data[name] = ss.str(); + data_[name] = ss.str(); return *this; } @@ -268,15 +268,15 @@ class Dict { template Dict &set(std::string name, T *const ptr) { - data[name] = get_ptr_string(ptr); + data_[name] = get_ptr_string(ptr); return *this; } std::string get_string(std::string key) const { - if (data.find(key) == data.end()) { + if (data_.find(key) == data_.end()) { TI_ERROR("No key named '{}' found.", key); } - return data.find(key)->second; + return data_.find(key)->second; } template @@ -293,14 +293,14 @@ inline std::string Dict::get(std::string key) const { template inline T Dict::get(std::string key, const T &default_val) const { - if (data.find(key) == data.end()) { + if (data_.find(key) == data_.end()) { return default_val; } else return get(key); } inline std::string Dict::get(std::string key, const char *default_val) const { - if (data.find(key) == data.end()) { + if (data_.find(key) == data_.end()) { return default_val; } else return get(key); diff --git a/taichi/common/exceptions.h b/taichi/common/exceptions.h new file mode 100644 index 0000000000000..8cdb4cfe40223 --- /dev/null +++ b/taichi/common/exceptions.h @@ -0,0 +1,32 @@ +#pragma once + +namespace taichi { +namespace lang { + +class IRModified {}; + +class TaichiExceptionImpl : public std::exception { + std::string msg_; + + public: + TaichiExceptionImpl(const std::string msg) : msg_(msg) { + } + const char *what() const throw() override { + return msg_.c_str(); + } +}; + +class TaichiTypeError : public TaichiExceptionImpl { + using TaichiExceptionImpl::TaichiExceptionImpl; +}; + +class TaichiSyntaxError : public TaichiExceptionImpl { + using TaichiExceptionImpl::TaichiExceptionImpl; +}; + +class TaichiRuntimeError : public TaichiExceptionImpl { + using TaichiExceptionImpl::TaichiExceptionImpl; +}; + +} // namespace lang +} // namespace taichi diff --git a/taichi/common/interface.h b/taichi/common/interface.h index 418d3d099135f..2588124637f42 100644 --- a/taichi/common/interface.h +++ b/taichi/common/interface.h @@ -17,40 +17,38 @@ TI_NAMESPACE_BEGIN template -TI_EXPORT std::shared_ptr create_instance(const std::string &alias); +std::shared_ptr create_instance(const std::string &alias); template -TI_EXPORT std::shared_ptr create_instance(const std::string &alias, - const Config &config); +std::shared_ptr create_instance(const std::string &alias, + const Config &config); template -TI_EXPORT std::unique_ptr create_instance_unique(const std::string &alias); +std::unique_ptr create_instance_unique(const std::string &alias); template -TI_EXPORT std::unique_ptr create_instance_unique(const std::string &alias, - const Config &config); +std::unique_ptr create_instance_unique(const std::string &alias, + const Config &config); template -TI_EXPORT std::unique_ptr create_instance_unique_ctor( - const std::string &alias, - const Config &config); +std::unique_ptr create_instance_unique_ctor(const std::string &alias, + const Config &config); template -TI_EXPORT T *create_instance_raw(const std::string &alias); +T *create_instance_raw(const std::string &alias); template -TI_EXPORT T *create_instance_raw(const std::string &alias, - const Config &config); +T *create_instance_raw(const std::string &alias, const Config &config); template -TI_EXPORT T *create_instance_placement(const std::string &alias, void *place); +T *create_instance_placement(const std::string &alias, void *place); template -TI_EXPORT T *create_instance_placement(const std::string &alias, - void *place, - const Config &config); +T *create_instance_placement(const std::string &alias, + void *place, + const Config &config); template -TI_EXPORT std::vector get_implementation_names(); +std::vector get_implementation_names(); class Unit { public: @@ -237,77 +235,76 @@ class InterfaceHolder { }; \ extern TI_IMPLEMENTATION_HOLDER_NAME(T) * TI_IMPLEMENTATION_HOLDER_PTR(T); -#define TI_INTERFACE_DEF(class_name, base_alias) \ - template <> \ - TI_EXPORT std::shared_ptr create_instance( \ - const std::string &alias) { \ - return TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance()->create( \ - alias); \ - } \ - template <> \ - TI_EXPORT std::shared_ptr create_instance( \ - const std::string &alias, const Config &config) { \ - auto instance = create_instance(alias); \ - instance->initialize(config); \ - return instance; \ - } \ - template <> \ - TI_EXPORT std::unique_ptr create_instance_unique( \ - const std::string &alias) { \ - return TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance() \ - ->create_unique(alias); \ - } \ - template <> \ - TI_EXPORT std::unique_ptr create_instance_unique( \ - const std::string &alias, const Config &config) { \ - auto instance = create_instance_unique(alias); \ - instance->initialize(config); \ - return instance; \ - } \ - template <> \ - TI_EXPORT std::unique_ptr create_instance_unique_ctor( \ - const std::string &alias, const Dict &config) { \ - return TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance() \ - ->create_unique_ctor(alias, config); \ - } \ - template <> \ - TI_EXPORT class_name *create_instance_raw(const std::string &alias) { \ - return TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance() \ - ->create_raw(alias); \ - } \ - template <> \ - TI_EXPORT class_name *create_instance_placement(const std::string &alias, \ - void *place) { \ - return TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance() \ - ->create_placement(alias, place); \ - } \ - template <> \ - TI_EXPORT class_name *create_instance_placement( \ - const std::string &alias, void *place, const Config &config) { \ - auto instance = create_instance_placement(alias, place); \ - instance->initialize(config); \ - return instance; \ - } \ - template <> \ - TI_EXPORT class_name *create_instance_raw(const std::string &alias, \ - const Config &config) { \ - auto instance = create_instance_raw(alias); \ - instance->initialize(config); \ - return instance; \ - } \ - template <> \ - std::vector get_implementation_names() { \ - return TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance() \ - ->get_implementation_names(); \ - } \ - TI_IMPLEMENTATION_HOLDER_NAME(class_name) * \ - TI_IMPLEMENTATION_HOLDER_PTR(class_name) = nullptr; \ - void *get_implementation_holder_instance_##class_name() { \ - if (!TI_IMPLEMENTATION_HOLDER_PTR(class_name)) { \ - TI_IMPLEMENTATION_HOLDER_PTR(class_name) = \ - new TI_IMPLEMENTATION_HOLDER_NAME(class_name)(base_alias); \ - } \ - return TI_IMPLEMENTATION_HOLDER_PTR(class_name); \ +#define TI_INTERFACE_DEF(class_name, base_alias) \ + template <> \ + std::shared_ptr create_instance(const std::string &alias) { \ + return TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance()->create( \ + alias); \ + } \ + template <> \ + std::shared_ptr create_instance(const std::string &alias, \ + const Config &config) { \ + auto instance = create_instance(alias); \ + instance->initialize(config); \ + return instance; \ + } \ + template <> \ + std::unique_ptr create_instance_unique( \ + const std::string &alias) { \ + return TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance() \ + ->create_unique(alias); \ + } \ + template <> \ + std::unique_ptr create_instance_unique(const std::string &alias, \ + const Config &config) { \ + auto instance = create_instance_unique(alias); \ + instance->initialize(config); \ + return instance; \ + } \ + template <> \ + std::unique_ptr create_instance_unique_ctor( \ + const std::string &alias, const Dict &config) { \ + return TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance() \ + ->create_unique_ctor(alias, config); \ + } \ + template <> \ + class_name *create_instance_raw(const std::string &alias) { \ + return TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance() \ + ->create_raw(alias); \ + } \ + template <> \ + class_name *create_instance_placement(const std::string &alias, \ + void *place) { \ + return TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance() \ + ->create_placement(alias, place); \ + } \ + template <> \ + class_name *create_instance_placement(const std::string &alias, void *place, \ + const Config &config) { \ + auto instance = create_instance_placement(alias, place); \ + instance->initialize(config); \ + return instance; \ + } \ + template <> \ + class_name *create_instance_raw(const std::string &alias, \ + const Config &config) { \ + auto instance = create_instance_raw(alias); \ + instance->initialize(config); \ + return instance; \ + } \ + template <> \ + std::vector get_implementation_names() { \ + return TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance() \ + ->get_implementation_names(); \ + } \ + TI_IMPLEMENTATION_HOLDER_NAME(class_name) * \ + TI_IMPLEMENTATION_HOLDER_PTR(class_name) = nullptr; \ + void *get_implementation_holder_instance_##class_name() { \ + if (!TI_IMPLEMENTATION_HOLDER_PTR(class_name)) { \ + TI_IMPLEMENTATION_HOLDER_PTR(class_name) = \ + new TI_IMPLEMENTATION_HOLDER_NAME(class_name)(base_alias); \ + } \ + return TI_IMPLEMENTATION_HOLDER_PTR(class_name); \ } #define TI_IMPLEMENTATION(base_class_name, class_name, alias) \ diff --git a/taichi/common/logging.cpp b/taichi/common/logging.cpp index 6c1a304e8d561..4b250bebb69ff 100644 --- a/taichi/common/logging.cpp +++ b/taichi/common/logging.cpp @@ -8,6 +8,9 @@ #include "spdlog/common.h" #include "spdlog/sinks/stdout_color_sinks.h" #include "spdlog/spdlog.h" +#ifdef ANDROID +#include "spdlog/sinks/android_sink.h" +#endif #include "taichi/common/core.h" namespace taichi { @@ -16,12 +19,12 @@ const auto default_logging_level = "info"; void Logger::set_level(const std::string &level_name) { auto new_level = level_enum_from_string(level_name); - level = new_level; - spdlog::set_level((spdlog::level::level_enum)level); + level_ = new_level; + spdlog::set_level((spdlog::level::level_enum)level_); } int Logger::get_level() { - return level; + return level_; } bool Logger::is_level_effective(const std::string &level_name) { @@ -52,8 +55,13 @@ int Logger::level_enum_from_string(const std::string &level_name) { } Logger::Logger() { - console = spdlog::stdout_color_mt("console"); - console->flush_on(spdlog::level::trace); +#ifdef ANDROID + console_ = spdlog::android_logger_mt("android", "taichi"); + console_->flush_on(spdlog::level::trace); +#else + console_ = spdlog::stdout_color_mt("console"); + console_->flush_on(spdlog::level::trace); +#endif TI_LOG_SET_PATTERN("%^[%L %D %X.%e %t] %v%$"); set_level_default(); @@ -64,23 +72,23 @@ void Logger::set_level_default() { } void Logger::trace(const std::string &s) { - console->trace(s); + console_->trace(s); } void Logger::debug(const std::string &s) { - console->debug(s); + console_->debug(s); } void Logger::info(const std::string &s) { - console->info(s); + console_->info(s); } void Logger::warn(const std::string &s) { - console->warn(s); + console_->warn(s); } void Logger::error(const std::string &s, bool raise_exception) { - console->error(s); + console_->error(s); fmt::print("\n\n"); if (print_stacktrace_fn_) { print_stacktrace_fn_(); @@ -100,7 +108,7 @@ void Logger::critical(const std::string &s) { } void Logger::flush() { - console->flush(); + console_->flush(); } void Logger::set_print_stacktrace_func(std::function print_fn) { diff --git a/taichi/common/logging.h b/taichi/common/logging.h index 69ee5cfd23bd6..0fd4b38b6284e 100644 --- a/taichi/common/logging.h +++ b/taichi/common/logging.h @@ -120,10 +120,10 @@ class logger; namespace taichi { -class Logger { +class TI_DLL_EXPORT Logger { private: - std::shared_ptr console; - int level; + std::shared_ptr console_; + int level_; std::function print_stacktrace_fn_; Logger(); diff --git a/taichi/common/platform_macros.h b/taichi/common/platform_macros.h index dee5ddf581747..022744d311ad1 100644 --- a/taichi/common/platform_macros.h +++ b/taichi/common/platform_macros.h @@ -9,6 +9,17 @@ #define _CRT_SECURE_NO_WARNINGS #endif +// https://gcc.gnu.org/wiki/Visibility +#if defined _WIN32 || defined _WIN64 || defined __CYGWIN__ +#ifdef __GNUC__ +#define TI_DLL_EXPORT __attribute__((dllexport)) +#else +#define TI_DLL_EXPORT __declspec(dllexport) +#endif // __GNUC__ +#else +#define TI_DLL_EXPORT __attribute__((visibility("default"))) +#endif // defined _WIN32 || defined _WIN64 || defined __CYGWIN__ + // Windows #if defined(_WIN64) #define TI_PLATFORM_WINDOWS @@ -20,14 +31,19 @@ static_assert(false, "32-bit Windows systems are not supported") // Linux #if defined(__linux__) +#if defined(ANDROID) +#define TI_PLATFORM_ANDROID +#else #define TI_PLATFORM_LINUX #endif +#endif // OSX #if defined(__APPLE__) #define TI_PLATFORM_OSX #endif -#if (defined(TI_PLATFORM_LINUX) || defined(TI_PLATFORM_OSX)) +#if (defined(TI_PLATFORM_LINUX) || defined(TI_PLATFORM_OSX) || \ + defined(__unix__)) #define TI_PLATFORM_UNIX #endif diff --git a/taichi/common/serialization.h b/taichi/common/serialization.h index a455af94db367..251d72ea92834 100644 --- a/taichi/common/serialization.h +++ b/taichi/common/serialization.h @@ -24,14 +24,13 @@ TI_NAMESPACE_BEGIN #else #define TI_NAMESPACE_BEGIN #define TI_NAMESPACE_END -#define TI_EXPORT #define TI_TRACE #define TI_CRITICAL #define TI_ASSERT assert #endif template -TI_EXPORT std::unique_ptr create_instance_unique(const std::string &alias); +std::unique_ptr create_instance_unique(const std::string &alias); //////////////////////////////////////////////////////////////////////////////// // A Minimalist Serializer for Taichi // @@ -58,7 +57,7 @@ template using is_unit_t = typename is_unit::type; } // namespace type - +class TextSerializer; namespace detail { template @@ -109,16 +108,31 @@ void serialize_kv_impl(SER &ser, } template -void serialize_kv_impl(SER &ser, - const std::array &keys, - T &&head, - Args &&... rest) { +typename std::enable_if::value, void>::type +serialize_kv_impl(SER &ser, + const std::array &keys, + T &&head, + Args &&... rest) { constexpr auto i = (N - 1 - sizeof...(Args)); std::string key{keys[i]}; ser(key.c_str(), head); serialize_kv_impl(ser, keys, rest...); } +// Specialize for TextSerializer since we need to append comma in the end for +// non-last object. +template +typename std::enable_if::value, void>::type +serialize_kv_impl(SER &ser, + const std::array &keys, + T &&head, + Args &&... rest) { + constexpr auto i = (N - 1 - sizeof...(Args)); + std::string key{keys[i]}; + ser(key.c_str(), head, true); + serialize_kv_impl(ser, keys, rest...); +} + } // namespace detail #define TI_IO_DECL \ @@ -148,10 +162,11 @@ void serialize_kv_impl(SER &ser, (std::is_same::type, \ T>()) +#if !defined(TI_ARCH_x86) static_assert( sizeof(std::size_t) == sizeof(uint64_t), "sizeof(std::size_t) should be 8. Try compiling with 64bit mode."); - +#endif template struct IO { using implemented = std::false_type; @@ -595,6 +610,7 @@ class BinarySerializer : public Serializer { using BinaryOutputSerializer = BinarySerializer; using BinaryInputSerializer = BinarySerializer; +// Serialize to JSON format class TextSerializer : public Serializer { public: std::string data; @@ -609,9 +625,9 @@ class TextSerializer : public Serializer { } private: - int indent; + int indent_; static constexpr int indent_width = 2; - bool first_line; + bool first_line_; template inline static constexpr bool is_elementary_type_v = @@ -620,8 +636,8 @@ class TextSerializer : public Serializer { public: TextSerializer() { - indent = 0; - first_line = false; + indent_ = 0; + first_line_ = false; } template @@ -632,13 +648,25 @@ class TextSerializer : public Serializer { } template - void operator()(const char *key, const T &t) { - this->process(key, t); + void operator()(const char *key, const T &t, bool append_comma = false) { + add_key(key); + process(t); + if (append_comma) { + add_raw(","); + } + } + + // Entry to make an AOT json file + template + void serialize_to_json(const char *key, const T &t) { + add_raw("{"); + (*this)(key, t); + add_raw("}"); } private: - void process(const char *key, const std::string &val) { - add_line(std::string(key) + ": " + val); + void process(const std::string &val) { + add_raw("\"" + val + "\""); } template @@ -649,172 +677,193 @@ class TextSerializer : public Serializer { // C-array template std::enable_if_t::value, void> process( - const char *key, const TArray &val) { std::stringstream ss; - ss << "["; + ss << "{"; for (std::size_t i = 0; i < n; i++) { ss << val[i]; if (i != n - 1) { ss << ", "; } } - ss << "]"; - add_line(key, ss.str()); + ss << "}"; + add_raw(ss.str()); } // C-array template std::enable_if_t::value, void> process( - const char *key, const TArray &val) { - add_line(key, "["); - indent++; + add_raw("{"); + indent_++; for (std::size_t i = 0; i < n; i++) { - this->process(("[" + std::to_string(i) + "]").c_str(), val[i]); + add_key(std::to_string(i).c_str()); + process(val[i]); + if (i != n - 1) { + add_raw(","); + } } - indent--; - add_line("]"); + indent_--; + add_raw("}"); } // std::array template std::enable_if_t::value, void> process( - const char *key, const StdTArray &val) { std::stringstream ss; - ss << "["; + ss << "{"; for (std::size_t i = 0; i < n; i++) { ss << val[i]; if (i != n - 1) { ss << ", "; } } - ss << "]"; - add_line(key, ss.str()); + ss << "}"; + add_raw(ss.str()); } // std::array template std::enable_if_t::value, void> process( - const char *key, const StdTArray &val) { - add_line(key, "["); - indent++; + add_raw("{"); + indent_++; for (std::size_t i = 0; i < n; i++) { - this->process(("[" + std::to_string(i) + "]").c_str(), val[i]); + add_key(std::to_string(i).c_str()); + process(val[i]); + if (i != n - 1) { + add_raw(","); + } } - indent--; - add_line("]"); + indent_--; + add_raw("}"); } // Elementary data types template - std::enable_if_t, void> process(const char *key, - const T &val) { + std::enable_if_t, void> process(const T &val) { static_assert(!has_io::value, ""); std::stringstream ss; ss << std::boolalpha << val; - add_line(key, ss.str()); + add_raw(ss.str()); } template - std::enable_if_t::value, void> process(const char *key, - const T &val) { - add_line(key, "{"); - indent++; + std::enable_if_t::value, void> process(const T &val) { + add_raw("{"); + indent_++; val.io(*this); - indent--; - add_line("}"); + indent_--; + add_raw("}"); } template - std::enable_if_t::value, void> process(const char *key, - const T &val) { - add_line(key, "{"); - indent++; + std::enable_if_t::value, void> process(const T &val) { + add_raw("{"); + indent_++; IO, decltype(*this)>()(*this, val); - indent--; - add_line("}"); + indent_--; + add_raw("}"); } template - std::enable_if_t, void> process(const char *key, - const T &val) { + std::enable_if_t, void> process(const T &val) { using UT = std::underlying_type_t; - this->process(key, static_cast(val)); + process(static_cast(val)); } template - void process(const char *key, const std::vector &val) { - add_line(key, "["); - indent++; + void process(const std::vector &val) { + add_raw("["); + indent_++; for (std::size_t i = 0; i < val.size(); i++) { - this->process(("[" + std::to_string(i) + "]").c_str(), val[i]); + process(val[i]); + if (i < val.size() - 1) { + add_raw(","); + } } - indent--; - add_line("]"); + indent_--; + add_raw("]"); } template - void process(const char *key, const std::pair &val) { - add_line(key, "("); - indent++; - this->process("first", val.first); - this->process("second", val.second); - indent--; - add_line(")"); + void process(const std::pair &val) { + add_raw("["); + indent_++; + process("first", val.first); + add_raw(", "); + process("second", val.second); + indent_--; + add_raw("]"); } // std::map template - void process(const char *key, const std::map &val) { - handle_associative_container(key, val); + void process(const std::map &val) { + handle_associative_container(val); } // std::unordered_map template - void process(const char *key, const std::unordered_map &val) { - handle_associative_container(key, val); + void process(const std::unordered_map &val) { + handle_associative_container(val); } // std::optional template - void process(const char *key, const std::optional &val) { - add_line(key, "{"); - indent++; - this->process("has_value", val.has_value()); + void process(const std::optional &val) { + add_raw("{"); + indent_++; + add_key("has_value"); + process(val.has_value()); if (val.has_value()) { - this->process("value", val.value()); + add_raw(","); + add_key("value"); + process(val.value()); } - indent--; - add_line("}"); + indent_--; + add_raw("}"); } template - void handle_associative_container(const char *key, const M &val) { - add_line(key, "{"); - indent++; - for (auto iter : val) { - auto first = iter.first; - this->process("key", first); - this->process("value", iter.second); + void handle_associative_container(const M &val) { + add_raw("{"); + indent_++; + for (auto iter = val.begin(); iter != val.end(); iter++) { + auto first = iter->first; + bool is_string = typeid(first) == typeid(std::string); + // Non-string keys must be wrapped by quotes. + if (!is_string) { + add_raw("\""); + } + process(first); + if (!is_string) { + add_raw("\""); + } + add_raw(": "); + process(iter->second); + if (std::next(iter) != val.end()) { + add_raw(","); + } } - indent--; - add_line("}"); + indent_--; + add_raw("}"); + } + + void add_raw(const std::string &str) { + data += str; } - void add_line(const std::string &str) { - if (first_line) { - first_line = false; + void add_key(const std::string &key) { + if (first_line_) { + first_line_ = false; } else { data += "\n"; } - data += std::string(indent_width * indent, ' ') + str; - } + data += std::string(indent_width * indent_, ' ') + "\"" + key + "\""; - void add_line(const std::string &key, const std::string &value) { - add_line(key + ": " + value); + add_raw(": "); } }; diff --git a/taichi/common/symbol_version.cpp b/taichi/common/symbol_version.cpp new file mode 100644 index 0000000000000..9f1bfeaffd3d7 --- /dev/null +++ b/taichi/common/symbol_version.cpp @@ -0,0 +1,61 @@ +/******************************************************************************* + Copyright (c) The Taichi Authors (2016- ). All Rights Reserved. + The use of this software is governed by the LICENSE file. +*******************************************************************************/ + +#include "taichi/common/core.h" + +#if defined(TI_PLATFORM_WINDOWS) +#include "taichi/platform/windows/windows.h" +#else +// Mac and Linux +#include +#endif + +TI_NAMESPACE_BEGIN + +extern "C" { +#if defined(TI_PLATFORM_LINUX) && defined(TI_ARCH_x64) +// Avoid dependency on higher glibc versions such as 2.27 or 2.29 +// Related issue: https://github.com/taichi-dev/taichi/issues/3174 +// log2f is used by a third party .a file, so we have to define a wrapper. +// https://stackoverflow.com/questions/8823267/linking-against-older-symbol-version-in-a-so-file +// The wrapper should be linked using target_link_libraries in TaichiCore.cmake +__asm__(".symver log2f,log2f@GLIBC_2.2.5"); +float __wrap_log2f(float x) { + return log2f(x); +} +// The following are offending symbols using higher GLIBC versions +// They will fail Vulkan tests if wrapping is enabled +__asm__(".symver exp2,exp2@GLIBC_2.2.5"); +float __wrap_exp2(float x) { + return exp2(x); +} +__asm__(".symver log2,log2@GLIBC_2.2.5"); +float __wrap_log2(float x) { + return log2(x); +} +__asm__(".symver logf,logf@GLIBC_2.2.5"); +float __wrap_logf(float x) { + return logf(x); +} +__asm__(".symver powf,powf@GLIBC_2.2.5"); +float __wrap_powf(float x, float y) { + return powf(x, y); +} +__asm__(".symver exp,exp@GLIBC_2.2.5"); +float __wrap_exp(float x) { + return exp(x); +} +__asm__(".symver log,log@GLIBC_2.2.5"); +float __wrap_log(float x) { + return log(x); +} +__asm__(".symver pow,pow@GLIBC_2.2.5"); +float __wrap_pow(float x, float y) { + return pow(x, y); +} +#endif +} + +TI_NAMESPACE_END diff --git a/taichi/common/task.h b/taichi/common/task.h index 1253e8f398fbb..d470c833edd7a 100644 --- a/taichi/common/task.h +++ b/taichi/common/task.h @@ -23,7 +23,7 @@ class Task : public Unit { return this->run(std::vector()); } - ~Task() { + ~Task() override { } }; diff --git a/taichi/common/version.h.in b/taichi/common/version.h.in new file mode 100644 index 0000000000000..49f4ec0d1a3a8 --- /dev/null +++ b/taichi/common/version.h.in @@ -0,0 +1,5 @@ +#pragma once +#define TI_VERSION_MAJOR @TI_VERSION_MAJOR@ +#define TI_VERSION_MINOR @TI_VERSION_MINOR@ +#define TI_VERSION_PATCH @TI_VERSION_PATCH@ +#define CUDA_VERSION "@CUDA_VERSION@" diff --git a/taichi/gui/android.cpp b/taichi/gui/android.cpp new file mode 100644 index 0000000000000..60a5e62a49ff1 --- /dev/null +++ b/taichi/gui/android.cpp @@ -0,0 +1,32 @@ +#include "taichi/gui/gui.h" + +// GGUI is not suppored on Android as the window management is handled by the +// framework directly. It also provides a Canvas through Skia library that users +// can leverage for rendering of 2D elements (circle, rectangle, ...) +#if defined(TI_GUI_ANDROID) +#include + +TI_NAMESPACE_BEGIN + +void GUI::process_event() { + TI_ERROR("GGUI not supported on Android"); +} + +void GUI::create_window() { + TI_ERROR("GGUI not supported on Android"); +} + +void GUI::redraw() { + TI_ERROR("GGUI not supported on Android"); +} + +void GUI::set_title(std::string title) { + TI_ERROR("GGUI not supported on Android"); +} + +GUI::~GUI() { +} + +TI_NAMESPACE_END + +#endif diff --git a/taichi/gui/gui.h b/taichi/gui/gui.h index 7a5b957d3a48e..8c2159cad5e22 100644 --- a/taichi/gui/gui.h +++ b/taichi/gui/gui.h @@ -9,9 +9,14 @@ #include #include -#if defined(TI_PLATFORM_LINUX) +#if defined(TI_PLATFORM_LINUX) || \ + (defined(TI_PLATFORM_UNIX) && !defined(TI_PLATFORM_OSX)) +#if defined(TI_PLATFORM_ANDROID) +#define TI_GUI_ANDROID +#elif !defined(TI_EMSCRIPTENED) #define TI_GUI_X11 #endif +#endif #if defined(TI_PLATFORM_WINDOWS) #define TI_GUI_WIN32 @@ -42,7 +47,7 @@ constexpr uint32 slider_bar_color = 0x333333; constexpr uint32 slider_circle_color = 0x555555; #endif -class Canvas { +class TI_DLL_EXPORT Canvas { struct Context { Vector4 _color; real _radius; @@ -409,7 +414,7 @@ class Canvas { Vector4 color) { position = transform(position); std::string folder; - folder = fmt::format("{}/../assets", lang::runtime_lib_dir()); + folder = fmt::format("{}/../../assets", lang::runtime_lib_dir()); auto ttf_path = fmt::format("{}/Go-Regular.ttf", folder); img.write_text(ttf_path, str, size, position.x, position.y, color); } @@ -432,6 +437,28 @@ class Canvas { } }; +#if defined(TI_GUI_ANDROID) + +class GUIBaseAndroid { + public: + // @TODO +}; + +using GUIBase = GUIBaseAndroid; + +#endif + +#if defined(TI_EMSCRIPTENED) + +class GUIBaseJavascript { + public: + // @TODO +}; + +using GUIBase = GUIBaseJavascript; + +#endif + #if defined(TI_GUI_X11) class CXImage; @@ -482,7 +509,7 @@ class GUIBaseCocoa { using GUIBase = GUIBaseCocoa; #endif -class GUI : public GUIBase { +class TI_DLL_EXPORT GUI : public GUIBase { public: std::string window_name; int width, height; diff --git a/taichi/inc/archs.inc.h b/taichi/inc/archs.inc.h index 8f9b8bf21b469..ca266b0968937 100644 --- a/taichi/inc/archs.inc.h +++ b/taichi/inc/archs.inc.h @@ -11,7 +11,7 @@ PER_ARCH(wasm) // WebAssembly PER_ARCH(cuda) // NVIDIA CUDA PER_ARCH(metal) // Apple Metal PER_ARCH(opengl) // OpenGL Compute Shaders -PER_ARCH(dx) // Microsoft DirectX, N/A +PER_ARCH(dx11) // Microsoft DirectX 11, WIP PER_ARCH(opencl) // OpenCL, N/A PER_ARCH(amdgpu) // AMD GPU, N/A PER_ARCH(vulkan) // Vulkan diff --git a/taichi/inc/constants.h b/taichi/inc/constants.h index 4d2562dcb8e12..820a3229e659d 100644 --- a/taichi/inc/constants.h +++ b/taichi/inc/constants.h @@ -10,7 +10,7 @@ constexpr int taichi_max_num_args = 8; constexpr int taichi_max_num_args_total = 64; constexpr int taichi_max_num_args_extra = 16; constexpr int taichi_max_num_snodes = 1024; -constexpr int kMaxNumSnodeTreesLlvm = 32; +constexpr int kMaxNumSnodeTreesLlvm = 512; constexpr int taichi_max_gpu_block_dim = 1024; constexpr std::size_t taichi_global_tmp_buffer_size = 1024 * 1024; constexpr int taichi_max_num_mem_requests = 1024 * 64; @@ -18,14 +18,19 @@ constexpr std::size_t taichi_page_size = 4096; constexpr std::size_t taichi_error_message_max_length = 2048; constexpr std::size_t taichi_error_message_max_num_arguments = 32; constexpr std::size_t taichi_result_buffer_entries = 32; +constexpr std::size_t taichi_max_num_ret_value = 30; // slot for kernel return value constexpr std::size_t taichi_result_buffer_ret_value_id = 0; // slot for error code and error message char * -constexpr std::size_t taichi_result_buffer_error_id = 1; -constexpr std::size_t taichi_result_buffer_runtime_query_id = 2; +constexpr std::size_t taichi_result_buffer_error_id = 30; +constexpr std::size_t taichi_result_buffer_runtime_query_id = 31; constexpr int taichi_listgen_max_element_size = 1024; +// use for auto mesh_local to determine shared-mem size per block (in bytes) +// TODO: get this at runtime +constexpr std::size_t default_shared_mem_size = 65536; + template T taichi_union_cast_with_different_sizes(G g) { union { diff --git a/taichi/inc/extensions.inc.h b/taichi/inc/extensions.inc.h index dae6b432b30a2..b2c68a3678bc7 100644 --- a/taichi/inc/extensions.inc.h +++ b/taichi/inc/extensions.inc.h @@ -2,6 +2,7 @@ PER_EXTENSION(sparse) // Sparse data structures PER_EXTENSION(async_mode) // Asynchronous execution mode PER_EXTENSION(quant) // Quantization +PER_EXTENSION(mesh) // MeshTaichi PER_EXTENSION(quant_basic) // Basic operations in quantization PER_EXTENSION(data64) // Metal doesn't support 64-bit data buffers yet... PER_EXTENSION(adstack) // For keeping the history of mutable local variables diff --git a/taichi/inc/offloaded_task_type.inc.h b/taichi/inc/offloaded_task_type.inc.h index 36ae476708f6e..ddb0845cff061 100644 --- a/taichi/inc/offloaded_task_type.inc.h +++ b/taichi/inc/offloaded_task_type.inc.h @@ -1,5 +1,6 @@ PER_TASK_TYPE(serial) PER_TASK_TYPE(range_for) PER_TASK_TYPE(struct_for) +PER_TASK_TYPE(mesh_for) PER_TASK_TYPE(listgen) PER_TASK_TYPE(gc) diff --git a/taichi/inc/statements.inc.h b/taichi/inc/statements.inc.h index a631fe23884c7..e82ba95ca407a 100644 --- a/taichi/inc/statements.inc.h +++ b/taichi/inc/statements.inc.h @@ -1,4 +1,5 @@ // Frontend statements +PER_STATEMENT(FrontendExternalFuncStmt) PER_STATEMENT(FrontendExprStmt) PER_STATEMENT(FrontendIfStmt) PER_STATEMENT(FrontendForStmt) @@ -16,14 +17,17 @@ PER_STATEMENT(FrontendReturnStmt) // Middle-end statement +// Decoration / debug statement +PER_STATEMENT(DecorationStmt) + // Without per-lane attributes PER_STATEMENT(RangeForStmt) PER_STATEMENT(StructForStmt) +PER_STATEMENT(MeshForStmt) PER_STATEMENT(IfStmt) PER_STATEMENT(WhileStmt) PER_STATEMENT(WhileControlStmt) PER_STATEMENT(ContinueStmt) -PER_STATEMENT(FuncBodyStmt) PER_STATEMENT(FuncCallStmt) PER_STATEMENT(ReturnStmt) @@ -71,10 +75,13 @@ PER_STATEMENT(ElementShuffleStmt) // Offloaded PER_STATEMENT(OffloadedStmt) +PER_STATEMENT(MeshRelationAccessStmt) +PER_STATEMENT(MeshIndexConversionStmt) +PER_STATEMENT(MeshPatchIndexStmt) PER_STATEMENT(LoopIndexStmt) PER_STATEMENT(LoopLinearIndexStmt) +PER_STATEMENT(GlobalThreadIndexStmt) PER_STATEMENT(BlockCornerIndexStmt) -PER_STATEMENT(BlockDimStmt) PER_STATEMENT(GlobalTemporaryStmt) PER_STATEMENT(ClearListStmt) diff --git a/taichi/inc/unary_op.inc.h b/taichi/inc/unary_op.inc.h index f10ec0badacfc..7bdbaffd9e4d6 100644 --- a/taichi/inc/unary_op.inc.h +++ b/taichi/inc/unary_op.inc.h @@ -1,5 +1,6 @@ PER_UNARY_OP(neg) PER_UNARY_OP(sqrt) +PER_UNARY_OP(round) PER_UNARY_OP(floor) PER_UNARY_OP(ceil) PER_UNARY_OP(cast_value) diff --git a/taichi/ir/analysis.h b/taichi/ir/analysis.h index a667fa55e12e7..ac536ac3c6bd9 100644 --- a/taichi/ir/analysis.h +++ b/taichi/ir/analysis.h @@ -1,8 +1,10 @@ #pragma once #include "taichi/ir/ir.h" +#include "taichi/ir/mesh.h" #include "taichi/ir/pass.h" #include "taichi/analysis/gather_uniquely_accessed_pointers.h" +#include "taichi/analysis/mesh_bls_analyzer.h" #include #include #include @@ -13,7 +15,7 @@ namespace lang { class DiffRange { private: - bool related; + bool related_; public: int coeff; @@ -31,22 +33,22 @@ class DiffRange { } DiffRange(bool related, int coeff, int low, int high) - : related(related), coeff(coeff), low(low), high(high) { + : related_(related), coeff(coeff), low(low), high(high) { if (!related) { this->low = this->high = 0; } } - bool related_() const { - return related; + bool related() const { + return related_; } bool linear_related() const { - return related && coeff == 1; + return related_ && coeff == 1; } bool certain() { - TI_ASSERT(related); + TI_ASSERT(related_); return high == low + 1; } }; @@ -104,7 +106,8 @@ gather_snode_read_writes(IRNode *root); std::vector gather_statements(IRNode *root, const std::function &test); void gather_uniquely_accessed_bit_structs(IRNode *root, AnalysisManager *amgr); -std::unordered_map +std::pair, + std::unordered_map> gather_uniquely_accessed_pointers(IRNode *root); std::unique_ptr> gather_used_atomics( IRNode *root); @@ -234,6 +237,16 @@ std::unordered_set constexpr_prop( void verify(IRNode *root); +// Mesh Related. +void gather_meshfor_relation_types(IRNode *node); +std::pair, + /* total= */ std::unordered_set> +gather_mesh_thread_local(OffloadedStmt *offload, const CompileConfig &config); +std::unique_ptr initialize_mesh_local_attribute( + OffloadedStmt *offload, + bool auto_mesh_local, + const CompileConfig &config); + } // namespace analysis } // namespace irpass } // namespace lang diff --git a/taichi/ir/basic_stmt_visitor.cpp b/taichi/ir/basic_stmt_visitor.cpp index 95f9847c4f198..989c71f35af46 100644 --- a/taichi/ir/basic_stmt_visitor.cpp +++ b/taichi/ir/basic_stmt_visitor.cpp @@ -40,14 +40,14 @@ void BasicStmtVisitor::visit(StructForStmt *for_stmt) { for_stmt->body->accept(this); } -void BasicStmtVisitor::visit(OffloadedStmt *stmt) { - preprocess_container_stmt(stmt); - stmt->all_blocks_accept(this); +void BasicStmtVisitor::visit(MeshForStmt *for_stmt) { + preprocess_container_stmt(for_stmt); + for_stmt->body->accept(this); } -void BasicStmtVisitor::visit(FuncBodyStmt *stmt) { +void BasicStmtVisitor::visit(OffloadedStmt *stmt) { preprocess_container_stmt(stmt); - stmt->body->accept(this); + stmt->all_blocks_accept(this); } void BasicStmtVisitor::visit(FrontendWhileStmt *stmt) { diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index b7324a7c73233..98594ceb84ac4 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -29,10 +29,10 @@ CFGNode::CFGNode(Block *block, TI_ASSERT(begin_location >= 0); TI_ASSERT(block); auto parent_block = block; - parent_blocks.insert(parent_block); + parent_blocks_.insert(parent_block); while (parent_block->parent_block()) { parent_block = parent_block->parent_block(); - parent_blocks.insert(parent_block); + parent_blocks_.insert(parent_block); } } } @@ -167,7 +167,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // |parent_blocks| is precomputed in the constructor of CFGNode. // TODO: What if |stmt| appears in an ancestor of |block| but after // |position|? - return parent_blocks.find(stmt->parent) != parent_blocks.end(); + return parent_blocks_.find(stmt->parent) != parent_blocks_.end(); }; /** * |stmt| is a definition in the UD-chain of |var|. Update |result| with @@ -261,15 +261,10 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) { if (auto local_load = stmt->cast()) { bool regular = true; auto alloca = local_load->src[0].var; - // TODO: store-to-load forwarding with TensorType Alloca - if (alloca->is()) { - regular = false; - } else { - for (int l = 0; l < stmt->width(); l++) { - if (local_load->src[l].offset != l || - local_load->src[l].var != alloca) { - regular = false; - } + for (int l = 0; l < stmt->width(); l++) { + if (local_load->src[l].offset != l || + local_load->src[l].var != alloca) { + regular = false; } } if (regular) { @@ -288,7 +283,7 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) { zero->repeat(result->width()); replace_with(i, std::move(zero), true); } else { - stmt->replace_with(result); + stmt->replace_usages_with(result); erase(i); // This causes end_location-- i--; // to cancel i++ in the for loop modified = true; @@ -442,6 +437,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { // After lower_access, we only analyze local variables and stacks. // Do not eliminate AllocaStmt and AdStackAllocaStmt here. if (!stmt->is() && !stmt->is() && + !stmt->is() && !may_contain_variable(live_in_this_node, store_ptr) && (contain_variable(killed_in_this_node, store_ptr) || !may_contain_variable(live_out, store_ptr))) { @@ -511,7 +507,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { // Only perform identical load elimination within a CFGNode. auto next_load_stmt = live_load_in_this_node[load_ptr]; TI_ASSERT(irpass::analysis::same_statements(stmt, next_load_stmt)); - next_load_stmt->replace_with(stmt); + next_load_stmt->replace_usages_with(stmt); erase(block->locate(next_load_stmt)); modified = true; } @@ -625,21 +621,18 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { TI_ASSERT(nodes[start_node]->empty()); nodes[start_node]->reach_gen.clear(); nodes[start_node]->reach_kill.clear(); - if (!after_lower_access) { - for (int i = 0; i < num_nodes; i++) { - for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { - auto stmt = nodes[i]->block->statements[j].get(); - if (stmt->is() || stmt->is() || + for (int i = 0; i < num_nodes; i++) { + for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { + auto stmt = nodes[i]->block->statements[j].get(); + if ((stmt->is() && + stmt->as()->origin->is()) || + (!after_lower_access && + (stmt->is() || stmt->is() || stmt->is() || stmt->is() || - stmt->is() || - (stmt->is() && - stmt->cast()->origin->is()) || - (stmt->is() && - stmt->cast()->is_unlowered_global_ptr())) { - // TODO: unify them - // A global pointer that may contain some data before this kernel. - nodes[start_node]->reach_gen.insert(stmt); - } + stmt->is() || stmt->is()))) { + // TODO: unify them + // A global pointer that may contain some data before this kernel. + nodes[start_node]->reach_gen.insert(stmt); } } } @@ -711,6 +704,10 @@ void ControlFlowGraph::live_variable_analysis( if (stmt->is() || stmt->is()) { return false; } + if (stmt->is() && + stmt->cast()->origin->is()) { + return false; + } if (auto *gptr = stmt->cast(); gptr && config_opt.has_value()) { TI_ASSERT(gptr->snodes.size() == 1); diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h index 88bac2f46d99b..1e98a4ce803d7 100644 --- a/taichi/ir/control_flow_graph.h +++ b/taichi/ir/control_flow_graph.h @@ -19,7 +19,7 @@ namespace lang { class CFGNode { private: // For accelerating get_store_forwarding_data() - std::unordered_set parent_blocks; + std::unordered_set parent_blocks_; public: // This node corresponds to block->statements[i] diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp index df6f050901437..7ae14b364eabe 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -6,9 +6,15 @@ TLANG_NAMESPACE_BEGIN -std::string Expr::serialize() const { +void Expr::serialize(std::ostream &ss) const { TI_ASSERT(expr); - return expr->serialize(); + expr->serialize(ss); +} + +std::string Expr::serialize() const { + std::stringstream ss; + serialize(ss); + return ss.str(); } void Expr::set_tb(const std::string &tb) { @@ -23,6 +29,14 @@ std::string Expr::get_attribute(const std::string &key) const { return expr->get_attribute(key); } +DataType Expr::get_ret_type() const { + return expr->ret_type; +} + +void Expr::type_check(CompileConfig *config) { + expr->type_check(config); +} + Expr select(const Expr &cond, const Expr &true_val, const Expr &false_val) { return Expr::make(TernaryOpType::select, cond, true_val, false_val); @@ -37,39 +51,20 @@ Expr operator~(const Expr &expr) { } Expr cast(const Expr &input, DataType dt) { - auto ret = - std::make_shared(UnaryOpType::cast_value, input); - ret->cast_type = dt; - return Expr(ret); + return Expr::make(UnaryOpType::cast_value, input, dt); } Expr bit_cast(const Expr &input, DataType dt) { - auto ret = std::make_shared(UnaryOpType::cast_bits, input); - ret->cast_type = dt; - return Expr(ret); + return Expr::make(UnaryOpType::cast_bits, input, dt); } Expr Expr::operator[](const ExprGroup &indices) const { TI_ASSERT(is() || is()); - return Expr::make(*this, indices.loaded()); + return Expr::make(*this, indices); } Expr &Expr::operator=(const Expr &o) { - if (get_current_program().current_callable) { - // Inside a kernel or a function - // Create an assignment in the IR - if (expr == nullptr) { - set(o.eval()); - } else if (expr->is_lvalue()) { - current_ast_builder().insert(std::make_unique( - ptr_if_global(*this), load_if_ptr(o))); - } else { - // set(o.eval()); - TI_ERROR("Cannot assign to non-lvalue: {}", serialize()); - } - } else { - set(o); // Literally set this Expr to o - } + set(o); return *this; } @@ -99,107 +94,23 @@ void Expr::set_grad(const Expr &o) { } Expr::Expr(int32 x) : Expr() { - expr = std::make_shared(x); + expr = std::make_shared(PrimitiveType::i32, x); } Expr::Expr(int64 x) : Expr() { - expr = std::make_shared(x); + expr = std::make_shared(PrimitiveType::i64, x); } Expr::Expr(float32 x) : Expr() { - expr = std::make_shared(x); + expr = std::make_shared(PrimitiveType::f32, x); } Expr::Expr(float64 x) : Expr() { - expr = std::make_shared(x); + expr = std::make_shared(PrimitiveType::f64, x); } Expr::Expr(const Identifier &id) : Expr() { expr = std::make_shared(id); } -Expr Expr::eval() const { - TI_ASSERT(expr != nullptr); - if (is()) { - return *this; - } - auto eval_stmt = Stmt::make(*this); - auto eval_expr = Expr::make(eval_stmt.get()); - eval_stmt->as()->eval_expr.set(eval_expr); - // needed in lower_ast to replace the statement itself with the - // lowered statement - current_ast_builder().insert(std::move(eval_stmt)); - return eval_expr; -} - -void Expr::operator+=(const Expr &o) { - if (this->atomic) { - (*this) = Expr::make( - AtomicOpType::add, ptr_if_global(*this), load_if_ptr(o)); - } else { - (*this) = (*this) + o; - } -} - -void Expr::operator-=(const Expr &o) { - if (this->atomic) { - (*this) = Expr::make( - AtomicOpType::sub, ptr_if_global(*this), load_if_ptr(o)); - } else { - (*this) = (*this) - o; - } -} - -void Expr::operator*=(const Expr &o) { - TI_ASSERT(!this->atomic); - (*this) = (*this) * load_if_ptr(o); -} - -void Expr::operator/=(const Expr &o) { - TI_ASSERT(!this->atomic); - (*this) = (*this) / load_if_ptr(o); -} - -Expr load_if_ptr(const Expr &ptr) { - if (ptr.is()) { - return Expr::make(ptr); - } else if (ptr.is()) { - TI_ASSERT(ptr.cast()->snode->num_active_indices == - 0); - return Expr::make(ptr[ExprGroup()]); - } else if (ptr.is()) { - auto tensor_ptr = ptr.cast(); - if (tensor_ptr->is_global_tensor()) - return Expr::make(ptr); - else if (tensor_ptr->is_local_tensor()) - return Expr::make(ptr); - else { - TI_NOT_IMPLEMENTED - } - } else - return ptr; -} - -Expr ptr_if_global(const Expr &var) { - if (var.is()) { - // singleton global variable - TI_ASSERT_INFO(var.snode()->num_active_indices == 0, - "Please always use 'x[None]' (instead of simply 'x') to " - "access any 0-D field."); - return var[ExprGroup()]; - } else { - // may be any local or global expr - return var; - } -} - -Expr Var(const Expr &x) { - auto var = Expr(std::make_shared()); - current_ast_builder().insert(std::make_unique( - std::static_pointer_cast(var.expr)->id, - PrimitiveType::unknown)); - var = x; - return var; -} - TLANG_NAMESPACE_END diff --git a/taichi/ir/expr.h b/taichi/ir/expr.h index 726e438d82d5f..e41d71d383947 100644 --- a/taichi/ir/expr.h +++ b/taichi/ir/expr.h @@ -1,9 +1,11 @@ #pragma once +#include "taichi/util/str.h" #include "taichi/ir/type_utils.h" TLANG_NAMESPACE_BEGIN +struct CompileConfig; class Expression; class Identifier; class ExprGroup; @@ -20,13 +22,13 @@ class Expr { atomic = false; } - Expr(int32 x); + explicit Expr(int32 x); - Expr(int64 x); + explicit Expr(int64 x); - Expr(float32 x); + explicit Expr(float32 x); - Expr(float64 x); + explicit Expr(float64 x); Expr(std::shared_ptr expr) : Expr() { this->expr = expr; @@ -43,7 +45,7 @@ class Expr { atomic = o.atomic; } - Expr(const Identifier &id); + explicit Expr(const Identifier &id); void set(const Expr &o) { expr = o.expr; @@ -72,16 +74,16 @@ class Expr { return cast() != nullptr; } + // FIXME: We really should disable it completely, + // but we can't. This is because the usage of + // std::variant in FrontendPrintStmt. Expr &operator=(const Expr &o); Expr operator[](const ExprGroup &indices) const; std::string serialize() const; + void serialize(std::ostream &ss) const; - void operator+=(const Expr &o); - void operator-=(const Expr &o); - void operator*=(const Expr &o); - void operator/=(const Expr &o); Expr operator!(); Expr eval() const; @@ -105,6 +107,10 @@ class Expr { void set_attribute(const std::string &key, const std::string &value); std::string get_attribute(const std::string &key) const; + + DataType get_ret_type() const; + + void type_check(CompileConfig *config); }; Expr select(const Expr &cond, const Expr &true_val, const Expr &false_val); @@ -128,15 +134,4 @@ Expr bit_cast(const Expr &input) { return taichi::lang::bit_cast(input, get_data_type()); } -Expr load_if_ptr(const Expr &ptr); -Expr ptr_if_global(const Expr &var); - -inline Expr smart_load(const Expr &var) { - return load_if_ptr(ptr_if_global(var)); -} - -// Begin: legacy frontend functions -Expr Var(const Expr &x); -// End: legacy frontend functions - TLANG_NAMESPACE_END diff --git a/taichi/ir/expression.cpp b/taichi/ir/expression.cpp index 43e084067b0b3..493ace469d00f 100644 --- a/taichi/ir/expression.cpp +++ b/taichi/ir/expression.cpp @@ -11,22 +11,19 @@ std::string Expression::get_attribute(const std::string &key) const { } } -ExprGroup ExprGroup::loaded() const { - auto indices_loaded = *this; - for (int i = 0; i < (int)this->size(); i++) - indices_loaded[i].set(load_if_ptr(indices_loaded[i])); - return indices_loaded; -} - -std::string ExprGroup::serialize() const { - std::string ret; +void ExprGroup::serialize(std::ostream &ss) const { for (int i = 0; i < (int)exprs.size(); i++) { - ret += exprs[i].serialize(); + exprs[i].serialize(ss); if (i + 1 < (int)exprs.size()) { - ret += ", "; + ss << ", "; } } - return ret; +} + +std::string ExprGroup::serialize() const { + std::stringstream ss; + serialize(ss); + return ss.str(); } } // namespace lang diff --git a/taichi/ir/expression.h b/taichi/ir/expression.h index f792feb4f4ef8..3f10a997e19fc 100644 --- a/taichi/ir/expression.h +++ b/taichi/ir/expression.h @@ -1,5 +1,7 @@ #pragma once +#include "taichi/program/compile_config.h" +#include "taichi/util/str.h" #include "taichi/ir/ir.h" #include "taichi/ir/expr.h" @@ -13,6 +15,7 @@ class Expression { Stmt *stmt; std::string tb; std::map attributes; + DataType ret_type; struct FlattenContext { VecStatement stmts; @@ -36,7 +39,12 @@ class Expression { stmt = nullptr; } - virtual std::string serialize() = 0; + virtual void type_check(CompileConfig *config) { + // TODO: make it pure virtual after type_check for all expressions are + // implemented + } + + virtual void serialize(std::ostream &ss) = 0; virtual void flatten(FlattenContext *ctx) { TI_NOT_IMPLEMENTED; @@ -64,22 +72,29 @@ class ExprGroup { } ExprGroup(const Expr &a) { - exprs.push_back(a); + exprs.emplace_back(a); } ExprGroup(const Expr &a, const Expr &b) { - exprs.push_back(a); - exprs.push_back(b); + exprs.emplace_back(a); + exprs.emplace_back(b); } ExprGroup(const ExprGroup &a, const Expr &b) { - exprs = a.exprs; - exprs.push_back(b); + exprs.resize(a.size() + 1); + + for (int i = 0; i < a.size(); ++i) { + exprs[i].set(a.exprs[i]); + } + exprs.back().set(b); } ExprGroup(const Expr &a, const ExprGroup &b) { - exprs = b.exprs; - exprs.insert(exprs.begin(), a); + exprs.resize(b.size() + 1); + exprs.front().set(a); + for (int i = 0; i < b.size(); i++) { + exprs[i + 1].set(b.exprs[i]); + } } void push_back(const Expr &expr) { @@ -98,8 +113,9 @@ class ExprGroup { return exprs[i]; } + void serialize(std::ostream &ss) const; + std::string serialize() const; - ExprGroup loaded() const; }; inline ExprGroup operator,(const Expr &a, const Expr &b) { diff --git a/taichi/ir/expression_ops.h b/taichi/ir/expression_ops.h index 7ff0bd1813d4e..79c29a5d2fbfc 100644 --- a/taichi/ir/expression_ops.h +++ b/taichi/ir/expression_ops.h @@ -56,6 +56,7 @@ #endif DEFINE_EXPRESSION_OP_UNARY(sqrt) +DEFINE_EXPRESSION_OP_UNARY(round) DEFINE_EXPRESSION_OP_UNARY(floor) DEFINE_EXPRESSION_OP_UNARY(ceil) DEFINE_EXPRESSION_OP_UNARY(abs) diff --git a/taichi/ir/frontend.cpp b/taichi/ir/frontend.cpp index 34c92fc01113a..b4f5b80b5b1a7 100644 --- a/taichi/ir/frontend.cpp +++ b/taichi/ir/frontend.cpp @@ -16,56 +16,4 @@ Expr global_new(DataType dt, std::string name) { auto id_expr = std::make_shared(name); return Expr::make(dt, id_expr->id); } - -Expr copy(const Expr &expr) { - auto e = expr.eval(); - auto stmt = Stmt::make( - VectorElement(e.cast()->stmt_ptr, 0)); - auto eval_expr = std::make_shared(stmt.get()); - current_ast_builder().insert(std::move(stmt)); - return Expr(eval_expr); -} - -void insert_snode_access_flag(SNodeAccessFlag v, const Expr &field) { - dec.mem_access_opt.add_flag(field.snode(), v); -} - -void reset_snode_access_flag() { - dec.reset(); -} - -// Begin: legacy frontend constructs - -For::For(const Expr &s, const Expr &e, const std::function &func) { - auto i = Expr(std::make_shared()); - auto stmt_unique = std::make_unique(i, s, e); - auto stmt = stmt_unique.get(); - current_ast_builder().insert(std::move(stmt_unique)); - auto _ = current_ast_builder().create_scope(stmt->body); - func(i); -} - -For::For(const Expr &i, - const Expr &s, - const Expr &e, - const std::function &func) { - auto stmt_unique = std::make_unique(i, s, e); - auto stmt = stmt_unique.get(); - current_ast_builder().insert(std::move(stmt_unique)); - auto _ = current_ast_builder().create_scope(stmt->body); - func(); -} - -For::For(const ExprGroup &i, - const Expr &global, - const std::function &func) { - auto stmt_unique = std::make_unique(i, global); - auto stmt = stmt_unique.get(); - current_ast_builder().insert(std::move(stmt_unique)); - auto _ = current_ast_builder().create_scope(stmt->body); - func(); -} - -// End: legacy frontend constructs - TLANG_NAMESPACE_END diff --git a/taichi/ir/frontend.h b/taichi/ir/frontend.h index e249913271726..fc96630a66ae7 100644 --- a/taichi/ir/frontend.h +++ b/taichi/ir/frontend.h @@ -34,13 +34,6 @@ Expr Rand() { return Expr::make(get_data_type()); } -template -T Eval(const T &t) { - return t.eval(); -} - -Expr copy(const Expr &expr); - template std::vector Axes(AX... axes) { auto ax_vec = std::vector({axes...}); @@ -59,18 +52,18 @@ inline Expr Atomic(Expr dest) { } // expr_group are indices -inline void Activate(SNode *snode, const ExprGroup &expr_group) { - current_ast_builder().insert(Stmt::make( - SNodeOpType::activate, snode, expr_group)); -} - -inline void Activate(const Expr &expr, const ExprGroup &expr_group) { - return Activate(expr.snode(), expr_group); +inline void Activate(ASTBuilder *ast_builder, + SNode *snode, + const ExprGroup &expr_group) { + ast_builder->insert(Stmt::make(SNodeOpType::activate, + snode, expr_group)); } -inline void Deactivate(SNode *snode, const ExprGroup &expr_group) { - current_ast_builder().insert(Stmt::make( - SNodeOpType::deactivate, snode, expr_group)); +inline void Deactivate(ASTBuilder *ast_builder, + SNode *snode, + const ExprGroup &expr_group) { + ast_builder->insert(Stmt::make(SNodeOpType::deactivate, + snode, expr_group)); } inline Expr Append(SNode *snode, const ExprGroup &indices, const Expr &val) { @@ -84,23 +77,10 @@ inline Expr Append(const Expr &expr, return Append(expr.snode(), indices, val); } -inline void InsertAssert(const std::string &text, const Expr &cond) { - current_ast_builder().insert(Stmt::make(cond, text)); -} - -inline void Clear(SNode *snode, const ExprGroup &indices) { - current_ast_builder().insert( - Stmt::make(SNodeOpType::clear, snode, indices)); -} - inline Expr is_active(SNode *snode, const ExprGroup &indices) { return Expr::make(snode, SNodeOpType::is_active, indices); } -inline void Clear(const Expr &expr, const ExprGroup &indices) { - return Clear(expr.snode(), indices); -} - inline Expr Length(SNode *snode, const ExprGroup &indices) { return Expr::make(snode, SNodeOpType::length, indices); } @@ -117,29 +97,6 @@ inline Expr AssumeInRange(const Expr &expr, } inline Expr LoopUnique(const Expr &input, const std::vector &covers) { - return Expr::make(load_if_ptr(input), covers); + return Expr::make(input, covers); } - -void insert_snode_access_flag(SNodeAccessFlag v, const Expr &field); - -void reset_snode_access_flag(); - -// Begin: legacy frontend constructs - -class For { - public: - For(const Expr &i, - const Expr &s, - const Expr &e, - const std::function &func); - - For(const ExprGroup &i, - const Expr &global, - const std::function &func); - - For(const Expr &s, const Expr &e, const std::function &func); -}; - -// End: legacy frontend constructs - TLANG_NAMESPACE_END diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 89b022cb4638f..cbbf7c089d0b2 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -2,17 +2,21 @@ #include "taichi/ir/statements.h" #include "taichi/program/program.h" +#include "taichi/common/exceptions.h" TLANG_NAMESPACE_BEGIN +#define TI_ASSERT_TYPE_CHECKED(x) \ + TI_ASSERT_INFO(x->ret_type != PrimitiveType::unknown, \ + "[{}] was not type-checked", x.serialize()) + FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type, SNode *snode, const ExprGroup &indices, const Expr &val) - : op_type(op_type), snode(snode), indices(indices.loaded()), val(val) { + : op_type(op_type), snode(snode), indices(indices), val(val) { if (val.expr != nullptr) { TI_ASSERT(op_type == SNodeOpType::append); - this->val.set(load_if_ptr(val)); } else { TI_ASSERT(op_type != SNodeOpType::append); } @@ -21,71 +25,98 @@ FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type, FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs) : lhs(lhs), rhs(rhs) { TI_ASSERT(lhs->is_lvalue()); + if (lhs.is() && lhs->ret_type == PrimitiveType::unknown) { + lhs.expr->ret_type = rhs->ret_type; + } } IRNode *FrontendContext::root() { - return static_cast(root_node.get()); + return static_cast(root_node_.get()); } FrontendForStmt::FrontendForStmt(const ExprGroup &loop_var, - const Expr &global_var) - : global_var(global_var) { - vectorize = dec.vectorize; - bit_vectorize = dec.bit_vectorize; - num_cpu_threads = dec.num_cpu_threads; - strictly_serialized = dec.strictly_serialized; - block_dim = dec.block_dim; - auto cfg = get_current_program().config; - if (cfg.arch == Arch::cuda) { - vectorize = 1; - num_cpu_threads = 1; - TI_ASSERT(block_dim <= taichi_max_gpu_block_dim); + const Expr &global_var, + Arch arch, + const ForLoopConfig &config) + : global_var(global_var), + bit_vectorize(config.bit_vectorize), + num_cpu_threads(config.num_cpu_threads), + strictly_serialized(config.strictly_serialized), + mem_access_opt(config.mem_access_opt), + block_dim(config.block_dim) { + if (arch == Arch::cuda) { + this->num_cpu_threads = 1; + TI_ASSERT(this->block_dim <= taichi_max_gpu_block_dim); } else { // cpu - if (num_cpu_threads == 0) - num_cpu_threads = std::thread::hardware_concurrency(); + if (this->num_cpu_threads == 0) + this->num_cpu_threads = std::thread::hardware_concurrency(); } - mem_access_opt = dec.mem_access_opt; - dec.reset(); - if (vectorize == -1) - vectorize = 1; - loop_var_id.resize(loop_var.size()); for (int i = 0; i < (int)loop_var.size(); i++) { loop_var_id[i] = loop_var[i].cast()->id; + loop_var[i].expr->ret_type = PrimitiveType::i32; } } -DecoratorRecorder dec; +FrontendForStmt::FrontendForStmt(const ExprGroup &loop_var, + const mesh::MeshPtr &mesh, + const mesh::MeshElementType &element_type, + Arch arch, + const ForLoopConfig &config) + : bit_vectorize(config.bit_vectorize), + num_cpu_threads(config.num_cpu_threads), + mem_access_opt(config.mem_access_opt), + block_dim(config.block_dim), + mesh_for(true), + mesh(mesh.ptr.get()), + element_type(element_type) { + if (arch == Arch::cuda) { + this->num_cpu_threads = 1; + TI_ASSERT(this->block_dim <= taichi_max_gpu_block_dim); + } else { + // cpu + if (this->num_cpu_threads == 0) + this->num_cpu_threads = std::thread::hardware_concurrency(); + } + loop_var_id.resize(loop_var.size()); + for (int i = 0; i < (int)loop_var.size(); i++) { + loop_var_id[i] = loop_var[i].cast()->id; + } +} -FrontendContext::FrontendContext() { - root_node = std::make_unique(); - current_builder = std::make_unique(root_node.get()); +FrontendContext::FrontendContext(Arch arch) { + root_node_ = std::make_unique(); + current_builder_ = std::make_unique(root_node_.get(), arch); } FrontendForStmt::FrontendForStmt(const Expr &loop_var, const Expr &begin, - const Expr &end) - : begin(begin), end(end) { - vectorize = dec.vectorize; - bit_vectorize = dec.bit_vectorize; - num_cpu_threads = dec.num_cpu_threads; - strictly_serialized = dec.strictly_serialized; - block_dim = dec.block_dim; - auto cfg = get_current_program().config; - if (cfg.arch == Arch::cuda) { - vectorize = 1; - num_cpu_threads = 1; + const Expr &end, + Arch arch, + const ForLoopConfig &config) + : begin(begin), + end(end), + bit_vectorize(config.bit_vectorize), + num_cpu_threads(config.num_cpu_threads), + strictly_serialized(config.strictly_serialized), + mem_access_opt(config.mem_access_opt), + block_dim(config.block_dim) { + if (arch == Arch::cuda) { + this->num_cpu_threads = 1; } else { - if (num_cpu_threads == 0) - num_cpu_threads = std::thread::hardware_concurrency(); + if (this->num_cpu_threads == 0) + this->num_cpu_threads = std::thread::hardware_concurrency(); } - mem_access_opt = dec.mem_access_opt; - dec.reset(); - if (vectorize == -1) - vectorize = 1; loop_var_id.resize(1); loop_var_id[0] = loop_var.cast()->id; + loop_var.expr->ret_type = PrimitiveType::i32; +} + +void ArgLoadExpression::type_check(CompileConfig *) { + TI_ASSERT_INFO(dt->is() && dt != PrimitiveType::unknown, + "Invalid dt [{}] for ArgLoadExpression", dt->to_string()); + ret_type = dt; } void ArgLoadExpression::flatten(FlattenContext *ctx) { @@ -94,21 +125,44 @@ void ArgLoadExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } +void RandExpression::type_check(CompileConfig *) { + TI_ASSERT_INFO(dt->is() && dt != PrimitiveType::unknown, + "Invalid dt [{}] for RandExpression", dt->to_string()); + ret_type = dt; +} + void RandExpression::flatten(FlattenContext *ctx) { auto ran = std::make_unique(dt); ctx->push_back(std::move(ran)); stmt = ctx->back_stmt(); } -std::string UnaryOpExpression::serialize() { +void UnaryOpExpression::serialize(std::ostream &ss) { + ss << '('; if (is_cast()) { - std::string reint = type == UnaryOpType::cast_value ? "" : "reinterpret_"; - return fmt::format("({}{}<{}> {})", reint, unary_op_type_name(type), - data_type_name(cast_type), operand->serialize()); + ss << (type == UnaryOpType::cast_value ? "" : "reinterpret_"); + ss << unary_op_type_name(type); + ss << '<' << data_type_name(cast_type) << "> "; } else { - return fmt::format("({} {})", unary_op_type_name(type), - operand->serialize()); + ss << unary_op_type_name(type) << ' '; } + operand->serialize(ss); + ss << ')'; +} + +void UnaryOpExpression::type_check(CompileConfig *) { + TI_ASSERT_TYPE_CHECKED(operand); + if (!operand->ret_type->is()) + throw TaichiTypeError( + fmt::format("unsupported operand type(s) for '{}': '{}'", + unary_op_type_name(type), operand->ret_type->to_string())); + if ((type == UnaryOpType::round || type == UnaryOpType::floor || + type == UnaryOpType::ceil || is_trigonometric(type)) && + !is_real(operand->ret_type)) + throw TaichiTypeError( + fmt::format("'{}' takes real inputs only, however '{}' is provided", + unary_op_type_name(type), operand->ret_type->to_string())); + ret_type = is_cast() ? cast_type : operand->ret_type; } bool UnaryOpExpression::is_cast() const { @@ -116,7 +170,7 @@ bool UnaryOpExpression::is_cast() const { } void UnaryOpExpression::flatten(FlattenContext *ctx) { - operand->flatten(ctx); + flatten_rvalue(operand, ctx); auto unary = std::make_unique(type, operand->stmt); if (is_cast()) { unary->cast_type = cast_type; @@ -126,52 +180,97 @@ void UnaryOpExpression::flatten(FlattenContext *ctx) { ctx->push_back(std::move(unary)); } +void BinaryOpExpression::type_check(CompileConfig *config) { + TI_ASSERT_TYPE_CHECKED(lhs); + TI_ASSERT_TYPE_CHECKED(rhs); + auto lhs_type = lhs->ret_type; + auto rhs_type = rhs->ret_type; + auto error = [&]() { + throw TaichiTypeError( + fmt::format("unsupported operand type(s) for '{}': '{}' and '{}'", + binary_op_type_symbol(type), lhs->ret_type->to_string(), + rhs->ret_type->to_string())); + }; + if (!lhs_type->is() || !rhs_type->is()) + error(); + if (binary_is_bitwise(type) && + (!is_integral(lhs_type) || !is_integral(rhs_type))) + error(); + if (is_comparison(type)) { + ret_type = PrimitiveType::i32; + return; + } + if (type == BinaryOpType::truediv) { + auto default_fp = config->default_fp; + if (!is_real(lhs_type)) { + lhs_type = default_fp; + } + if (!is_real(rhs_type)) { + rhs_type = default_fp; + } + } + ret_type = promoted_type(lhs_type, rhs_type); +} + void BinaryOpExpression::flatten(FlattenContext *ctx) { // if (stmt) // return; - lhs->flatten(ctx); - rhs->flatten(ctx); + flatten_rvalue(lhs, ctx); + flatten_rvalue(rhs, ctx); ctx->push_back(std::make_unique(type, lhs->stmt, rhs->stmt)); ctx->stmts.back()->tb = tb; stmt = ctx->back_stmt(); } +void TernaryOpExpression::type_check(CompileConfig *) { + TI_ASSERT_TYPE_CHECKED(op1); + TI_ASSERT_TYPE_CHECKED(op2); + TI_ASSERT_TYPE_CHECKED(op3); + auto op1_type = op1->ret_type; + auto op2_type = op2->ret_type; + auto op3_type = op3->ret_type; + auto error = [&]() { + throw TaichiTypeError( + fmt::format("unsupported operand type(s) for '{}': '{}', '{}' and '{}'", + ternary_type_name(type), op1->ret_type->to_string(), + op2->ret_type->to_string(), op3->ret_type->to_string())); + }; + if (!is_integral(op1_type) || !op2_type->is() || + !op3_type->is()) + error(); + ret_type = promoted_type(op2_type, op3_type); +} + void TernaryOpExpression::flatten(FlattenContext *ctx) { // if (stmt) // return; - op1->flatten(ctx); - op2->flatten(ctx); - op3->flatten(ctx); + flatten_rvalue(op1, ctx); + flatten_rvalue(op2, ctx); + flatten_rvalue(op3, ctx); ctx->push_back( std::make_unique(type, op1->stmt, op2->stmt, op3->stmt)); stmt = ctx->back_stmt(); } +void InternalFuncCallExpression::type_check(CompileConfig *) { + for (auto &arg : args) { + TI_ASSERT_TYPE_CHECKED(arg); + // no arg type compatibility check for now due to lack of specification + } + // internal func calls have default return type + ret_type = PrimitiveType::i32; +} + void InternalFuncCallExpression::flatten(FlattenContext *ctx) { std::vector args_stmts(args.size()); for (int i = 0; i < (int)args.size(); ++i) { - args[i]->flatten(ctx); + flatten_rvalue(args[i], ctx); args_stmts[i] = args[i]->stmt; } ctx->push_back(func_name, args_stmts); stmt = ctx->back_stmt(); } -void ExternalFuncCallExpression::flatten(FlattenContext *ctx) { - std::vector arg_statements, output_statements; - for (auto &s : args) { - s.set(load_if_ptr(s)); - s->flatten(ctx); - arg_statements.push_back(s->stmt); - } - for (auto &s : outputs) { - output_statements.push_back(s.cast()->flatten_noload(ctx)); - } - ctx->push_back(std::make_unique( - func, source, arg_statements, output_statements)); - stmt = ctx->back_stmt(); -} - void ExternalTensorExpression::flatten(FlattenContext *ctx) { auto ptr = Stmt::make(arg_id, dt, /*is_ptr=*/true); ctx->push_back(std::move(ptr)); @@ -185,16 +284,42 @@ void GlobalVariableExpression::flatten(FlattenContext *ctx) { ctx->push_back(std::move(ptr)); } -std::string GlobalPtrExpression::serialize() { - std::string s = fmt::format( - "{}[", snode ? snode->get_node_type_name_hinted() : var.serialize()); +void GlobalPtrExpression::type_check(CompileConfig *) { + // Currently, dimension compatibility check happens in Python + if (snode != nullptr) { + ret_type = snode->dt; + } else if (var.is()) { + ret_type = + var.cast()->snode->dt->get_compute_type(); + } else if (var.is()) { + for (int i = 0; i < indices.exprs.size(); i++) { + auto &expr = indices.exprs[i]; + TI_ASSERT_TYPE_CHECKED(expr); + if (!is_integral(expr->ret_type)) + throw TaichiTypeError( + fmt::format("indices must be integers, however '{}' is " + "provided as index {}", + expr->ret_type->to_string(), i)); + } + ret_type = var.cast()->dt; + } else { + TI_ERROR("Invalid GlobalPtrExpression"); + } +} + +void GlobalPtrExpression::serialize(std::ostream &ss) { + if (snode) { + ss << snode->get_node_type_name_hinted(); + } else { + var.serialize(ss); + } + ss << '['; for (int i = 0; i < (int)indices.size(); i++) { - s += indices.exprs[i]->serialize(); + indices.exprs[i]->serialize(ss); if (i + 1 < (int)indices.size()) - s += ", "; + ss << ", "; } - s += "]"; - return s; + ss << ']'; } void GlobalPtrExpression::flatten(FlattenContext *ctx) { @@ -209,7 +334,7 @@ void GlobalPtrExpression::flatten(FlattenContext *ctx) { offsets = snode->index_offsets; } for (int i = 0; i < (int)indices.size(); i++) { - indices.exprs[i]->flatten(ctx); + flatten_rvalue(indices.exprs[i], ctx); Stmt *ind = indices.exprs[i]->stmt; if (!offsets.empty()) { // Subtract offsets from indices so that new indices are @@ -223,13 +348,34 @@ void GlobalPtrExpression::flatten(FlattenContext *ctx) { ctx->push_back(std::make_unique(snode, index_stmts)); } else { TI_ASSERT(var.is()); - var->flatten(ctx); + flatten_lvalue(var, ctx); + auto expr = var.cast(); ctx->push_back(std::make_unique( - var.cast()->stmt, index_stmts)); + expr->stmt, index_stmts, expr->element_shape, expr->element_dim)); } stmt = ctx->back_stmt(); } +void TensorElementExpression::type_check(CompileConfig *) { + std::string invalid_msg{ + "Invalid TensorElementExpression: the source is neither a local tensor " + "nor a global tensor field"}; + if (is_local_tensor()) { + TI_ASSERT_INFO(var->ret_type->is(), invalid_msg); + ret_type = var->ret_type->cast()->get_element_type(); + } else if (is_global_tensor()) { + TI_ASSERT_INFO( + var.is() && + var.cast()->var.is(), + invalid_msg); + ret_type = var.cast() + ->var.cast() + ->snode->dt; + } else { + TI_ERROR(invalid_msg); + } +} + bool TensorElementExpression::is_local_tensor() const { return var.is(); } @@ -239,128 +385,121 @@ bool TensorElementExpression::is_global_tensor() const { } void TensorElementExpression::flatten(FlattenContext *ctx) { - var->flatten(ctx); - Stmt *var_stmt = var->stmt; - DataType element_type; - if (var.is()) { - // Local tensor subscripting - TI_ASSERT(layout_stride == 1); - TI_ASSERT(var_stmt->ret_type->is()); - auto tensor_type = var_stmt->ret_type->cast(); - element_type = tensor_type->get_element_type(); - } else { - TI_ASSERT(var.is()); - // Global tensor subscripting - SNode *snode = var.cast() - ->var.cast() - ->snode; - // layout_stride != 1 is satisfied if and only if subscripting on SOA - // global tensor. - TI_ASSERT(layout_stride == 1 || snode->is_path_all_dense); - element_type = snode->dt; - } - // Compute exact offset - // Type A[x, y, ...] - // ^^^^^^^^^ - indices[0].set(load_if_ptr(indices[0])); - indices[0]->flatten(ctx); - Stmt *offset_stmt = indices[0]->stmt; - for (int i = 1; i < (int)shape.size(); ++i) { - Stmt *shape_on_i = - ctx->push_back(Stmt::make(TypedConstant(shape[i]))); - Stmt *mul_stmt = ctx->push_back( - Stmt::make(BinaryOpType::mul, offset_stmt, shape_on_i)); - indices[i].set(load_if_ptr(indices[i])); - indices[i]->flatten(ctx); - ctx->push_back(Stmt::make(BinaryOpType::add, mul_stmt, - indices[i]->stmt)); - offset_stmt = ctx->back_stmt(); - } - // Type A[x, y, ...] - // ^^^^ - Stmt *dt_size_stmt = ctx->push_back( - Stmt::make(TypedConstant(data_type_size(element_type)))); - ctx->push_back( - Stmt::make(BinaryOpType::mul, offset_stmt, dt_size_stmt)); - offset_stmt = ctx->back_stmt(); - Stmt *layout_stride_stmt = - ctx->push_back(Stmt::make(TypedConstant(layout_stride))); - ctx->push_back(Stmt::make(BinaryOpType::mul, offset_stmt, - layout_stride_stmt)); - ctx->push_back(std::make_unique(var_stmt, ctx->back_stmt())); - stmt = ctx->back_stmt(); + flatten_lvalue(var, ctx); + Stmt *offset_stmt = ctx->push_back(TypedConstant(0)); + for (int i = 0; i < (int)shape.size(); ++i) { + flatten_rvalue(indices[i], ctx); + Stmt *shape_stmt = ctx->push_back(TypedConstant(shape[i])); + Stmt *mul_stmt = ctx->push_back(BinaryOpType::mul, + offset_stmt, shape_stmt); + offset_stmt = ctx->push_back(BinaryOpType::add, mul_stmt, + indices[i]->stmt); + } + Stmt *stride_stmt = ctx->push_back(TypedConstant(stride)); + offset_stmt = + ctx->push_back(BinaryOpType::mul, offset_stmt, stride_stmt); + stmt = ctx->push_back(var->stmt, offset_stmt); +} + +void RangeAssumptionExpression::type_check(CompileConfig *) { + TI_ASSERT_TYPE_CHECKED(input); + TI_ASSERT_TYPE_CHECKED(base); + if (!input->ret_type->is() || + !base->ret_type->is() || input->ret_type != base->ret_type) + throw TaichiTypeError( + fmt::format("unsupported operand type(s) for " + "'range_assumption': '{}' and '{}'", + input->ret_type->to_string(), base->ret_type->to_string())); + ret_type = input->ret_type; } void RangeAssumptionExpression::flatten(FlattenContext *ctx) { - input->flatten(ctx); - base->flatten(ctx); + flatten_rvalue(input, ctx); + flatten_rvalue(base, ctx); ctx->push_back( Stmt::make(input->stmt, base->stmt, low, high)); stmt = ctx->back_stmt(); } -std::string LoopUniqueExpression::serialize() { - std::string result = "loop_unique(" + input->serialize(); +void LoopUniqueExpression::type_check(CompileConfig *) { + TI_ASSERT_TYPE_CHECKED(input); + if (!input->ret_type->is()) + throw TaichiTypeError( + fmt::format("unsupported operand type(s) for 'loop_unique': '{}'", + input->ret_type->to_string())); + ret_type = input->ret_type; +} + +void LoopUniqueExpression::serialize(std::ostream &ss) { + ss << "loop_unique("; + input.serialize(ss); for (int i = 0; i < covers.size(); i++) { if (i == 0) - result += ", covers=["; - result += covers[i]->get_node_type_name_hinted(); + ss << ", covers=["; + ss << covers[i]->get_node_type_name_hinted(); if (i == (int)covers.size() - 1) - result += "]"; + ss << ']'; else - result += ", "; + ss << ", "; } - result += ")"; - return result; + ss << ')'; } void LoopUniqueExpression::flatten(FlattenContext *ctx) { - input->flatten(ctx); + flatten_rvalue(input, ctx); ctx->push_back(Stmt::make(input->stmt, covers)); stmt = ctx->back_stmt(); } void IdExpression::flatten(FlattenContext *ctx) { - auto var_stmt = ctx->current_block->lookup_var(id); - if (var_stmt->is()) { - if (var_stmt->ret_type->is()) { - // For TensorType alloca, directly return the first element's address - stmt = var_stmt; - } else { - // For other alloca, load the value and then return - ctx->push_back( - std::make_unique(LocalAddress(var_stmt, 0))); - stmt = ctx->back_stmt(); - } + stmt = ctx->current_block->lookup_var(id); +} + +void AtomicOpExpression::type_check(CompileConfig *) { + TI_ASSERT_TYPE_CHECKED(dest); + TI_ASSERT_TYPE_CHECKED(val); + auto error = [&]() { + throw TaichiTypeError(fmt::format( + "unsupported operand type(s) for 'atomic_{}': '{}' and '{}'", + atomic_op_type_name(op_type), dest->ret_type->to_string(), + val->ret_type->to_string())); + }; + if (!val->ret_type->is()) + error(); + if (auto cit = dest->ret_type->cast()) { + ret_type = cit->get_compute_type(); + } else if (auto cft = dest->ret_type->cast()) { + ret_type = cft->get_compute_type(); + } else if (dest->ret_type->is()) { + ret_type = dest->ret_type; } else { - // The loop index may have a coordinate offset. - TI_ASSERT(var_stmt->is() || var_stmt->is()); - stmt = var_stmt; + error(); } } -std::string AtomicOpExpression::serialize() { +void AtomicOpExpression::serialize(std::ostream &ss) { if (op_type == AtomicOpType::add) { - return fmt::format("atomic_add({}, {})", dest.serialize(), val.serialize()); + ss << "atomic_add("; } else if (op_type == AtomicOpType::sub) { - return fmt::format("atomic_sub({}, {})", dest.serialize(), val.serialize()); + ss << "atomic_sub("; } else if (op_type == AtomicOpType::min) { - return fmt::format("atomic_min({}, {})", dest.serialize(), val.serialize()); + ss << "atomic_min("; } else if (op_type == AtomicOpType::max) { - return fmt::format("atomic_max({}, {})", dest.serialize(), val.serialize()); + ss << "atomic_max("; } else if (op_type == AtomicOpType::bit_and) { - return fmt::format("atomic_bit_and({}, {})", dest.serialize(), - val.serialize()); + ss << "atomic_bit_and("; } else if (op_type == AtomicOpType::bit_or) { - return fmt::format("atomic_bit_or({}, {})", dest.serialize(), - val.serialize()); + ss << "atomic_bit_or("; } else if (op_type == AtomicOpType::bit_xor) { - return fmt::format("atomic_bit_xor({}, {})", dest.serialize(), - val.serialize()); + ss << "atomic_bit_xor("; } else { // min/max not supported in the LLVM backend yet. TI_NOT_IMPLEMENTED; } + dest.serialize(ss); + ss << ", "; + val.serialize(ss); + ss << ")"; } void AtomicOpExpression::flatten(FlattenContext *ctx) { @@ -371,39 +510,45 @@ void AtomicOpExpression::flatten(FlattenContext *ctx) { } // expand rhs auto expr = val; - expr->flatten(ctx); + flatten_rvalue(expr, ctx); if (dest.is()) { // local variable // emit local store stmt auto alloca = ctx->current_block->lookup_var(dest.cast()->id); ctx->push_back(op_type, alloca, expr->stmt); - } else if (dest.is()) { - auto tensor_ptr = dest.cast(); - tensor_ptr->flatten(ctx); - ctx->push_back(op_type, tensor_ptr->stmt, expr->stmt); - } else { // global variable - TI_ASSERT(dest.is()); - auto global_ptr = dest.cast(); - global_ptr->flatten(ctx); - ctx->push_back(op_type, ctx->back_stmt(), expr->stmt); + } else { + TI_ASSERT(dest.is() || + dest.is()); + flatten_lvalue(dest, ctx); + ctx->push_back(op_type, dest->stmt, expr->stmt); } stmt = ctx->back_stmt(); } -std::string SNodeOpExpression::serialize() { - if (value.expr) { - return fmt::format("{}({}, [{}], {})", snode_op_type_name(op_type), - snode->get_node_type_name_hinted(), indices.serialize(), - value.serialize()); +void SNodeOpExpression::type_check(CompileConfig *) { + if (op_type == SNodeOpType::get_addr) { + ret_type = PrimitiveType::u64; } else { - return fmt::format("{}({}, [{}])", snode_op_type_name(op_type), - snode->get_node_type_name_hinted(), indices.serialize()); + ret_type = PrimitiveType::i32; + } +} + +void SNodeOpExpression::serialize(std::ostream &ss) { + ss << snode_op_type_name(op_type); + ss << '('; + ss << snode->get_node_type_name_hinted() << ", ["; + indices.serialize(ss); + ss << "]"; + if (value.expr) { + ss << ' '; + value.serialize(ss); } + ss << ')'; } void SNodeOpExpression::flatten(FlattenContext *ctx) { std::vector indices_stmt; for (int i = 0; i < (int)indices.size(); i++) { - indices[i]->flatten(ctx); + flatten_rvalue(indices[i], ctx); indices_stmt.push_back(indices[i]->stmt); } auto ptr = ctx->push_back(snode, indices_stmt); @@ -418,7 +563,7 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) { } else if (op_type == SNodeOpType::get_addr) { ctx->push_back(SNodeOpType::get_addr, snode, ptr, nullptr); } else if (op_type == SNodeOpType::append) { - value->flatten(ctx); + flatten_rvalue(value, ctx); ctx->push_back(SNodeOpType::append, snode, ptr, value->stmt); TI_ERROR_IF(snode->type != SNodeType::dynamic, "ti.append only works on dynamic nodes."); @@ -430,21 +575,11 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } -void LocalLoadExpression::flatten(FlattenContext *ctx) { - ptr->flatten(ctx); - auto ptr_offset_stmt = ctx->back_stmt(); - TI_ASSERT(ptr_offset_stmt->is()); - auto local_addr = - LaneAttribute(LocalAddress(ptr_offset_stmt, 0)); - auto local_load_stmt = - ctx->push_back(LaneAttribute(local_addr)); - stmt = local_load_stmt; -} - -void GlobalLoadExpression::flatten(FlattenContext *ctx) { - ptr->flatten(ctx); - ctx->push_back(std::make_unique(ptr->stmt)); - stmt = ctx->back_stmt(); +void ConstExpression::type_check(CompileConfig *) { + TI_ASSERT_INFO( + val.dt->is() && val.dt != PrimitiveType::unknown, + "Invalid dt [{}] for ConstExpression", val.dt->to_string()); + ret_type = val.dt; } void ConstExpression::flatten(FlattenContext *ctx) { @@ -452,6 +587,13 @@ void ConstExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } +void ExternalTensorShapeAlongAxisExpression::type_check(CompileConfig *) { + TI_ASSERT_INFO(ptr.is(), + "Invalid ptr [{}] for ExternalTensorShapeAlongAxisExpression", + ptr.serialize()); + ret_type = PrimitiveType::i32; +} + void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) { auto temp = ptr.cast(); TI_ASSERT(0 <= axis && axis < temp->dim); @@ -459,55 +601,370 @@ void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } +void FuncCallExpression::type_check(CompileConfig *) { + for (auto &arg : args.exprs) { + TI_ASSERT_TYPE_CHECKED(arg); + // no arg type compatibility check for now due to lack of specification + } + TI_ASSERT_INFO(func->rets.size() <= 1, + "Too many (> 1) return values for FuncCallExpression"); + if (func->rets.size() == 1) { + ret_type = func->rets[0].dt; + } +} + void FuncCallExpression::flatten(FlattenContext *ctx) { std::vector stmt_args; for (auto &arg : args.exprs) { - arg->flatten(ctx); + flatten_rvalue(arg, ctx); stmt_args.push_back(arg->stmt); } ctx->push_back(func, stmt_args); stmt = ctx->back_stmt(); } -std::string FuncCallExpression::serialize() { - return fmt::format("func_call(\"{}\", {})", func->func_key.get_full_name(), - args.serialize()); +void FuncCallExpression::serialize(std::ostream &ss) { + ss << "func_call(\"" << func->func_key.get_full_name() << "\", "; + args.serialize(ss); + ss << ')'; +} + +// Mesh related. + +void MeshPatchIndexExpression::flatten(FlattenContext *ctx) { + auto pid_stmt = std::make_unique(); + ctx->push_back(std::move(pid_stmt)); + stmt = ctx->back_stmt(); +} + +void MeshPatchIndexExpression::type_check(CompileConfig *) { + ret_type = PrimitiveType::i32; +} + +void MeshRelationAccessExpression::type_check(CompileConfig *) { + ret_type = PrimitiveType::i32; +} + +void MeshRelationAccessExpression::flatten(FlattenContext *ctx) { + flatten_rvalue(mesh_idx, ctx); + if (neighbor_idx) { + flatten_rvalue(neighbor_idx, ctx); + ctx->push_back(mesh, mesh_idx->stmt, to_type, + neighbor_idx->stmt); + } else { + ctx->push_back(mesh, mesh_idx->stmt, to_type); + } + stmt = ctx->back_stmt(); +} + +void MeshIndexConversionExpression::type_check(CompileConfig *) { + ret_type = PrimitiveType::i32; +} + +void MeshIndexConversionExpression::flatten(FlattenContext *ctx) { + flatten_rvalue(idx, ctx); + ctx->push_back(mesh, idx_type, idx->stmt, conv_type); + stmt = ctx->back_stmt(); } Block *ASTBuilder::current_block() { - if (stack.empty()) + if (stack_.empty()) return nullptr; else - return stack.back(); + return stack_.back(); } Stmt *ASTBuilder::get_last_stmt() { - TI_ASSERT(!stack.empty()); - return stack.back()->back(); + TI_ASSERT(!stack_.empty()); + return stack_.back()->back(); } void ASTBuilder::insert(std::unique_ptr &&stmt, int location) { - TI_ASSERT(!stack.empty()); - stack.back()->insert(std::move(stmt), location); + TI_ASSERT(!stack_.empty()); + stack_.back()->insert(std::move(stmt), location); } void ASTBuilder::stop_gradient(SNode *snode) { - TI_ASSERT(!stack.empty()); - stack.back()->stop_gradients.push_back(snode); + TI_ASSERT(!stack_.empty()); + stack_.back()->stop_gradients.push_back(snode); +} + +void ASTBuilder::insert_assignment(Expr &lhs, const Expr &rhs) { + // Inside a kernel or a function + // Create an assignment in the IR + if (lhs.expr == nullptr) { + lhs.set(rhs); + } else if (lhs.expr->is_lvalue()) { + this->insert(std::make_unique(lhs, rhs)); + } else { + TI_ERROR("Cannot assign to non-lvalue: {}", lhs.serialize()); + } +} + +Expr ASTBuilder::make_var(const Expr &x) { + auto var = Expr(std::make_shared()); + this->insert(std::make_unique( + std::static_pointer_cast(var.expr)->id, + PrimitiveType::unknown)); + this->insert_assignment(var, x); + return var; +} + +void ASTBuilder::insert_for(const Expr &s, + const Expr &e, + const std::function &func) { + auto i = Expr(std::make_shared()); + auto stmt_unique = std::make_unique(i, s, e, this->arch_, + for_loop_dec_.config); + for_loop_dec_.reset(); + auto stmt = stmt_unique.get(); + this->insert(std::move(stmt_unique)); + this->create_scope(stmt->body); + func(i); + this->pop_scope(); +} + +Expr ASTBuilder::insert_thread_idx_expr() { + auto loop = stack_.size() ? stack_.back()->parent_stmt : nullptr; + TI_ERROR_IF(arch_ != Arch::cuda && !arch_is_cpu(arch_), + "ti.thread_idx() is only available in cuda or cpu context."); + if (loop != nullptr) { + auto i = stack_.size() - 1; + while (!(loop->is())) { + loop = i > 0 ? stack_[--i]->parent_stmt : nullptr; + if (loop == nullptr) + break; + } + } + TI_ERROR_IF(!(loop && loop->is()), + "ti.thread_idx() is only valid within loops."); + return Expr::make("linear_thread_idx", + std::vector{}); +} + +Expr ASTBuilder::insert_patch_idx_expr() { + auto loop = stack_.size() ? stack_.back()->parent_stmt : nullptr; + if (loop != nullptr) { + auto i = stack_.size() - 1; + while (!(loop->is())) { + loop = i > 0 ? stack_[--i]->parent_stmt : nullptr; + if (loop == nullptr) + break; + } + } + TI_ERROR_IF(!(loop && loop->is() && + loop->as()->mesh_for), + "ti.mesh_patch_idx() is only valid within mesh-for loops."); + return Expr::make(); +} + +void ASTBuilder::create_kernel_exprgroup_return(const ExprGroup &group) { + this->insert(Stmt::make(group)); +} + +void ASTBuilder::create_print( + std::vector> contents) { + this->insert(std::make_unique(contents)); +} + +void ASTBuilder::begin_func(const std::string &funcid) { + auto stmt_unique = std::make_unique(funcid); + auto stmt = stmt_unique.get(); + this->insert(std::move(stmt_unique)); + this->create_scope(stmt->body); +} + +void ASTBuilder::end_func(const std::string &funcid) { + this->pop_scope(); +} + +void ASTBuilder::begin_frontend_if(const Expr &cond) { + auto stmt_tmp = std::make_unique(cond); + this->insert(std::move(stmt_tmp)); +} + +void ASTBuilder::begin_frontend_if_true() { + auto if_stmt = this->get_last_stmt()->as(); + this->create_scope(if_stmt->true_statements); +} + +void ASTBuilder::begin_frontend_if_false() { + auto if_stmt = this->get_last_stmt()->as(); + this->create_scope(if_stmt->false_statements); +} + +void ASTBuilder::insert_external_func_call(std::size_t func_addr, + std::string source, + std::string filename, + std::string funcname, + const ExprGroup &args, + const ExprGroup &outputs) { + auto stmt = Stmt::make( + (void *)func_addr, source, filename, funcname, args.exprs, outputs.exprs); + this->insert(std::move(stmt)); +} + +Expr ASTBuilder::expr_alloca() { + auto var = Expr(std::make_shared()); + this->insert(std::make_unique( + std::static_pointer_cast(var.expr)->id, + PrimitiveType::unknown)); + return var; +} + +Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, + const DataType &element_type, + const ExprGroup &elements) { + auto var = Expr(std::make_shared()); + this->insert(std::make_unique( + std::static_pointer_cast(var.expr)->id, shape, + element_type)); + var->ret_type = this->get_last_stmt()->ret_type; + for (int i = 0; i < (int)elements.exprs.size(); ++i) { + ExprGroup reversed_indices; + int linearized_index = i; + for (int d = (int)shape.size() - 1; d >= 0; --d) { + reversed_indices.push_back( + Expr::make(linearized_index % shape[d])); + linearized_index /= shape[d]; + } + ExprGroup indices; + for (int d = 0; d < (int)shape.size(); ++d) + indices.push_back(reversed_indices[(int)shape.size() - 1 - d]); + this->insert(std::make_unique( + Expr::make(var, indices, shape, 1), + elements.exprs[i])); + } + return var; +} + +void ASTBuilder::expr_assign(const Expr &lhs, const Expr &rhs, std::string tb) { + TI_ASSERT(lhs->is_lvalue()); + auto stmt = std::make_unique(lhs, rhs); + stmt->set_tb(tb); + this->insert(std::move(stmt)); +} + +void ASTBuilder::create_assert_stmt(const Expr &cond, + const std::string &msg, + const std::vector &args) { + auto stmt_unique = std::make_unique(cond, msg, args); + this->insert(std::move(stmt_unique)); +} + +void ASTBuilder::begin_frontend_range_for(const Expr &i, + const Expr &s, + const Expr &e) { + auto stmt_unique = + std::make_unique(i, s, e, arch_, for_loop_dec_.config); + for_loop_dec_.reset(); + auto stmt = stmt_unique.get(); + this->insert(std::move(stmt_unique)); + this->create_scope(stmt->body, For); +} + +void ASTBuilder::begin_frontend_struct_for(const ExprGroup &loop_vars, + const Expr &global) { + auto stmt_unique = std::make_unique(loop_vars, global, arch_, + for_loop_dec_.config); + for_loop_dec_.reset(); + auto stmt = stmt_unique.get(); + this->insert(std::move(stmt_unique)); + this->create_scope(stmt->body, For); +} + +void ASTBuilder::begin_frontend_mesh_for( + const Expr &i, + const mesh::MeshPtr &mesh_ptr, + const mesh::MeshElementType &element_type) { + auto stmt_unique = std::make_unique( + i, mesh_ptr, element_type, arch_, for_loop_dec_.config); + for_loop_dec_.reset(); + auto stmt = stmt_unique.get(); + this->insert(std::move(stmt_unique)); + this->create_scope(stmt->body, For); +} + +void ASTBuilder::begin_frontend_while(const Expr &cond) { + auto stmt_unique = std::make_unique(cond); + auto stmt = stmt_unique.get(); + this->insert(std::move(stmt_unique)); + this->create_scope(stmt->body, While); +} + +void ASTBuilder::insert_break_stmt() { + if (loop_state_stack_.back() == Outermost) { + throw TaichiSyntaxError("Cannot break in the outermost loop"); + } + this->insert(Stmt::make()); +} + +void ASTBuilder::insert_continue_stmt() { + this->insert(Stmt::make()); } -std::unique_ptr ASTBuilder::create_scope( - std::unique_ptr &list) { +void ASTBuilder::insert_expr_stmt(const Expr &val) { + this->insert(Stmt::make(val)); +} + +void ASTBuilder::create_scope(std::unique_ptr &list, LoopType tp) { TI_ASSERT(list == nullptr); + LoopState prev = loop_state_stack_.back(); + if (tp == NotLoop) { + loop_state_stack_.push_back(prev); + } else if (tp == For && stack_.size() == 1) { + loop_state_stack_.push_back(Outermost); + } else { + loop_state_stack_.push_back(Inner); + } list = std::make_unique(); - if (!stack.empty()) { + if (!stack_.empty()) { list->parent_stmt = get_last_stmt(); } - return std::make_unique(this, list.get()); + stack_.push_back(list.get()); +} + +void ASTBuilder::pop_scope() { + stack_.pop_back(); + loop_state_stack_.pop_back(); +} + +void flatten_lvalue(Expr expr, Expression::FlattenContext *ctx) { + expr->flatten(ctx); } -ASTBuilder ¤t_ast_builder() { - return get_current_program().current_callable->context->builder(); +void flatten_global_load(Expr ptr, Expression::FlattenContext *ctx) { + ctx->push_back(std::make_unique(ptr->stmt)); + ptr->stmt = ctx->back_stmt(); +} + +void flatten_local_load(Expr ptr, Expression::FlattenContext *ctx) { + ctx->push_back(LocalAddress(ptr->stmt, 0)); + ptr->stmt = ctx->back_stmt(); +} + +void flatten_rvalue(Expr ptr, Expression::FlattenContext *ctx) { + ptr->flatten(ctx); + if (ptr.is()) { + if (ptr->stmt->is()) { + flatten_local_load(ptr, ctx); + } + } else if (ptr.is()) { + flatten_global_load(ptr, ctx); + } else if (ptr.is()) { + TI_ASSERT(ptr.cast()->snode->num_active_indices == + 0); + flatten_global_load(ptr[ExprGroup()], ctx); + } else if (ptr.is()) { + auto tensor_ptr = ptr.cast(); + if (tensor_ptr->is_global_tensor()) + flatten_global_load(ptr, ctx); + else if (tensor_ptr->is_local_tensor()) + flatten_local_load(ptr, ctx); + else { + TI_NOT_IMPLEMENTED + } + } } TLANG_NAMESPACE_END diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index a330bd28b53da..18eb2ad97df9b 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -7,11 +7,47 @@ #include "taichi/ir/stmt_op_types.h" #include "taichi/ir/ir.h" #include "taichi/ir/expression.h" +#include "taichi/backends/arch.h" #include "taichi/program/function.h" +#include "taichi/ir/mesh.h" TLANG_NAMESPACE_BEGIN +struct ForLoopConfig { + int bit_vectorize{0}; + int num_cpu_threads{0}; + bool strictly_serialized{false}; + MemoryAccessOptions mem_access_opt; + int block_dim{0}; + bool uniform{false}; +}; + // Frontend Statements +class FrontendExternalFuncStmt : public Stmt { + public: + void *so_func; + std::string asm_source; + std::string bc_filename; + std::string bc_funcname; + std::vector args; + std::vector outputs; + + FrontendExternalFuncStmt(void *so_func, + const std::string &asm_source, + const std::string &bc_filename, + const std::string &bc_funcname, + const std::vector &args, + const std::vector &outputs) + : so_func(so_func), + asm_source(asm_source), + bc_filename(bc_filename), + bc_funcname(bc_funcname), + args(args), + outputs(outputs) { + } + + TI_DEFINE_ACCEPT +}; class FrontendExprStmt : public Stmt { public: @@ -71,7 +107,7 @@ class FrontendAssertStmt : public Stmt { const std::vector &args_) : text(text), cond(cond) { for (auto &a : args_) { - args.push_back(load_if_ptr(a)); + args.push_back(a); } } @@ -92,7 +128,7 @@ class FrontendIfStmt : public Stmt { Expr condition; std::unique_ptr true_statements, false_statements; - FrontendIfStmt(const Expr &condition) : condition(load_if_ptr(condition)) { + FrontendIfStmt(const Expr &condition) : condition(condition) { } bool is_container_statement() const override { @@ -110,7 +146,7 @@ class FrontendPrintStmt : public Stmt { FrontendPrintStmt(const std::vector &contents_) { for (const auto &c : contents_) { if (std::holds_alternative(c)) - contents.push_back(load_if_ptr(std::get(c))); + contents.push_back(std::get(c)); else contents.push_back(c); } @@ -119,44 +155,46 @@ class FrontendPrintStmt : public Stmt { TI_DEFINE_ACCEPT }; -// This statement evaluates the expression. -// The expression should have side effects otherwise the expression will do -// nothing. -class FrontendEvalStmt : public Stmt { - public: - Expr expr; - Expr eval_expr; - - FrontendEvalStmt(const Expr &expr) : expr(load_if_ptr(expr)) { - } - - TI_DEFINE_ACCEPT -}; - class FrontendForStmt : public Stmt { public: Expr begin, end; Expr global_var; std::unique_ptr body; std::vector loop_var_id; - int vectorize; int bit_vectorize; int num_cpu_threads; bool strictly_serialized; MemoryAccessOptions mem_access_opt; int block_dim; + bool mesh_for = false; + mesh::Mesh *mesh; + mesh::MeshElementType element_type; + bool is_ranged() const { - if (global_var.expr == nullptr) { + if (global_var.expr == nullptr && !mesh_for) { return true; } else { return false; } } - FrontendForStmt(const ExprGroup &loop_var, const Expr &global_var); + FrontendForStmt(const ExprGroup &loop_var, + const Expr &global_var, + Arch arch, + const ForLoopConfig &config); + + FrontendForStmt(const ExprGroup &loop_var, + const mesh::MeshPtr &mesh, + const mesh::MeshElementType &element_type, + Arch arch, + const ForLoopConfig &config); - FrontendForStmt(const Expr &loop_var, const Expr &begin, const Expr &end); + FrontendForStmt(const Expr &loop_var, + const Expr &begin, + const Expr &end, + Arch arch, + const ForLoopConfig &config); bool is_container_statement() const override { return true; @@ -208,7 +246,7 @@ class FrontendWhileStmt : public Stmt { Expr cond; std::unique_ptr body; - FrontendWhileStmt(const Expr &cond) : cond(load_if_ptr(cond)) { + FrontendWhileStmt(const Expr &cond) : cond(cond) { } bool is_container_statement() const override { @@ -220,9 +258,9 @@ class FrontendWhileStmt : public Stmt { class FrontendReturnStmt : public Stmt { public: - Expr value; + ExprGroup values; - FrontendReturnStmt(const Expr &value) : value(value) { + FrontendReturnStmt(const ExprGroup &group) : values(group) { } bool is_container_statement() const override { @@ -242,8 +280,10 @@ class ArgLoadExpression : public Expression { ArgLoadExpression(int arg_id, DataType dt) : arg_id(arg_id), dt(dt) { } - std::string serialize() override { - return fmt::format("arg[{}] (dt={})", arg_id, data_type_name(dt)); + void type_check(CompileConfig *config) override; + + void serialize(std::ostream &ss) override { + ss << fmt::format("arg[{}] (dt={})", arg_id, data_type_name(dt)); } void flatten(FlattenContext *ctx) override; @@ -256,8 +296,10 @@ class RandExpression : public Expression { RandExpression(DataType dt) : dt(dt) { } - std::string serialize() override { - return fmt::format("rand<{}>()", data_type_name(dt)); + void type_check(CompileConfig *config) override; + + void serialize(std::ostream &ss) override { + ss << fmt::format("rand<{}>()", data_type_name(dt)); } void flatten(FlattenContext *ctx) override; @@ -270,13 +312,19 @@ class UnaryOpExpression : public Expression { DataType cast_type; UnaryOpExpression(UnaryOpType type, const Expr &operand) - : type(type), operand(smart_load(operand)) { + : type(type), operand(operand) { cast_type = PrimitiveType::unknown; } + UnaryOpExpression(UnaryOpType type, const Expr &operand, DataType cast_type) + : type(type), operand(operand), cast_type(cast_type) { + } + + void type_check(CompileConfig *config) override; + bool is_cast() const; - std::string serialize() override; + void serialize(std::ostream &ss) override; void flatten(FlattenContext *ctx) override; }; @@ -287,14 +335,19 @@ class BinaryOpExpression : public Expression { Expr lhs, rhs; BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs) - : type(type) { - this->lhs.set(smart_load(lhs)); - this->rhs.set(smart_load(rhs)); + : type(type), lhs(lhs), rhs(rhs) { } - std::string serialize() override { - return fmt::format("({} {} {})", lhs->serialize(), - binary_op_type_symbol(type), rhs->serialize()); + void type_check(CompileConfig *config) override; + + void serialize(std::ostream &ss) override { + ss << '('; + lhs->serialize(ss); + ss << ' '; + ss << binary_op_type_symbol(type); + ss << ' '; + rhs->serialize(ss); + ss << ')'; } void flatten(FlattenContext *ctx) override; @@ -310,14 +363,21 @@ class TernaryOpExpression : public Expression { const Expr &op2, const Expr &op3) : type(type) { - this->op1.set(load_if_ptr(op1)); - this->op2.set(load_if_ptr(op2)); - this->op3.set(load_if_ptr(op3)); + this->op1.set(op1); + this->op2.set(op2); + this->op3.set(op3); } - std::string serialize() override { - return fmt::format("{}({} {} {})", ternary_type_name(type), - op1->serialize(), op2->serialize(), op3->serialize()); + void type_check(CompileConfig *config) override; + + void serialize(std::ostream &ss) override { + ss << ternary_type_name(type) << '('; + op1->serialize(ss); + ss << ' '; + op2->serialize(ss); + ss << ' '; + op3->serialize(ss); + ss << ')'; } void flatten(FlattenContext *ctx) override; @@ -332,61 +392,28 @@ class InternalFuncCallExpression : public Expression { const std::vector &args_) : func_name(func_name) { for (auto &a : args_) { - args.push_back(load_if_ptr(a)); + args.push_back(a); } } - std::string serialize() override { + void type_check(CompileConfig *config) override; + + void serialize(std::ostream &ss) override { + ss << "internal call " << func_name << '('; std::string args_str; for (int i = 0; i < args.size(); i++) { if (i != 0) { - args_str += ", "; + ss << ", "; } - args_str += args[i]->serialize(); - } - return fmt::format("internal call {}({})", func_name, args_str); - } - - void flatten(FlattenContext *ctx) override; -}; - -class ExternalFuncCallExpression : public Expression { - public: - void *func; - std::string source; - std::vector args; - std::vector outputs; - - ExternalFuncCallExpression(void *func, - std::string const &source, - const std::vector &args, - const std::vector &outputs) - : func(func), source(source), args(args), outputs(outputs) { - } - - std::string serialize() override { - std::string io = "inputs="; - - for (auto &s : args) { - io += s.serialize(); - } - - io += ", outputs="; - - for (auto &s : outputs) { - io += s.serialize(); - } - - if (func) { - return fmt::format("call {:x} ({})", (uint64)func, io); - } else { - return fmt::format("asm \"{}\" ({})", source, io); + args[i]->serialize(ss); } + ss << ')'; } void flatten(FlattenContext *ctx) override; }; +// TODO: Make this a non-expr class ExternalTensorExpression : public Expression { public: DataType dt; @@ -395,6 +422,9 @@ class ExternalTensorExpression : public Expression { int element_dim; // 0: scalar; 1: vector (SOA); 2: matrix (SOA); -1: vector // (AOS); -2: matrix (AOS) + // Fill element shape if compile-time specialization is desired. + std::vector element_shape; + ExternalTensorExpression(const DataType &dt, int dim, int arg_id, @@ -403,13 +433,26 @@ class ExternalTensorExpression : public Expression { set_attribute("dim", std::to_string(dim)); } - std::string serialize() override { - return fmt::format("{}d_ext_arr", dim); + ExternalTensorExpression(const DataType &dt, + int dim, + int arg_id, + int element_dim, + const std::vector &element_shape) + : ExternalTensorExpression(dt, dim, arg_id, element_dim) { + this->element_shape = element_shape; + } + + void type_check(CompileConfig *config) override { + } + + void serialize(std::ostream &ss) override { + ss << fmt::format("{}d_ext_arr", dim); } void flatten(FlattenContext *ctx) override; }; +// TODO: Make this a non-expr class GlobalVariableExpression : public Expression { public: Identifier ident; @@ -434,13 +477,16 @@ class GlobalVariableExpression : public Expression { is_primal = true; } + void type_check(CompileConfig *config) override { + } + void set_snode(SNode *snode) { this->snode = snode; set_attribute("dim", std::to_string(snode->num_active_indices)); } - std::string serialize() override { - return "#" + ident.name(); + void serialize(std::ostream &ss) override { + ss << "#" << ident.name(); } void flatten(FlattenContext *ctx) override; @@ -460,7 +506,9 @@ class GlobalPtrExpression : public Expression { : snode(snode), indices(indices) { } - std::string serialize() override; + void type_check(CompileConfig *config) override; + + void serialize(std::ostream &ss) override; void flatten(FlattenContext *ctx) override; @@ -474,35 +522,38 @@ class TensorElementExpression : public Expression { Expr var; ExprGroup indices; std::vector shape; - int layout_stride{1}; + int stride{0}; TensorElementExpression(const Expr &var, const ExprGroup &indices, const std::vector &shape, - int layout_stride) - : var(var), indices(indices), shape(shape), layout_stride(layout_stride) { + int stride) + : var(var), indices(indices), shape(shape), stride(stride) { + // TODO: shape & indices check } + void type_check(CompileConfig *config) override; + bool is_local_tensor() const; bool is_global_tensor() const; - std::string serialize() override { - std::string s = fmt::format("{}[", var.serialize()); + void serialize(std::ostream &ss) override { + var.serialize(ss); + ss << '['; for (int i = 0; i < (int)indices.size(); i++) { - s += indices.exprs[i]->serialize(); + indices.exprs[i]->serialize(ss); if (i + 1 < (int)indices.size()) - s += ", "; + ss << ", "; } - s += "] ("; + ss << "] ("; for (int i = 0; i < (int)shape.size(); i++) { - s += std::to_string(shape[i]); + ss << std::to_string(shape[i]); if (i + 1 < (int)shape.size()) - s += ", "; + ss << ", "; } - s += ", layout_stride = " + std::to_string(layout_stride); - s += ")"; - return s; + ss << ", stride = " + std::to_string(stride); + ss << ')'; } void flatten(FlattenContext *ctx) override; @@ -512,23 +563,6 @@ class TensorElementExpression : public Expression { } }; -class EvalExpression : public Expression { - public: - Stmt *stmt_ptr; - int stmt_id; - EvalExpression(Stmt *stmt) : stmt_ptr(stmt), stmt_id(stmt_ptr->id) { - // cache stmt->id since it may be released later - } - - std::string serialize() override { - return fmt::format("%{}", stmt_id); - } - - void flatten(FlattenContext *ctx) override { - stmt = stmt_ptr; - } -}; - class RangeAssumptionExpression : public Expression { public: Expr input, base; @@ -541,10 +575,17 @@ class RangeAssumptionExpression : public Expression { : input(input), base(base), low(low), high(high) { } - std::string serialize() override { - return fmt::format("assume_in_range({}{:+d} <= ({}) < {}{:+d})", - base.serialize(), low, input.serialize(), - base.serialize(), high); + void type_check(CompileConfig *config) override; + + void serialize(std::ostream &ss) override { + ss << "assume_in_range({"; + base.serialize(ss); + ss << fmt::format("{:+d}", low); + ss << " <= ("; + input.serialize(ss); + ss << ") < "; + base.serialize(ss); + ss << fmt::format("{:+d})", high); } void flatten(FlattenContext *ctx) override; @@ -559,7 +600,9 @@ class LoopUniqueExpression : public Expression { : input(input), covers(covers) { } - std::string serialize() override; + void type_check(CompileConfig *config) override; + + void serialize(std::ostream &ss) override; void flatten(FlattenContext *ctx) override; }; @@ -572,8 +615,11 @@ class IdExpression : public Expression { IdExpression(const Identifier &id) : id(id) { } - std::string serialize() override { - return id.name(); + void type_check(CompileConfig *config) override { + } + + void serialize(std::ostream &ss) override { + ss << id.name(); } void flatten(FlattenContext *ctx) override; @@ -597,7 +643,9 @@ class AtomicOpExpression : public Expression { : op_type(op_type), dest(dest), val(val) { } - std::string serialize() override; + void type_check(CompileConfig *config) override; + + void serialize(std::ostream &ss) override; void flatten(FlattenContext *ctx) override; }; @@ -620,78 +668,150 @@ class SNodeOpExpression : public Expression { : snode(snode), op_type(op_type), indices(indices), value(value) { } - std::string serialize() override; + void type_check(CompileConfig *config) override; + + void serialize(std::ostream &ss) override; void flatten(FlattenContext *ctx) override; }; -class LocalLoadExpression : public Expression { +class ConstExpression : public Expression { public: - Expr ptr; - LocalLoadExpression(const Expr &ptr) : ptr(ptr) { + TypedConstant val; + + template + ConstExpression(const T &x) : val(x) { + ret_type = val.dt; } + template + ConstExpression(const DataType &dt, const T &x) : val({dt, x}) { + ret_type = dt; + } + + void type_check(CompileConfig *config) override; - std::string serialize() override { - return "lcl load " + ptr.serialize(); + void serialize(std::ostream &ss) override { + ss << val.stringify(); } void flatten(FlattenContext *ctx) override; }; -class GlobalLoadExpression : public Expression { +class ExternalTensorShapeAlongAxisExpression : public Expression { public: Expr ptr; - GlobalLoadExpression(const Expr &ptr) : ptr(ptr) { + int axis; + + void serialize(std::ostream &ss) override { + ss << "external_tensor_shape_along_axis("; + ptr->serialize(ss); + ss << ", " << axis << ')'; } - std::string serialize() override { - return "gbl load " + ptr.serialize(); + ExternalTensorShapeAlongAxisExpression(const Expr &ptr, int axis) + : ptr(ptr), axis(axis) { } + void type_check(CompileConfig *config) override; + void flatten(FlattenContext *ctx) override; }; -class ConstExpression : public Expression { +class FuncCallExpression : public Expression { public: - TypedConstant val; + Function *func; + ExprGroup args; - template - ConstExpression(const T &x) : val(x) { + void type_check(CompileConfig *config) override; + + void serialize(std::ostream &ss) override; + + FuncCallExpression(Function *func, const ExprGroup &args) + : func(func), args(args) { + } + + void flatten(FlattenContext *ctx) override; +}; + +// Mesh related. + +class MeshPatchIndexExpression : public Expression { + public: + MeshPatchIndexExpression() { } - std::string serialize() override { - return val.stringify(); + void type_check(CompileConfig *config) override; + + void serialize(std::ostream &ss) override { + ss << fmt::format("mesh_patch_idx()"); } void flatten(FlattenContext *ctx) override; }; -class ExternalTensorShapeAlongAxisExpression : public Expression { +class MeshRelationAccessExpression : public Expression { public: - Expr ptr; - int axis; + mesh::Mesh *mesh; + Expr mesh_idx; + mesh::MeshElementType to_type; + Expr neighbor_idx; + + void type_check(CompileConfig *config) override; + + void serialize(std::ostream &ss) override { + if (neighbor_idx) { + ss << "mesh_relation_access("; + mesh_idx->serialize(ss); + ss << ", " << mesh::element_type_name(to_type) << "["; + neighbor_idx->serialize(ss); + ss << "])"; + } else { + ss << "mesh_relation_size("; + mesh_idx->serialize(ss); + ss << ", " << mesh::element_type_name(to_type) << ")"; + } + } - std::string serialize() override { - return fmt::format("external_tensor_shape_along_axis({}, {})", - ptr->serialize(), axis); + MeshRelationAccessExpression(mesh::Mesh *mesh, + const Expr mesh_idx, + mesh::MeshElementType to_type) + : mesh(mesh), mesh_idx(mesh_idx), to_type(to_type) { } - ExternalTensorShapeAlongAxisExpression(const Expr &ptr, int axis) - : ptr(ptr), axis(axis) { + MeshRelationAccessExpression(mesh::Mesh *mesh, + const Expr mesh_idx, + mesh::MeshElementType to_type, + const Expr neighbor_idx) + : mesh(mesh), + mesh_idx(mesh_idx), + to_type(to_type), + neighbor_idx(neighbor_idx) { } void flatten(FlattenContext *ctx) override; }; -class FuncCallExpression : public Expression { +class MeshIndexConversionExpression : public Expression { public: - Function *func; - ExprGroup args; + mesh::Mesh *mesh; + mesh::MeshElementType idx_type; + Expr idx; + mesh::ConvType conv_type; - std::string serialize() override; + void type_check(CompileConfig *config) override; - FuncCallExpression(Function *func, const ExprGroup &args) - : func(func), args(args) { + void serialize(std::ostream &ss) override { + ss << "mesh_index_conversion(" << mesh::conv_type_name(conv_type) << ", " + << mesh::element_type_name(idx_type) << ", "; + idx->serialize(ss); + ss << ")"; + } + + MeshIndexConversionExpression(mesh::Mesh *mesh, + mesh::MeshElementType idx_type, + const Expr idx, + mesh::ConvType conv_type) + : mesh(mesh), idx_type(idx_type), idx(idx), conv_type(conv_type) { } void flatten(FlattenContext *ctx) override; @@ -699,53 +819,133 @@ class FuncCallExpression : public Expression { class ASTBuilder { private: - std::vector stack; + enum LoopState { None, Outermost, Inner }; + enum LoopType { NotLoop, For, While }; - public: - ASTBuilder(Block *initial) { - stack.push_back(initial); - } + class ForLoopDecoratorRecorder { + public: + ForLoopConfig config; - void insert(std::unique_ptr &&stmt, int location = -1); - - struct ScopeGuard { - ASTBuilder *builder; - Block *list; - ScopeGuard(ASTBuilder *builder, Block *list) - : builder(builder), list(list) { - builder->stack.push_back(list); + ForLoopDecoratorRecorder() { + reset(); } - ~ScopeGuard() { - builder->stack.pop_back(); + void reset() { + config.bit_vectorize = -1; + config.num_cpu_threads = 0; + config.uniform = false; + config.mem_access_opt.clear(); + config.block_dim = 0; + config.strictly_serialized = false; } }; - std::unique_ptr create_scope(std::unique_ptr &list); + std::vector stack_; + std::vector loop_state_stack_; + Arch arch_; + ForLoopDecoratorRecorder for_loop_dec_; + + public: + ASTBuilder(Block *initial, Arch arch) : arch_(arch) { + stack_.push_back(initial); + loop_state_stack_.push_back(None); + } + + void insert(std::unique_ptr &&stmt, int location = -1); + Block *current_block(); Stmt *get_last_stmt(); void stop_gradient(SNode *); + void insert_assignment(Expr &lhs, const Expr &rhs); + Expr make_var(const Expr &x); + void insert_for(const Expr &s, + const Expr &e, + const std::function &func); + + Expr insert_thread_idx_expr(); + Expr insert_patch_idx_expr(); + void create_kernel_exprgroup_return(const ExprGroup &group); + void create_print(std::vector> contents); + void begin_func(const std::string &funcid); + void end_func(const std::string &funcid); + void begin_frontend_if(const Expr &cond); + void begin_frontend_if_true(); + void begin_frontend_if_false(); + void insert_external_func_call(std::size_t func_addr, + std::string source, + std::string filename, + std::string funcname, + const ExprGroup &args, + const ExprGroup &outputs); + Expr expr_alloca(); + Expr expr_alloca_local_tensor(const std::vector &shape, + const DataType &element_type, + const ExprGroup &elements); + void expr_assign(const Expr &lhs, const Expr &rhs, std::string tb); + void create_assert_stmt(const Expr &cond, + const std::string &msg, + const std::vector &args); + void begin_frontend_range_for(const Expr &i, const Expr &s, const Expr &e); + void begin_frontend_struct_for(const ExprGroup &loop_vars, + const Expr &global); + void begin_frontend_mesh_for(const Expr &i, + const mesh::MeshPtr &mesh_ptr, + const mesh::MeshElementType &element_type); + void begin_frontend_while(const Expr &cond); + void insert_break_stmt(); + void insert_continue_stmt(); + void insert_expr_stmt(const Expr &val); + + void create_scope(std::unique_ptr &list, LoopType tp = NotLoop); + void pop_scope(); + + void bit_vectorize(int v) { + for_loop_dec_.config.bit_vectorize = v; + } + + void parallelize(int v) { + for_loop_dec_.config.num_cpu_threads = v; + } + + void strictly_serialize() { + for_loop_dec_.config.strictly_serialized = true; + } + + void block_dim(int v) { + TI_ASSERT(bit::is_power_of_two(v)); + for_loop_dec_.config.block_dim = v; + } + + void insert_snode_access_flag(SNodeAccessFlag v, const Expr &field) { + for_loop_dec_.config.mem_access_opt.add_flag(field.snode(), v); + } + + void reset_snode_access_flag() { + for_loop_dec_.reset(); + } }; -ASTBuilder ¤t_ast_builder(); - class FrontendContext { private: - std::unique_ptr current_builder; - std::unique_ptr root_node; + std::unique_ptr current_builder_; + std::unique_ptr root_node_; public: - FrontendContext(); + FrontendContext(Arch arch); ASTBuilder &builder() { - return *current_builder; + return *current_builder_; } IRNode *root(); std::unique_ptr get_root() { - return std::move(root_node); + return std::move(root_node_); } }; +void flatten_lvalue(Expr expr, Expression::FlattenContext *ctx); + +void flatten_rvalue(Expr expr, Expression::FlattenContext *ctx); + TLANG_NAMESPACE_END diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 7989a2783c30a..3692513481eea 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -16,21 +16,13 @@ std::string snode_access_flag_name(SNodeAccessFlag type) { return "block_local"; } else if (type == SNodeAccessFlag::read_only) { return "read_only"; + } else if (type == SNodeAccessFlag::mesh_local) { + return "mesh_local"; } else { TI_ERROR("Undefined SNode AccessType (value={})", int(type)); } } -void DecoratorRecorder::reset() { - vectorize = -1; - bit_vectorize = -1; - num_cpu_threads = 0; - uniform = false; - mem_access_opt.clear(); - block_dim = 0; - strictly_serialized = false; -} - int Identifier::id_counter = 0; std::string Identifier::raw_name() const { if (name_.empty()) @@ -89,7 +81,7 @@ int StmtFieldSNode::get_snode_id(SNode *snode) { bool StmtFieldSNode::equal(const StmtField *other_generic) const { if (auto other = dynamic_cast(other_generic)) { - return get_snode_id(snode) == get_snode_id(other->snode); + return get_snode_id(snode_) == get_snode_id(other->snode_); } else { // Different types return false; @@ -171,7 +163,7 @@ Stmt *Stmt::insert_after_me(std::unique_ptr &&new_stmt) { return ret; } -void Stmt::replace_with(Stmt *new_stmt) { +void Stmt::replace_usages_with(Stmt *new_stmt) { irpass::replace_all_usages_with(nullptr, this, new_stmt); } @@ -391,7 +383,7 @@ void Block::replace_with(Stmt *old_statement, } TI_ASSERT(location != -1); if (replace_usages && !new_statements.stmts.empty()) - old_statement->replace_with(new_statements.back().get()); + old_statement->replace_usages_with(new_statements.back().get()); trash_bin.push_back(std::move(statements[location])); if (new_statements.size() == 1) { // Keep all std::vector::iterator valid in this case. @@ -444,76 +436,87 @@ std::unique_ptr Block::clone() const { } DelayedIRModifier::~DelayedIRModifier() { - TI_ASSERT(to_insert_before.empty()); - TI_ASSERT(to_insert_after.empty()); - TI_ASSERT(to_erase.empty()); - TI_ASSERT(to_replace_with.empty()); - TI_ASSERT(to_extract_to_block_front.empty()); + TI_ASSERT(to_insert_before_.empty()); + TI_ASSERT(to_insert_after_.empty()); + TI_ASSERT(to_erase_.empty()); + TI_ASSERT(to_replace_with_.empty()); + TI_ASSERT(to_extract_to_block_front_.empty()); + TI_ASSERT(to_type_check_.empty()); } void DelayedIRModifier::erase(Stmt *stmt) { - to_erase.push_back(stmt); + to_erase_.push_back(stmt); } void DelayedIRModifier::insert_before(Stmt *old_statement, std::unique_ptr new_statements) { - to_insert_before.emplace_back(old_statement, - VecStatement(std::move(new_statements))); + to_insert_before_.emplace_back(old_statement, + VecStatement(std::move(new_statements))); } void DelayedIRModifier::insert_before(Stmt *old_statement, VecStatement &&new_statements) { - to_insert_before.emplace_back(old_statement, std::move(new_statements)); + to_insert_before_.emplace_back(old_statement, std::move(new_statements)); } void DelayedIRModifier::insert_after(Stmt *old_statement, std::unique_ptr new_statements) { - to_insert_after.emplace_back(old_statement, - VecStatement(std::move(new_statements))); + to_insert_after_.emplace_back(old_statement, + VecStatement(std::move(new_statements))); } void DelayedIRModifier::insert_after(Stmt *old_statement, VecStatement &&new_statements) { - to_insert_after.emplace_back(old_statement, std::move(new_statements)); + to_insert_after_.emplace_back(old_statement, std::move(new_statements)); } void DelayedIRModifier::replace_with(Stmt *stmt, VecStatement &&new_statements, bool replace_usages) { - to_replace_with.emplace_back(stmt, std::move(new_statements), replace_usages); + to_replace_with_.emplace_back(stmt, std::move(new_statements), + replace_usages); } void DelayedIRModifier::extract_to_block_front(Stmt *stmt, Block *blk) { - to_extract_to_block_front.emplace_back(stmt, blk); + to_extract_to_block_front_.emplace_back(stmt, blk); +} + +void DelayedIRModifier::type_check(IRNode *node, CompileConfig cfg) { + to_type_check_.emplace_back(node, cfg); } bool DelayedIRModifier::modify_ir() { bool force_modified = modified_; modified_ = false; - if (to_insert_before.empty() && to_insert_after.empty() && to_erase.empty() && - to_replace_with.empty() && to_extract_to_block_front.empty()) + if (to_insert_before_.empty() && to_insert_after_.empty() && + to_erase_.empty() && to_replace_with_.empty() && + to_extract_to_block_front_.empty() && to_type_check_.empty()) return force_modified; - for (auto &i : to_insert_before) { + for (auto &i : to_insert_before_) { i.first->parent->insert_before(i.first, std::move(i.second)); } - to_insert_before.clear(); - for (auto &i : to_insert_after) { + to_insert_before_.clear(); + for (auto &i : to_insert_after_) { i.first->parent->insert_after(i.first, std::move(i.second)); } - to_insert_after.clear(); - for (auto &stmt : to_erase) { + to_insert_after_.clear(); + for (auto &stmt : to_erase_) { stmt->parent->erase(stmt); } - to_erase.clear(); - for (auto &i : to_replace_with) { + to_erase_.clear(); + for (auto &i : to_replace_with_) { std::get<0>(i)->replace_with(std::move(std::get<1>(i)), std::get<2>(i)); } - to_replace_with.clear(); - for (auto &i : to_extract_to_block_front) { + to_replace_with_.clear(); + for (auto &i : to_extract_to_block_front_) { auto extracted = i.first->parent->extract(i.first); i.second->insert(std::move(extracted), 0); } - to_extract_to_block_front.clear(); + to_extract_to_block_front_.clear(); + for (auto &i : to_type_check_) { + irpass::type_check(i.first, i.second); + } + to_type_check_.clear(); return true; } diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 34f7ac860eb69..0a765fa36ff6a 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -9,8 +9,9 @@ #include #include "taichi/common/core.h" -#include "taichi/ir/ir_modified.h" +#include "taichi/common/exceptions.h" #include "taichi/ir/snode.h" +#include "taichi/ir/mesh.h" #include "taichi/ir/type_factory.h" #include "taichi/util/short_name.h" @@ -27,7 +28,7 @@ class SNode; class Kernel; struct CompileConfig; -enum class SNodeAccessFlag : int { block_local, read_only }; +enum class SNodeAccessFlag : int { block_local, read_only, mesh_local }; std::string snode_access_flag_name(SNodeAccessFlag type); class MemoryAccessOptions { @@ -70,23 +71,6 @@ class MemoryAccessOptions { #include "taichi/inc/statements.inc.h" #undef PER_STATEMENT -class DecoratorRecorder { - public: - int vectorize; - int bit_vectorize; - int num_cpu_threads; - bool strictly_serialized; - MemoryAccessOptions mem_access_opt; - int block_dim; - bool uniform; - - DecoratorRecorder() { - reset(); - } - - void reset(); -}; - class Identifier { public: static int id_counter; @@ -394,28 +378,28 @@ class StmtField { template class StmtFieldNumeric final : public StmtField { private: - std::variant value; + std::variant value_; public: - explicit StmtFieldNumeric(T *value) : value(value) { + explicit StmtFieldNumeric(T *value) : value_(value) { } - explicit StmtFieldNumeric(T value) : value(value) { + explicit StmtFieldNumeric(T value) : value_(value) { } bool equal(const StmtField *other_generic) const override { if (auto other = dynamic_cast(other_generic)) { - if (std::holds_alternative(other->value) && - std::holds_alternative(value)) { - return *(std::get(other->value)) == *(std::get(value)); - } else if (std::holds_alternative(other->value) || - std::holds_alternative(value)) { + if (std::holds_alternative(other->value_) && + std::holds_alternative(value_)) { + return *(std::get(other->value_)) == *(std::get(value_)); + } else if (std::holds_alternative(other->value_) || + std::holds_alternative(value_)) { TI_ERROR( "Inconsistent StmtField value types: a pointer value is compared " "to a non-pointer value."); return false; } else { - return std::get(other->value) == std::get(value); + return std::get(other->value_) == std::get(value_); } } else { // Different types @@ -426,10 +410,10 @@ class StmtFieldNumeric final : public StmtField { class StmtFieldSNode final : public StmtField { private: - SNode *const &snode; + SNode *const &snode_; public: - explicit StmtFieldSNode(SNode *const &snode) : snode(snode) { + explicit StmtFieldSNode(SNode *const &snode) : snode_(snode) { } static int get_snode_id(SNode *snode); @@ -451,12 +435,12 @@ class StmtFieldMemoryAccessOptions final : public StmtField { class StmtFieldManager { private: - Stmt *stmt; + Stmt *stmt_; public: std::vector> fields; - StmtFieldManager(Stmt *stmt) : stmt(stmt) { + StmtFieldManager(Stmt *stmt) : stmt_(stmt) { } template @@ -547,7 +531,7 @@ class Stmt : public IRNode { bool has_operand(Stmt *stmt) const; - void replace_with(Stmt *new_stmt); + void replace_usages_with(Stmt *new_stmt); void replace_with(VecStatement &&new_statements, bool replace_usages = true); virtual void replace_operand_with(Stmt *old_stmt, Stmt *new_stmt); @@ -596,7 +580,7 @@ class Stmt : public IRNode { TI_NOT_IMPLEMENTED } - virtual ~Stmt() = default; + ~Stmt() override = default; }; class Block : public IRNode { @@ -672,11 +656,12 @@ class Block : public IRNode { class DelayedIRModifier { private: - std::vector> to_insert_before; - std::vector> to_insert_after; - std::vector> to_replace_with; - std::vector to_erase; - std::vector> to_extract_to_block_front; + std::vector> to_insert_before_; + std::vector> to_insert_after_; + std::vector> to_replace_with_; + std::vector to_erase_; + std::vector> to_extract_to_block_front_; + std::vector> to_type_check_; bool modified_{false}; public: @@ -690,6 +675,7 @@ class DelayedIRModifier { VecStatement &&new_statements, bool replace_usages = true); void extract_to_block_front(Stmt *stmt, Block *blk); + void type_check(IRNode *node, CompileConfig cfg); bool modify_ir(); // Force the next call of modify_ir() to return true. @@ -703,29 +689,6 @@ struct LocalAddress { LocalAddress(Stmt *var, int offset); }; -extern DecoratorRecorder dec; - -inline void Vectorize(int v) { - dec.vectorize = v; -} - -inline void BitVectorize(int v) { - dec.bit_vectorize = v; -} - -inline void Parallelize(int v) { - dec.num_cpu_threads = v; -} - -inline void StrictlySerialize() { - dec.strictly_serialized = true; -} - -inline void BlockDim(int v) { - TI_ASSERT(bit::is_power_of_two(v)); - dec.block_dim = v; -} - class VectorElement { public: Stmt *stmt; @@ -743,7 +706,7 @@ inline void StmtFieldManager::operator()(const char *key, T &&value) { using decay_T = typename std::decay::type; if constexpr (is_specialization::value || is_specialization::value) { - stmt->field_manager.fields.emplace_back( + stmt_->field_manager.fields.emplace_back( std::make_unique>(value.size())); for (int i = 0; i < (int)value.size(); i++) { (*this)("__element", value[i]); @@ -751,30 +714,30 @@ inline void StmtFieldManager::operator()(const char *key, T &&value) { } else if constexpr (std::is_same>::value) { if (std::holds_alternative(value)) { - stmt->field_manager.fields.emplace_back( + stmt_->field_manager.fields.emplace_back( std::make_unique>( std::get(value))); } else { (*this)("__element", std::get(value)); } } else if constexpr (std::is_same::value) { - stmt->register_operand(const_cast(value)); + stmt_->register_operand(const_cast(value)); } else if constexpr (std::is_same::value) { - stmt->register_operand(const_cast(value.var)); - stmt->field_manager.fields.emplace_back( + stmt_->register_operand(const_cast(value.var)); + stmt_->field_manager.fields.emplace_back( std::make_unique>(value.offset)); } else if constexpr (std::is_same::value) { - stmt->register_operand(const_cast(value.stmt)); - stmt->field_manager.fields.emplace_back( + stmt_->register_operand(const_cast(value.stmt)); + stmt_->field_manager.fields.emplace_back( std::make_unique>(value.index)); } else if constexpr (std::is_same::value) { - stmt->field_manager.fields.emplace_back( + stmt_->field_manager.fields.emplace_back( std::make_unique(value)); } else if constexpr (std::is_same::value) { - stmt->field_manager.fields.emplace_back( + stmt_->field_manager.fields.emplace_back( std::make_unique(value)); } else { - stmt->field_manager.fields.emplace_back( + stmt_->field_manager.fields.emplace_back( std::make_unique>>(&value)); } } diff --git a/taichi/ir/ir_builder.cpp b/taichi/ir/ir_builder.cpp index b0332b9477bfc..43d6c7c456549 100644 --- a/taichi/ir/ir_builder.cpp +++ b/taichi/ir/ir_builder.cpp @@ -85,23 +85,31 @@ IRBuilder::IfGuard::~IfGuard() { RangeForStmt *IRBuilder::create_range_for(Stmt *begin, Stmt *end, - int vectorize, int bit_vectorize, int num_cpu_threads, int block_dim, bool strictly_serialized) { return insert(Stmt::make_typed( - begin, end, std::make_unique(), vectorize, bit_vectorize, - num_cpu_threads, block_dim, strictly_serialized)); + begin, end, std::make_unique(), bit_vectorize, num_cpu_threads, + block_dim, strictly_serialized)); } StructForStmt *IRBuilder::create_struct_for(SNode *snode, - int vectorize, int bit_vectorize, int num_cpu_threads, int block_dim) { return insert(Stmt::make_typed( - snode, std::make_unique(), vectorize, bit_vectorize, + snode, std::make_unique(), bit_vectorize, num_cpu_threads, + block_dim)); +} + +MeshForStmt *IRBuilder::create_mesh_for(mesh::Mesh *mesh, + mesh::MeshElementType element_type, + int bit_vectorize, + int num_cpu_threads, + int block_dim) { + return insert(Stmt::make_typed( + mesh, element_type, std::make_unique(), bit_vectorize, num_cpu_threads, block_dim)); } @@ -208,6 +216,10 @@ UnaryOpStmt *IRBuilder::create_logical_not(Stmt *value) { return insert(Stmt::make_typed(UnaryOpType::logic_not, value)); } +UnaryOpStmt *IRBuilder::create_round(Stmt *value) { + return insert(Stmt::make_typed(UnaryOpType::round, value)); +} + UnaryOpStmt *IRBuilder::create_floor(Stmt *value) { return insert(Stmt::make_typed(UnaryOpType::floor, value)); } @@ -415,7 +427,8 @@ GlobalPtrStmt *IRBuilder::create_global_ptr( ExternalPtrStmt *IRBuilder::create_external_ptr( ArgLoadStmt *ptr, const std::vector &indices) { - return insert(Stmt::make_typed(ptr, indices)); + return insert( + Stmt::make_typed(ptr, indices, std::vector(), 0)); } AdStackAllocaStmt *IRBuilder::create_ad_stack(const DataType &dt, @@ -445,4 +458,36 @@ void IRBuilder::ad_stack_accumulate_adjoint(AdStackAllocaStmt *stack, insert(Stmt::make_typed(stack, val)); } +// Mesh related. + +MeshRelationAccessStmt *IRBuilder::get_relation_size( + mesh::Mesh *mesh, + Stmt *mesh_idx, + mesh::MeshElementType to_type) { + return insert( + Stmt::make_typed(mesh, mesh_idx, to_type)); +} + +MeshRelationAccessStmt *IRBuilder::get_relation_access( + mesh::Mesh *mesh, + Stmt *mesh_idx, + mesh::MeshElementType to_type, + Stmt *neighbor_idx) { + return insert(Stmt::make_typed( + mesh, mesh_idx, to_type, neighbor_idx)); +} + +MeshIndexConversionStmt *IRBuilder::get_index_conversion( + mesh::Mesh *mesh, + mesh::MeshElementType idx_type, + Stmt *idx, + mesh::ConvType conv_type) { + return insert(Stmt::make_typed(mesh, idx_type, idx, + conv_type)); +} + +MeshPatchIndexStmt *IRBuilder::get_patch_index() { + return insert(Stmt::make_typed()); +} + TLANG_NAMESPACE_END diff --git a/taichi/ir/ir_builder.h b/taichi/ir/ir_builder.h index ca90b358a3760..049b09649c3d0 100644 --- a/taichi/ir/ir_builder.h +++ b/taichi/ir/ir_builder.h @@ -1,6 +1,7 @@ #pragma once #include "taichi/ir/ir.h" +#include "taichi/ir/mesh.h" TLANG_NAMESPACE_BEGIN @@ -49,6 +50,7 @@ class IRBuilder { } if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { set_insertion_point({loop->body.get(), 0}); } else { @@ -101,16 +103,19 @@ class IRBuilder { // Control flows. RangeForStmt *create_range_for(Stmt *begin, Stmt *end, - int vectorize = -1, int bit_vectorize = -1, int num_cpu_threads = 0, int block_dim = 0, bool strictly_serialized = false); StructForStmt *create_struct_for(SNode *snode, - int vectorize = -1, int bit_vectorize = -1, int num_cpu_threads = 0, int block_dim = 0); + MeshForStmt *create_mesh_for(mesh::Mesh *mesh, + mesh::MeshElementType element_type, + int bit_vectorize = -1, + int num_cpu_threads = 0, + int block_dim = 0); WhileStmt *create_while_true(); IfStmt *create_if(Stmt *cond); WhileControlStmt *create_break(); @@ -145,6 +150,7 @@ class IRBuilder { UnaryOpStmt *create_neg(Stmt *value); UnaryOpStmt *create_not(Stmt *value); // bitwise UnaryOpStmt *create_logical_not(Stmt *value); + UnaryOpStmt *create_round(Stmt *value); UnaryOpStmt *create_floor(Stmt *value); UnaryOpStmt *create_ceil(Stmt *value); UnaryOpStmt *create_abs(Stmt *value); @@ -256,6 +262,20 @@ class IRBuilder { AdStackLoadTopAdjStmt *ad_stack_load_top_adjoint(AdStackAllocaStmt *stack); void ad_stack_accumulate_adjoint(AdStackAllocaStmt *stack, Stmt *val); + // Mesh related. + MeshRelationAccessStmt *get_relation_size(mesh::Mesh *mesh, + Stmt *mesh_idx, + mesh::MeshElementType to_type); + MeshRelationAccessStmt *get_relation_access(mesh::Mesh *mesh, + Stmt *mesh_idx, + mesh::MeshElementType to_type, + Stmt *neighbor_idx); + MeshIndexConversionStmt *get_index_conversion(mesh::Mesh *mesh, + mesh::MeshElementType idx_type, + Stmt *idx, + mesh::ConvType conv_type); + MeshPatchIndexStmt *get_patch_index(); + private: std::unique_ptr root_{nullptr}; InsertPoint insert_point_; diff --git a/taichi/ir/mesh.cpp b/taichi/ir/mesh.cpp new file mode 100644 index 0000000000000..9219c84c4da4a --- /dev/null +++ b/taichi/ir/mesh.cpp @@ -0,0 +1,61 @@ +#include "taichi/ir/mesh.h" + +namespace taichi { +namespace lang { +namespace mesh { + +std::string element_type_name(MeshElementType type) { + if (type == MeshElementType::Vertex) + return "verts"; + else if (type == MeshElementType::Edge) + return "edges"; + else if (type == MeshElementType::Face) + return "faces"; + else if (type == MeshElementType::Cell) + return "cells"; + else { + TI_NOT_IMPLEMENTED; + } +} + +std::string relation_type_name(MeshRelationType type) { + return element_type_name(MeshElementType(from_end_element_order(type))) + + "-" + element_type_name(MeshElementType(to_end_element_order(type))); +} + +std::string conv_type_name(ConvType type) { + if (type == mesh::ConvType::l2g) + return "local to global"; + else if (type == mesh::ConvType::l2r) + return "local to reordered"; + else if (type == mesh::ConvType::g2r) { + return "global to reordered"; + } else { + TI_NOT_IMPLEMENTED; + } +} + +int element_order(MeshElementType type) { + return int(type); +} + +int from_end_element_order(MeshRelationType rel) { + return int(rel) >> 0x2; +} + +int to_end_element_order(MeshRelationType rel) { + return int(rel) & 0x3; +} + +MeshRelationType relation_by_orders(int from_order, int to_order) { + return MeshRelationType((from_order << 2) | to_order); +} + +MeshRelationType inverse_relation(MeshRelationType rel) { + return relation_by_orders(to_end_element_order(rel), + from_end_element_order(rel)); +} + +} // namespace mesh +} // namespace lang +} // namespace taichi diff --git a/taichi/ir/mesh.h b/taichi/ir/mesh.h new file mode 100644 index 0000000000000..ba87b58091568 --- /dev/null +++ b/taichi/ir/mesh.h @@ -0,0 +1,96 @@ +#pragma once + +#include + +#include "taichi/ir/type.h" +#include "taichi/ir/snode.h" +#include "taichi/ir/scratch_pad.h" + +#include + +namespace taichi { +namespace lang { + +class Stmt; + +namespace mesh { + +enum class MeshTopology { Triangle = 3, Tetrahedron = 4 }; + +enum class MeshElementType { Vertex = 0, Edge = 1, Face = 2, Cell = 3 }; + +std::string element_type_name(MeshElementType type); + +enum class MeshRelationType { + VV = 0, + VE = 1, + VF = 2, + VC = 3, + EV = 4, + EE = 5, + EF = 6, + EC = 7, + FV = 8, + FE = 9, + FF = 10, + FC = 11, + CV = 12, + CE = 13, + CF = 14, + CC = 15, +}; + +std::string relation_type_name(MeshRelationType type); + +enum class ConvType { l2g, l2r, g2r }; + +std::string conv_type_name(ConvType type); + +int element_order(MeshElementType type); +int from_end_element_order(MeshRelationType rel); +int to_end_element_order(MeshRelationType rel); +MeshRelationType relation_by_orders(int from_order, int to_order); +MeshRelationType inverse_relation(MeshRelationType rel); + +struct MeshLocalRelation { + MeshLocalRelation(SNode *value_, SNode *offset_) + : value(value_), offset(offset_) { + fixed = false; + } + + MeshLocalRelation(SNode *value_) : value(value_) { + fixed = true; + } + + bool fixed; + SNode *value{nullptr}; + SNode *offset{nullptr}; +}; + +class Mesh { + public: + Mesh() = default; + + template + using MeshMapping = std::unordered_map; + + int num_patches{0}; + MeshMapping num_elements{}; + MeshMapping + patch_max_element_num{}; // the max number of mesh element in each patch + + MeshMapping owned_offset{}; // prefix of owned element + MeshMapping total_offset{}; // prefix of total element + std::map, SNode *> + index_mapping{}; // mapping from one index space to another index space + + std::map relations; +}; + +struct MeshPtr { // Mesh wrapper in python + std::shared_ptr ptr; +}; + +} // namespace mesh +} // namespace lang +} // namespace taichi diff --git a/taichi/ir/offloaded_task_type.h b/taichi/ir/offloaded_task_type.h index e32a338259a99..673a8164c2087 100644 --- a/taichi/ir/offloaded_task_type.h +++ b/taichi/ir/offloaded_task_type.h @@ -4,7 +4,8 @@ #include -TLANG_NAMESPACE_BEGIN +namespace taichi { +namespace lang { enum class OffloadedTaskType : int { #define PER_TASK_TYPE(x) x, @@ -14,4 +15,5 @@ enum class OffloadedTaskType : int { std::string offloaded_task_type_name(OffloadedTaskType tt); -TLANG_NAMESPACE_END +} // namespace lang +} // namespace taichi diff --git a/taichi/ir/snode.cpp b/taichi/ir/snode.cpp index a17117ece854a..b5b715884868c 100644 --- a/taichi/ir/snode.cpp +++ b/taichi/ir/snode.cpp @@ -1,8 +1,11 @@ #include "taichi/ir/snode.h" +#include + #include "taichi/ir/ir.h" #include "taichi/ir/statements.h" #include "taichi/program/program.h" +#include "taichi/program/snode_rw_accessors_bank.h" TLANG_NAMESPACE_BEGIN @@ -11,7 +14,8 @@ std::atomic SNode::counter{0}; SNode &SNode::insert_children(SNodeType t) { TI_ASSERT(t != SNodeType::root); - auto new_ch = std::make_unique(depth + 1, t); + auto new_ch = std::make_unique(depth + 1, t, snode_to_glb_var_exprs_, + snode_rw_accessors_bank_); new_ch->parent = this; new_ch->is_path_all_dense = (is_path_all_dense && !new_ch->need_activation()); for (int i = 0; i < taichi_max_num_indices; i++) { @@ -46,7 +50,10 @@ SNode &SNode::create_node(std::vector axes, auto &new_node = insert_children(type); for (int i = 0; i < (int)axes.size(); i++) { - TI_ASSERT(sizes[i] > 0); + if (sizes[i] <= 0) { + throw TaichiRuntimeError( + "Every dimension of a Taichi field should be positive"); + } auto &ind = axes[i]; new_node.extractors[ind.value].activate( bit::log2int(bit::least_pot_bound(sizes[i]))); @@ -76,11 +83,17 @@ SNode &SNode::create_node(std::vector axes, std::sort(new_node.physical_index_position, new_node.physical_index_position + new_node.num_active_indices); // infer extractors - int acc_shape = 1; + int64 acc_shape = 1; for (int i = taichi_max_num_indices - 1; i >= 0; i--) { - new_node.extractors[i].acc_shape = acc_shape; + // casting to int32 in extractors. + new_node.extractors[i].acc_shape = static_cast(acc_shape); acc_shape *= new_node.extractors[i].shape; } + if (acc_shape > std::numeric_limits::max()) { + TI_WARN( + "Snode index might be out of int32 boundary but int64 indexing is not " + "supported yet."); + } new_node.num_cells_per_container = acc_shape; // infer extractors (only for POT) int acc_offsets = 0; @@ -165,10 +178,46 @@ int SNode::shape_along_axis(int i) const { return extractor.num_elements_from_root; } -SNode::SNode() : SNode(0, SNodeType::undefined) { +int64 SNode::read_int(const std::vector &i) { + return snode_rw_accessors_bank_->get(this).read_int(i); +} + +uint64 SNode::read_uint(const std::vector &i) { + return snode_rw_accessors_bank_->get(this).read_uint(i); +} + +float64 SNode::read_float(const std::vector &i) { + return snode_rw_accessors_bank_->get(this).read_float(i); +} + +void SNode::write_int(const std::vector &i, int64 val) { + snode_rw_accessors_bank_->get(this).write_int(i, val); +} + +void SNode::write_float(const std::vector &i, float64 val) { + snode_rw_accessors_bank_->get(this).write_float(i, val); +} + +Expr SNode::get_expr() const { + return Expr(snode_to_glb_var_exprs_->at(this)); +} + +SNode::SNode(SNodeGlobalVarExprMap *snode_to_glb_var_exprs, + SNodeRwAccessorsBank *snode_rw_accessors_bank) + : SNode(0, + SNodeType::undefined, + snode_to_glb_var_exprs, + snode_rw_accessors_bank) { } -SNode::SNode(int depth, SNodeType t) : depth(depth), type(t) { +SNode::SNode(int depth, + SNodeType t, + SNodeGlobalVarExprMap *snode_to_glb_var_exprs, + SNodeRwAccessorsBank *snode_rw_accessors_bank) + : depth(depth), + type(t), + snode_to_glb_var_exprs_(snode_to_glb_var_exprs), + snode_rw_accessors_bank_(snode_rw_accessors_bank) { id = counter++; node_type_name = get_node_type_name(); total_num_bits = 0; diff --git a/taichi/ir/snode.h b/taichi/ir/snode.h index f50fca4d0b6a0..6931bcc82e449 100644 --- a/taichi/ir/snode.h +++ b/taichi/ir/snode.h @@ -6,9 +6,12 @@ #include "taichi/ir/expr.h" #include "taichi/ir/snode_types.h" #include "taichi/ir/type.h" +#include "taichi/program/snode_expr_utils.h" namespace taichi { namespace lang { +class Program; +class SNodeRwAccessorsBank; /** * Dimension (or axis) of a tensor. @@ -123,6 +126,7 @@ class SNode { int total_bit_start{0}; int chunk_size{0}; std::size_t cell_size_bytes{0}; + std::size_t offset_bytes_in_parent_cell{0}; PrimitiveType *physical_type{nullptr}; // for bit_struct and bit_array only DataType dt; bool has_ambient{false}; @@ -146,9 +150,13 @@ class SNode { // Whether the path from root to |this| contains only `dense` SNodes. bool is_path_all_dense{true}; - SNode(); + SNode(SNodeGlobalVarExprMap *snode_to_glb_var_exprs = nullptr, + SNodeRwAccessorsBank *snode_rw_accessors_bank = nullptr); - SNode(int depth, SNodeType t); + SNode(int depth, + SNodeType t, + SNodeGlobalVarExprMap *snode_to_glb_var_exprs = nullptr, + SNodeRwAccessorsBank *snode_rw_accessors_bank = nullptr); SNode(const SNode &); @@ -315,6 +323,22 @@ class SNode { int shape_along_axis(int i) const; + void place(Expr &expr, const std::vector &offset) { + place_child(&expr, offset, this, snode_to_glb_var_exprs_); + } + + void lazy_grad() { + make_lazy_grad(this, snode_to_glb_var_exprs_); + } + + int64 read_int(const std::vector &i); + uint64 read_uint(const std::vector &i); + float64 read_float(const std::vector &i); + void write_int(const std::vector &i, int64 val); + void write_float(const std::vector &i, float64 val); + + Expr get_expr() const; + uint64 fetch_reader_result(); // TODO: refactor void begin_shared_exp_placement(); @@ -329,6 +353,8 @@ class SNode { private: int snode_tree_id_{0}; + SNodeGlobalVarExprMap *snode_to_glb_var_exprs_{nullptr}; + SNodeRwAccessorsBank *snode_rw_accessors_bank_{nullptr}; }; } // namespace lang diff --git a/taichi/ir/state_machine.cpp b/taichi/ir/state_machine.cpp deleted file mode 100644 index 4c747c4cc08f2..0000000000000 --- a/taichi/ir/state_machine.cpp +++ /dev/null @@ -1,440 +0,0 @@ -#include "taichi/ir/state_machine.h" - -#include "taichi/ir/statements.h" -#include "taichi/ir/analysis.h" -#include "taichi/ir/ir_modified.h" - -TLANG_NAMESPACE_BEGIN - -std::unique_ptr> StateMachine::used_atomics; - -StateMachine::StateMachine(Stmt *var, bool zero_initialized) - : var(var), - stored(never), - stored_in_this_if_or_loop(never), - loaded(never), - loaded_in_this_if_or_loop(never), - last_store(nullptr), - last_store_forwardable(false), - last_store_eliminable(false), - last_atomic(nullptr), - last_atomic_eliminable(false), - maybe_loaded_before_first_definite_store_in_this_if_or_loop(false) { - if (!zero_initialized) - stored = stored_in_this_if_or_loop = maybe; -} - -bool StateMachine::same_data(Stmt *store_stmt1, Stmt *store_stmt2) { - if (store_stmt1->is()) { - if (!store_stmt2->is()) - return false; - return irpass::analysis::same_statements( - store_stmt1->as()->val, - store_stmt2->as()->val); - } else { - if (!store_stmt2->is()) - return false; - return irpass::analysis::same_statements( - store_stmt1->as()->val, - store_stmt2->as()->val); - } -} - -StateMachine::State StateMachine::merge_either_a_or_b( - const StateMachine::State &a, - const StateMachine::State &b) { - if (a == definitely && b == definitely) - return definitely; - if (a != never || b != never) - return maybe; - return never; -} - -StateMachine::State StateMachine::merge_a_and_b(const StateMachine::State &a, - const StateMachine::State &b) { - if (a == definitely || b == definitely) - return definitely; - if (a == maybe || b == maybe) - return maybe; - return never; -} - -StateMachine::State StateMachine::merge_a_and_maybe_b( - const StateMachine::State &a, - const StateMachine::State &b) { - if (a == definitely) - return definitely; - if (a == maybe || b != never) - return maybe; - return never; -} - -void StateMachine::rebuild_atomics_usage(IRNode *root) { - used_atomics = irpass::analysis::gather_used_atomics(root); -} - -void StateMachine::atomic_op(AtomicOpStmt *stmt) { - // This statement is loading the last store, so we can't eliminate it. - if (stored_in_this_if_or_loop != definitely) - maybe_loaded_before_first_definite_store_in_this_if_or_loop = true; - - stored = stored_in_this_if_or_loop = definitely; - loaded = loaded_in_this_if_or_loop = definitely; - - last_store = nullptr; - last_store_forwardable = false; - last_store_eliminable = false; - - TI_ASSERT(used_atomics); - last_atomic = stmt; - last_atomic_eliminable = used_atomics->find(stmt) == used_atomics->end(); -} - -void StateMachine::store(Stmt *store_stmt) { - TI_ASSERT(store_stmt->is() || - store_stmt->is()); - if (last_store && last_store_eliminable && - stored_in_this_if_or_loop == definitely) { - // The last store is never loaded. - last_store->parent->erase(last_store); - throw IRModified(); - } - if (last_atomic && last_atomic_eliminable && - stored_in_this_if_or_loop == definitely) { - // The last AtomicOpStmt is never used. - last_atomic->parent->erase(last_atomic); - throw IRModified(); - } - if (last_store_forwardable && same_data(last_store, store_stmt)) { - // This store is useless. - store_stmt->parent->erase(store_stmt); - throw IRModified(); - } - stored = stored_in_this_if_or_loop = definitely; - - last_store = store_stmt; - last_store_forwardable = true; - last_store_eliminable = true; - - last_atomic = nullptr; - last_atomic_eliminable = false; -} - -void StateMachine::load(Stmt *load_stmt) { - // The load_stmt == nullptr case is only for an offloaded range_for loading - // global temps via begin_offset and end_offset. - if (load_stmt) - TI_ASSERT(load_stmt->is() || - load_stmt->is()); - if (stored_in_this_if_or_loop != definitely) - maybe_loaded_before_first_definite_store_in_this_if_or_loop = true; - loaded = loaded_in_this_if_or_loop = definitely; - last_store_eliminable = false; - last_atomic_eliminable = false; - if (!load_stmt) - return; - - if (stored == never) { - auto zero = load_stmt->insert_after_me(Stmt::make( - LaneAttribute(load_stmt->ret_type))); - zero->repeat(load_stmt->width()); - int current_stmt_id = load_stmt->parent->locate(load_stmt); - load_stmt->replace_with(zero); - load_stmt->parent->erase(current_stmt_id); - throw IRModified(); - } - if (last_store_forwardable) { - // store-forwarding - if (last_store->is()) - load_stmt->replace_with(last_store->as()->val); - else - load_stmt->replace_with(last_store->as()->val); - load_stmt->parent->erase(load_stmt); - throw IRModified(); - } -} - -void StateMachine::continue_or_break() { - last_store_eliminable = false; - last_atomic_eliminable = false; -} - -void StateMachine::maybe_atomic_op() { - if (stored_in_this_if_or_loop != definitely) - maybe_loaded_before_first_definite_store_in_this_if_or_loop = true; - if (stored == never) - stored = maybe; - if (stored_in_this_if_or_loop == never) - stored_in_this_if_or_loop = maybe; - if (loaded == never) - loaded = maybe; - if (loaded_in_this_if_or_loop == never) - loaded_in_this_if_or_loop = maybe; - - last_store = nullptr; - last_store_forwardable = false; - last_store_eliminable = false; - - last_atomic = nullptr; - last_atomic_eliminable = false; -} - -void StateMachine::maybe_store(Stmt *store_stmt) { - TI_ASSERT(store_stmt->is() || - store_stmt->is()); - if (stored == never) - stored = maybe; - if (stored_in_this_if_or_loop == never) - stored_in_this_if_or_loop = maybe; - - if (last_store_forwardable) { - last_store_forwardable = same_data(last_store, store_stmt); - } -} - -void StateMachine::maybe_load() { - if (stored_in_this_if_or_loop != definitely) - maybe_loaded_before_first_definite_store_in_this_if_or_loop = true; - if (loaded == never) - loaded = maybe; - if (loaded_in_this_if_or_loop == never) - loaded_in_this_if_or_loop = maybe; - last_store_eliminable = false; - last_atomic_eliminable = false; -} - -void StateMachine::mark_as_loop_var() { - stored = stored_in_this_if_or_loop = definitely; - loaded = loaded_in_this_if_or_loop = definitely; - last_store = nullptr; - last_store_forwardable = false; - last_store_eliminable = false; - last_atomic = nullptr; - last_atomic_eliminable = false; - maybe_loaded_before_first_definite_store_in_this_if_or_loop = false; -} - -void StateMachine::begin_offload() { - last_store_forwardable = false; -} - -void StateMachine::begin_if_or_loop() { - stored_in_this_if_or_loop = never; - loaded_in_this_if_or_loop = never; - maybe_loaded_before_first_definite_store_in_this_if_or_loop = false; -} - -void StateMachine::merge_from_if(const StateMachine &true_branch, - const StateMachine &false_branch) { - if (last_store && last_store_eliminable && - true_branch.stored_in_this_if_or_loop == definitely && - !true_branch - .maybe_loaded_before_first_definite_store_in_this_if_or_loop && - false_branch.stored_in_this_if_or_loop == definitely && - !false_branch - .maybe_loaded_before_first_definite_store_in_this_if_or_loop) { - // The last store is never loaded. - last_store->parent->erase(last_store); - throw IRModified(); - } - if (last_atomic && last_atomic_eliminable && - true_branch.stored_in_this_if_or_loop == definitely && - !true_branch - .maybe_loaded_before_first_definite_store_in_this_if_or_loop && - false_branch.stored_in_this_if_or_loop == definitely && - !false_branch - .maybe_loaded_before_first_definite_store_in_this_if_or_loop) { - // The last AtomicOpStmt is never used. - last_atomic->parent->erase(last_atomic); - throw IRModified(); - } - - if (stored_in_this_if_or_loop != definitely) { - maybe_loaded_before_first_definite_store_in_this_if_or_loop = - maybe_loaded_before_first_definite_store_in_this_if_or_loop || - true_branch - .maybe_loaded_before_first_definite_store_in_this_if_or_loop || - false_branch - .maybe_loaded_before_first_definite_store_in_this_if_or_loop; - } - - stored = merge_either_a_or_b(true_branch.stored, false_branch.stored); - stored_in_this_if_or_loop = merge_a_and_b( - stored_in_this_if_or_loop, - merge_either_a_or_b(true_branch.stored_in_this_if_or_loop, - false_branch.stored_in_this_if_or_loop)); - loaded = merge_either_a_or_b(true_branch.loaded, false_branch.loaded); - loaded_in_this_if_or_loop = merge_a_and_b( - loaded_in_this_if_or_loop, - merge_either_a_or_b(true_branch.loaded_in_this_if_or_loop, - false_branch.loaded_in_this_if_or_loop)); - - if (true_branch.last_store_forwardable && - false_branch.last_store_forwardable && - same_data(true_branch.last_store, false_branch.last_store)) { - last_store_forwardable = true; - if (last_store == true_branch.last_store || - last_store == false_branch.last_store) { - // The last store didn't change. - last_store_eliminable = - last_store_eliminable && - true_branch.last_store == false_branch.last_store && - true_branch.last_store_eliminable && - false_branch.last_store_eliminable; - } else { - TI_ASSERT(true_branch.last_store != false_branch.last_store); - // if $b - // $c : store $a <- v1 - // else - // $d : store $a <- v1 - // Maybe move them outside in the future? - if (true_branch.last_store_eliminable) { - last_store = true_branch.last_store; - last_store_eliminable = true; - } else { - last_store = false_branch.last_store; - last_store_eliminable = false_branch.last_store_eliminable; - } - } - } else { - last_store_forwardable = false; - // We only care if we can eliminate the last store here. - if (true_branch.last_store == last_store && - false_branch.last_store == last_store) { - // The last store didn't change. - last_store_eliminable = last_store_eliminable && - true_branch.last_store_eliminable && - false_branch.last_store_eliminable; - } else { - // The last store changed. - bool current_eliminable = - last_store && last_store_eliminable && - !true_branch - .maybe_loaded_before_first_definite_store_in_this_if_or_loop && - !false_branch - .maybe_loaded_before_first_definite_store_in_this_if_or_loop; - bool true_eliminable = true_branch.last_store != last_store && - true_branch.last_store != nullptr && - true_branch.last_store_eliminable; - bool false_eliminable = false_branch.last_store != last_store && - false_branch.last_store != nullptr && - false_branch.last_store_eliminable; - if (true_eliminable) { - last_store = true_branch.last_store; - last_store_eliminable = true; - } else if (false_eliminable) { - last_store = false_branch.last_store; - last_store_eliminable = true; - } else if (current_eliminable) { - last_store_eliminable = true; - } else { - // Neither branch provides a eliminable local store. - last_store = nullptr; - last_store_eliminable = false; - } - } - } - - // We only care if we can eliminate the last AtomicOpStmt here. - if (true_branch.last_atomic == last_atomic && - false_branch.last_atomic == last_atomic) { - // The last AtomicOpStmt didn't change. - last_atomic_eliminable = last_atomic_eliminable && - true_branch.last_atomic_eliminable && - false_branch.last_atomic_eliminable; - } else { - // The last store changed. - bool current_eliminable = - last_atomic && last_atomic_eliminable && - !true_branch - .maybe_loaded_before_first_definite_store_in_this_if_or_loop && - !false_branch - .maybe_loaded_before_first_definite_store_in_this_if_or_loop; - bool true_eliminable = true_branch.last_atomic != last_atomic && - true_branch.last_atomic != nullptr && - true_branch.last_atomic_eliminable; - bool false_eliminable = false_branch.last_atomic != last_atomic && - false_branch.last_atomic != nullptr && - false_branch.last_atomic_eliminable; - if (true_eliminable) { - last_atomic = true_branch.last_atomic; - last_atomic_eliminable = true; - } else if (false_eliminable) { - last_atomic = false_branch.last_atomic; - last_atomic_eliminable = true; - } else if (current_eliminable) { - last_atomic_eliminable = true; - } else { - // Neither branch provides a eliminable local store. - last_atomic = nullptr; - last_atomic_eliminable = false; - } - } -} - -void StateMachine::merge_from_loop(const StateMachine &loop) { - if (stored_in_this_if_or_loop != definitely) { - maybe_loaded_before_first_definite_store_in_this_if_or_loop = - maybe_loaded_before_first_definite_store_in_this_if_or_loop || - loop.maybe_loaded_before_first_definite_store_in_this_if_or_loop; - } - - stored = merge_a_and_maybe_b(stored, loop.stored); - stored_in_this_if_or_loop = merge_a_and_maybe_b( - stored_in_this_if_or_loop, loop.stored_in_this_if_or_loop); - loaded = merge_a_and_maybe_b(loaded, loop.loaded); - loaded_in_this_if_or_loop = merge_a_and_maybe_b( - loaded_in_this_if_or_loop, loop.loaded_in_this_if_or_loop); - - // We must be cautious here because of possible Continues and WhileControls. - if (loop.stored_in_this_if_or_loop != never) { - // Not forwardable if stored in the loop. - if (loop.loaded_in_this_if_or_loop != never) { - // Not eliminable if loaded in the loop. - last_store = nullptr; - last_store_forwardable = false; - last_store_eliminable = false; - last_atomic = nullptr; - last_atomic_eliminable = false; - } else { - last_store = loop.last_store; - last_store_forwardable = false; - last_store_eliminable = loop.last_atomic_eliminable; - last_atomic = loop.last_atomic; - last_atomic_eliminable = loop.last_atomic_eliminable; - } - } else { - if (loop.loaded_in_this_if_or_loop != never) { - // Not eliminable if loaded in the loop. - last_store_eliminable = false; - last_atomic_eliminable = false; - } - } -} - -void StateMachine::finalize() { - if (last_store && last_store_eliminable) { - // The last store is never loaded. - last_store->parent->erase(last_store); - throw IRModified(); - } - if (last_atomic && last_atomic_eliminable) { - // The last AtomicOpStmt is never used. - last_atomic->parent->erase(last_atomic); - throw IRModified(); - } - if (stored == never && loaded == never) { - // Never stored and never loaded. - // For future vectorization, if it's an alloca, we need to check that - // this alloca is not used as masks (this can be done by checking operands) - // before eliminating it. - var->parent->erase(var); - throw IRModified(); - } -} - -Stmt *StateMachine::get_var() const { - return var; -} - -TLANG_NAMESPACE_END diff --git a/taichi/ir/state_machine.h b/taichi/ir/state_machine.h deleted file mode 100644 index db34f4b1f30b1..0000000000000 --- a/taichi/ir/state_machine.h +++ /dev/null @@ -1,79 +0,0 @@ -#pragma once - -#include - -#include "taichi/ir/ir.h" - -TLANG_NAMESPACE_BEGIN - -// State machine for AllocaStmt/GlobalTemporaryStmt/GlobalPtrStmt. -class StateMachine { - private: - Stmt *var; - static std::unique_ptr> used_atomics; - - bool same_data(Stmt *store_stmt1, Stmt *store_stmt2); - - public: - // If neither stored nor loaded (nor used as operands in masks/loop_vars), - // we can safely delete this variable if it's an alloca or a global temp. - enum State { never, maybe, definitely }; - State stored; // Is this variable ever stored (or atomic-operated)? - State stored_in_this_if_or_loop; - State loaded; // Is this variable ever loaded (or atomic-operated)? - State loaded_in_this_if_or_loop; - - Stmt *last_store; - - // last_store_forwardable: Can we do store-forwarding? - bool last_store_forwardable; - - // last_store_eliminable: Can we eliminate last_store? - bool last_store_eliminable; - - AtomicOpStmt *last_atomic; - - // last_atomic_eliminable: Can we eliminate last_atomic? - bool last_atomic_eliminable; - - // Is this variable ever loaded before the first *definite* store in the - // current if branch? This is ONLY for determining whether we can eliminate - // the last store before the IfStmt. - bool maybe_loaded_before_first_definite_store_in_this_if_or_loop; - - StateMachine() { - TI_ERROR("StateMachine constructor invoked with no parameters.") - } - explicit StateMachine(Stmt *var, bool zero_initialized); - - // This must be called before using StateMachine to eliminate AtomicOpStmts. - static void rebuild_atomics_usage(IRNode *root); - - static State merge_either_a_or_b(const State &a, const State &b); - static State merge_a_and_b(const State &a, const State &b); - static State merge_a_and_maybe_b(const State &a, const State &b); - - void atomic_op(AtomicOpStmt *stmt); - void store(Stmt *store_stmt); - void load(Stmt *load_stmt = nullptr); - - void continue_or_break(); - - void maybe_atomic_op(); - void maybe_store(Stmt *store_stmt); - void maybe_load(); - - void mark_as_loop_var(); - - void begin_offload(); - void begin_if_or_loop(); - void merge_from_if(const StateMachine &true_branch, - const StateMachine &false_branch); - void merge_from_loop(const StateMachine &loop); - - void finalize(); - - Stmt *get_var() const; -}; - -TLANG_NAMESPACE_END diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 924f05cea93d4..2da47acc5ed28 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -4,16 +4,6 @@ TLANG_NAMESPACE_BEGIN -bool ContinueStmt::as_return() const { - TI_ASSERT(scope != nullptr); - if (auto *offl = scope->cast(); offl) { - TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for || - offl->task_type == OffloadedStmt::TaskType::struct_for); - return true; - } - return false; -} - UnaryOpStmt::UnaryOpStmt(UnaryOpType op_type, Stmt *operand) : op_type(op_type), operand(operand) { TI_ASSERT(!operand->is()); @@ -21,6 +11,12 @@ UnaryOpStmt::UnaryOpStmt(UnaryOpType op_type, Stmt *operand) TI_STMT_REG_FIELDS; } +DecorationStmt::DecorationStmt(Stmt *operand, + const std::vector &decoration) + : operand(operand), decoration(decoration) { + TI_STMT_REG_FIELDS; +} + bool UnaryOpStmt::is_cast() const { return unary_op_is_cast(op_type); } @@ -49,6 +45,15 @@ ExternalPtrStmt::ExternalPtrStmt(const LaneAttribute &base_ptrs, TI_STMT_REG_FIELDS; } +ExternalPtrStmt::ExternalPtrStmt(const LaneAttribute &base_ptrs, + const std::vector &indices, + const std::vector &element_shape, + int element_dim) + : ExternalPtrStmt(base_ptrs, indices) { + this->element_shape = element_shape; + this->element_dim = element_dim; +} + GlobalPtrStmt::GlobalPtrStmt(const LaneAttribute &snodes, const std::vector &indices, bool activate) @@ -246,19 +251,19 @@ std::unique_ptr ConstStmt::copy() { RangeForStmt::RangeForStmt(Stmt *begin, Stmt *end, std::unique_ptr &&body, - int vectorize, int bit_vectorize, int num_cpu_threads, int block_dim, - bool strictly_serialized) + bool strictly_serialized, + std::string range_hint) : begin(begin), end(end), body(std::move(body)), - vectorize(vectorize), bit_vectorize(bit_vectorize), num_cpu_threads(num_cpu_threads), block_dim(block_dim), - strictly_serialized(strictly_serialized) { + strictly_serialized(strictly_serialized), + range_hint(range_hint) { reversed = false; this->body->parent_stmt = this; TI_STMT_REG_FIELDS; @@ -266,21 +271,19 @@ RangeForStmt::RangeForStmt(Stmt *begin, std::unique_ptr RangeForStmt::clone() const { auto new_stmt = std::make_unique( - begin, end, body->clone(), vectorize, bit_vectorize, num_cpu_threads, - block_dim, strictly_serialized); + begin, end, body->clone(), bit_vectorize, num_cpu_threads, block_dim, + strictly_serialized); new_stmt->reversed = reversed; return new_stmt; } StructForStmt::StructForStmt(SNode *snode, std::unique_ptr &&body, - int vectorize, int bit_vectorize, int num_cpu_threads, int block_dim) : snode(snode), body(std::move(body)), - vectorize(vectorize), bit_vectorize(bit_vectorize), num_cpu_threads(num_cpu_threads), block_dim(block_dim) { @@ -289,23 +292,36 @@ StructForStmt::StructForStmt(SNode *snode, } std::unique_ptr StructForStmt::clone() const { - auto new_stmt = std::make_unique(snode, body->clone(), - vectorize, bit_vectorize, - num_cpu_threads, block_dim); + auto new_stmt = std::make_unique( + snode, body->clone(), bit_vectorize, num_cpu_threads, block_dim); new_stmt->mem_access_opt = mem_access_opt; return new_stmt; } -FuncBodyStmt::FuncBodyStmt(const std::string &funcid, - std::unique_ptr &&body) - : funcid(funcid), body(std::move(body)) { - if (this->body) - this->body->parent_stmt = this; +MeshForStmt::MeshForStmt(mesh::Mesh *mesh, + mesh::MeshElementType element_type, + std::unique_ptr &&body, + int bit_vectorize, + int num_cpu_threads, + int block_dim) + : mesh(mesh), + body(std::move(body)), + bit_vectorize(bit_vectorize), + num_cpu_threads(num_cpu_threads), + block_dim(block_dim), + major_from_type(element_type) { + this->body->parent_stmt = this; TI_STMT_REG_FIELDS; } -std::unique_ptr FuncBodyStmt::clone() const { - return std::make_unique(funcid, body->clone()); +std::unique_ptr MeshForStmt::clone() const { + auto new_stmt = + std::make_unique(mesh, major_from_type, body->clone(), + bit_vectorize, num_cpu_threads, block_dim); + new_stmt->major_to_types = major_to_types; + new_stmt->minor_relation_types = minor_relation_types; + new_stmt->mem_access_opt = mem_access_opt; + return new_stmt; } FuncCallStmt::FuncCallStmt(Function *func, const std::vector &args) @@ -349,6 +365,8 @@ std::string OffloadedStmt::task_name() const { return "range_for"; } else if (task_type == TaskType::struct_for) { return "struct_for"; + } else if (task_type == TaskType::mesh_for) { + return "mesh_for"; } else if (task_type == TaskType::listgen) { TI_ASSERT(snode); return fmt::format("listgen_{}", snode->get_node_type_name_hinted()); @@ -379,10 +397,25 @@ std::unique_ptr OffloadedStmt::clone() const { new_stmt->reversed = reversed; new_stmt->num_cpu_threads = num_cpu_threads; new_stmt->index_offsets = index_offsets; + + new_stmt->mesh = mesh; + new_stmt->major_from_type = major_from_type; + new_stmt->major_to_types = major_to_types; + new_stmt->minor_relation_types = minor_relation_types; + + new_stmt->owned_offset_local = owned_offset_local; + new_stmt->total_offset_local = total_offset_local; + new_stmt->owned_num_local = owned_num_local; + new_stmt->total_num_local = total_num_local; + if (tls_prologue) { new_stmt->tls_prologue = tls_prologue->clone(); new_stmt->tls_prologue->parent_stmt = new_stmt.get(); } + if (mesh_prologue) { + new_stmt->mesh_prologue = mesh_prologue->clone(); + new_stmt->mesh_prologue->parent_stmt = new_stmt.get(); + } if (bls_prologue) { new_stmt->bls_prologue = bls_prologue->clone(); new_stmt->bls_prologue->parent_stmt = new_stmt.get(); @@ -405,9 +438,12 @@ std::unique_ptr OffloadedStmt::clone() const { return new_stmt; } -void OffloadedStmt::all_blocks_accept(IRVisitor *visitor) { +void OffloadedStmt::all_blocks_accept(IRVisitor *visitor, + bool skip_mesh_prologue) { if (tls_prologue) tls_prologue->accept(visitor); + if (mesh_prologue && !skip_mesh_prologue) + mesh_prologue->accept(visitor); if (bls_prologue) bls_prologue->accept(visitor); if (body) diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 785f6eec1fffc..218258d59ca8e 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -3,7 +3,10 @@ #include "taichi/ir/ir.h" #include "taichi/ir/offloaded_task_type.h" #include "taichi/ir/stmt_op_types.h" -#include "taichi/program/arch.h" +#include "taichi/backends/arch.h" +#include "taichi/ir/mesh.h" + +#include namespace taichi { namespace lang { @@ -91,12 +94,43 @@ class ContinueStmt : public Stmt { // // If run_foo_kernel() is directly inlined within foo_kernel(), `return` // could prematurely terminate the entire kernel. - bool as_return() const; TI_STMT_DEF_FIELDS(scope); TI_DEFINE_ACCEPT_AND_CLONE; }; +/** + * A decoration statement. The decorated "operands" will keep this decoration. + */ +class DecorationStmt : public Stmt { + public: + enum class Decoration : uint32_t { kUnknown, kLoopUnique }; + + Stmt *operand; + std::vector decoration; + + DecorationStmt(Stmt *operand, const std::vector &decoration); + + bool same_operation(DecorationStmt *o) const { + return false; + } + + bool is_cast() const { + return false; + } + + bool has_global_side_effect() const override { + return false; + } + + bool dead_instruction_eliminable() const override { + return false; + } + + TI_STMT_DEF_FIELDS(operand, decoration); + TI_DEFINE_ACCEPT_AND_CLONE +}; + /** * A unary operation. The field |cast_type| is used only when is_cast() is true. */ @@ -261,10 +295,19 @@ class ExternalPtrStmt : public Stmt { public: LaneAttribute base_ptrs; std::vector indices; + std::vector element_shape; + // AOS: element_dim < 0 + // SOA: element_dim > 0 + int element_dim; ExternalPtrStmt(const LaneAttribute &base_ptrs, const std::vector &indices); + ExternalPtrStmt(const LaneAttribute &base_ptrs, + const std::vector &indices, + const std::vector &element_shape, + int element_dim); + bool has_global_side_effect() const override { return false; } @@ -376,6 +419,10 @@ class ExternalTensorShapeAlongAxisStmt : public Stmt { ExternalTensorShapeAlongAxisStmt(int axis, int arg_id); + bool has_global_side_effect() const override { + return false; + } + TI_STMT_DEF_FIELDS(ret_type, axis, arg_id); TI_DEFINE_ACCEPT_AND_CLONE }; @@ -408,23 +455,40 @@ class AssertStmt : public Stmt { */ class ExternalFuncCallStmt : public Stmt { public: - void *func; - std::string source; + enum Type { SHARED_OBJECT = 0, ASSEMBLY = 1, BITCODE = 2 }; + + Type type; + void *so_func; // SHARED_OBJECT + std::string asm_source; // ASM + std::string bc_filename; // BITCODE + std::string bc_funcname; // BITCODE std::vector arg_stmts; - std::vector output_stmts; + std::vector output_stmts; // BITCODE doesn't use this - ExternalFuncCallStmt(void *func, - const std::string &source, + ExternalFuncCallStmt(Type type, + void *so_func, + std::string asm_source, + std::string bc_filename, + std::string bc_funcname, const std::vector &arg_stmts, const std::vector &output_stmts) - : func(func), - source(source), + : type(type), + so_func(so_func), + asm_source(asm_source), + bc_filename(bc_filename), + bc_funcname(bc_funcname), arg_stmts(arg_stmts), output_stmts(output_stmts) { TI_STMT_REG_FIELDS; } - TI_STMT_DEF_FIELDS(func, arg_stmts, output_stmts); + TI_STMT_DEF_FIELDS(type, + so_func, + asm_source, + bc_filename, + bc_funcname, + arg_stmts, + output_stmts); TI_DEFINE_ACCEPT_AND_CLONE }; @@ -705,20 +769,20 @@ class RangeForStmt : public Stmt { Stmt *begin, *end; std::unique_ptr body; bool reversed; - int vectorize; int bit_vectorize; int num_cpu_threads; int block_dim; bool strictly_serialized; + std::string range_hint; RangeForStmt(Stmt *begin, Stmt *end, std::unique_ptr &&body, - int vectorize, int bit_vectorize, int num_cpu_threads, int block_dim, - bool strictly_serialized); + bool strictly_serialized, + std::string range_hint = ""); bool is_container_statement() const override { return true; @@ -733,7 +797,6 @@ class RangeForStmt : public Stmt { TI_STMT_DEF_FIELDS(begin, end, reversed, - vectorize, bit_vectorize, num_cpu_threads, block_dim, @@ -752,7 +815,6 @@ class StructForStmt : public Stmt { std::unique_ptr block_initialization; std::unique_ptr block_finalization; std::vector index_offsets; - int vectorize; int bit_vectorize; int num_cpu_threads; int block_dim; @@ -760,7 +822,6 @@ class StructForStmt : public Stmt { StructForStmt(SNode *snode, std::unique_ptr &&body, - int vectorize, int bit_vectorize, int num_cpu_threads, int block_dim); @@ -773,7 +834,6 @@ class StructForStmt : public Stmt { TI_STMT_DEF_FIELDS(snode, index_offsets, - vectorize, bit_vectorize, num_cpu_threads, block_dim, @@ -782,15 +842,26 @@ class StructForStmt : public Stmt { }; /** - * An inline Taichi function. - * TODO: This statement seems unused. + * meshfor */ -class FuncBodyStmt : public Stmt { +class MeshForStmt : public Stmt { public: - std::string funcid; + mesh::Mesh *mesh; std::unique_ptr body; + int bit_vectorize; + int num_cpu_threads; + int block_dim; + mesh::MeshElementType major_from_type; + std::unordered_set major_to_types{}; + std::unordered_set minor_relation_types{}; + MemoryAccessOptions mem_access_opt; - FuncBodyStmt(const std::string &funcid, std::unique_ptr &&body); + MeshForStmt(mesh::Mesh *mesh, + mesh::MeshElementType element_type, + std::unique_ptr &&body, + int bit_vectorize, + int num_cpu_threads, + int block_dim); bool is_container_statement() const override { return true; @@ -798,7 +869,14 @@ class FuncBodyStmt : public Stmt { std::unique_ptr clone() const override; - TI_STMT_DEF_FIELDS(funcid); + TI_STMT_DEF_FIELDS(mesh, + bit_vectorize, + num_cpu_threads, + block_dim, + major_from_type, + major_to_types, + minor_relation_types, + mem_access_opt); TI_DEFINE_ACCEPT }; @@ -821,13 +899,35 @@ class FuncCallStmt : public Stmt { */ class ReturnStmt : public Stmt { public: - Stmt *value; + std::vector values; + + explicit ReturnStmt(const std::vector &values) : values(values) { + TI_STMT_REG_FIELDS; + } - explicit ReturnStmt(Stmt *value) : value(value) { + explicit ReturnStmt(Stmt *value) : values({value}) { TI_STMT_REG_FIELDS; } - TI_STMT_DEF_FIELDS(value); + std::vector element_types() { + std::vector ele_types; + for (auto &x : values) { + ele_types.push_back(x->element_type()); + } + return ele_types; + } + + std::string values_raw_names() { + std::string names; + for (auto &x : values) { + names += x->raw_name() + ", "; + } + names.pop_back(); + names.pop_back(); + return names; + } + + TI_STMT_DEF_FIELDS(values); TI_DEFINE_ACCEPT_AND_CLONE }; @@ -851,19 +951,6 @@ class WhileStmt : public Stmt { TI_DEFINE_ACCEPT }; -// TODO: remove this -class PragmaSLPStmt : public Stmt { - public: - int slp_width; - - PragmaSLPStmt(int slp_width) : slp_width(slp_width) { - TI_STMT_REG_FIELDS; - } - - TI_STMT_DEF_FIELDS(slp_width); - TI_DEFINE_ACCEPT_AND_CLONE -}; - // TODO: document for this class ElementShuffleStmt : public Stmt { public: @@ -1062,10 +1149,27 @@ class OffloadedStmt : public Stmt { int block_dim{1}; bool reversed{false}; int num_cpu_threads{1}; + Stmt *end_stmt{nullptr}; + std::string range_hint = ""; + + mesh::Mesh *mesh{nullptr}; + mesh::MeshElementType major_from_type; + std::unordered_set major_to_types; + std::unordered_set minor_relation_types; + + std::unordered_map + owned_offset_local; // |owned_offset[idx]| + std::unordered_map + total_offset_local; // |total_offset[idx]| + std::unordered_map + owned_num_local; // |owned_offset[idx+1] - owned_offset[idx]| + std::unordered_map + total_num_local; // |total_offset[idx+1] - total_offset[idx]| std::vector index_offsets; std::unique_ptr tls_prologue; + std::unique_ptr mesh_prologue; // mesh-for only block std::unique_ptr bls_prologue; std::unique_ptr body; std::unique_ptr bls_epilogue; @@ -1090,7 +1194,7 @@ class OffloadedStmt : public Stmt { std::unique_ptr clone() const override; - void all_blocks_accept(IRVisitor *visitor); + void all_blocks_accept(IRVisitor *visitor, bool skip_mesh_prologue = false); TI_STMT_DEF_FIELDS(ret_type /*inherited from Stmt*/, task_type, @@ -1123,6 +1227,27 @@ class LoopIndexStmt : public Stmt { TI_STMT_REG_FIELDS; } + bool is_mesh_index() const { + if (auto offload = loop->cast()) { + return offload->task_type == OffloadedTaskType::mesh_for; + } else if (loop->cast()) { + return true; + } else { + return false; + } + } + + mesh::MeshElementType mesh_index_type() const { + TI_ASSERT(is_mesh_index()); + if (auto offload = loop->cast()) { + return offload->major_from_type; + } else if (auto mesh_for = loop->cast()) { + return mesh_for->major_from_type; + } else { + TI_NOT_IMPLEMENTED; + } + } + bool has_global_side_effect() const override { return false; } @@ -1155,15 +1280,11 @@ class LoopLinearIndexStmt : public Stmt { }; /** - * The lowest |index|-th index of the |loop| among the iterations iterated by - * the block. + * global thread index, i.e. thread_idx() + block_idx() * block_dim() */ -class BlockCornerIndexStmt : public Stmt { +class GlobalThreadIndexStmt : public Stmt { public: - Stmt *loop; - int index; - - BlockCornerIndexStmt(Stmt *loop, int index) : loop(loop), index(index) { + explicit GlobalThreadIndexStmt() { TI_STMT_REG_FIELDS; } @@ -1171,14 +1292,20 @@ class BlockCornerIndexStmt : public Stmt { return false; } - TI_STMT_DEF_FIELDS(ret_type, loop, index); + TI_STMT_DEF_FIELDS(ret_type); TI_DEFINE_ACCEPT_AND_CLONE }; -// TODO: remove this -class BlockDimStmt : public Stmt { +/** + * The lowest |index|-th index of the |loop| among the iterations iterated by + * the block. + */ +class BlockCornerIndexStmt : public Stmt { public: - BlockDimStmt() { + Stmt *loop; + int index; + + BlockCornerIndexStmt(Stmt *loop, int index) : loop(loop), index(index) { TI_STMT_REG_FIELDS; } @@ -1186,7 +1313,7 @@ class BlockDimStmt : public Stmt { return false; } - TI_STMT_DEF_FIELDS(ret_type); + TI_STMT_DEF_FIELDS(ret_type, loop, index); TI_DEFINE_ACCEPT_AND_CLONE }; @@ -1471,5 +1598,112 @@ class BitStructStoreStmt : public Stmt { TI_DEFINE_ACCEPT_AND_CLONE; }; +// Mesh related. + +/** + * The relation access, mesh_idx -> to_type[neighbor_idx] + * If neibhor_idex has no value, it returns the number of neighbors (length of + * relation) of a mesh idx + */ +class MeshRelationAccessStmt : public Stmt { + public: + mesh::Mesh *mesh; + Stmt *mesh_idx; + mesh::MeshElementType to_type; + Stmt *neighbor_idx; + + MeshRelationAccessStmt(mesh::Mesh *mesh, + Stmt *mesh_idx, + mesh::MeshElementType to_type, + Stmt *neighbor_idx) + : mesh(mesh), + mesh_idx(mesh_idx), + to_type(to_type), + neighbor_idx(neighbor_idx) { + this->ret_type = PrimitiveType::i32; + TI_STMT_REG_FIELDS; + } + + MeshRelationAccessStmt(mesh::Mesh *mesh, + Stmt *mesh_idx, + mesh::MeshElementType to_type) + : mesh(mesh), + mesh_idx(mesh_idx), + to_type(to_type), + neighbor_idx(nullptr) { + this->ret_type = PrimitiveType::i32; + TI_STMT_REG_FIELDS; + } + + bool is_size() const { + return neighbor_idx == nullptr; + } + + bool has_global_side_effect() const override { + return false; + } + + mesh::MeshElementType from_type() const { + if (auto idx = mesh_idx->cast()) { + TI_ASSERT(idx->is_mesh_index()); + return idx->mesh_index_type(); + } else if (auto idx = mesh_idx->cast()) { + TI_ASSERT(!idx->is_size()); + return idx->to_type; + } else { + TI_NOT_IMPLEMENTED; + } + } + + TI_STMT_DEF_FIELDS(ret_type, mesh, mesh_idx, to_type, neighbor_idx); + TI_DEFINE_ACCEPT_AND_CLONE +}; + +/** + * Convert a mesh index to another index space + */ +class MeshIndexConversionStmt : public Stmt { + public: + mesh::Mesh *mesh; + mesh::MeshElementType idx_type; + Stmt *idx; + + mesh::ConvType conv_type; + + MeshIndexConversionStmt(mesh::Mesh *mesh, + mesh::MeshElementType idx_type, + Stmt *idx, + mesh::ConvType conv_type) + : mesh(mesh), idx_type(idx_type), idx(idx), conv_type(conv_type) { + this->ret_type = PrimitiveType::i32; + TI_STMT_REG_FIELDS; + } + + bool has_global_side_effect() const override { + return false; + } + + TI_STMT_DEF_FIELDS(ret_type, mesh, idx_type, idx, conv_type); + TI_DEFINE_ACCEPT_AND_CLONE +}; + +/** + * The patch index of the |mesh_loop|. + */ +class MeshPatchIndexStmt : public Stmt { + public: + MeshPatchIndexStmt() { + this->ret_type = PrimitiveType::i32; + TI_STMT_REG_FIELDS; + } + + bool has_global_side_effect() const override { + return false; + } + + TI_STMT_DEF_FIELDS(ret_type); + TI_DEFINE_ACCEPT_AND_CLONE +}; + } // namespace lang } // namespace taichi diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index 82fe83caf7fd7..4ff9b74fc4204 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -13,6 +13,8 @@ #include "taichi/transforms/inlining.h" #include "taichi/transforms/lower_access.h" #include "taichi/transforms/make_block_local.h" +#include "taichi/transforms/make_mesh_block_local.h" +#include "taichi/transforms/demote_mesh_statements.h" #include "taichi/transforms/simplify.h" #include "taichi/common/trait.h" @@ -38,7 +40,6 @@ bool alg_simp(IRNode *root, const CompileConfig &config); bool demote_operations(IRNode *root, const CompileConfig &config); bool binary_op_simplify(IRNode *root, const CompileConfig &config); bool whole_kernel_cse(IRNode *root); -void variable_optimization(IRNode *root, bool after_lower_access); bool extract_constant(IRNode *root, const CompileConfig &config); bool unreachable_code_elimination(IRNode *root); bool loop_invariant_code_motion(IRNode *root, const CompileConfig &config); @@ -46,15 +47,14 @@ void full_simplify(IRNode *root, const CompileConfig &config, const FullSimplifyPass::Args &args); void print(IRNode *root, std::string *output = nullptr); +void frontend_type_check(IRNode *root); void lower_ast(IRNode *root); void type_check(IRNode *root, const CompileConfig &config); bool inlining(IRNode *root, const CompileConfig &config, const InliningPass::Args &args); -void loop_vectorize(IRNode *root, const CompileConfig &config); void bit_loop_vectorize(IRNode *root); void slp_vectorize(IRNode *root); -void vector_split(IRNode *root, int max_width, bool serial_schedule); void replace_all_usages_with(IRNode *root, Stmt *old_stmt, Stmt *new_stmt); bool check_out_of_bound(IRNode *root, const CompileConfig &config, @@ -64,6 +64,15 @@ std::unique_ptr initialize_scratch_pad(OffloadedStmt *root); void make_block_local(IRNode *root, const CompileConfig &config, const MakeBlockLocalPass::Args &args); +void make_mesh_thread_local(IRNode *root, + const CompileConfig &config, + const MakeBlockLocalPass::Args &args); +void make_mesh_block_local(IRNode *root, + const CompileConfig &config, + const MakeMeshBlockLocal::Args &args); +void demote_mesh_statements(IRNode *root, + const CompileConfig &config, + const DemoteMeshStatements::Args &args); bool remove_loop_unique(IRNode *root); bool remove_range_assumption(IRNode *root); bool lower_access(IRNode *root, @@ -108,6 +117,7 @@ bool replace_statements(IRNode *root, std::function filter, std::function finder); void demote_dense_struct_fors(IRNode *root, bool packed); +void demote_no_access_mesh_fors(IRNode *root); bool demote_atomics(IRNode *root, const CompileConfig &config); void reverse_segments(IRNode *root); // for autograd void detect_read_only(IRNode *root); @@ -137,7 +147,6 @@ void compile_to_offloads(IRNode *ir, const CompileConfig &config, Kernel *kernel, bool verbose, - bool vectorize, bool grad, bool ad_use_stack, bool start_from_ast); @@ -155,7 +164,6 @@ void offload_to_executable(IRNode *ir, void compile_to_executable(IRNode *ir, const CompileConfig &config, Kernel *kernel, - bool vectorize, bool grad, bool ad_use_stack, bool verbose, diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 66b59835db4c1..eb220683bbadd 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -78,10 +78,6 @@ std::string PointerType::to_string() const { } } -std::string VectorType::to_string() const { - return fmt::format("[{} x {}]", num_elements_, element_->to_string()); -} - std::string TensorType::to_string() const { std::string s = "[Tensor ("; for (int i = 0; i < (int)shape_.size(); ++i) { @@ -92,11 +88,7 @@ std::string TensorType::to_string() const { } int Type::vector_width() const { - if (auto vec = cast()) { - return vec->get_num_elements(); - } else { - return 1; - } + return 1; // TODO: CPU vectorization } bool Type::is_primitive(PrimitiveTypeID type) const { diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 6b4acf069565a..d3f3a7f5a9f61 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -11,7 +11,7 @@ enum class PrimitiveTypeID : int { #undef PER_TYPE }; -class Type { +class TI_DLL_EXPORT Type { public: virtual std::string to_string() const = 0; @@ -51,7 +51,7 @@ class Type { }; // A "Type" handle. This should be removed later. -class DataType { +class TI_DLL_EXPORT DataType { public: DataType(); @@ -107,7 +107,7 @@ class DataType { // Note that all types are immutable once created. -class PrimitiveType : public Type { +class TI_DLL_EXPORT PrimitiveType : public Type { public: #define PER_TYPE(x) static DataType x; #include "taichi/inc/data_type.inc.h" @@ -121,7 +121,7 @@ class PrimitiveType : public Type { std::string to_string() const override; - virtual Type *get_compute_type() override { + Type *get_compute_type() override { return this; } @@ -154,28 +154,6 @@ class PointerType : public Type { bool is_bit_pointer_{false}; }; -class VectorType : public Type { - public: - VectorType(int num_elements, Type *element) - : num_elements_(num_elements), element_(element) { - TI_ASSERT(num_elements_ != 1); - } - - Type *get_element_type() const { - return element_; - } - - int get_num_elements() const { - return num_elements_; - } - - std::string to_string() const override; - - private: - int num_elements_{0}; - Type *element_{nullptr}; -}; - class TensorType : public Type { public: TensorType(std::vector shape, Type *element) @@ -197,6 +175,10 @@ class TensorType : public Type { return shape_; } + Type *get_compute_type() override { + return this; + } + std::string to_string() const override; private: @@ -385,6 +367,24 @@ class TypedConstant { TypedConstant(float64 x) : dt(PrimitiveType::f64), val_f64(x) { } + TypedConstant(int8 x) : dt(PrimitiveType::i8), val_i8(x) { + } + + TypedConstant(int16 x) : dt(PrimitiveType::i16), val_i16(x) { + } + + TypedConstant(uint8 x) : dt(PrimitiveType::u8), val_u8(x) { + } + + TypedConstant(uint16 x) : dt(PrimitiveType::u16), val_u16(x) { + } + + TypedConstant(uint32 x) : dt(PrimitiveType::u32), val_u32(x) { + } + + TypedConstant(uint64 x) : dt(PrimitiveType::u64), val_u64(x) { + } + template TypedConstant(DataType dt, const T &value) : dt(dt) { // TODO: loud failure on pointers diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index 95f12ee1969ec..2df31d64fc070 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -22,14 +22,6 @@ Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) { return primitive_types_[id].get(); } -Type *TypeFactory::get_vector_type(int num_elements, Type *element) { - auto key = std::make_pair(num_elements, element); - if (vector_types_.find(key) == vector_types_.end()) { - vector_types_[key] = std::make_unique(num_elements, element); - } - return vector_types_[key].get(); -} - Type *TypeFactory::get_tensor_type(std::vector shape, Type *element) { auto encode = [](const std::vector &shape) -> std::string { std::string s; @@ -57,11 +49,11 @@ Type *TypeFactory::get_custom_int_type(int num_bits, bool is_signed, Type *compute_type) { auto key = std::make_tuple(num_bits, is_signed, compute_type); - if (custom_int_types.find(key) == custom_int_types.end()) { - custom_int_types[key] = + if (custom_int_types_.find(key) == custom_int_types_.end()) { + custom_int_types_[key] = std::make_unique(num_bits, is_signed, compute_type); } - return custom_int_types[key].get(); + return custom_int_types_[key].get(); } Type *TypeFactory::get_custom_float_type(Type *digits_type, @@ -69,11 +61,11 @@ Type *TypeFactory::get_custom_float_type(Type *digits_type, Type *compute_type, float64 scale) { auto key = std::make_tuple(digits_type, exponent_type, compute_type, scale); - if (custom_float_types.find(key) == custom_float_types.end()) { - custom_float_types[key] = std::make_unique( + if (custom_float_types_.find(key) == custom_float_types_.end()) { + custom_float_types_[key] = std::make_unique( digits_type, exponent_type, compute_type, scale); } - return custom_float_types[key].get(); + return custom_float_types_[key].get(); } Type *TypeFactory::get_bit_struct_type(PrimitiveType *physical_type, @@ -174,11 +166,6 @@ class TypePromotionMapping { TI_WARN("promoted_type got a pointer input."); } - if (d->is()) { - d = d->as()->get_element_type(); - TI_WARN("promoted_type got a vector input."); - } - if (d->is()) { d = d->as()->get_element_type(); TI_WARN("promoted_type got a tensor input."); diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index 8257a7e02704d..ddcd498b15c47 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -17,8 +17,6 @@ class TypeFactory { PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true); - Type *get_vector_type(int num_elements, Type *element); - Type *get_tensor_type(std::vector shape, Type *element); Type *get_pointer_type(Type *element, bool is_bit_pointer = false); @@ -60,11 +58,11 @@ class TypeFactory { // TODO: use unordered map std::map, std::unique_ptr> - custom_int_types; + custom_int_types_; // TODO: use unordered map std::map, std::unique_ptr> - custom_float_types; + custom_float_types_; // TODO: avoid duplication std::vector> bit_struct_types_; diff --git a/taichi/ir/type_utils.cpp b/taichi/ir/type_utils.cpp index 7ba723994c7c9..a14e435473c02 100644 --- a/taichi/ir/type_utils.cpp +++ b/taichi/ir/type_utils.cpp @@ -20,7 +20,11 @@ std::string data_type_name(DataType t) { } std::string data_type_format(DataType dt) { - if (dt->is_primitive(PrimitiveTypeID::i32)) { + if (dt->is_primitive(PrimitiveTypeID::i16)) { + return "%hd"; + } else if (dt->is_primitive(PrimitiveTypeID::u16)) { + return "%hu"; + } else if (dt->is_primitive(PrimitiveTypeID::i32)) { return "%d"; } else if (dt->is_primitive(PrimitiveTypeID::u32)) { return "%u"; @@ -36,6 +40,11 @@ std::string data_type_format(DataType dt) { return "%.12f"; } else if (dt->is()) { return "%d"; + } else if (dt->is_primitive(PrimitiveTypeID::f16)) { + // f16 (and f32) is converted to f64 before printing, see + // CodeGenLLVM::visit(PrintStmt *stmt) and + // CodeGenLLVMCUDA::visit(PrintStmt *stmt) for more details. + return "%f"; } else { TI_NOT_IMPLEMENTED } diff --git a/taichi/ir/type_utils.h b/taichi/ir/type_utils.h index b27f78cdc3ee7..b5e10ed0f29c5 100644 --- a/taichi/ir/type_utils.h +++ b/taichi/ir/type_utils.h @@ -73,6 +73,10 @@ inline PrimitiveTypeID get_primitive_data_type() { } } +inline bool is_custom_type(DataType dt) { + return dt->is() || dt->is(); +} + inline bool is_real(DataType dt) { return dt->is_primitive(PrimitiveTypeID::f16) || dt->is_primitive(PrimitiveTypeID::f32) || diff --git a/taichi/ir/visitors.h b/taichi/ir/visitors.h index cc1807fd43c4e..1b1642054d889 100644 --- a/taichi/ir/visitors.h +++ b/taichi/ir/visitors.h @@ -21,9 +21,9 @@ class BasicStmtVisitor : public IRVisitor { void visit(StructForStmt *for_stmt) override; - void visit(OffloadedStmt *stmt) override; + void visit(MeshForStmt *for_stmt) override; - void visit(FuncBodyStmt *stmt) override; + void visit(OffloadedStmt *stmt) override; void visit(FrontendWhileStmt *stmt) override; diff --git a/taichi/jit/jit_module.h b/taichi/jit/jit_module.h index 8422f4e966a44..d78a4a1d744a7 100644 --- a/taichi/jit/jit_module.h +++ b/taichi/jit/jit_module.h @@ -4,7 +4,6 @@ #include #include "taichi/inc/constants.h" -#include "taichi/llvm/llvm_fwd.h" #include "taichi/lang_util.h" #include "taichi/program/kernel_profiler.h" diff --git a/taichi/jit/jit_session.cpp b/taichi/jit/jit_session.cpp index 3dbe3c421d3b8..dd9547589949b 100644 --- a/taichi/jit/jit_session.cpp +++ b/taichi/jit/jit_session.cpp @@ -1,29 +1,39 @@ #include "taichi/jit/jit_session.h" +#ifdef TI_WITH_LLVM #include "llvm/IR/DataLayout.h" +#endif TLANG_NAMESPACE_BEGIN -std::unique_ptr create_llvm_jit_session_cpu(Arch arch); -std::unique_ptr create_llvm_jit_session_cuda(Arch arch); +#ifdef TI_WITH_LLVM +std::unique_ptr create_llvm_jit_session_cpu( + LlvmProgramImpl *llvm_prog, + Arch arch); +std::unique_ptr create_llvm_jit_session_cuda( + LlvmProgramImpl *llvm_prog, + Arch arch); +#endif -std::unique_ptr JITSession::create(Arch arch) { +JITSession::JITSession(LlvmProgramImpl *llvm_prog) : llvm_prog_(llvm_prog) { +} + +std::unique_ptr JITSession::create(LlvmProgramImpl *llvm_prog, + Arch arch) { +#ifdef TI_WITH_LLVM if (arch_is_cpu(arch)) { - return create_llvm_jit_session_cpu(arch); + return create_llvm_jit_session_cpu(llvm_prog, arch); } else if (arch == Arch::cuda) { #if defined(TI_WITH_CUDA) - return create_llvm_jit_session_cuda(arch); + return create_llvm_jit_session_cuda(llvm_prog, arch); #else TI_NOT_IMPLEMENTED #endif - } else - TI_NOT_IMPLEMENTED -} - -std::size_t JITSession::get_type_size(llvm::Type *type) { - return get_data_layout().getTypeAllocSize(type); + } +#else + TI_ERROR("Llvm disabled"); +#endif + return nullptr; } -llvm::DataLayout JITSession::get_data_layout(){TI_NOT_IMPLEMENTED} - TLANG_NAMESPACE_END diff --git a/taichi/jit/jit_session.h b/taichi/jit/jit_session.h index 5a9f5cd71d312..7a5e141d5bd3a 100644 --- a/taichi/jit/jit_session.h +++ b/taichi/jit/jit_session.h @@ -11,13 +11,17 @@ TLANG_NAMESPACE_BEGIN // Backend JIT compiler for all archs +class LlvmProgramImpl; + class JITSession { + private: + LlvmProgramImpl *llvm_prog_; + protected: std::vector> modules; public: - JITSession() { - } + JITSession(LlvmProgramImpl *llvm_prog); virtual JITModule *add_module(std::unique_ptr M, int max_reg = 0) = 0; @@ -28,16 +32,20 @@ class JITSession { TI_NOT_IMPLEMENTED } - virtual llvm::DataLayout get_data_layout(); - - std::size_t get_type_size(llvm::Type *type); + virtual llvm::DataLayout get_data_layout() = 0; - static std::unique_ptr create(Arch arch); + static std::unique_ptr create(LlvmProgramImpl *llvm_prog, + Arch arch); virtual void global_optimize_module(llvm::Module *module) { } virtual ~JITSession() = default; + + protected: + LlvmProgramImpl *llvm_prog() const { + return llvm_prog_; + } }; TLANG_NAMESPACE_END diff --git a/taichi/lang_util.cpp b/taichi/lang_util.cpp index cc6441acb0b47..31f6f89f489df 100644 --- a/taichi/lang_util.cpp +++ b/taichi/lang_util.cpp @@ -3,7 +3,7 @@ #include "taichi/lang_util.h" #include "taichi/math/linalg.h" -#include "taichi/program/arch.h" +#include "taichi/backends/arch.h" #include "taichi/program/program.h" #include "taichi/program/compile_config.h" #include "taichi/system/timer.h" diff --git a/taichi/lang_util.h b/taichi/lang_util.h index 14be02a6a3440..020f5fa8eb0be 100644 --- a/taichi/lang_util.h +++ b/taichi/lang_util.h @@ -4,7 +4,7 @@ #include "taichi/util/io.h" #include "taichi/common/core.h" #include "taichi/system/profiler.h" -#include "taichi/ir/ir_modified.h" +#include "taichi/common/exceptions.h" #include "taichi/ir/stmt_op_types.h" #include "taichi/ir/type.h" #include "taichi/ir/type_utils.h" @@ -20,9 +20,9 @@ real measure_cpe(std::function target, int64 elements_per_call, real time_second = default_measurement_time); -struct Context; +struct RuntimeContext; -using FunctionType = std::function; +using FunctionType = std::function; inline std::string make_list(const std::vector &data, std::string bracket = "") { diff --git a/taichi/llvm/llvm_codegen_utils.h b/taichi/llvm/llvm_codegen_utils.h index 601cff5bb158a..fcb941abde4f7 100644 --- a/taichi/llvm/llvm_codegen_utils.h +++ b/taichi/llvm/llvm_codegen_utils.h @@ -46,11 +46,6 @@ std::string type_name(llvm::Type *type); void check_func_call_signature(llvm::Value *func, std::vector arglist); -template -inline bool check_func_call_signature(llvm::Value *func, Args &&... args) { - return check_func_call_signature(func, {args...}); -} - class LLVMModuleBuilder { public: std::unique_ptr module{nullptr}; diff --git a/taichi/llvm/llvm_context.cpp b/taichi/llvm/llvm_context.cpp index b5da5eeeb32f1..11ddc2220d166 100644 --- a/taichi/llvm/llvm_context.cpp +++ b/taichi/llvm/llvm_context.cpp @@ -12,6 +12,7 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/IR/Module.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsNVPTX.h" #include "llvm/IR/LLVMContext.h" @@ -57,10 +58,11 @@ namespace lang { using namespace llvm; -TaichiLLVMContext::TaichiLLVMContext(Arch arch) : arch(arch) { +TaichiLLVMContext::TaichiLLVMContext(LlvmProgramImpl *llvm_prog, Arch arch) + : arch_(arch) { TI_TRACE("Creating Taichi llvm context for arch: {}", arch_name(arch)); - main_thread_id = std::this_thread::get_id(); - main_thread_data = get_this_thread_data(); + main_thread_id_ = std::this_thread::get_id(); + main_thread_data_ = get_this_thread_data(); llvm::remove_fatal_error_handler(); llvm::install_fatal_error_handler( [](void *user_data, const std::string &reason, bool gen_crash_diag) { @@ -91,7 +93,7 @@ TaichiLLVMContext::TaichiLLVMContext(Arch arch) : arch(arch) { TI_NOT_IMPLEMENTED #endif } - jit = JITSession::create(arch); + jit = JITSession::create(llvm_prog, arch); TI_TRACE("Taichi llvm context created."); } @@ -120,6 +122,8 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { return llvm::Type::getInt32Ty(*ctx); } else if (dt->is_primitive(PrimitiveTypeID::u64)) { return llvm::Type::getInt64Ty(*ctx); + } else if (dt->is_primitive(PrimitiveTypeID::f16)) { + return llvm::Type::getHalfTy(*ctx); } else { TI_INFO(data_type_name(dt)); TI_NOT_IMPLEMENTED @@ -158,7 +162,7 @@ std::unique_ptr TaichiLLVMContext::clone_module_to_context( std::string bitcode; { - std::lock_guard _(mut); + std::lock_guard _(mut_); llvm::raw_string_ostream sos(bitcode); // Use a scope to make sure sos flushes on destruction llvm::WriteBitcodeToFile(*module, sos); @@ -261,10 +265,10 @@ void TaichiLLVMContext::init_runtime_jit_module() { std::unique_ptr TaichiLLVMContext::clone_runtime_module() { TI_AUTO_PROF - TI_ASSERT(std::this_thread::get_id() == main_thread_id); + TI_ASSERT(std::this_thread::get_id() == main_thread_id_); auto data = get_this_thread_data(); if (!data->runtime_module) { - data->runtime_module = clone_module(get_runtime_fn(arch)); + data->runtime_module = clone_module(get_runtime_fn(arch_)); } std::unique_ptr cloned; @@ -283,7 +287,7 @@ std::unique_ptr TaichiLLVMContext::clone_module( auto ctx = get_this_thread_context(); std::unique_ptr module = module_from_bitcode_file( fmt::format("{}/{}", runtime_lib_dir(), file), ctx); - if (arch == Arch::cuda) { + if (arch_ == Arch::cuda) { module->setTargetTriple("nvptx64-nvidia-cuda"); #if defined(TI_WITH_CUDA) @@ -405,7 +409,7 @@ std::unique_ptr TaichiLLVMContext::clone_module( void TaichiLLVMContext::link_module_with_cuda_libdevice( std::unique_ptr &module) { TI_AUTO_PROF - TI_ASSERT(arch == Arch::cuda); + TI_ASSERT(arch_ == Arch::cuda); auto libdevice_module = module_from_bitcode_file(libdevice_path(), get_this_thread_context()); @@ -432,7 +436,7 @@ void TaichiLLVMContext::link_module_with_cuda_libdevice( if (!func) { TI_INFO("Function {} not found", func_name); } else - func->setLinkage(Function::InternalLinkage); + func->setLinkage(llvm::Function::InternalLinkage); } } @@ -460,6 +464,8 @@ llvm::Value *TaichiLLVMContext::get_constant(DataType dt, T t) { auto ctx = get_this_thread_context(); if (dt->is_primitive(PrimitiveTypeID::f32)) { return llvm::ConstantFP::get(*ctx, llvm::APFloat((float32)t)); + } else if (dt->is_primitive(PrimitiveTypeID::f16)) { + return llvm::ConstantFP::get(llvm::Type::getHalfTy(*ctx), (float32)t); } else if (dt->is_primitive(PrimitiveTypeID::f64)) { return llvm::ConstantFP::get(*ctx, llvm::APFloat((float64)t)); } else if (is_integral(dt)) { @@ -513,7 +519,12 @@ std::string TaichiLLVMContext::type_name(llvm::Type *type) { } std::size_t TaichiLLVMContext::get_type_size(llvm::Type *type) { - return jit->get_type_size(type); + return get_data_layout().getTypeAllocSize(type); +} + +std::size_t TaichiLLVMContext::get_struct_element_offset(llvm::StructType *type, + int idx) { + return get_data_layout().getStructLayout(type)->getElementOffset(idx); } void TaichiLLVMContext::mark_inline(llvm::Function *f) { @@ -630,15 +641,15 @@ void TaichiLLVMContext::eliminate_unused_functions( } TaichiLLVMContext::ThreadLocalData *TaichiLLVMContext::get_this_thread_data() { - std::lock_guard _(thread_map_mut); + std::lock_guard _(thread_map_mut_); auto tid = std::this_thread::get_id(); - if (per_thread_data.find(tid) == per_thread_data.end()) { + if (per_thread_data_.find(tid) == per_thread_data_.end()) { std::stringstream ss; ss << tid; TI_TRACE("Creating thread local data for thread {}", ss.str()); - per_thread_data[tid] = std::make_unique(); + per_thread_data_[tid] = std::make_unique(); } - return per_thread_data[tid].get(); + return per_thread_data_[tid].get(); } llvm::LLVMContext *TaichiLLVMContext::get_this_thread_context() { @@ -663,7 +674,7 @@ llvm::Module *TaichiLLVMContext::get_this_thread_struct_module() { ThreadLocalData *data = get_this_thread_data(); if (!data->struct_module) { data->struct_module = clone_module_to_this_thread_context( - main_thread_data->struct_module.get()); + main_thread_data_->struct_module.get()); } return data->struct_module.get(); } @@ -702,7 +713,7 @@ auto make_slim_libdevice = [](const std::vector &args) { void TaichiLLVMContext::update_runtime_jit_module( std::unique_ptr module) { - if (arch == Arch::cuda) { + if (arch_ == Arch::cuda) { for (auto &f : *module) { bool is_kernel = false; const std::string func_name = f.getName(); @@ -726,6 +737,22 @@ void TaichiLLVMContext::update_runtime_jit_module( runtime_jit_module = add_module(std::move(module)); } +void TaichiLLVMContext::delete_functions_of_snode_tree(int id) { + if (!snode_tree_funcs_.count(id)) { + return; + } + llvm::Module *module = get_this_thread_struct_module(); + for (auto str : snode_tree_funcs_[id]) { + auto *func = module->getFunction(str); + func->eraseFromParent(); + } + snode_tree_funcs_.erase(id); +} + +void TaichiLLVMContext::add_function_to_snode_tree(int id, std::string func) { + snode_tree_funcs_[id].push_back(func); +} + TI_REGISTER_TASK(make_slim_libdevice); } // namespace lang diff --git a/taichi/llvm/llvm_context.h b/taichi/llvm/llvm_context.h index b99ad00f08973..dd2510d873570 100644 --- a/taichi/llvm/llvm_context.h +++ b/taichi/llvm/llvm_context.h @@ -37,7 +37,7 @@ class TaichiLLVMContext { // main_thread is defined to be the thread that runs the initializer JITModule *runtime_jit_module{nullptr}; - TaichiLLVMContext(Arch arch); + TaichiLLVMContext(LlvmProgramImpl *llvm_prog, Arch arch); virtual ~TaichiLLVMContext(); @@ -104,6 +104,8 @@ class TaichiLLVMContext { std::size_t get_type_size(llvm::Type *type); + std::size_t get_struct_element_offset(llvm::StructType *type, int idx); + template llvm::Value *get_constant(T t); @@ -126,6 +128,10 @@ class TaichiLLVMContext { void mark_function_as_cuda_kernel(llvm::Function *func, int block_dim = 0); + void add_function_to_snode_tree(int id, std::string func); + + void delete_functions_of_snode_tree(int id); + private: std::unique_ptr clone_module_to_context( llvm::Module *module, @@ -145,15 +151,20 @@ class TaichiLLVMContext { void update_runtime_jit_module(std::unique_ptr module); std::unordered_map> - per_thread_data; + per_thread_data_; - Arch arch; + Arch arch_; - std::thread::id main_thread_id; - ThreadLocalData *main_thread_data{nullptr}; - std::mutex mut; - std::mutex thread_map_mut; + std::thread::id main_thread_id_; + ThreadLocalData *main_thread_data_{nullptr}; + std::mutex mut_; + std::mutex thread_map_mut_; + + std::unordered_map> snode_tree_funcs_; }; +std::unique_ptr module_from_bitcode_file(std::string bitcode_path, + llvm::LLVMContext *ctx); + } // namespace lang } // namespace taichi diff --git a/taichi/llvm/llvm_device.cpp b/taichi/llvm/llvm_device.cpp new file mode 100644 index 0000000000000..5643dd9dd16f2 --- /dev/null +++ b/taichi/llvm/llvm_device.cpp @@ -0,0 +1,16 @@ +#include "taichi/llvm/llvm_device.h" + +namespace taichi { +namespace lang { + +uint64_t *LlvmDevice::allocate_llvm_runtime_memory_jit( + const LlvmRuntimeAllocParams ¶ms) { + params.runtime_jit->call( + "runtime_memory_allocate_aligned", params.runtime, params.size, + taichi_page_size); + return taichi_union_cast_with_different_sizes(fetch_result_uint64( + taichi_result_buffer_runtime_query_id, params.result_buffer)); +} + +} // namespace lang +} // namespace taichi diff --git a/taichi/llvm/llvm_device.h b/taichi/llvm/llvm_device.h new file mode 100644 index 0000000000000..e692f68521459 --- /dev/null +++ b/taichi/llvm/llvm_device.h @@ -0,0 +1,27 @@ +#pragma once + +#include "taichi/backends/device.h" + +namespace taichi { +namespace lang { + +class LlvmDevice : public Device { + public: + struct LlvmRuntimeAllocParams : AllocParams { + bool use_cached{true}; + JITModule *runtime_jit{nullptr}; + LLVMRuntime *runtime{nullptr}; + uint64 *result_buffer{nullptr}; + }; + + virtual DeviceAllocation allocate_memory_runtime( + const LlvmRuntimeAllocParams ¶ms) { + TI_NOT_IMPLEMENTED; + } + + uint64_t *allocate_llvm_runtime_memory_jit( + const LlvmRuntimeAllocParams ¶ms); +}; + +} // namespace lang +} // namespace taichi diff --git a/taichi/llvm/llvm_fwd.h b/taichi/llvm/llvm_fwd.h index 70621ae6a8bd7..dae04dde079a5 100644 --- a/taichi/llvm/llvm_fwd.h +++ b/taichi/llvm/llvm_fwd.h @@ -7,6 +7,7 @@ class Value; class Module; class Function; class DataLayout; +class StructType; class JITSymbol; class ExitOnError; namespace orc { diff --git a/taichi/llvm/llvm_program.cpp b/taichi/llvm/llvm_program.cpp index 6dd2d937788ae..c07596707360f 100644 --- a/taichi/llvm/llvm_program.cpp +++ b/taichi/llvm/llvm_program.cpp @@ -1,7 +1,8 @@ #include "llvm_program.h" +#include "llvm/IR/Module.h" #include "taichi/backends/cuda/cuda_driver.h" -#include "taichi/program/arch.h" +#include "taichi/backends/arch.h" #include "taichi/platform/cuda/detect_cuda.h" #include "taichi/math/arithmetic.h" #include "taichi/runtime/llvm/mem_request.h" @@ -36,15 +37,15 @@ void *taichi_allocate_aligned(MemoryPool *memory_pool, LlvmProgramImpl::LlvmProgramImpl(CompileConfig &config_, KernelProfilerBase *profiler) : ProgramImpl(config_) { - runtime_mem_info = Runtime::create(config_.arch); + runtime_mem_info_ = Runtime::create(config_.arch); if (config_.arch == Arch::cuda) { - if (!runtime_mem_info) { + if (!runtime_mem_info_) { TI_WARN("Taichi is not compiled with CUDA."); config_.arch = host_arch(); } else if (!is_cuda_api_available()) { TI_WARN("No CUDA driver API detected."); config_.arch = host_arch(); - } else if (!runtime_mem_info->detected()) { + } else if (!runtime_mem_info_->detected()) { TI_WARN("No CUDA device detected."); config_.arch = host_arch(); } else { @@ -55,40 +56,51 @@ LlvmProgramImpl::LlvmProgramImpl(CompileConfig &config_, } } - snode_tree_buffer_manager = std::make_unique(this); + snode_tree_buffer_manager_ = std::make_unique(this); - thread_pool = std::make_unique(config->cpu_max_num_threads); + thread_pool_ = std::make_unique(config->cpu_max_num_threads); - preallocated_device_buffer = nullptr; - llvm_runtime = nullptr; - llvm_context_host = std::make_unique(host_arch()); + preallocated_device_buffer_ = nullptr; + llvm_runtime_ = nullptr; + llvm_context_host_ = std::make_unique(this, host_arch()); if (config_.arch == Arch::cuda) { #if defined(TI_WITH_CUDA) - int num_SMs; + int num_SMs{1}; CUDADriver::get_instance().device_get_attribute( &num_SMs, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, nullptr); - int query_max_block_dim; + int query_max_block_dim{1024}; CUDADriver::get_instance().device_get_attribute( &query_max_block_dim, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, nullptr); + int version{0}; + CUDADriver::get_instance().driver_get_version(&version); + int query_max_block_per_sm{16}; + if (version >= 11000) { + // query this attribute only when CUDA version is above 11.0 + CUDADriver::get_instance().device_get_attribute( + &query_max_block_per_sm, + CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR, nullptr); + } if (config_.max_block_dim == 0) { config_.max_block_dim = query_max_block_dim; } if (config_.saturating_grid_dim == 0) { - // each SM can have 16-32 resident blocks - config_.saturating_grid_dim = num_SMs * 32; + if (version >= 11000) { + TI_TRACE("CUDA max blocks per SM = {}", query_max_block_per_sm); + } + config_.saturating_grid_dim = num_SMs * query_max_block_per_sm * 2; } #endif } if (arch_is_cpu(config->arch)) { config_.max_block_dim = 1024; - device_ = std::make_unique(); + device_ = std::make_shared(); } - if (config->kernel_profiler && runtime_mem_info) { - runtime_mem_info->set_profiler(profiler); + if (config->kernel_profiler && runtime_mem_info_) { + runtime_mem_info_->set_profiler(profiler); } #if defined(TI_WITH_CUDA) if (config_.arch == Arch::cuda) { @@ -98,7 +110,7 @@ LlvmProgramImpl::LlvmProgramImpl(CompileConfig &config_, CUDAContext::get_instance().set_profiler(nullptr); } CUDAContext::get_instance().set_debug(config->debug); - device_ = std::make_unique(); + device_ = std::make_shared(); } #endif } @@ -106,13 +118,14 @@ LlvmProgramImpl::LlvmProgramImpl(CompileConfig &config_, void LlvmProgramImpl::initialize_host() { // Note this cannot be placed inside LlvmProgramImpl constructor, see doc // string for init_runtime_jit_module() for more details. - llvm_context_host->init_runtime_jit_module(); + llvm_context_host_->init_runtime_jit_module(); } void LlvmProgramImpl::maybe_initialize_cuda_llvm_context() { - if (config->arch == Arch::cuda && llvm_context_device == nullptr) { - llvm_context_device = std::make_unique(Arch::cuda); - llvm_context_device->init_runtime_jit_module(); + if (config->arch == Arch::cuda && llvm_context_device_ == nullptr) { + llvm_context_device_ = + std::make_unique(this, Arch::cuda); + llvm_context_device_->init_runtime_jit_module(); } } @@ -150,13 +163,14 @@ void LlvmProgramImpl::initialize_llvm_runtime_snodes(const SNodeTree *tree, TaichiLLVMContext *tlctx = nullptr; if (config->arch == Arch::cuda) { #if defined(TI_WITH_CUDA) - tlctx = llvm_context_device.get(); + tlctx = llvm_context_device_.get(); #else TI_NOT_IMPLEMENTED #endif } else { - tlctx = llvm_context_host.get(); + tlctx = llvm_context_host_.get(); } + auto *const runtime_jit = tlctx->runtime_jit_module; // By the time this creator is called, "this" is already destroyed. // Therefore it is necessary to capture members by values. @@ -167,9 +181,18 @@ void LlvmProgramImpl::initialize_llvm_runtime_snodes(const SNodeTree *tree, std::size_t rounded_size = taichi::iroundup(scomp->root_size, taichi_page_size); - Ptr root_buffer = snode_tree_buffer_manager->allocate( - runtime_jit, llvm_runtime, rounded_size, taichi_page_size, tree->id(), + Ptr root_buffer = snode_tree_buffer_manager_->allocate( + runtime_jit, llvm_runtime_, rounded_size, taichi_page_size, tree->id(), result_buffer); + if (config->arch == Arch::cuda) { +#if defined(TI_WITH_CUDA) + CUDADriver::get_instance().memset(root_buffer, 0, rounded_size); +#else + TI_NOT_IMPLEMENTED +#endif + } else { + std::memset(root_buffer, 0, rounded_size); + } DeviceAllocation alloc{kDeviceNullAllocation}; @@ -185,9 +208,19 @@ void LlvmProgramImpl::initialize_llvm_runtime_snodes(const SNodeTree *tree, snode_tree_allocs_[tree->id()] = alloc; + bool all_dense = config->demote_dense_struct_fors; + for (int i = 0; i < (int)snodes.size(); i++) { + if (snodes[i]->type != SNodeType::dense && + snodes[i]->type != SNodeType::place && + snodes[i]->type != SNodeType::root) { + all_dense = false; + break; + } + } + runtime_jit->call( - "runtime_initialize_snodes", llvm_runtime, scomp->root_size, root_id, - (int)snodes.size(), tree->id(), rounded_size, root_buffer); + "runtime_initialize_snodes", llvm_runtime_, scomp->root_size, root_id, + (int)snodes.size(), tree->id(), rounded_size, root_buffer, all_dense); for (int i = 0; i < (int)snodes.size(); i++) { if (is_gc_able(snodes[i]->type)) { @@ -203,7 +236,7 @@ void LlvmProgramImpl::initialize_llvm_runtime_snodes(const SNodeTree *tree, } TI_TRACE("Initializing allocator for snode {} (node size {})", snode_id, node_size); - auto rt = llvm_runtime; + auto rt = llvm_runtime_; runtime_jit->call( "runtime_NodeAllocator_initialize", rt, snode_id, node_size); TI_TRACE("Allocating ambient element for snode {} (node size {})", @@ -214,34 +247,32 @@ void LlvmProgramImpl::initialize_llvm_runtime_snodes(const SNodeTree *tree, } } -void LlvmProgramImpl::materialize_snode_tree( +void LlvmProgramImpl::compile_snode_tree_types( SNodeTree *tree, - std::vector> &snode_trees_, - std::unordered_map &snodes, - uint64 *result_buffer) { + std::vector> &snode_trees) { auto *const root = tree->root(); - auto host_module = clone_struct_compiler_initial_context( - snode_trees_, llvm_context_host.get()); - std::unique_ptr scomp = std::make_unique( - host_arch(), this, std::move(host_module)); - scomp->run(*root); - - for (auto snode : scomp->snodes) { - snodes[snode->id] = snode; - } - if (arch_is_cpu(config->arch)) { - initialize_llvm_runtime_snodes(tree, scomp.get(), result_buffer); - } else if (config->arch == Arch::cuda) { - auto device_module = clone_struct_compiler_initial_context( - snode_trees_, llvm_context_device.get()); + auto host_module = clone_struct_compiler_initial_context( + snode_trees, llvm_context_host_.get()); + struct_compiler_ = std::make_unique( + host_arch(), this, std::move(host_module), tree->id()); - std::unique_ptr scomp_gpu = - std::make_unique(Arch::cuda, this, - std::move(device_module)); - scomp_gpu->run(*root); - initialize_llvm_runtime_snodes(tree, scomp_gpu.get(), result_buffer); + } else { + TI_ASSERT(config->arch == Arch::cuda); + auto device_module = clone_struct_compiler_initial_context( + snode_trees, llvm_context_device_.get()); + struct_compiler_ = std::make_unique( + Arch::cuda, this, std::move(device_module), tree->id()); } + struct_compiler_->run(*root); +} + +void LlvmProgramImpl::materialize_snode_tree( + SNodeTree *tree, + std::vector> &snode_trees_, + uint64 *result_buffer) { + compile_snode_tree_types(tree, snode_trees_); + initialize_llvm_runtime_snodes(tree, struct_compiler_.get(), result_buffer); } uint64 LlvmProgramImpl::fetch_result_uint64(int i, uint64 *result_buffer) { @@ -270,7 +301,7 @@ std::size_t LlvmProgramImpl::get_snode_num_dynamically_allocated( auto node_allocator = runtime_query("LLVMRuntime_get_node_allocators", result_buffer, - llvm_runtime, snode->id); + llvm_runtime_, snode->id); auto data_list = runtime_query("NodeManager_get_data_list", result_buffer, node_allocator); @@ -313,7 +344,7 @@ void LlvmProgramImpl::materialize_runtime(MemoryPool *memory_pool, CUDADriver::get_instance().malloc( (void **)result_buffer_ptr, sizeof(uint64) * taichi_result_buffer_entries); - const auto total_mem = runtime_mem_info->get_total_memory(); + const auto total_mem = runtime_mem_info_->get_total_memory(); if (config->device_memory_fraction == 0) { TI_ASSERT(config->device_memory_GB > 0); prealloc_size = std::size_t(config->device_memory_GB * (1UL << 30)); @@ -327,22 +358,22 @@ void LlvmProgramImpl::materialize_runtime(MemoryPool *memory_pool, Device::AllocParams preallocated_device_buffer_alloc_params; preallocated_device_buffer_alloc_params.size = prealloc_size; - preallocated_device_buffer_alloc = + preallocated_device_buffer_alloc_ = cuda_device()->allocate_memory(preallocated_device_buffer_alloc_params); cuda::CudaDevice::AllocInfo preallocated_device_buffer_alloc_info = - cuda_device()->get_alloc_info(preallocated_device_buffer_alloc); - preallocated_device_buffer = preallocated_device_buffer_alloc_info.ptr; + cuda_device()->get_alloc_info(preallocated_device_buffer_alloc_); + preallocated_device_buffer_ = preallocated_device_buffer_alloc_info.ptr; - CUDADriver::get_instance().memset(preallocated_device_buffer, 0, + CUDADriver::get_instance().memset(preallocated_device_buffer_, 0, prealloc_size); - tlctx = llvm_context_device.get(); + tlctx = llvm_context_device_.get(); #else TI_NOT_IMPLEMENTED #endif } else { *result_buffer_ptr = (uint64 *)memory_pool->allocate( sizeof(uint64) * taichi_result_buffer_entries, 8); - tlctx = llvm_context_host.get(); + tlctx = llvm_context_host_.get(); } auto *const runtime_jit = tlctx->runtime_jit_module; @@ -371,17 +402,17 @@ void LlvmProgramImpl::materialize_runtime(MemoryPool *memory_pool, runtime_jit->call( "runtime_initialize", *result_buffer_ptr, memory_pool, prealloc_size, - preallocated_device_buffer, starting_rand_state, num_rand_states, + preallocated_device_buffer_, starting_rand_state, num_rand_states, (void *)&taichi_allocate_aligned, (void *)std::printf, (void *)std::vsnprintf); TI_TRACE("LLVMRuntime initialized (excluding `root`)"); - llvm_runtime = fetch_result(taichi_result_buffer_ret_value_id, - *result_buffer_ptr); + llvm_runtime_ = fetch_result(taichi_result_buffer_ret_value_id, + *result_buffer_ptr); TI_TRACE("LLVMRuntime pointer fetched"); if (arch_use_host_memory(config->arch)) { - runtime_jit->call("runtime_get_mem_req_queue", llvm_runtime); + runtime_jit->call("runtime_get_mem_req_queue", llvm_runtime_); auto mem_req_queue = fetch_result(taichi_result_buffer_ret_value_id, *result_buffer_ptr); memory_pool->set_queue((MemRequestQueue *)mem_req_queue); @@ -389,36 +420,37 @@ void LlvmProgramImpl::materialize_runtime(MemoryPool *memory_pool, if (arch_use_host_memory(config->arch)) { runtime_jit->call( - "LLVMRuntime_initialize_thread_pool", llvm_runtime, thread_pool.get(), + "LLVMRuntime_initialize_thread_pool", llvm_runtime_, thread_pool_.get(), (void *)ThreadPool::static_run); runtime_jit->call("LLVMRuntime_set_assert_failed", - llvm_runtime, (void *)assert_failed_host); + llvm_runtime_, + (void *)assert_failed_host); } if (arch_is_cpu(config->arch)) { // Profiler functions can only be called on CPU kernels - runtime_jit->call("LLVMRuntime_set_profiler", llvm_runtime, + runtime_jit->call("LLVMRuntime_set_profiler", llvm_runtime_, profiler); runtime_jit->call( - "LLVMRuntime_set_profiler_start", llvm_runtime, + "LLVMRuntime_set_profiler_start", llvm_runtime_, (void *)&KernelProfilerBase::profiler_start); runtime_jit->call( - "LLVMRuntime_set_profiler_stop", llvm_runtime, + "LLVMRuntime_set_profiler_stop", llvm_runtime_, (void *)&KernelProfilerBase::profiler_stop); } } void LlvmProgramImpl::check_runtime_error(uint64 *result_buffer) { synchronize(); - auto tlctx = llvm_context_host.get(); - if (llvm_context_device) { + auto tlctx = llvm_context_host_.get(); + if (llvm_context_device_) { // In case there is a standalone device context (e.g. CUDA without unified // memory), use the device context instead. - tlctx = llvm_context_device.get(); + tlctx = llvm_context_device_.get(); } auto *runtime_jit_module = tlctx->runtime_jit_module; runtime_jit_module->call("runtime_retrieve_and_reset_error_code", - llvm_runtime); + llvm_runtime_); auto error_code = fetch_result(taichi_result_buffer_error_id, result_buffer); @@ -431,7 +463,7 @@ void LlvmProgramImpl::check_runtime_error(uint64 *result_buffer) { // "fetch_result" that works across device/host memroy is necessary. for (int i = 0;; i++) { runtime_jit_module->call("runtime_retrieve_error_message", - llvm_runtime, i); + llvm_runtime_, i); auto c = fetch_result(taichi_result_buffer_error_id, result_buffer); error_message_template += c; if (c == '\0') { @@ -444,7 +476,7 @@ void LlvmProgramImpl::check_runtime_error(uint64 *result_buffer) { error_message_template, [runtime_jit_module, result_buffer, this](int argument_id) { runtime_jit_module->call( - "runtime_retrieve_error_message_argument", llvm_runtime, + "runtime_retrieve_error_message_argument", llvm_runtime_, argument_id); return fetch_result(taichi_result_buffer_error_id, result_buffer); @@ -457,11 +489,11 @@ void LlvmProgramImpl::check_runtime_error(uint64 *result_buffer) { } void LlvmProgramImpl::finalize() { - if (runtime_mem_info) - runtime_mem_info->set_profiler(nullptr); + if (runtime_mem_info_) + runtime_mem_info_->set_profiler(nullptr); #if defined(TI_WITH_CUDA) - if (preallocated_device_buffer != nullptr) { - cuda_device()->dealloc_memory(preallocated_device_buffer_alloc); + if (preallocated_device_buffer_ != nullptr) { + cuda_device()->dealloc_memory(preallocated_device_buffer_alloc_); } #endif } @@ -481,7 +513,7 @@ void LlvmProgramImpl::print_memory_profiler_info( std::function visit = [&](SNode *snode, int depth) { auto element_list = runtime_query("LLVMRuntime_get_element_lists", result_buffer, - llvm_runtime, snode->id); + llvm_runtime_, snode->id); if (snode->type != SNodeType::place) { fmt::print("SNode {:10}\n", snode->get_node_type_name_hinted()); @@ -492,7 +524,7 @@ void LlvmProgramImpl::print_memory_profiler_info( auto node_allocator = runtime_query("LLVMRuntime_get_node_allocators", - result_buffer, llvm_runtime, snode->id); + result_buffer, llvm_runtime_, snode->id); if (node_allocator) { auto free_list = runtime_query("NodeManager_get_free_list", @@ -531,7 +563,7 @@ void LlvmProgramImpl::print_memory_profiler_info( } auto total_requested_memory = runtime_query( - "LLVMRuntime_get_total_requested_memory", result_buffer, llvm_runtime); + "LLVMRuntime_get_total_requested_memory", result_buffer, llvm_runtime_); fmt::print( "Total requested dynamic memory (excluding alignment padding): {:n} B\n", @@ -550,10 +582,65 @@ cpu::CpuDevice *LlvmProgramImpl::cpu_device() { return static_cast(device_.get()); } +LlvmDevice *LlvmProgramImpl::llvm_device() { + TI_ASSERT(dynamic_cast(device_.get())); + return static_cast(device_.get()); +} + DevicePtr LlvmProgramImpl::get_snode_tree_device_ptr(int tree_id) { DeviceAllocation tree_alloc = snode_tree_allocs_[tree_id]; return tree_alloc.get_ptr(); } +DeviceAllocation LlvmProgramImpl::allocate_memory_ndarray( + std::size_t alloc_size, + uint64 *result_buffer) { + TaichiLLVMContext *tlctx = nullptr; + if (llvm_context_device_) { + tlctx = llvm_context_device_.get(); + } else { + tlctx = llvm_context_host_.get(); + } + + return llvm_device()->allocate_memory_runtime( + {{alloc_size, /*host_write=*/false, /*host_read=*/false, + /*export_sharing=*/false, AllocUsage::Storage}, + config->ndarray_use_cached_allocator, + tlctx->runtime_jit_module, + get_llvm_runtime(), + result_buffer}); +} + +std::shared_ptr LlvmProgramImpl::get_device_shared() { + return device_; +} + +uint64_t *LlvmProgramImpl::get_ndarray_alloc_info_ptr( + const DeviceAllocation &alloc) { + if (config->arch == Arch::cuda) { +#if defined(TI_WITH_CUDA) + return (uint64_t *)cuda_device()->get_alloc_info(alloc).ptr; +#else + TI_NOT_IMPLEMENTED +#endif + } else { + return (uint64_t *)cpu_device()->get_alloc_info(alloc).ptr; + } +} + +void LlvmProgramImpl::fill_ndarray(const DeviceAllocation &alloc, + std::size_t size, + uint32_t data) { + auto ptr = get_ndarray_alloc_info_ptr(alloc); + if (config->arch == Arch::cuda) { +#if defined(TI_WITH_CUDA) + CUDADriver::get_instance().memsetd32((void *)ptr, data, size); +#else + TI_NOT_IMPLEMENTED +#endif + } else { + std::fill((uint32_t *)ptr, (uint32_t *)ptr + size, data); + } +} } // namespace lang } // namespace taichi diff --git a/taichi/llvm/llvm_program.h b/taichi/llvm/llvm_program.h index 3b7aefc15218e..72ef3c111b42f 100644 --- a/taichi/llvm/llvm_program.h +++ b/taichi/llvm/llvm_program.h @@ -1,4 +1,5 @@ #pragma once +#include "taichi/llvm/llvm_device.h" #include "taichi/system/snode_tree_buffer_manager.h" #include "taichi/inc/constants.h" #include "taichi/program/compile_config.h" @@ -6,7 +7,6 @@ #include "taichi/llvm/llvm_context.h" #include "taichi/runtime/runtime.h" #include "taichi/system/threading.h" -#include "llvm/IR/Module.h" #include "taichi/struct/struct.h" #include "taichi/struct/struct_llvm.h" #include "taichi/program/snode_expr_utils.h" @@ -18,6 +18,10 @@ #include +namespace llvm { +class Module; +} + namespace taichi { namespace lang { class StructCompiler; @@ -45,22 +49,25 @@ class LlvmProgramImpl : public ProgramImpl { TaichiLLVMContext *get_llvm_context(Arch arch) { if (arch_is_cpu(arch)) { - return llvm_context_host.get(); + return llvm_context_host_.get(); } else { - return llvm_context_device.get(); + return llvm_context_device_.get(); } } LLVMRuntime *get_llvm_runtime() { - return static_cast(llvm_runtime); + return static_cast(llvm_runtime_); } FunctionType compile(Kernel *kernel, OffloadedStmt *offloaded) override; + void compile_snode_tree_types( + SNodeTree *tree, + std::vector> &snode_trees) override; + void materialize_snode_tree( SNodeTree *tree, std::vector> &snode_trees_, - std::unordered_map &snodes, uint64 *result_buffer) override; template @@ -80,8 +87,10 @@ class LlvmProgramImpl : public ProgramImpl { SNode *snode, uint64 *result_buffer) override; - virtual void destroy_snode_tree(SNodeTree *snode_tree) override { - snode_tree_buffer_manager->destroy(snode_tree); + void destroy_snode_tree(SNodeTree *snode_tree) override { + get_llvm_context(host_arch()) + ->delete_functions_of_snode_tree(snode_tree->id()); + snode_tree_buffer_manager_->destroy(snode_tree); } void print_memory_profiler_info( @@ -94,6 +103,17 @@ class LlvmProgramImpl : public ProgramImpl { void finalize(); + DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size, + uint64 *result_buffer) override; + + uint64_t *get_ndarray_alloc_info_ptr(const DeviceAllocation &alloc); + + std::shared_ptr get_device_shared() override; + + void fill_ndarray(const DeviceAllocation &alloc, + std::size_t size, + uint32_t data); + private: std::unique_ptr clone_struct_compiler_initial_context( const std::vector> &snode_trees_, @@ -113,14 +133,14 @@ class LlvmProgramImpl : public ProgramImpl { TI_ASSERT(arch_uses_llvm(config->arch)); TaichiLLVMContext *tlctx = nullptr; - if (llvm_context_device) { - tlctx = llvm_context_device.get(); + if (llvm_context_device_) { + tlctx = llvm_context_device_.get(); } else { - tlctx = llvm_context_host.get(); + tlctx = llvm_context_host_.get(); } auto runtime = tlctx->runtime_jit_module; - runtime->call("runtime_" + key, llvm_runtime, + runtime->call("runtime_" + key, llvm_runtime_, std::forward(args)...); return taichi_union_cast_with_different_sizes(fetch_result_uint64( taichi_result_buffer_runtime_query_id, result_buffer)); @@ -132,28 +152,30 @@ class LlvmProgramImpl : public ProgramImpl { TI_NOT_IMPLEMENTED; } - virtual Device *get_compute_device() override { + Device *get_compute_device() override { return device_.get(); } DevicePtr get_snode_tree_device_ptr(int tree_id) override; private: - std::unique_ptr llvm_context_host{nullptr}; - std::unique_ptr llvm_context_device{nullptr}; - std::unique_ptr thread_pool{nullptr}; - std::unique_ptr runtime_mem_info{nullptr}; - std::unique_ptr snode_tree_buffer_manager{nullptr}; - void *llvm_runtime{nullptr}; - void *preallocated_device_buffer{nullptr}; // TODO: move to memory allocator + std::unique_ptr llvm_context_host_{nullptr}; + std::unique_ptr llvm_context_device_{nullptr}; + std::unique_ptr thread_pool_{nullptr}; + std::unique_ptr runtime_mem_info_{nullptr}; + std::unique_ptr snode_tree_buffer_manager_{nullptr}; + std::unique_ptr struct_compiler_{nullptr}; + void *llvm_runtime_{nullptr}; + void *preallocated_device_buffer_{nullptr}; // TODO: move to memory allocator - DeviceAllocation preallocated_device_buffer_alloc{kDeviceNullAllocation}; + DeviceAllocation preallocated_device_buffer_alloc_{kDeviceNullAllocation}; std::unordered_map snode_tree_allocs_; - std::unique_ptr device_; + std::shared_ptr device_{nullptr}; cuda::CudaDevice *cuda_device(); cpu::CpuDevice *cpu_device(); + LlvmDevice *llvm_device(); }; } // namespace lang } // namespace taichi diff --git a/taichi/math/arithmetic.h b/taichi/math/arithmetic.h index 34ddb762c0c8e..abf0f1d2db628 100644 --- a/taichi/math/arithmetic.h +++ b/taichi/math/arithmetic.h @@ -1,5 +1,7 @@ #pragma once +#include + #include "taichi/common/trait.h" namespace taichi { diff --git a/taichi/math/array_2d.h b/taichi/math/array_2d.h index 122f8a7a584b1..4a995f2b42b10 100644 --- a/taichi/math/array_2d.h +++ b/taichi/math/array_2d.h @@ -19,7 +19,7 @@ TI_NAMESPACE_BEGIN template <> class IndexND<2> { private: - int x[2], y[2]; + int x_[2], y_[2]; public: using Index = IndexND<2>; @@ -37,28 +37,28 @@ class IndexND<2> { int y0, int y1, Vector2 storage_offset = Vector2(0.5f, 0.5f)) { - x[0] = x0; - x[1] = x1; - y[0] = y0; - y[1] = y1; - i = x[0]; - j = y[0]; + x_[0] = x0; + x_[1] = x1; + y_[0] = y0; + y_[1] = y1; + i = x_[0]; + j = y_[0]; // offset = 0; - stride = y[1] - y[0]; + stride = y_[1] - y_[0]; this->storage_offset = storage_offset; } IndexND(Vector2i start, Vector2i end, Vector2 storage_offset = Vector2(0.5f, 0.5f)) { - x[0] = start[0]; - x[1] = end[0]; - y[0] = start[1]; - y[1] = end[1]; - i = x[0]; - j = y[0]; + x_[0] = start[0]; + x_[1] = end[0]; + y_[0] = start[1]; + y_[1] = end[1]; + i = x_[0]; + j = y_[0]; // offset = 0; - stride = y[1] - y[0]; + stride = y_[1] - y_[0]; this->storage_offset = storage_offset; } @@ -70,10 +70,10 @@ class IndexND<2> { void next() { j++; // offset++; - if (j == y[1]) { - j = y[0]; + if (j == y_[1]) { + j = y_[0]; i++; - if (i == x[1]) { + if (i == x_[1]) { } } } @@ -92,8 +92,8 @@ class IndexND<2> { } Index &to_end() { - i = x[1]; - j = y[0]; + i = x_[1]; + j = y_[0]; // offset = (x[1] - x[0]) * (y[1] - y[0]); return *this; } @@ -146,10 +146,10 @@ typedef IndexND<2> Index2D; template <> class RegionND<2> { private: - int x[2], y[2]; - Index2D index_begin; - Index2D index_end; - Vector2 storage_offset; + int x_[2], y_[2]; + Index2D index_begin_; + Index2D index_end_; + Vector2 storage_offset_; public: using Region = RegionND<2>; @@ -162,41 +162,41 @@ class RegionND<2> { int y0, int y1, Vector2 storage_offset = Vector2(0.5f, 0.5f)) { - x[0] = x0; - x[1] = x1; - y[0] = y0; - y[1] = y1; - index_begin = Index2D(x0, x1, y0, y1, storage_offset); - index_end = Index2D(x0, x1, y0, y1, storage_offset).to_end(); - this->storage_offset = storage_offset; + x_[0] = x0; + x_[1] = x1; + y_[0] = y0; + y_[1] = y1; + index_begin_ = Index2D(x0, x1, y0, y1, storage_offset); + index_end_ = Index2D(x0, x1, y0, y1, storage_offset).to_end(); + this->storage_offset_ = storage_offset; } RegionND(Vector2i start, Vector2i end, Vector2 storage_offset = Vector2(0.5f, 0.5f)) { - x[0] = start[0]; - x[1] = end[0]; - y[0] = start[1]; - y[1] = end[1]; - index_begin = Index2D(start, end, storage_offset); - index_end = Index2D(start, end, storage_offset).to_end(); - this->storage_offset = storage_offset; + x_[0] = start[0]; + x_[1] = end[0]; + y_[0] = start[1]; + y_[1] = end[1]; + index_begin_ = Index2D(start, end, storage_offset); + index_end_ = Index2D(start, end, storage_offset).to_end(); + this->storage_offset_ = storage_offset; } const Index2D begin() const { - return index_begin; + return index_begin_; } Index2D begin() { - return index_begin; + return index_begin_; } const Index2D end() const { - return index_end; + return index_end_; } Index2D end() { - return index_end; + return index_end_; } }; diff --git a/taichi/math/scalar.h b/taichi/math/scalar.h index 2f4164ff7817a..1ce736428050a 100644 --- a/taichi/math/scalar.h +++ b/taichi/math/scalar.h @@ -161,12 +161,13 @@ TI_FORCE_INLINE bool abnormal(T m) noexcept { inline int64 get_largest_pot(int64 a) noexcept { TI_ASSERT_INFO(a > 0, "a should be positive, instead of " + std::to_string(a)); - // TODO: optimize - int64 i = 1; - while (i <= a / 2) { - i *= 2; + + /* This code was copied from https://stackoverflow.com/a/20207950 and edited + It uses loop unrolling, which all (modern) compilers will do. */ + for (int64 i = 1; i < 64; i *= 2) { + a |= (a >> i); } - return i; + return a - (a >> 1); } TI_NAMESPACE_END diff --git a/taichi/math/svd.h b/taichi/math/svd.h index 5521486441099..584c61e378cc1 100644 --- a/taichi/math/svd.h +++ b/taichi/math/svd.h @@ -1,3 +1,5 @@ +#include +#include "taichi/ir/frontend_ir.h" #include "taichi/ir/ir.h" TLANG_NAMESPACE_BEGIN @@ -39,7 +41,8 @@ std::tuple -sifakis_svd_export(const Expr &a00, +sifakis_svd_export(ASTBuilder *ast_builder, + const Expr &a00, const Expr &a01, const Expr &a02, const Expr &a10, @@ -54,802 +57,842 @@ sifakis_svd_export(const Expr &a00, constexpr Tf Sine_Pi_Over_Eight = 0.3826834323650897f; constexpr Tf Cosine_Pi_Over_Eight = 0.9238795325112867f; - auto Sfour_gamma_squared = Var(Tf(0.0)); - auto Ssine_pi_over_eight = Var(Tf(0.0)); - auto Scosine_pi_over_eight = Var(Tf(0.0)); - auto Sone_half = Var(Tf(0.0)); - auto Sone = Var(Tf(0.0)); - auto Stiny_number = Var(Tf(0.0)); - auto Ssmall_number = Var(Tf(0.0)); - auto Sa11 = Var(Tf(0.0)); - auto Sa21 = Var(Tf(0.0)); - auto Sa31 = Var(Tf(0.0)); - auto Sa12 = Var(Tf(0.0)); - auto Sa22 = Var(Tf(0.0)); - auto Sa32 = Var(Tf(0.0)); - auto Sa13 = Var(Tf(0.0)); - auto Sa23 = Var(Tf(0.0)); - auto Sa33 = Var(Tf(0.0)); - auto Sv11 = Var(Tf(0.0)); - auto Sv21 = Var(Tf(0.0)); - auto Sv31 = Var(Tf(0.0)); - auto Sv12 = Var(Tf(0.0)); - auto Sv22 = Var(Tf(0.0)); - auto Sv32 = Var(Tf(0.0)); - auto Sv13 = Var(Tf(0.0)); - auto Sv23 = Var(Tf(0.0)); - auto Sv33 = Var(Tf(0.0)); - auto Su11 = Var(Tf(0.0)); - auto Su21 = Var(Tf(0.0)); - auto Su31 = Var(Tf(0.0)); - auto Su12 = Var(Tf(0.0)); - auto Su22 = Var(Tf(0.0)); - auto Su32 = Var(Tf(0.0)); - auto Su13 = Var(Tf(0.0)); - auto Su23 = Var(Tf(0.0)); - auto Su33 = Var(Tf(0.0)); - auto Sc = Var(Tf(0.0)); - auto Ss = Var(Tf(0.0)); - auto Sch = Var(Tf(0.0)); - auto Ssh = Var(Tf(0.0)); - auto Stmp1 = Var(Tf(0.0)); - auto Stmp2 = Var(Tf(0.0)); - auto Stmp3 = Var(Tf(0.0)); - auto Stmp4 = Var(Tf(0.0)); - auto Stmp5 = Var(Tf(0.0)); - auto Sqvs = Var(Tf(0.0)); - auto Sqvvx = Var(Tf(0.0)); - auto Sqvvy = Var(Tf(0.0)); - auto Sqvvz = Var(Tf(0.0)); - auto Ss11 = Var(Tf(0.0)); - auto Ss21 = Var(Tf(0.0)); - auto Ss31 = Var(Tf(0.0)); - auto Ss22 = Var(Tf(0.0)); - auto Ss32 = Var(Tf(0.0)); - auto Ss33 = Var(Tf(0.0)); - Sfour_gamma_squared = Expr(Four_Gamma_Squared); - Ssine_pi_over_eight = Expr(Sine_Pi_Over_Eight); - Scosine_pi_over_eight = Expr(Cosine_Pi_Over_Eight); - Sone_half = Tf(0.5f); - Sone = Tf(1.0f); - Stiny_number = Tf(1.e-20f); - Ssmall_number = Tf(1.e-12f); - Sa11 = load_if_ptr(a00); - Sa21 = load_if_ptr(a10); - Sa31 = load_if_ptr(a20); - Sa12 = load_if_ptr(a01); - Sa22 = load_if_ptr(a11); - Sa32 = load_if_ptr(a21); - Sa13 = load_if_ptr(a02); - Sa23 = load_if_ptr(a12); - Sa33 = load_if_ptr(a22); - Sqvs = Tf(1.0f); - Sqvvx = Tf(0.0f); - Sqvvy = Tf(0.0f); - Sqvvz = Tf(0.0f); - Ss11 = Sa11 * Sa11; - Stmp1 = Sa21 * Sa21; - Ss11 = Stmp1 + Ss11; - Stmp1 = Sa31 * Sa31; - Ss11 = Stmp1 + Ss11; - Ss21 = Sa12 * Sa11; - Stmp1 = Sa22 * Sa21; - Ss21 = Stmp1 + Ss21; - Stmp1 = Sa32 * Sa31; - Ss21 = Stmp1 + Ss21; - Ss31 = Sa13 * Sa11; - Stmp1 = Sa23 * Sa21; - Ss31 = Stmp1 + Ss31; - Stmp1 = Sa33 * Sa31; - Ss31 = Stmp1 + Ss31; - Ss22 = Sa12 * Sa12; - Stmp1 = Sa22 * Sa22; - Ss22 = Stmp1 + Ss22; - Stmp1 = Sa32 * Sa32; - Ss22 = Stmp1 + Ss22; - Ss32 = Sa13 * Sa12; - Stmp1 = Sa23 * Sa22; - Ss32 = Stmp1 + Ss32; - Stmp1 = Sa33 * Sa32; - Ss32 = Stmp1 + Ss32; - Ss33 = Sa13 * Sa13; - Stmp1 = Sa23 * Sa23; - Ss33 = Stmp1 + Ss33; - Stmp1 = Sa33 * Sa33; - Ss33 = Stmp1 + Ss33; - StrictlySerialize(); - For(0, num_iters, [&](Expr sweep) { - Ssh = Ss21 * Sone_half; - Stmp5 = Ss11 - Ss22; - Stmp2 = Ssh * Ssh; - Stmp1 = bit_cast(select(Stmp2 >= Stiny_number, - Expr(Ti(int32(0xffffffff))), Expr(Ti(0)))); - Ssh = svd_bitwise_and(Stmp1, Ssh); - Sch = svd_bitwise_and(Stmp1, Stmp5); - Stmp2 = svd_bitwise_and(Expr(~bit_cast(Stmp1)), Sone); - Sch = svd_bitwise_or(Sch, Stmp2); - Stmp1 = Ssh * Ssh; - Stmp2 = Sch * Sch; - Stmp3 = Stmp1 + Stmp2; - Stmp4 = rsqrt(Stmp3); - Ssh = Stmp4 * Ssh; - Sch = Stmp4 * Sch; - Stmp1 = Sfour_gamma_squared * Stmp1; - Stmp1 = bit_cast( - select(Stmp2 <= Stmp1, Expr(Ti(int32(0xffffffff))), Expr(Ti(0)))); - Stmp2 = svd_bitwise_and(Ssine_pi_over_eight, Stmp1); - Ssh = svd_bitwise_and(Expr(~bit_cast(Stmp1)), Ssh); - Ssh = svd_bitwise_or(Ssh, Stmp2); - Stmp2 = svd_bitwise_and(Scosine_pi_over_eight, Stmp1); - Sch = svd_bitwise_and(Expr(~bit_cast(Stmp1)), Sch); - Sch = svd_bitwise_or(Sch, Stmp2); - Stmp1 = Ssh * Ssh; - Stmp2 = Sch * Sch; - Sc = Stmp2 - Stmp1; - Ss = Sch * Ssh; - Ss = Ss + Ss; - Stmp3 = Stmp1 + Stmp2; - Ss33 = Ss33 * Stmp3; - Ss31 = Ss31 * Stmp3; - Ss32 = Ss32 * Stmp3; - Ss33 = Ss33 * Stmp3; - Stmp1 = Ss * Ss31; - Stmp2 = Ss * Ss32; - Ss31 = Sc * Ss31; - Ss32 = Sc * Ss32; - Ss31 = Stmp2 + Ss31; - Ss32 = Ss32 - Stmp1; - Stmp2 = Ss * Ss; - Stmp1 = Ss22 * Stmp2; - Stmp3 = Ss11 * Stmp2; - Stmp4 = Sc * Sc; - Ss11 = Ss11 * Stmp4; - Ss22 = Ss22 * Stmp4; - Ss11 = Ss11 + Stmp1; - Ss22 = Ss22 + Stmp3; - Stmp4 = Stmp4 - Stmp2; - Stmp2 = Ss21 + Ss21; - Ss21 = Ss21 * Stmp4; - Stmp4 = Sc * Ss; - Stmp2 = Stmp2 * Stmp4; - Stmp5 = Stmp5 * Stmp4; - Ss11 = Ss11 + Stmp2; - Ss21 = Ss21 - Stmp5; - Ss22 = Ss22 - Stmp2; - Stmp1 = Ssh * Sqvvx; - Stmp2 = Ssh * Sqvvy; - Stmp3 = Ssh * Sqvvz; - Ssh = Ssh * Sqvs; - Sqvs = Sch * Sqvs; - Sqvvx = Sch * Sqvvx; - Sqvvy = Sch * Sqvvy; - Sqvvz = Sch * Sqvvz; - Sqvvz = Sqvvz + Ssh; - Sqvs = Sqvs - Stmp3; - Sqvvx = Sqvvx + Stmp2; - Sqvvy = Sqvvy - Stmp1; - Ssh = Ss32 * Sone_half; - Stmp5 = Ss22 - Ss33; - Stmp2 = Ssh * Ssh; - Stmp1 = bit_cast(select(Stmp2 >= Stiny_number, - Expr(Ti(int32(0xffffffff))), Expr(Ti(0)))); - Ssh = svd_bitwise_and(Stmp1, Ssh); - Sch = svd_bitwise_and(Stmp1, Stmp5); - Stmp2 = svd_bitwise_and(Expr(~bit_cast(Stmp1)), Sone); - Sch = svd_bitwise_or(Sch, Stmp2); - Stmp1 = Ssh * Ssh; - Stmp2 = Sch * Sch; - Stmp3 = Stmp1 + Stmp2; - Stmp4 = rsqrt(Stmp3); - Ssh = Stmp4 * Ssh; - Sch = Stmp4 * Sch; - Stmp1 = Sfour_gamma_squared * Stmp1; - Stmp1 = bit_cast( - select(Stmp2 <= Stmp1, Expr(Ti(int32(0xffffffff))), Expr(Ti(0)))); - Stmp2 = svd_bitwise_and(Ssine_pi_over_eight, Stmp1); - Ssh = svd_bitwise_and(Expr(~bit_cast(Stmp1)), Ssh); - Ssh = svd_bitwise_or(Ssh, Stmp2); - Stmp2 = svd_bitwise_and(Scosine_pi_over_eight, Stmp1); - Sch = svd_bitwise_and(Expr(~bit_cast(Stmp1)), Sch); - Sch = svd_bitwise_or(Sch, Stmp2); - Stmp1 = Ssh * Ssh; - Stmp2 = Sch * Sch; - Sc = Stmp2 - Stmp1; - Ss = Sch * Ssh; - Ss = Ss + Ss; - Stmp3 = Stmp1 + Stmp2; - Ss11 = Ss11 * Stmp3; - Ss21 = Ss21 * Stmp3; - Ss31 = Ss31 * Stmp3; - Ss11 = Ss11 * Stmp3; - Stmp1 = Ss * Ss21; - Stmp2 = Ss * Ss31; - Ss21 = Sc * Ss21; - Ss31 = Sc * Ss31; - Ss21 = Stmp2 + Ss21; - Ss31 = Ss31 - Stmp1; - Stmp2 = Ss * Ss; - Stmp1 = Ss33 * Stmp2; - Stmp3 = Ss22 * Stmp2; - Stmp4 = Sc * Sc; - Ss22 = Ss22 * Stmp4; - Ss33 = Ss33 * Stmp4; - Ss22 = Ss22 + Stmp1; - Ss33 = Ss33 + Stmp3; - Stmp4 = Stmp4 - Stmp2; - Stmp2 = Ss32 + Ss32; - Ss32 = Ss32 * Stmp4; - Stmp4 = Sc * Ss; - Stmp2 = Stmp2 * Stmp4; - Stmp5 = Stmp5 * Stmp4; - Ss22 = Ss22 + Stmp2; - Ss32 = Ss32 - Stmp5; - Ss33 = Ss33 - Stmp2; - Stmp1 = Ssh * Sqvvx; - Stmp2 = Ssh * Sqvvy; - Stmp3 = Ssh * Sqvvz; - Ssh = Ssh * Sqvs; - Sqvs = Sch * Sqvs; - Sqvvx = Sch * Sqvvx; - Sqvvy = Sch * Sqvvy; - Sqvvz = Sch * Sqvvz; - Sqvvx = Sqvvx + Ssh; - Sqvs = Sqvs - Stmp1; - Sqvvy = Sqvvy + Stmp3; - Sqvvz = Sqvvz - Stmp2; - Ssh = Ss31 * Sone_half; - Stmp5 = Ss33 - Ss11; - Stmp2 = Ssh * Ssh; - Stmp1 = bit_cast(select(Stmp2 >= Stiny_number, - Expr(Ti(int32(0xffffffff))), Expr(Ti(0)))); - Ssh = svd_bitwise_and(Stmp1, Ssh); - Sch = svd_bitwise_and(Stmp1, Stmp5); - Stmp2 = svd_bitwise_and(Expr(~bit_cast(Stmp1)), Sone); - Sch = svd_bitwise_or(Sch, Stmp2); - Stmp1 = Ssh * Ssh; - Stmp2 = Sch * Sch; - Stmp3 = Stmp1 + Stmp2; - Stmp4 = rsqrt(Stmp3); - Ssh = Stmp4 * Ssh; - Sch = Stmp4 * Sch; - Stmp1 = Sfour_gamma_squared * Stmp1; - Stmp1 = bit_cast( - select(Stmp2 <= Stmp1, Expr(Ti(int32(0xffffffff))), Expr(Ti(0)))); - Stmp2 = svd_bitwise_and(Ssine_pi_over_eight, Stmp1); - Ssh = svd_bitwise_and(Expr(~bit_cast(Stmp1)), Ssh); - Ssh = svd_bitwise_or(Ssh, Stmp2); - Stmp2 = svd_bitwise_and(Scosine_pi_over_eight, Stmp1); - Sch = svd_bitwise_and(Expr(~bit_cast(Stmp1)), Sch); - Sch = svd_bitwise_or(Sch, Stmp2); - Stmp1 = Ssh * Ssh; - Stmp2 = Sch * Sch; - Sc = Stmp2 - Stmp1; - Ss = Sch * Ssh; - Ss = Ss + Ss; - Stmp3 = Stmp1 + Stmp2; - Ss22 = Ss22 * Stmp3; - Ss32 = Ss32 * Stmp3; - Ss21 = Ss21 * Stmp3; - Ss22 = Ss22 * Stmp3; - Stmp1 = Ss * Ss32; - Stmp2 = Ss * Ss21; - Ss32 = Sc * Ss32; - Ss21 = Sc * Ss21; - Ss32 = Stmp2 + Ss32; - Ss21 = Ss21 - Stmp1; - Stmp2 = Ss * Ss; - Stmp1 = Ss11 * Stmp2; - Stmp3 = Ss33 * Stmp2; - Stmp4 = Sc * Sc; - Ss33 = Ss33 * Stmp4; - Ss11 = Ss11 * Stmp4; - Ss33 = Ss33 + Stmp1; - Ss11 = Ss11 + Stmp3; - Stmp4 = Stmp4 - Stmp2; - Stmp2 = Ss31 + Ss31; - Ss31 = Ss31 * Stmp4; - Stmp4 = Sc * Ss; - Stmp2 = Stmp2 * Stmp4; - Stmp5 = Stmp5 * Stmp4; - Ss33 = Ss33 + Stmp2; - Ss31 = Ss31 - Stmp5; - Ss11 = Ss11 - Stmp2; - Stmp1 = Ssh * Sqvvx; - Stmp2 = Ssh * Sqvvy; - Stmp3 = Ssh * Sqvvz; - Ssh = Ssh * Sqvs; - Sqvs = Sch * Sqvs; - Sqvvx = Sch * Sqvvx; - Sqvvy = Sch * Sqvvy; - Sqvvz = Sch * Sqvvz; - Sqvvy = Sqvvy + Ssh; - Sqvs = Sqvs - Stmp2; - Sqvvz = Sqvvz + Stmp1; - Sqvvx = Sqvvx - Stmp3; + auto Var = + std::bind(&ASTBuilder::make_var, ast_builder, std::placeholders::_1); + + auto Sfour_gamma_squared = Var(Expr(Tf(0.0))); + auto Ssine_pi_over_eight = Var(Expr(Tf(0.0))); + auto Scosine_pi_over_eight = Var(Expr(Tf(0.0))); + auto Sone_half = Var(Expr(Tf(0.0))); + auto Sone = Var(Expr(Tf(0.0))); + auto Stiny_number = Var(Expr(Tf(0.0))); + auto Ssmall_number = Var(Expr(Tf(0.0))); + auto Sa11 = Var(Expr(Tf(0.0))); + auto Sa21 = Var(Expr(Tf(0.0))); + auto Sa31 = Var(Expr(Tf(0.0))); + auto Sa12 = Var(Expr(Tf(0.0))); + auto Sa22 = Var(Expr(Tf(0.0))); + auto Sa32 = Var(Expr(Tf(0.0))); + auto Sa13 = Var(Expr(Tf(0.0))); + auto Sa23 = Var(Expr(Tf(0.0))); + auto Sa33 = Var(Expr(Tf(0.0))); + auto Sv11 = Var(Expr(Tf(0.0))); + auto Sv21 = Var(Expr(Tf(0.0))); + auto Sv31 = Var(Expr(Tf(0.0))); + auto Sv12 = Var(Expr(Tf(0.0))); + auto Sv22 = Var(Expr(Tf(0.0))); + auto Sv32 = Var(Expr(Tf(0.0))); + auto Sv13 = Var(Expr(Tf(0.0))); + auto Sv23 = Var(Expr(Tf(0.0))); + auto Sv33 = Var(Expr(Tf(0.0))); + auto Su11 = Var(Expr(Tf(0.0))); + auto Su21 = Var(Expr(Tf(0.0))); + auto Su31 = Var(Expr(Tf(0.0))); + auto Su12 = Var(Expr(Tf(0.0))); + auto Su22 = Var(Expr(Tf(0.0))); + auto Su32 = Var(Expr(Tf(0.0))); + auto Su13 = Var(Expr(Tf(0.0))); + auto Su23 = Var(Expr(Tf(0.0))); + auto Su33 = Var(Expr(Tf(0.0))); + auto Sc = Var(Expr(Tf(0.0))); + auto Ss = Var(Expr(Tf(0.0))); + auto Sch = Var(Expr(Tf(0.0))); + auto Ssh = Var(Expr(Tf(0.0))); + auto Stmp1 = Var(Expr(Tf(0.0))); + auto Stmp2 = Var(Expr(Tf(0.0))); + auto Stmp3 = Var(Expr(Tf(0.0))); + auto Stmp4 = Var(Expr(Tf(0.0))); + auto Stmp5 = Var(Expr(Tf(0.0))); + auto Sqvs = Var(Expr(Tf(0.0))); + auto Sqvvx = Var(Expr(Tf(0.0))); + auto Sqvvy = Var(Expr(Tf(0.0))); + auto Sqvvz = Var(Expr(Tf(0.0))); + auto Ss11 = Var(Expr(Tf(0.0))); + auto Ss21 = Var(Expr(Tf(0.0))); + auto Ss31 = Var(Expr(Tf(0.0))); + auto Ss22 = Var(Expr(Tf(0.0))); + auto Ss32 = Var(Expr(Tf(0.0))); + auto Ss33 = Var(Expr(Tf(0.0))); + ast_builder->insert_assignment(Sfour_gamma_squared, Expr(Four_Gamma_Squared)); + ast_builder->insert_assignment(Ssine_pi_over_eight, Expr(Sine_Pi_Over_Eight)); + ast_builder->insert_assignment(Scosine_pi_over_eight, + Expr(Cosine_Pi_Over_Eight)); + ast_builder->insert_assignment(Sone_half, Expr(Tf(0.5f))); + ast_builder->insert_assignment(Sone, Expr(Tf(1.0f))); + ast_builder->insert_assignment(Stiny_number, Expr(Tf(1.e-20f))); + ast_builder->insert_assignment(Ssmall_number, Expr(Tf(1.e-12f))); + ast_builder->insert_assignment(Sa11, a00); + ast_builder->insert_assignment(Sa21, a10); + ast_builder->insert_assignment(Sa31, a20); + ast_builder->insert_assignment(Sa12, a01); + ast_builder->insert_assignment(Sa22, a11); + ast_builder->insert_assignment(Sa32, a21); + ast_builder->insert_assignment(Sa13, a02); + ast_builder->insert_assignment(Sa23, a12); + ast_builder->insert_assignment(Sa33, a22); + ast_builder->insert_assignment(Sqvs, Expr(Tf(1.0f))); + ast_builder->insert_assignment(Sqvvx, Expr(Tf(0.0f))); + ast_builder->insert_assignment(Sqvvy, Expr(Tf(0.0f))); + ast_builder->insert_assignment(Sqvvz, Expr(Tf(0.0f))); + ast_builder->insert_assignment(Ss11, Sa11 * Sa11); + ast_builder->insert_assignment(Stmp1, Sa21 * Sa21); + ast_builder->insert_assignment(Ss11, Stmp1 + Ss11); + ast_builder->insert_assignment(Stmp1, Sa31 * Sa31); + ast_builder->insert_assignment(Ss11, Stmp1 + Ss11); + ast_builder->insert_assignment(Ss21, Sa12 * Sa11); + ast_builder->insert_assignment(Stmp1, Sa22 * Sa21); + ast_builder->insert_assignment(Ss21, Stmp1 + Ss21); + ast_builder->insert_assignment(Stmp1, Sa32 * Sa31); + ast_builder->insert_assignment(Ss21, Stmp1 + Ss21); + ast_builder->insert_assignment(Ss31, Sa13 * Sa11); + ast_builder->insert_assignment(Stmp1, Sa23 * Sa21); + ast_builder->insert_assignment(Ss31, Stmp1 + Ss31); + ast_builder->insert_assignment(Stmp1, Sa33 * Sa31); + ast_builder->insert_assignment(Ss31, Stmp1 + Ss31); + ast_builder->insert_assignment(Ss22, Sa12 * Sa12); + ast_builder->insert_assignment(Stmp1, Sa22 * Sa22); + ast_builder->insert_assignment(Ss22, Stmp1 + Ss22); + ast_builder->insert_assignment(Stmp1, Sa32 * Sa32); + ast_builder->insert_assignment(Ss22, Stmp1 + Ss22); + ast_builder->insert_assignment(Ss32, Sa13 * Sa12); + ast_builder->insert_assignment(Stmp1, Sa23 * Sa22); + ast_builder->insert_assignment(Ss32, Stmp1 + Ss32); + ast_builder->insert_assignment(Stmp1, Sa33 * Sa32); + ast_builder->insert_assignment(Ss32, Stmp1 + Ss32); + ast_builder->insert_assignment(Ss33, Sa13 * Sa13); + ast_builder->insert_assignment(Stmp1, Sa23 * Sa23); + ast_builder->insert_assignment(Ss33, Stmp1 + Ss33); + ast_builder->insert_assignment(Stmp1, Sa33 * Sa33); + ast_builder->insert_assignment(Ss33, Stmp1 + Ss33); + ast_builder->strictly_serialize(); + ast_builder->insert_for(Expr(0), Expr(num_iters), [&](Expr sweep) { + ast_builder->insert_assignment(Ssh, Ss21 * Sone_half); + ast_builder->insert_assignment(Stmp5, Ss11 - Ss22); + ast_builder->insert_assignment(Stmp2, Ssh * Ssh); + ast_builder->insert_assignment( + Stmp1, bit_cast(select(Stmp2 >= Stiny_number, + Expr(Ti(int32(0xffffffff))), Expr(Ti(0))))); + ast_builder->insert_assignment(Ssh, svd_bitwise_and(Stmp1, Ssh)); + ast_builder->insert_assignment(Sch, svd_bitwise_and(Stmp1, Stmp5)); + ast_builder->insert_assignment( + Stmp2, svd_bitwise_and(Expr(~bit_cast(Stmp1)), Sone)); + ast_builder->insert_assignment(Sch, svd_bitwise_or(Sch, Stmp2)); + ast_builder->insert_assignment(Stmp1, Ssh * Ssh); + ast_builder->insert_assignment(Stmp2, Sch * Sch); + ast_builder->insert_assignment(Stmp3, Stmp1 + Stmp2); + ast_builder->insert_assignment(Stmp4, rsqrt(Stmp3)); + ast_builder->insert_assignment(Ssh, Stmp4 * Ssh); + ast_builder->insert_assignment(Sch, Stmp4 * Sch); + ast_builder->insert_assignment(Stmp1, Sfour_gamma_squared * Stmp1); + ast_builder->insert_assignment( + Stmp1, bit_cast(select(Stmp2 <= Stmp1, Expr(Ti(int32(0xffffffff))), + Expr(Ti(0))))); + ast_builder->insert_assignment( + Stmp2, svd_bitwise_and(Ssine_pi_over_eight, Stmp1)); + ast_builder->insert_assignment( + Ssh, svd_bitwise_and(Expr(~bit_cast(Stmp1)), Ssh)); + ast_builder->insert_assignment(Ssh, svd_bitwise_or(Ssh, Stmp2)); + ast_builder->insert_assignment( + Stmp2, svd_bitwise_and(Scosine_pi_over_eight, Stmp1)); + ast_builder->insert_assignment( + Sch, svd_bitwise_and(Expr(~bit_cast(Stmp1)), Sch)); + ast_builder->insert_assignment(Sch, svd_bitwise_or(Sch, Stmp2)); + ast_builder->insert_assignment(Stmp1, Ssh * Ssh); + ast_builder->insert_assignment(Stmp2, Sch * Sch); + ast_builder->insert_assignment(Sc, Stmp2 - Stmp1); + ast_builder->insert_assignment(Ss, Sch * Ssh); + ast_builder->insert_assignment(Ss, Ss + Ss); + ast_builder->insert_assignment(Stmp3, Stmp1 + Stmp2); + ast_builder->insert_assignment(Ss33, Ss33 * Stmp3); + ast_builder->insert_assignment(Ss31, Ss31 * Stmp3); + ast_builder->insert_assignment(Ss32, Ss32 * Stmp3); + ast_builder->insert_assignment(Ss33, Ss33 * Stmp3); + ast_builder->insert_assignment(Stmp1, Ss * Ss31); + ast_builder->insert_assignment(Stmp2, Ss * Ss32); + ast_builder->insert_assignment(Ss31, Sc * Ss31); + ast_builder->insert_assignment(Ss32, Sc * Ss32); + ast_builder->insert_assignment(Ss31, Stmp2 + Ss31); + ast_builder->insert_assignment(Ss32, Ss32 - Stmp1); + ast_builder->insert_assignment(Stmp2, Ss * Ss); + ast_builder->insert_assignment(Stmp1, Ss22 * Stmp2); + ast_builder->insert_assignment(Stmp3, Ss11 * Stmp2); + ast_builder->insert_assignment(Stmp4, Sc * Sc); + ast_builder->insert_assignment(Ss11, Ss11 * Stmp4); + ast_builder->insert_assignment(Ss22, Ss22 * Stmp4); + ast_builder->insert_assignment(Ss11, Ss11 + Stmp1); + ast_builder->insert_assignment(Ss22, Ss22 + Stmp3); + ast_builder->insert_assignment(Stmp4, Stmp4 - Stmp2); + ast_builder->insert_assignment(Stmp2, Ss21 + Ss21); + ast_builder->insert_assignment(Ss21, Ss21 * Stmp4); + ast_builder->insert_assignment(Stmp4, Sc * Ss); + ast_builder->insert_assignment(Stmp2, Stmp2 * Stmp4); + ast_builder->insert_assignment(Stmp5, Stmp5 * Stmp4); + ast_builder->insert_assignment(Ss11, Ss11 + Stmp2); + ast_builder->insert_assignment(Ss21, Ss21 - Stmp5); + ast_builder->insert_assignment(Ss22, Ss22 - Stmp2); + ast_builder->insert_assignment(Stmp1, Ssh * Sqvvx); + ast_builder->insert_assignment(Stmp2, Ssh * Sqvvy); + ast_builder->insert_assignment(Stmp3, Ssh * Sqvvz); + ast_builder->insert_assignment(Ssh, Ssh * Sqvs); + ast_builder->insert_assignment(Sqvs, Sch * Sqvs); + ast_builder->insert_assignment(Sqvvx, Sch * Sqvvx); + ast_builder->insert_assignment(Sqvvy, Sch * Sqvvy); + ast_builder->insert_assignment(Sqvvz, Sch * Sqvvz); + ast_builder->insert_assignment(Sqvvz, Sqvvz + Ssh); + ast_builder->insert_assignment(Sqvs, Sqvs - Stmp3); + ast_builder->insert_assignment(Sqvvx, Sqvvx + Stmp2); + ast_builder->insert_assignment(Sqvvy, Sqvvy - Stmp1); + ast_builder->insert_assignment(Ssh, Ss32 * Sone_half); + ast_builder->insert_assignment(Stmp5, Ss22 - Ss33); + ast_builder->insert_assignment(Stmp2, Ssh * Ssh); + ast_builder->insert_assignment( + Stmp1, bit_cast(select(Stmp2 >= Stiny_number, + Expr(Ti(int32(0xffffffff))), Expr(Ti(0))))); + ast_builder->insert_assignment(Ssh, svd_bitwise_and(Stmp1, Ssh)); + ast_builder->insert_assignment(Sch, svd_bitwise_and(Stmp1, Stmp5)); + ast_builder->insert_assignment( + Stmp2, svd_bitwise_and(Expr(~bit_cast(Stmp1)), Sone)); + ast_builder->insert_assignment(Sch, svd_bitwise_or(Sch, Stmp2)); + ast_builder->insert_assignment(Stmp1, Ssh * Ssh); + ast_builder->insert_assignment(Stmp2, Sch * Sch); + ast_builder->insert_assignment(Stmp3, Stmp1 + Stmp2); + ast_builder->insert_assignment(Stmp4, rsqrt(Stmp3)); + ast_builder->insert_assignment(Ssh, Stmp4 * Ssh); + ast_builder->insert_assignment(Sch, Stmp4 * Sch); + ast_builder->insert_assignment(Stmp1, Sfour_gamma_squared * Stmp1); + ast_builder->insert_assignment( + Stmp1, bit_cast(select(Stmp2 <= Stmp1, Expr(Ti(int32(0xffffffff))), + Expr(Ti(0))))); + ast_builder->insert_assignment( + Stmp2, svd_bitwise_and(Ssine_pi_over_eight, Stmp1)); + ast_builder->insert_assignment( + Ssh, svd_bitwise_and(Expr(~bit_cast(Stmp1)), Ssh)); + ast_builder->insert_assignment(Ssh, svd_bitwise_or(Ssh, Stmp2)); + ast_builder->insert_assignment( + Stmp2, svd_bitwise_and(Scosine_pi_over_eight, Stmp1)); + ast_builder->insert_assignment( + Sch, svd_bitwise_and(Expr(~bit_cast(Stmp1)), Sch)); + ast_builder->insert_assignment(Sch, svd_bitwise_or(Sch, Stmp2)); + ast_builder->insert_assignment(Stmp1, Ssh * Ssh); + ast_builder->insert_assignment(Stmp2, Sch * Sch); + ast_builder->insert_assignment(Sc, Stmp2 - Stmp1); + ast_builder->insert_assignment(Ss, Sch * Ssh); + ast_builder->insert_assignment(Ss, Ss + Ss); + ast_builder->insert_assignment(Stmp3, Stmp1 + Stmp2); + ast_builder->insert_assignment(Ss11, Ss11 * Stmp3); + ast_builder->insert_assignment(Ss21, Ss21 * Stmp3); + ast_builder->insert_assignment(Ss31, Ss31 * Stmp3); + ast_builder->insert_assignment(Ss11, Ss11 * Stmp3); + ast_builder->insert_assignment(Stmp1, Ss * Ss21); + ast_builder->insert_assignment(Stmp2, Ss * Ss31); + ast_builder->insert_assignment(Ss21, Sc * Ss21); + ast_builder->insert_assignment(Ss31, Sc * Ss31); + ast_builder->insert_assignment(Ss21, Stmp2 + Ss21); + ast_builder->insert_assignment(Ss31, Ss31 - Stmp1); + ast_builder->insert_assignment(Stmp2, Ss * Ss); + ast_builder->insert_assignment(Stmp1, Ss33 * Stmp2); + ast_builder->insert_assignment(Stmp3, Ss22 * Stmp2); + ast_builder->insert_assignment(Stmp4, Sc * Sc); + ast_builder->insert_assignment(Ss22, Ss22 * Stmp4); + ast_builder->insert_assignment(Ss33, Ss33 * Stmp4); + ast_builder->insert_assignment(Ss22, Ss22 + Stmp1); + ast_builder->insert_assignment(Ss33, Ss33 + Stmp3); + ast_builder->insert_assignment(Stmp4, Stmp4 - Stmp2); + ast_builder->insert_assignment(Stmp2, Ss32 + Ss32); + ast_builder->insert_assignment(Ss32, Ss32 * Stmp4); + ast_builder->insert_assignment(Stmp4, Sc * Ss); + ast_builder->insert_assignment(Stmp2, Stmp2 * Stmp4); + ast_builder->insert_assignment(Stmp5, Stmp5 * Stmp4); + ast_builder->insert_assignment(Ss22, Ss22 + Stmp2); + ast_builder->insert_assignment(Ss32, Ss32 - Stmp5); + ast_builder->insert_assignment(Ss33, Ss33 - Stmp2); + ast_builder->insert_assignment(Stmp1, Ssh * Sqvvx); + ast_builder->insert_assignment(Stmp2, Ssh * Sqvvy); + ast_builder->insert_assignment(Stmp3, Ssh * Sqvvz); + ast_builder->insert_assignment(Ssh, Ssh * Sqvs); + ast_builder->insert_assignment(Sqvs, Sch * Sqvs); + ast_builder->insert_assignment(Sqvvx, Sch * Sqvvx); + ast_builder->insert_assignment(Sqvvy, Sch * Sqvvy); + ast_builder->insert_assignment(Sqvvz, Sch * Sqvvz); + ast_builder->insert_assignment(Sqvvx, Sqvvx + Ssh); + ast_builder->insert_assignment(Sqvs, Sqvs - Stmp1); + ast_builder->insert_assignment(Sqvvy, Sqvvy + Stmp3); + ast_builder->insert_assignment(Sqvvz, Sqvvz - Stmp2); + ast_builder->insert_assignment(Ssh, Ss31 * Sone_half); + ast_builder->insert_assignment(Stmp5, Ss33 - Ss11); + ast_builder->insert_assignment(Stmp2, Ssh * Ssh); + ast_builder->insert_assignment( + Stmp1, bit_cast(select(Stmp2 >= Stiny_number, + Expr(Ti(int32(0xffffffff))), Expr(Ti(0))))); + ast_builder->insert_assignment(Ssh, svd_bitwise_and(Stmp1, Ssh)); + ast_builder->insert_assignment(Sch, svd_bitwise_and(Stmp1, Stmp5)); + ast_builder->insert_assignment( + Stmp2, svd_bitwise_and(Expr(~bit_cast(Stmp1)), Sone)); + ast_builder->insert_assignment(Sch, svd_bitwise_or(Sch, Stmp2)); + ast_builder->insert_assignment(Stmp1, Ssh * Ssh); + ast_builder->insert_assignment(Stmp2, Sch * Sch); + ast_builder->insert_assignment(Stmp3, Stmp1 + Stmp2); + ast_builder->insert_assignment(Stmp4, rsqrt(Stmp3)); + ast_builder->insert_assignment(Ssh, Stmp4 * Ssh); + ast_builder->insert_assignment(Sch, Stmp4 * Sch); + ast_builder->insert_assignment(Stmp1, Sfour_gamma_squared * Stmp1); + ast_builder->insert_assignment( + Stmp1, bit_cast(select(Stmp2 <= Stmp1, Expr(Ti(int32(0xffffffff))), + Expr(Ti(0))))); + ast_builder->insert_assignment( + Stmp2, svd_bitwise_and(Ssine_pi_over_eight, Stmp1)); + ast_builder->insert_assignment( + Ssh, svd_bitwise_and(Expr(~bit_cast(Stmp1)), Ssh)); + ast_builder->insert_assignment(Ssh, svd_bitwise_or(Ssh, Stmp2)); + ast_builder->insert_assignment( + Stmp2, svd_bitwise_and(Scosine_pi_over_eight, Stmp1)); + ast_builder->insert_assignment( + Sch, svd_bitwise_and(Expr(~bit_cast(Stmp1)), Sch)); + ast_builder->insert_assignment(Sch, svd_bitwise_or(Sch, Stmp2)); + ast_builder->insert_assignment(Stmp1, Ssh * Ssh); + ast_builder->insert_assignment(Stmp2, Sch * Sch); + ast_builder->insert_assignment(Sc, Stmp2 - Stmp1); + ast_builder->insert_assignment(Ss, Sch * Ssh); + ast_builder->insert_assignment(Ss, Ss + Ss); + ast_builder->insert_assignment(Stmp3, Stmp1 + Stmp2); + ast_builder->insert_assignment(Ss22, Ss22 * Stmp3); + ast_builder->insert_assignment(Ss32, Ss32 * Stmp3); + ast_builder->insert_assignment(Ss21, Ss21 * Stmp3); + ast_builder->insert_assignment(Ss22, Ss22 * Stmp3); + ast_builder->insert_assignment(Stmp1, Ss * Ss32); + ast_builder->insert_assignment(Stmp2, Ss * Ss21); + ast_builder->insert_assignment(Ss32, Sc * Ss32); + ast_builder->insert_assignment(Ss21, Sc * Ss21); + ast_builder->insert_assignment(Ss32, Stmp2 + Ss32); + ast_builder->insert_assignment(Ss21, Ss21 - Stmp1); + ast_builder->insert_assignment(Stmp2, Ss * Ss); + ast_builder->insert_assignment(Stmp1, Ss11 * Stmp2); + ast_builder->insert_assignment(Stmp3, Ss33 * Stmp2); + ast_builder->insert_assignment(Stmp4, Sc * Sc); + ast_builder->insert_assignment(Ss33, Ss33 * Stmp4); + ast_builder->insert_assignment(Ss11, Ss11 * Stmp4); + ast_builder->insert_assignment(Ss33, Ss33 + Stmp1); + ast_builder->insert_assignment(Ss11, Ss11 + Stmp3); + ast_builder->insert_assignment(Stmp4, Stmp4 - Stmp2); + ast_builder->insert_assignment(Stmp2, Ss31 + Ss31); + ast_builder->insert_assignment(Ss31, Ss31 * Stmp4); + ast_builder->insert_assignment(Stmp4, Sc * Ss); + ast_builder->insert_assignment(Stmp2, Stmp2 * Stmp4); + ast_builder->insert_assignment(Stmp5, Stmp5 * Stmp4); + ast_builder->insert_assignment(Ss33, Ss33 + Stmp2); + ast_builder->insert_assignment(Ss31, Ss31 - Stmp5); + ast_builder->insert_assignment(Ss11, Ss11 - Stmp2); + ast_builder->insert_assignment(Stmp1, Ssh * Sqvvx); + ast_builder->insert_assignment(Stmp2, Ssh * Sqvvy); + ast_builder->insert_assignment(Stmp3, Ssh * Sqvvz); + ast_builder->insert_assignment(Ssh, Ssh * Sqvs); + ast_builder->insert_assignment(Sqvs, Sch * Sqvs); + ast_builder->insert_assignment(Sqvvx, Sch * Sqvvx); + ast_builder->insert_assignment(Sqvvy, Sch * Sqvvy); + ast_builder->insert_assignment(Sqvvz, Sch * Sqvvz); + ast_builder->insert_assignment(Sqvvy, Sqvvy + Ssh); + ast_builder->insert_assignment(Sqvs, Sqvs - Stmp2); + ast_builder->insert_assignment(Sqvvz, Sqvvz + Stmp1); + ast_builder->insert_assignment(Sqvvx, Sqvvx - Stmp3); }); - Stmp2 = Sqvs * Sqvs; - Stmp1 = Sqvvx * Sqvvx; - Stmp2 = Stmp1 + Stmp2; - Stmp1 = Sqvvy * Sqvvy; - Stmp2 = Stmp1 + Stmp2; - Stmp1 = Sqvvz * Sqvvz; - Stmp2 = Stmp1 + Stmp2; - Stmp1 = rsqrt(Stmp2); - Stmp4 = Stmp1 * Sone_half; - Stmp3 = Stmp1 * Stmp4; - Stmp3 = Stmp1 * Stmp3; - Stmp3 = Stmp2 * Stmp3; - Stmp1 = Stmp1 + Stmp4; - Stmp1 = Stmp1 - Stmp3; - Sqvs = Sqvs * Stmp1; - Sqvvx = Sqvvx * Stmp1; - Sqvvy = Sqvvy * Stmp1; - Sqvvz = Sqvvz * Stmp1; - Stmp1 = Sqvvx * Sqvvx; - Stmp2 = Sqvvy * Sqvvy; - Stmp3 = Sqvvz * Sqvvz; - Sv11 = Sqvs * Sqvs; - Sv22 = Sv11 - Stmp1; - Sv33 = Sv22 - Stmp2; - Sv33 = Sv33 + Stmp3; - Sv22 = Sv22 + Stmp2; - Sv22 = Sv22 - Stmp3; - Sv11 = Sv11 + Stmp1; - Sv11 = Sv11 - Stmp2; - Sv11 = Sv11 - Stmp3; - Stmp1 = Sqvvx + Sqvvx; - Stmp2 = Sqvvy + Sqvvy; - Stmp3 = Sqvvz + Sqvvz; - Sv32 = Sqvs * Stmp1; - Sv13 = Sqvs * Stmp2; - Sv21 = Sqvs * Stmp3; - Stmp1 = Sqvvy * Stmp1; - Stmp2 = Sqvvz * Stmp2; - Stmp3 = Sqvvx * Stmp3; - Sv12 = Stmp1 - Sv21; - Sv23 = Stmp2 - Sv32; - Sv31 = Stmp3 - Sv13; - Sv21 = Stmp1 + Sv21; - Sv32 = Stmp2 + Sv32; - Sv13 = Stmp3 + Sv13; - Stmp2 = Sa12; - Stmp3 = Sa13; - Sa12 = Sv12 * Sa11; - Sa13 = Sv13 * Sa11; - Sa11 = Sv11 * Sa11; - Stmp1 = Sv21 * Stmp2; - Sa11 = Sa11 + Stmp1; - Stmp1 = Sv31 * Stmp3; - Sa11 = Sa11 + Stmp1; - Stmp1 = Sv22 * Stmp2; - Sa12 = Sa12 + Stmp1; - Stmp1 = Sv32 * Stmp3; - Sa12 = Sa12 + Stmp1; - Stmp1 = Sv23 * Stmp2; - Sa13 = Sa13 + Stmp1; - Stmp1 = Sv33 * Stmp3; - Sa13 = Sa13 + Stmp1; - Stmp2 = Sa22; - Stmp3 = Sa23; - Sa22 = Sv12 * Sa21; - Sa23 = Sv13 * Sa21; - Sa21 = Sv11 * Sa21; - Stmp1 = Sv21 * Stmp2; - Sa21 = Sa21 + Stmp1; - Stmp1 = Sv31 * Stmp3; - Sa21 = Sa21 + Stmp1; - Stmp1 = Sv22 * Stmp2; - Sa22 = Sa22 + Stmp1; - Stmp1 = Sv32 * Stmp3; - Sa22 = Sa22 + Stmp1; - Stmp1 = Sv23 * Stmp2; - Sa23 = Sa23 + Stmp1; - Stmp1 = Sv33 * Stmp3; - Sa23 = Sa23 + Stmp1; - Stmp2 = Sa32; - Stmp3 = Sa33; - Sa32 = Sv12 * Sa31; - Sa33 = Sv13 * Sa31; - Sa31 = Sv11 * Sa31; - Stmp1 = Sv21 * Stmp2; - Sa31 = Sa31 + Stmp1; - Stmp1 = Sv31 * Stmp3; - Sa31 = Sa31 + Stmp1; - Stmp1 = Sv22 * Stmp2; - Sa32 = Sa32 + Stmp1; - Stmp1 = Sv32 * Stmp3; - Sa32 = Sa32 + Stmp1; - Stmp1 = Sv23 * Stmp2; - Sa33 = Sa33 + Stmp1; - Stmp1 = Sv33 * Stmp3; - Sa33 = Sa33 + Stmp1; - Stmp1 = Sa11 * Sa11; - Stmp4 = Sa21 * Sa21; - Stmp1 = Stmp1 + Stmp4; - Stmp4 = Sa31 * Sa31; - Stmp1 = Stmp1 + Stmp4; - Stmp2 = Sa12 * Sa12; - Stmp4 = Sa22 * Sa22; - Stmp2 = Stmp2 + Stmp4; - Stmp4 = Sa32 * Sa32; - Stmp2 = Stmp2 + Stmp4; - Stmp3 = Sa13 * Sa13; - Stmp4 = Sa23 * Sa23; - Stmp3 = Stmp3 + Stmp4; - Stmp4 = Sa33 * Sa33; - Stmp3 = Stmp3 + Stmp4; - Stmp4 = bit_cast( - select(Stmp1 < Stmp2, Expr(Ti(int32(0xffffffff))), Expr(Ti(0)))); - Stmp5 = svd_bitwise_xor(Sa11, Sa12); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sa11 = svd_bitwise_xor(Sa11, Stmp5); - Sa12 = svd_bitwise_xor(Sa12, Stmp5); - Stmp5 = svd_bitwise_xor(Sa21, Sa22); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sa21 = svd_bitwise_xor(Sa21, Stmp5); - Sa22 = svd_bitwise_xor(Sa22, Stmp5); - Stmp5 = svd_bitwise_xor(Sa31, Sa32); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sa31 = svd_bitwise_xor(Sa31, Stmp5); - Sa32 = svd_bitwise_xor(Sa32, Stmp5); - Stmp5 = svd_bitwise_xor(Sv11, Sv12); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sv11 = svd_bitwise_xor(Sv11, Stmp5); - Sv12 = svd_bitwise_xor(Sv12, Stmp5); - Stmp5 = svd_bitwise_xor(Sv21, Sv22); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sv21 = svd_bitwise_xor(Sv21, Stmp5); - Sv22 = svd_bitwise_xor(Sv22, Stmp5); - Stmp5 = svd_bitwise_xor(Sv31, Sv32); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sv31 = svd_bitwise_xor(Sv31, Stmp5); - Sv32 = svd_bitwise_xor(Sv32, Stmp5); - Stmp5 = svd_bitwise_xor(Stmp1, Stmp2); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Stmp1 = svd_bitwise_xor(Stmp1, Stmp5); - Stmp2 = svd_bitwise_xor(Stmp2, Stmp5); - Stmp5 = Tf(-2.0f); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Stmp4 = Tf(1.0f); - Stmp4 = Stmp4 + Stmp5; - Sa12 = Sa12 * Stmp4; - Sa22 = Sa22 * Stmp4; - Sa32 = Sa32 * Stmp4; - Sv12 = Sv12 * Stmp4; - Sv22 = Sv22 * Stmp4; - Sv32 = Sv32 * Stmp4; - Stmp4 = bit_cast( - select(Stmp1 < Stmp3, Expr(Ti(int32(0xffffffff))), Expr(Ti(0)))); - Stmp5 = svd_bitwise_xor(Sa11, Sa13); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sa11 = svd_bitwise_xor(Sa11, Stmp5); - Sa13 = svd_bitwise_xor(Sa13, Stmp5); - Stmp5 = svd_bitwise_xor(Sa21, Sa23); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sa21 = svd_bitwise_xor(Sa21, Stmp5); - Sa23 = svd_bitwise_xor(Sa23, Stmp5); - Stmp5 = svd_bitwise_xor(Sa31, Sa33); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sa31 = svd_bitwise_xor(Sa31, Stmp5); - Sa33 = svd_bitwise_xor(Sa33, Stmp5); - Stmp5 = svd_bitwise_xor(Sv11, Sv13); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sv11 = svd_bitwise_xor(Sv11, Stmp5); - Sv13 = svd_bitwise_xor(Sv13, Stmp5); - Stmp5 = svd_bitwise_xor(Sv21, Sv23); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sv21 = svd_bitwise_xor(Sv21, Stmp5); - Sv23 = svd_bitwise_xor(Sv23, Stmp5); - Stmp5 = svd_bitwise_xor(Sv31, Sv33); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sv31 = svd_bitwise_xor(Sv31, Stmp5); - Sv33 = svd_bitwise_xor(Sv33, Stmp5); - Stmp5 = svd_bitwise_xor(Stmp1, Stmp3); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Stmp1 = svd_bitwise_xor(Stmp1, Stmp5); - Stmp3 = svd_bitwise_xor(Stmp3, Stmp5); - Stmp5 = Tf(-2.0f); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Stmp4 = Tf(1.0f); - Stmp4 = Stmp4 + Stmp5; - Sa11 = Sa11 * Stmp4; - Sa21 = Sa21 * Stmp4; - Sa31 = Sa31 * Stmp4; - Sv11 = Sv11 * Stmp4; - Sv21 = Sv21 * Stmp4; - Sv31 = Sv31 * Stmp4; - Stmp4 = bit_cast( - select(Stmp2 < Stmp3, Expr(Ti(int32(0xffffffff))), Expr(Ti(0)))); - Stmp5 = svd_bitwise_xor(Sa12, Sa13); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sa12 = svd_bitwise_xor(Sa12, Stmp5); - Sa13 = svd_bitwise_xor(Sa13, Stmp5); - Stmp5 = svd_bitwise_xor(Sa22, Sa23); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sa22 = svd_bitwise_xor(Sa22, Stmp5); - Sa23 = svd_bitwise_xor(Sa23, Stmp5); - Stmp5 = svd_bitwise_xor(Sa32, Sa33); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sa32 = svd_bitwise_xor(Sa32, Stmp5); - Sa33 = svd_bitwise_xor(Sa33, Stmp5); - Stmp5 = svd_bitwise_xor(Sv12, Sv13); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sv12 = svd_bitwise_xor(Sv12, Stmp5); - Sv13 = svd_bitwise_xor(Sv13, Stmp5); - Stmp5 = svd_bitwise_xor(Sv22, Sv23); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sv22 = svd_bitwise_xor(Sv22, Stmp5); - Sv23 = svd_bitwise_xor(Sv23, Stmp5); - Stmp5 = svd_bitwise_xor(Sv32, Sv33); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Sv32 = svd_bitwise_xor(Sv32, Stmp5); - Sv33 = svd_bitwise_xor(Sv33, Stmp5); - Stmp5 = svd_bitwise_xor(Stmp2, Stmp3); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Stmp2 = svd_bitwise_xor(Stmp2, Stmp5); - Stmp3 = svd_bitwise_xor(Stmp3, Stmp5); - Stmp5 = Tf(-2.0f); - Stmp5 = svd_bitwise_and(Stmp5, Stmp4); - Stmp4 = Tf(1.0f); - Stmp4 = Stmp4 + Stmp5; - Sa13 = Sa13 * Stmp4; - Sa23 = Sa23 * Stmp4; - Sa33 = Sa33 * Stmp4; - Sv13 = Sv13 * Stmp4; - Sv23 = Sv23 * Stmp4; - Sv33 = Sv33 * Stmp4; - Su11 = Tf(1.0f); - Su21 = Tf(0.0f); - Su31 = Tf(0.0f); - Su12 = Tf(0.0f); - Su22 = Tf(1.0f); - Su32 = Tf(0.0f); - Su13 = Tf(0.0f); - Su23 = Tf(0.0f); - Su33 = Tf(1.0f); - Ssh = Sa21 * Sa21; - Ssh = bit_cast( - select(Ssh >= Ssmall_number, Expr(Ti(int32(0xffffffff))), Expr(Ti(0)))); - Ssh = svd_bitwise_and(Ssh, Sa21); - Stmp5 = Tf(0.0f); - Sch = Stmp5 - Sa11; - Sch = max(Sch, Sa11); - Sch = max(Sch, Ssmall_number); - Stmp5 = bit_cast( - select(Sa11 >= Stmp5, Expr(Ti(int32(0xffffffff))), Expr(Ti(0)))); - Stmp1 = Sch * Sch; - Stmp2 = Ssh * Ssh; - Stmp2 = Stmp1 + Stmp2; - Stmp1 = rsqrt(Stmp2); - Stmp4 = Stmp1 * Sone_half; - Stmp3 = Stmp1 * Stmp4; - Stmp3 = Stmp1 * Stmp3; - Stmp3 = Stmp2 * Stmp3; - Stmp1 = Stmp1 + Stmp4; - Stmp1 = Stmp1 - Stmp3; - Stmp1 = Stmp1 * Stmp2; - Sch = Sch + Stmp1; - Stmp1 = svd_bitwise_and(Expr(~bit_cast(Stmp5)), Ssh); - Stmp2 = svd_bitwise_and(Expr(~bit_cast(Stmp5)), Sch); - Sch = svd_bitwise_and(Stmp5, Sch); - Ssh = svd_bitwise_and(Stmp5, Ssh); - Sch = svd_bitwise_or(Sch, Stmp1); - Ssh = svd_bitwise_or(Ssh, Stmp2); - Stmp1 = Sch * Sch; - Stmp2 = Ssh * Ssh; - Stmp2 = Stmp1 + Stmp2; - Stmp1 = rsqrt(Stmp2); - Stmp4 = Stmp1 * Sone_half; - Stmp3 = Stmp1 * Stmp4; - Stmp3 = Stmp1 * Stmp3; - Stmp3 = Stmp2 * Stmp3; - Stmp1 = Stmp1 + Stmp4; - Stmp1 = Stmp1 - Stmp3; - Sch = Sch * Stmp1; - Ssh = Ssh * Stmp1; - Sc = Sch * Sch; - Ss = Ssh * Ssh; - Sc = Sc - Ss; - Ss = Ssh * Sch; - Ss = Ss + Ss; - Stmp1 = Ss * Sa11; - Stmp2 = Ss * Sa21; - Sa11 = Sc * Sa11; - Sa21 = Sc * Sa21; - Sa11 = Sa11 + Stmp2; - Sa21 = Sa21 - Stmp1; - Stmp1 = Ss * Sa12; - Stmp2 = Ss * Sa22; - Sa12 = Sc * Sa12; - Sa22 = Sc * Sa22; - Sa12 = Sa12 + Stmp2; - Sa22 = Sa22 - Stmp1; - Stmp1 = Ss * Sa13; - Stmp2 = Ss * Sa23; - Sa13 = Sc * Sa13; - Sa23 = Sc * Sa23; - Sa13 = Sa13 + Stmp2; - Sa23 = Sa23 - Stmp1; - Stmp1 = Ss * Su11; - Stmp2 = Ss * Su12; - Su11 = Sc * Su11; - Su12 = Sc * Su12; - Su11 = Su11 + Stmp2; - Su12 = Su12 - Stmp1; - Stmp1 = Ss * Su21; - Stmp2 = Ss * Su22; - Su21 = Sc * Su21; - Su22 = Sc * Su22; - Su21 = Su21 + Stmp2; - Su22 = Su22 - Stmp1; - Stmp1 = Ss * Su31; - Stmp2 = Ss * Su32; - Su31 = Sc * Su31; - Su32 = Sc * Su32; - Su31 = Su31 + Stmp2; - Su32 = Su32 - Stmp1; - Ssh = Sa31 * Sa31; - Ssh = bit_cast( - select(Ssh >= Ssmall_number, Expr(Ti(int32(0xffffffff))), Expr(Ti(0)))); - Ssh = svd_bitwise_and(Ssh, Sa31); - Stmp5 = Tf(0.0f); - Sch = Stmp5 - Sa11; - Sch = max(Sch, Sa11); - Sch = max(Sch, Ssmall_number); - Stmp5 = bit_cast( - select(Sa11 >= Stmp5, Expr(Ti(int32(0xffffffff))), Expr(Ti(0)))); - Stmp1 = Sch * Sch; - Stmp2 = Ssh * Ssh; - Stmp2 = Stmp1 + Stmp2; - Stmp1 = rsqrt(Stmp2); - Stmp4 = Stmp1 * Sone_half; - Stmp3 = Stmp1 * Stmp4; - Stmp3 = Stmp1 * Stmp3; - Stmp3 = Stmp2 * Stmp3; - Stmp1 = Stmp1 + Stmp4; - Stmp1 = Stmp1 - Stmp3; - Stmp1 = Stmp1 * Stmp2; - Sch = Sch + Stmp1; - Stmp1 = svd_bitwise_and(Expr(~bit_cast(Stmp5)), Ssh); - Stmp2 = svd_bitwise_and(Expr(~bit_cast(Stmp5)), Sch); - Sch = svd_bitwise_and(Stmp5, Sch); - Ssh = svd_bitwise_and(Stmp5, Ssh); - Sch = svd_bitwise_or(Sch, Stmp1); - Ssh = svd_bitwise_or(Ssh, Stmp2); - Stmp1 = Sch * Sch; - Stmp2 = Ssh * Ssh; - Stmp2 = Stmp1 + Stmp2; - Stmp1 = rsqrt(Stmp2); - Stmp4 = Stmp1 * Sone_half; - Stmp3 = Stmp1 * Stmp4; - Stmp3 = Stmp1 * Stmp3; - Stmp3 = Stmp2 * Stmp3; - Stmp1 = Stmp1 + Stmp4; - Stmp1 = Stmp1 - Stmp3; - Sch = Sch * Stmp1; - Ssh = Ssh * Stmp1; - Sc = Sch * Sch; - Ss = Ssh * Ssh; - Sc = Sc - Ss; - Ss = Ssh * Sch; - Ss = Ss + Ss; - Stmp1 = Ss * Sa11; - Stmp2 = Ss * Sa31; - Sa11 = Sc * Sa11; - Sa31 = Sc * Sa31; - Sa11 = Sa11 + Stmp2; - Sa31 = Sa31 - Stmp1; - Stmp1 = Ss * Sa12; - Stmp2 = Ss * Sa32; - Sa12 = Sc * Sa12; - Sa32 = Sc * Sa32; - Sa12 = Sa12 + Stmp2; - Sa32 = Sa32 - Stmp1; - Stmp1 = Ss * Sa13; - Stmp2 = Ss * Sa33; - Sa13 = Sc * Sa13; - Sa33 = Sc * Sa33; - Sa13 = Sa13 + Stmp2; - Sa33 = Sa33 - Stmp1; - Stmp1 = Ss * Su11; - Stmp2 = Ss * Su13; - Su11 = Sc * Su11; - Su13 = Sc * Su13; - Su11 = Su11 + Stmp2; - Su13 = Su13 - Stmp1; - Stmp1 = Ss * Su21; - Stmp2 = Ss * Su23; - Su21 = Sc * Su21; - Su23 = Sc * Su23; - Su21 = Su21 + Stmp2; - Su23 = Su23 - Stmp1; - Stmp1 = Ss * Su31; - Stmp2 = Ss * Su33; - Su31 = Sc * Su31; - Su33 = Sc * Su33; - Su31 = Su31 + Stmp2; - Su33 = Su33 - Stmp1; - Ssh = Sa32 * Sa32; - Ssh = bit_cast( - select(Ssh >= Ssmall_number, Expr(Ti(int32(0xffffffff))), Expr(Ti(0)))); - Ssh = svd_bitwise_and(Ssh, Sa32); - Stmp5 = Tf(0.0f); - Sch = Stmp5 - Sa22; - Sch = max(Sch, Sa22); - Sch = max(Sch, Ssmall_number); - Stmp5 = bit_cast( - select(Sa22 >= Stmp5, Expr(Ti(int32(0xffffffff))), Expr(Ti(0)))); - Stmp1 = Sch * Sch; - Stmp2 = Ssh * Ssh; - Stmp2 = Stmp1 + Stmp2; - Stmp1 = rsqrt(Stmp2); - Stmp4 = Stmp1 * Sone_half; - Stmp3 = Stmp1 * Stmp4; - Stmp3 = Stmp1 * Stmp3; - Stmp3 = Stmp2 * Stmp3; - Stmp1 = Stmp1 + Stmp4; - Stmp1 = Stmp1 - Stmp3; - Stmp1 = Stmp1 * Stmp2; - Sch = Sch + Stmp1; - Stmp1 = svd_bitwise_and(Expr(~bit_cast(Stmp5)), Ssh); - Stmp2 = svd_bitwise_and(Expr(~bit_cast(Stmp5)), Sch); - Sch = svd_bitwise_and(Stmp5, Sch); - Ssh = svd_bitwise_and(Stmp5, Ssh); - Sch = svd_bitwise_or(Sch, Stmp1); - Ssh = svd_bitwise_or(Ssh, Stmp2); - Stmp1 = Sch * Sch; - Stmp2 = Ssh * Ssh; - Stmp2 = Stmp1 + Stmp2; - Stmp1 = rsqrt(Stmp2); - Stmp4 = Stmp1 * Sone_half; - Stmp3 = Stmp1 * Stmp4; - Stmp3 = Stmp1 * Stmp3; - Stmp3 = Stmp2 * Stmp3; - Stmp1 = Stmp1 + Stmp4; - Stmp1 = Stmp1 - Stmp3; - Sch = Sch * Stmp1; - Ssh = Ssh * Stmp1; - Sc = Sch * Sch; - Ss = Ssh * Ssh; - Sc = Sc - Ss; - Ss = Ssh * Sch; - Ss = Ss + Ss; - Stmp1 = Ss * Sa21; - Stmp2 = Ss * Sa31; - Sa21 = Sc * Sa21; - Sa31 = Sc * Sa31; - Sa21 = Sa21 + Stmp2; - Sa31 = Sa31 - Stmp1; - Stmp1 = Ss * Sa22; - Stmp2 = Ss * Sa32; - Sa22 = Sc * Sa22; - Sa32 = Sc * Sa32; - Sa22 = Sa22 + Stmp2; - Sa32 = Sa32 - Stmp1; - Stmp1 = Ss * Sa23; - Stmp2 = Ss * Sa33; - Sa23 = Sc * Sa23; - Sa33 = Sc * Sa33; - Sa23 = Sa23 + Stmp2; - Sa33 = Sa33 - Stmp1; - Stmp1 = Ss * Su12; - Stmp2 = Ss * Su13; - Su12 = Sc * Su12; - Su13 = Sc * Su13; - Su12 = Su12 + Stmp2; - Su13 = Su13 - Stmp1; - Stmp1 = Ss * Su22; - Stmp2 = Ss * Su23; - Su22 = Sc * Su22; - Su23 = Sc * Su23; - Su22 = Su22 + Stmp2; - Su23 = Su23 - Stmp1; - Stmp1 = Ss * Su32; - Stmp2 = Ss * Su33; - Su32 = Sc * Su32; - Su33 = Sc * Su33; - Su32 = Su32 + Stmp2; - Su33 = Su33 - Stmp1; + ast_builder->insert_assignment(Stmp2, Sqvs * Sqvs); + ast_builder->insert_assignment(Stmp1, Sqvvx * Sqvvx); + ast_builder->insert_assignment(Stmp2, Stmp1 + Stmp2); + ast_builder->insert_assignment(Stmp1, Sqvvy * Sqvvy); + ast_builder->insert_assignment(Stmp2, Stmp1 + Stmp2); + ast_builder->insert_assignment(Stmp1, Sqvvz * Sqvvz); + ast_builder->insert_assignment(Stmp2, Stmp1 + Stmp2); + ast_builder->insert_assignment(Stmp1, rsqrt(Stmp2)); + ast_builder->insert_assignment(Stmp4, Stmp1 * Sone_half); + ast_builder->insert_assignment(Stmp3, Stmp1 * Stmp4); + ast_builder->insert_assignment(Stmp3, Stmp1 * Stmp3); + ast_builder->insert_assignment(Stmp3, Stmp2 * Stmp3); + ast_builder->insert_assignment(Stmp1, Stmp1 + Stmp4); + ast_builder->insert_assignment(Stmp1, Stmp1 - Stmp3); + ast_builder->insert_assignment(Sqvs, Sqvs * Stmp1); + ast_builder->insert_assignment(Sqvvx, Sqvvx * Stmp1); + ast_builder->insert_assignment(Sqvvy, Sqvvy * Stmp1); + ast_builder->insert_assignment(Sqvvz, Sqvvz * Stmp1); + ast_builder->insert_assignment(Stmp1, Sqvvx * Sqvvx); + ast_builder->insert_assignment(Stmp2, Sqvvy * Sqvvy); + ast_builder->insert_assignment(Stmp3, Sqvvz * Sqvvz); + ast_builder->insert_assignment(Sv11, Sqvs * Sqvs); + ast_builder->insert_assignment(Sv22, Sv11 - Stmp1); + ast_builder->insert_assignment(Sv33, Sv22 - Stmp2); + ast_builder->insert_assignment(Sv33, Sv33 + Stmp3); + ast_builder->insert_assignment(Sv22, Sv22 + Stmp2); + ast_builder->insert_assignment(Sv22, Sv22 - Stmp3); + ast_builder->insert_assignment(Sv11, Sv11 + Stmp1); + ast_builder->insert_assignment(Sv11, Sv11 - Stmp2); + ast_builder->insert_assignment(Sv11, Sv11 - Stmp3); + ast_builder->insert_assignment(Stmp1, Sqvvx + Sqvvx); + ast_builder->insert_assignment(Stmp2, Sqvvy + Sqvvy); + ast_builder->insert_assignment(Stmp3, Sqvvz + Sqvvz); + ast_builder->insert_assignment(Sv32, Sqvs * Stmp1); + ast_builder->insert_assignment(Sv13, Sqvs * Stmp2); + ast_builder->insert_assignment(Sv21, Sqvs * Stmp3); + ast_builder->insert_assignment(Stmp1, Sqvvy * Stmp1); + ast_builder->insert_assignment(Stmp2, Sqvvz * Stmp2); + ast_builder->insert_assignment(Stmp3, Sqvvx * Stmp3); + ast_builder->insert_assignment(Sv12, Stmp1 - Sv21); + ast_builder->insert_assignment(Sv23, Stmp2 - Sv32); + ast_builder->insert_assignment(Sv31, Stmp3 - Sv13); + ast_builder->insert_assignment(Sv21, Stmp1 + Sv21); + ast_builder->insert_assignment(Sv32, Stmp2 + Sv32); + ast_builder->insert_assignment(Sv13, Stmp3 + Sv13); + ast_builder->insert_assignment(Stmp2, Sa12); + ast_builder->insert_assignment(Stmp3, Sa13); + ast_builder->insert_assignment(Sa12, Sv12 * Sa11); + ast_builder->insert_assignment(Sa13, Sv13 * Sa11); + ast_builder->insert_assignment(Sa11, Sv11 * Sa11); + ast_builder->insert_assignment(Stmp1, Sv21 * Stmp2); + ast_builder->insert_assignment(Sa11, Sa11 + Stmp1); + ast_builder->insert_assignment(Stmp1, Sv31 * Stmp3); + ast_builder->insert_assignment(Sa11, Sa11 + Stmp1); + ast_builder->insert_assignment(Stmp1, Sv22 * Stmp2); + ast_builder->insert_assignment(Sa12, Sa12 + Stmp1); + ast_builder->insert_assignment(Stmp1, Sv32 * Stmp3); + ast_builder->insert_assignment(Sa12, Sa12 + Stmp1); + ast_builder->insert_assignment(Stmp1, Sv23 * Stmp2); + ast_builder->insert_assignment(Sa13, Sa13 + Stmp1); + ast_builder->insert_assignment(Stmp1, Sv33 * Stmp3); + ast_builder->insert_assignment(Sa13, Sa13 + Stmp1); + ast_builder->insert_assignment(Stmp2, Sa22); + ast_builder->insert_assignment(Stmp3, Sa23); + ast_builder->insert_assignment(Sa22, Sv12 * Sa21); + ast_builder->insert_assignment(Sa23, Sv13 * Sa21); + ast_builder->insert_assignment(Sa21, Sv11 * Sa21); + ast_builder->insert_assignment(Stmp1, Sv21 * Stmp2); + ast_builder->insert_assignment(Sa21, Sa21 + Stmp1); + ast_builder->insert_assignment(Stmp1, Sv31 * Stmp3); + ast_builder->insert_assignment(Sa21, Sa21 + Stmp1); + ast_builder->insert_assignment(Stmp1, Sv22 * Stmp2); + ast_builder->insert_assignment(Sa22, Sa22 + Stmp1); + ast_builder->insert_assignment(Stmp1, Sv32 * Stmp3); + ast_builder->insert_assignment(Sa22, Sa22 + Stmp1); + ast_builder->insert_assignment(Stmp1, Sv23 * Stmp2); + ast_builder->insert_assignment(Sa23, Sa23 + Stmp1); + ast_builder->insert_assignment(Stmp1, Sv33 * Stmp3); + ast_builder->insert_assignment(Sa23, Sa23 + Stmp1); + ast_builder->insert_assignment(Stmp2, Sa32); + ast_builder->insert_assignment(Stmp3, Sa33); + ast_builder->insert_assignment(Sa32, Sv12 * Sa31); + ast_builder->insert_assignment(Sa33, Sv13 * Sa31); + ast_builder->insert_assignment(Sa31, Sv11 * Sa31); + ast_builder->insert_assignment(Stmp1, Sv21 * Stmp2); + ast_builder->insert_assignment(Sa31, Sa31 + Stmp1); + ast_builder->insert_assignment(Stmp1, Sv31 * Stmp3); + ast_builder->insert_assignment(Sa31, Sa31 + Stmp1); + ast_builder->insert_assignment(Stmp1, Sv22 * Stmp2); + ast_builder->insert_assignment(Sa32, Sa32 + Stmp1); + ast_builder->insert_assignment(Stmp1, Sv32 * Stmp3); + ast_builder->insert_assignment(Sa32, Sa32 + Stmp1); + ast_builder->insert_assignment(Stmp1, Sv23 * Stmp2); + ast_builder->insert_assignment(Sa33, Sa33 + Stmp1); + ast_builder->insert_assignment(Stmp1, Sv33 * Stmp3); + ast_builder->insert_assignment(Sa33, Sa33 + Stmp1); + ast_builder->insert_assignment(Stmp1, Sa11 * Sa11); + ast_builder->insert_assignment(Stmp4, Sa21 * Sa21); + ast_builder->insert_assignment(Stmp1, Stmp1 + Stmp4); + ast_builder->insert_assignment(Stmp4, Sa31 * Sa31); + ast_builder->insert_assignment(Stmp1, Stmp1 + Stmp4); + ast_builder->insert_assignment(Stmp2, Sa12 * Sa12); + ast_builder->insert_assignment(Stmp4, Sa22 * Sa22); + ast_builder->insert_assignment(Stmp2, Stmp2 + Stmp4); + ast_builder->insert_assignment(Stmp4, Sa32 * Sa32); + ast_builder->insert_assignment(Stmp2, Stmp2 + Stmp4); + ast_builder->insert_assignment(Stmp3, Sa13 * Sa13); + ast_builder->insert_assignment(Stmp4, Sa23 * Sa23); + ast_builder->insert_assignment(Stmp3, Stmp3 + Stmp4); + ast_builder->insert_assignment(Stmp4, Sa33 * Sa33); + ast_builder->insert_assignment(Stmp3, Stmp3 + Stmp4); + ast_builder->insert_assignment( + Stmp4, bit_cast(select(Stmp1 < Stmp2, Expr(Ti(int32(0xffffffff))), + Expr(Ti(0))))); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sa11, Sa12)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sa11, svd_bitwise_xor(Sa11, Stmp5)); + ast_builder->insert_assignment(Sa12, svd_bitwise_xor(Sa12, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sa21, Sa22)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sa21, svd_bitwise_xor(Sa21, Stmp5)); + ast_builder->insert_assignment(Sa22, svd_bitwise_xor(Sa22, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sa31, Sa32)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sa31, svd_bitwise_xor(Sa31, Stmp5)); + ast_builder->insert_assignment(Sa32, svd_bitwise_xor(Sa32, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sv11, Sv12)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sv11, svd_bitwise_xor(Sv11, Stmp5)); + ast_builder->insert_assignment(Sv12, svd_bitwise_xor(Sv12, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sv21, Sv22)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sv21, svd_bitwise_xor(Sv21, Stmp5)); + ast_builder->insert_assignment(Sv22, svd_bitwise_xor(Sv22, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sv31, Sv32)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sv31, svd_bitwise_xor(Sv31, Stmp5)); + ast_builder->insert_assignment(Sv32, svd_bitwise_xor(Sv32, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Stmp1, Stmp2)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Stmp1, svd_bitwise_xor(Stmp1, Stmp5)); + ast_builder->insert_assignment(Stmp2, svd_bitwise_xor(Stmp2, Stmp5)); + ast_builder->insert_assignment(Stmp5, Expr(Tf(-2.0f))); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Stmp4, Expr(Tf(1.0f))); + ast_builder->insert_assignment(Stmp4, Stmp4 + Stmp5); + ast_builder->insert_assignment(Sa12, Sa12 * Stmp4); + ast_builder->insert_assignment(Sa22, Sa22 * Stmp4); + ast_builder->insert_assignment(Sa32, Sa32 * Stmp4); + ast_builder->insert_assignment(Sv12, Sv12 * Stmp4); + ast_builder->insert_assignment(Sv22, Sv22 * Stmp4); + ast_builder->insert_assignment(Sv32, Sv32 * Stmp4); + ast_builder->insert_assignment( + Stmp4, bit_cast(select(Stmp1 < Stmp3, Expr(Ti(int32(0xffffffff))), + Expr(Ti(0))))); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sa11, Sa13)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sa11, svd_bitwise_xor(Sa11, Stmp5)); + ast_builder->insert_assignment(Sa13, svd_bitwise_xor(Sa13, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sa21, Sa23)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sa21, svd_bitwise_xor(Sa21, Stmp5)); + ast_builder->insert_assignment(Sa23, svd_bitwise_xor(Sa23, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sa31, Sa33)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sa31, svd_bitwise_xor(Sa31, Stmp5)); + ast_builder->insert_assignment(Sa33, svd_bitwise_xor(Sa33, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sv11, Sv13)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sv11, svd_bitwise_xor(Sv11, Stmp5)); + ast_builder->insert_assignment(Sv13, svd_bitwise_xor(Sv13, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sv21, Sv23)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sv21, svd_bitwise_xor(Sv21, Stmp5)); + ast_builder->insert_assignment(Sv23, svd_bitwise_xor(Sv23, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sv31, Sv33)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sv31, svd_bitwise_xor(Sv31, Stmp5)); + ast_builder->insert_assignment(Sv33, svd_bitwise_xor(Sv33, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Stmp1, Stmp3)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Stmp1, svd_bitwise_xor(Stmp1, Stmp5)); + ast_builder->insert_assignment(Stmp3, svd_bitwise_xor(Stmp3, Stmp5)); + ast_builder->insert_assignment(Stmp5, Expr(Tf(-2.0f))); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Stmp4, Expr(Tf(1.0f))); + ast_builder->insert_assignment(Stmp4, Stmp4 + Stmp5); + ast_builder->insert_assignment(Sa11, Sa11 * Stmp4); + ast_builder->insert_assignment(Sa21, Sa21 * Stmp4); + ast_builder->insert_assignment(Sa31, Sa31 * Stmp4); + ast_builder->insert_assignment(Sv11, Sv11 * Stmp4); + ast_builder->insert_assignment(Sv21, Sv21 * Stmp4); + ast_builder->insert_assignment(Sv31, Sv31 * Stmp4); + ast_builder->insert_assignment( + Stmp4, bit_cast(select(Stmp2 < Stmp3, Expr(Ti(int32(0xffffffff))), + Expr(Ti(0))))); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sa12, Sa13)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sa12, svd_bitwise_xor(Sa12, Stmp5)); + ast_builder->insert_assignment(Sa13, svd_bitwise_xor(Sa13, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sa22, Sa23)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sa22, svd_bitwise_xor(Sa22, Stmp5)); + ast_builder->insert_assignment(Sa23, svd_bitwise_xor(Sa23, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sa32, Sa33)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sa32, svd_bitwise_xor(Sa32, Stmp5)); + ast_builder->insert_assignment(Sa33, svd_bitwise_xor(Sa33, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sv12, Sv13)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sv12, svd_bitwise_xor(Sv12, Stmp5)); + ast_builder->insert_assignment(Sv13, svd_bitwise_xor(Sv13, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sv22, Sv23)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sv22, svd_bitwise_xor(Sv22, Stmp5)); + ast_builder->insert_assignment(Sv23, svd_bitwise_xor(Sv23, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Sv32, Sv33)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Sv32, svd_bitwise_xor(Sv32, Stmp5)); + ast_builder->insert_assignment(Sv33, svd_bitwise_xor(Sv33, Stmp5)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_xor(Stmp2, Stmp3)); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Stmp2, svd_bitwise_xor(Stmp2, Stmp5)); + ast_builder->insert_assignment(Stmp3, svd_bitwise_xor(Stmp3, Stmp5)); + ast_builder->insert_assignment(Stmp5, Expr(Tf(-2.0f))); + ast_builder->insert_assignment(Stmp5, svd_bitwise_and(Stmp5, Stmp4)); + ast_builder->insert_assignment(Stmp4, Expr(Tf(1.0f))); + ast_builder->insert_assignment(Stmp4, Stmp4 + Stmp5); + ast_builder->insert_assignment(Sa13, Sa13 * Stmp4); + ast_builder->insert_assignment(Sa23, Sa23 * Stmp4); + ast_builder->insert_assignment(Sa33, Sa33 * Stmp4); + ast_builder->insert_assignment(Sv13, Sv13 * Stmp4); + ast_builder->insert_assignment(Sv23, Sv23 * Stmp4); + ast_builder->insert_assignment(Sv33, Sv33 * Stmp4); + ast_builder->insert_assignment(Su11, Expr(Tf(1.0f))); + ast_builder->insert_assignment(Su21, Expr(Tf(0.0f))); + ast_builder->insert_assignment(Su31, Expr(Tf(0.0f))); + ast_builder->insert_assignment(Su12, Expr(Tf(0.0f))); + ast_builder->insert_assignment(Su22, Expr(Tf(1.0f))); + ast_builder->insert_assignment(Su32, Expr(Tf(0.0f))); + ast_builder->insert_assignment(Su13, Expr(Tf(0.0f))); + ast_builder->insert_assignment(Su23, Expr(Tf(0.0f))); + ast_builder->insert_assignment(Su33, Expr(Tf(1.0f))); + ast_builder->insert_assignment(Ssh, Sa21 * Sa21); + ast_builder->insert_assignment( + Ssh, bit_cast(select(Ssh >= Ssmall_number, + Expr(Ti(int32(0xffffffff))), Expr(Ti(0))))); + ast_builder->insert_assignment(Ssh, svd_bitwise_and(Ssh, Sa21)); + ast_builder->insert_assignment(Stmp5, Expr(Tf(0.0f))); + ast_builder->insert_assignment(Sch, Stmp5 - Sa11); + ast_builder->insert_assignment(Sch, max(Sch, Sa11)); + ast_builder->insert_assignment(Sch, max(Sch, Ssmall_number)); + ast_builder->insert_assignment( + Stmp5, bit_cast(select(Sa11 >= Stmp5, Expr(Ti(int32(0xffffffff))), + Expr(Ti(0))))); + ast_builder->insert_assignment(Stmp1, Sch * Sch); + ast_builder->insert_assignment(Stmp2, Ssh * Ssh); + ast_builder->insert_assignment(Stmp2, Stmp1 + Stmp2); + ast_builder->insert_assignment(Stmp1, rsqrt(Stmp2)); + ast_builder->insert_assignment(Stmp4, Stmp1 * Sone_half); + ast_builder->insert_assignment(Stmp3, Stmp1 * Stmp4); + ast_builder->insert_assignment(Stmp3, Stmp1 * Stmp3); + ast_builder->insert_assignment(Stmp3, Stmp2 * Stmp3); + ast_builder->insert_assignment(Stmp1, Stmp1 + Stmp4); + ast_builder->insert_assignment(Stmp1, Stmp1 - Stmp3); + ast_builder->insert_assignment(Stmp1, Stmp1 * Stmp2); + ast_builder->insert_assignment(Sch, Sch + Stmp1); + ast_builder->insert_assignment( + Stmp1, svd_bitwise_and(Expr(~bit_cast(Stmp5)), Ssh)); + ast_builder->insert_assignment( + Stmp2, svd_bitwise_and(Expr(~bit_cast(Stmp5)), Sch)); + ast_builder->insert_assignment(Sch, svd_bitwise_and(Stmp5, Sch)); + ast_builder->insert_assignment(Ssh, svd_bitwise_and(Stmp5, Ssh)); + ast_builder->insert_assignment(Sch, svd_bitwise_or(Sch, Stmp1)); + ast_builder->insert_assignment(Ssh, svd_bitwise_or(Ssh, Stmp2)); + ast_builder->insert_assignment(Stmp1, Sch * Sch); + ast_builder->insert_assignment(Stmp2, Ssh * Ssh); + ast_builder->insert_assignment(Stmp2, Stmp1 + Stmp2); + ast_builder->insert_assignment(Stmp1, rsqrt(Stmp2)); + ast_builder->insert_assignment(Stmp4, Stmp1 * Sone_half); + ast_builder->insert_assignment(Stmp3, Stmp1 * Stmp4); + ast_builder->insert_assignment(Stmp3, Stmp1 * Stmp3); + ast_builder->insert_assignment(Stmp3, Stmp2 * Stmp3); + ast_builder->insert_assignment(Stmp1, Stmp1 + Stmp4); + ast_builder->insert_assignment(Stmp1, Stmp1 - Stmp3); + ast_builder->insert_assignment(Sch, Sch * Stmp1); + ast_builder->insert_assignment(Ssh, Ssh * Stmp1); + ast_builder->insert_assignment(Sc, Sch * Sch); + ast_builder->insert_assignment(Ss, Ssh * Ssh); + ast_builder->insert_assignment(Sc, Sc - Ss); + ast_builder->insert_assignment(Ss, Ssh * Sch); + ast_builder->insert_assignment(Ss, Ss + Ss); + ast_builder->insert_assignment(Stmp1, Ss * Sa11); + ast_builder->insert_assignment(Stmp2, Ss * Sa21); + ast_builder->insert_assignment(Sa11, Sc * Sa11); + ast_builder->insert_assignment(Sa21, Sc * Sa21); + ast_builder->insert_assignment(Sa11, Sa11 + Stmp2); + ast_builder->insert_assignment(Sa21, Sa21 - Stmp1); + ast_builder->insert_assignment(Stmp1, Ss * Sa12); + ast_builder->insert_assignment(Stmp2, Ss * Sa22); + ast_builder->insert_assignment(Sa12, Sc * Sa12); + ast_builder->insert_assignment(Sa22, Sc * Sa22); + ast_builder->insert_assignment(Sa12, Sa12 + Stmp2); + ast_builder->insert_assignment(Sa22, Sa22 - Stmp1); + ast_builder->insert_assignment(Stmp1, Ss * Sa13); + ast_builder->insert_assignment(Stmp2, Ss * Sa23); + ast_builder->insert_assignment(Sa13, Sc * Sa13); + ast_builder->insert_assignment(Sa23, Sc * Sa23); + ast_builder->insert_assignment(Sa13, Sa13 + Stmp2); + ast_builder->insert_assignment(Sa23, Sa23 - Stmp1); + ast_builder->insert_assignment(Stmp1, Ss * Su11); + ast_builder->insert_assignment(Stmp2, Ss * Su12); + ast_builder->insert_assignment(Su11, Sc * Su11); + ast_builder->insert_assignment(Su12, Sc * Su12); + ast_builder->insert_assignment(Su11, Su11 + Stmp2); + ast_builder->insert_assignment(Su12, Su12 - Stmp1); + ast_builder->insert_assignment(Stmp1, Ss * Su21); + ast_builder->insert_assignment(Stmp2, Ss * Su22); + ast_builder->insert_assignment(Su21, Sc * Su21); + ast_builder->insert_assignment(Su22, Sc * Su22); + ast_builder->insert_assignment(Su21, Su21 + Stmp2); + ast_builder->insert_assignment(Su22, Su22 - Stmp1); + ast_builder->insert_assignment(Stmp1, Ss * Su31); + ast_builder->insert_assignment(Stmp2, Ss * Su32); + ast_builder->insert_assignment(Su31, Sc * Su31); + ast_builder->insert_assignment(Su32, Sc * Su32); + ast_builder->insert_assignment(Su31, Su31 + Stmp2); + ast_builder->insert_assignment(Su32, Su32 - Stmp1); + ast_builder->insert_assignment(Ssh, Sa31 * Sa31); + ast_builder->insert_assignment( + Ssh, bit_cast(select(Ssh >= Ssmall_number, + Expr(Ti(int32(0xffffffff))), Expr(Ti(0))))); + ast_builder->insert_assignment(Ssh, svd_bitwise_and(Ssh, Sa31)); + ast_builder->insert_assignment(Stmp5, Expr(Tf(0.0f))); + ast_builder->insert_assignment(Sch, Stmp5 - Sa11); + ast_builder->insert_assignment(Sch, max(Sch, Sa11)); + ast_builder->insert_assignment(Sch, max(Sch, Ssmall_number)); + ast_builder->insert_assignment( + Stmp5, bit_cast(select(Sa11 >= Stmp5, Expr(Ti(int32(0xffffffff))), + Expr(Ti(0))))); + ast_builder->insert_assignment(Stmp1, Sch * Sch); + ast_builder->insert_assignment(Stmp2, Ssh * Ssh); + ast_builder->insert_assignment(Stmp2, Stmp1 + Stmp2); + ast_builder->insert_assignment(Stmp1, rsqrt(Stmp2)); + ast_builder->insert_assignment(Stmp4, Stmp1 * Sone_half); + ast_builder->insert_assignment(Stmp3, Stmp1 * Stmp4); + ast_builder->insert_assignment(Stmp3, Stmp1 * Stmp3); + ast_builder->insert_assignment(Stmp3, Stmp2 * Stmp3); + ast_builder->insert_assignment(Stmp1, Stmp1 + Stmp4); + ast_builder->insert_assignment(Stmp1, Stmp1 - Stmp3); + ast_builder->insert_assignment(Stmp1, Stmp1 * Stmp2); + ast_builder->insert_assignment(Sch, Sch + Stmp1); + ast_builder->insert_assignment( + Stmp1, svd_bitwise_and(Expr(~bit_cast(Stmp5)), Ssh)); + ast_builder->insert_assignment( + Stmp2, svd_bitwise_and(Expr(~bit_cast(Stmp5)), Sch)); + ast_builder->insert_assignment(Sch, svd_bitwise_and(Stmp5, Sch)); + ast_builder->insert_assignment(Ssh, svd_bitwise_and(Stmp5, Ssh)); + ast_builder->insert_assignment(Sch, svd_bitwise_or(Sch, Stmp1)); + ast_builder->insert_assignment(Ssh, svd_bitwise_or(Ssh, Stmp2)); + ast_builder->insert_assignment(Stmp1, Sch * Sch); + ast_builder->insert_assignment(Stmp2, Ssh * Ssh); + ast_builder->insert_assignment(Stmp2, Stmp1 + Stmp2); + ast_builder->insert_assignment(Stmp1, rsqrt(Stmp2)); + ast_builder->insert_assignment(Stmp4, Stmp1 * Sone_half); + ast_builder->insert_assignment(Stmp3, Stmp1 * Stmp4); + ast_builder->insert_assignment(Stmp3, Stmp1 * Stmp3); + ast_builder->insert_assignment(Stmp3, Stmp2 * Stmp3); + ast_builder->insert_assignment(Stmp1, Stmp1 + Stmp4); + ast_builder->insert_assignment(Stmp1, Stmp1 - Stmp3); + ast_builder->insert_assignment(Sch, Sch * Stmp1); + ast_builder->insert_assignment(Ssh, Ssh * Stmp1); + ast_builder->insert_assignment(Sc, Sch * Sch); + ast_builder->insert_assignment(Ss, Ssh * Ssh); + ast_builder->insert_assignment(Sc, Sc - Ss); + ast_builder->insert_assignment(Ss, Ssh * Sch); + ast_builder->insert_assignment(Ss, Ss + Ss); + ast_builder->insert_assignment(Stmp1, Ss * Sa11); + ast_builder->insert_assignment(Stmp2, Ss * Sa31); + ast_builder->insert_assignment(Sa11, Sc * Sa11); + ast_builder->insert_assignment(Sa31, Sc * Sa31); + ast_builder->insert_assignment(Sa11, Sa11 + Stmp2); + ast_builder->insert_assignment(Sa31, Sa31 - Stmp1); + ast_builder->insert_assignment(Stmp1, Ss * Sa12); + ast_builder->insert_assignment(Stmp2, Ss * Sa32); + ast_builder->insert_assignment(Sa12, Sc * Sa12); + ast_builder->insert_assignment(Sa32, Sc * Sa32); + ast_builder->insert_assignment(Sa12, Sa12 + Stmp2); + ast_builder->insert_assignment(Sa32, Sa32 - Stmp1); + ast_builder->insert_assignment(Stmp1, Ss * Sa13); + ast_builder->insert_assignment(Stmp2, Ss * Sa33); + ast_builder->insert_assignment(Sa13, Sc * Sa13); + ast_builder->insert_assignment(Sa33, Sc * Sa33); + ast_builder->insert_assignment(Sa13, Sa13 + Stmp2); + ast_builder->insert_assignment(Sa33, Sa33 - Stmp1); + ast_builder->insert_assignment(Stmp1, Ss * Su11); + ast_builder->insert_assignment(Stmp2, Ss * Su13); + ast_builder->insert_assignment(Su11, Sc * Su11); + ast_builder->insert_assignment(Su13, Sc * Su13); + ast_builder->insert_assignment(Su11, Su11 + Stmp2); + ast_builder->insert_assignment(Su13, Su13 - Stmp1); + ast_builder->insert_assignment(Stmp1, Ss * Su21); + ast_builder->insert_assignment(Stmp2, Ss * Su23); + ast_builder->insert_assignment(Su21, Sc * Su21); + ast_builder->insert_assignment(Su23, Sc * Su23); + ast_builder->insert_assignment(Su21, Su21 + Stmp2); + ast_builder->insert_assignment(Su23, Su23 - Stmp1); + ast_builder->insert_assignment(Stmp1, Ss * Su31); + ast_builder->insert_assignment(Stmp2, Ss * Su33); + ast_builder->insert_assignment(Su31, Sc * Su31); + ast_builder->insert_assignment(Su33, Sc * Su33); + ast_builder->insert_assignment(Su31, Su31 + Stmp2); + ast_builder->insert_assignment(Su33, Su33 - Stmp1); + ast_builder->insert_assignment(Ssh, Sa32 * Sa32); + ast_builder->insert_assignment( + Ssh, bit_cast(select(Ssh >= Ssmall_number, + Expr(Ti(int32(0xffffffff))), Expr(Ti(0))))); + ast_builder->insert_assignment(Ssh, svd_bitwise_and(Ssh, Sa32)); + ast_builder->insert_assignment(Stmp5, Expr(Tf(0.0f))); + ast_builder->insert_assignment(Sch, Stmp5 - Sa22); + ast_builder->insert_assignment(Sch, max(Sch, Sa22)); + ast_builder->insert_assignment(Sch, max(Sch, Ssmall_number)); + ast_builder->insert_assignment( + Stmp5, bit_cast(select(Sa22 >= Stmp5, Expr(Ti(int32(0xffffffff))), + Expr(Ti(0))))); + ast_builder->insert_assignment(Stmp1, Sch * Sch); + ast_builder->insert_assignment(Stmp2, Ssh * Ssh); + ast_builder->insert_assignment(Stmp2, Stmp1 + Stmp2); + ast_builder->insert_assignment(Stmp1, rsqrt(Stmp2)); + ast_builder->insert_assignment(Stmp4, Stmp1 * Sone_half); + ast_builder->insert_assignment(Stmp3, Stmp1 * Stmp4); + ast_builder->insert_assignment(Stmp3, Stmp1 * Stmp3); + ast_builder->insert_assignment(Stmp3, Stmp2 * Stmp3); + ast_builder->insert_assignment(Stmp1, Stmp1 + Stmp4); + ast_builder->insert_assignment(Stmp1, Stmp1 - Stmp3); + ast_builder->insert_assignment(Stmp1, Stmp1 * Stmp2); + ast_builder->insert_assignment(Sch, Sch + Stmp1); + ast_builder->insert_assignment( + Stmp1, svd_bitwise_and(Expr(~bit_cast(Stmp5)), Ssh)); + ast_builder->insert_assignment( + Stmp2, svd_bitwise_and(Expr(~bit_cast(Stmp5)), Sch)); + ast_builder->insert_assignment(Sch, svd_bitwise_and(Stmp5, Sch)); + ast_builder->insert_assignment(Ssh, svd_bitwise_and(Stmp5, Ssh)); + ast_builder->insert_assignment(Sch, svd_bitwise_or(Sch, Stmp1)); + ast_builder->insert_assignment(Ssh, svd_bitwise_or(Ssh, Stmp2)); + ast_builder->insert_assignment(Stmp1, Sch * Sch); + ast_builder->insert_assignment(Stmp2, Ssh * Ssh); + ast_builder->insert_assignment(Stmp2, Stmp1 + Stmp2); + ast_builder->insert_assignment(Stmp1, rsqrt(Stmp2)); + ast_builder->insert_assignment(Stmp4, Stmp1 * Sone_half); + ast_builder->insert_assignment(Stmp3, Stmp1 * Stmp4); + ast_builder->insert_assignment(Stmp3, Stmp1 * Stmp3); + ast_builder->insert_assignment(Stmp3, Stmp2 * Stmp3); + ast_builder->insert_assignment(Stmp1, Stmp1 + Stmp4); + ast_builder->insert_assignment(Stmp1, Stmp1 - Stmp3); + ast_builder->insert_assignment(Sch, Sch * Stmp1); + ast_builder->insert_assignment(Ssh, Ssh * Stmp1); + ast_builder->insert_assignment(Sc, Sch * Sch); + ast_builder->insert_assignment(Ss, Ssh * Ssh); + ast_builder->insert_assignment(Sc, Sc - Ss); + ast_builder->insert_assignment(Ss, Ssh * Sch); + ast_builder->insert_assignment(Ss, Ss + Ss); + ast_builder->insert_assignment(Stmp1, Ss * Sa21); + ast_builder->insert_assignment(Stmp2, Ss * Sa31); + ast_builder->insert_assignment(Sa21, Sc * Sa21); + ast_builder->insert_assignment(Sa31, Sc * Sa31); + ast_builder->insert_assignment(Sa21, Sa21 + Stmp2); + ast_builder->insert_assignment(Sa31, Sa31 - Stmp1); + ast_builder->insert_assignment(Stmp1, Ss * Sa22); + ast_builder->insert_assignment(Stmp2, Ss * Sa32); + ast_builder->insert_assignment(Sa22, Sc * Sa22); + ast_builder->insert_assignment(Sa32, Sc * Sa32); + ast_builder->insert_assignment(Sa22, Sa22 + Stmp2); + ast_builder->insert_assignment(Sa32, Sa32 - Stmp1); + ast_builder->insert_assignment(Stmp1, Ss * Sa23); + ast_builder->insert_assignment(Stmp2, Ss * Sa33); + ast_builder->insert_assignment(Sa23, Sc * Sa23); + ast_builder->insert_assignment(Sa33, Sc * Sa33); + ast_builder->insert_assignment(Sa23, Sa23 + Stmp2); + ast_builder->insert_assignment(Sa33, Sa33 - Stmp1); + ast_builder->insert_assignment(Stmp1, Ss * Su12); + ast_builder->insert_assignment(Stmp2, Ss * Su13); + ast_builder->insert_assignment(Su12, Sc * Su12); + ast_builder->insert_assignment(Su13, Sc * Su13); + ast_builder->insert_assignment(Su12, Su12 + Stmp2); + ast_builder->insert_assignment(Su13, Su13 - Stmp1); + ast_builder->insert_assignment(Stmp1, Ss * Su22); + ast_builder->insert_assignment(Stmp2, Ss * Su23); + ast_builder->insert_assignment(Su22, Sc * Su22); + ast_builder->insert_assignment(Su23, Sc * Su23); + ast_builder->insert_assignment(Su22, Su22 + Stmp2); + ast_builder->insert_assignment(Su23, Su23 - Stmp1); + ast_builder->insert_assignment(Stmp1, Ss * Su32); + ast_builder->insert_assignment(Stmp2, Ss * Su33); + ast_builder->insert_assignment(Su32, Sc * Su32); + ast_builder->insert_assignment(Su33, Sc * Su33); + ast_builder->insert_assignment(Su32, Su32 + Stmp2); + ast_builder->insert_assignment(Su33, Su33 - Stmp1); return std::make_tuple(Su11, Su12, Su13, Su21, Su22, Su23, Su31, Su32, Su33, Sv11, Sv12, Sv13, Sv21, Sv22, Sv23, Sv31, Sv32, Sv33, Sa11, Sa22, Sa33); diff --git a/taichi/platform/mac/objc_api.h b/taichi/platform/mac/objc_api.h index 76ad0c3a5d32f..c9264dfdaf9f6 100644 --- a/taichi/platform/mac/objc_api.h +++ b/taichi/platform/mac/objc_api.h @@ -1,3 +1,5 @@ +#pragma once + #include #include "taichi/common/core.h" diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp index 730451f719dd8..a8e644ef610f2 100644 --- a/taichi/program/async_engine.cpp +++ b/taichi/program/async_engine.cpp @@ -19,17 +19,17 @@ TLANG_NAMESPACE_BEGIN ParallelExecutor::ParallelExecutor(const std::string &name, int num_threads) : name_(name), - num_threads(num_threads), - status(ExecutorStatus::uninitialized), - running_threads(0) { + num_threads_(num_threads), + status_(ExecutorStatus::uninitialized), + running_threads_(0) { { - auto _ = std::lock_guard(mut); + auto _ = std::lock_guard(mut_); for (int i = 0; i < num_threads; i++) { - threads.emplace_back([this]() { this->worker_loop(); }); + threads_.emplace_back([this]() { this->worker_loop(); }); } - status = ExecutorStatus::initialized; + status_ = ExecutorStatus::initialized; } init_cv_.notify_all(); } @@ -39,47 +39,47 @@ ParallelExecutor::~ParallelExecutor() { // new tasks from being enqueued during shut down. flush(); { - auto _ = std::lock_guard(mut); - status = ExecutorStatus::finalized; + auto _ = std::lock_guard(mut_); + status_ = ExecutorStatus::finalized; } // Signal the workers that they need to shutdown. worker_cv_.notify_all(); - for (auto &th : threads) { + for (auto &th : threads_) { th.join(); } } void ParallelExecutor::enqueue(const TaskType &func) { { - std::lock_guard _(mut); - task_queue.push_back(func); + std::lock_guard _(mut_); + task_queue_.push_back(func); } worker_cv_.notify_all(); } void ParallelExecutor::flush() { - std::unique_lock lock(mut); + std::unique_lock lock(mut_); while (!flush_cv_cond()) { flush_cv_.wait(lock); } } bool ParallelExecutor::flush_cv_cond() { - return (task_queue.empty() && running_threads == 0); + return (task_queue_.empty() && running_threads_ == 0); } void ParallelExecutor::worker_loop() { TI_DEBUG("Starting worker thread."); - auto thread_id = thread_counter++; + auto thread_id = thread_counter_++; std::string thread_name = name_; - if (num_threads != 1) + if (num_threads_ != 1) thread_name += fmt::format("_{}", thread_id); Timeline::get_this_thread_instance().set_name(thread_name); { - std::unique_lock lock(mut); - while (status == ExecutorStatus::uninitialized) { + std::unique_lock lock(mut_); + while (status_ == ExecutorStatus::uninitialized) { init_cv_.wait(lock); } } @@ -89,25 +89,25 @@ void ParallelExecutor::worker_loop() { while (!done) { bool notify_flush_cv = false; { - std::unique_lock lock(mut); - while (task_queue.empty() && status == ExecutorStatus::initialized) { + std::unique_lock lock(mut_); + while (task_queue_.empty() && status_ == ExecutorStatus::initialized) { worker_cv_.wait(lock); } // So long as |task_queue| is not empty, we keep running. - if (!task_queue.empty()) { - auto task = task_queue.front(); - running_threads++; - task_queue.pop_front(); + if (!task_queue_.empty()) { + auto task = task_queue_.front(); + running_threads_++; + task_queue_.pop_front(); lock.unlock(); // Run the task task(); lock.lock(); - running_threads--; + running_threads_--; } notify_flush_cv = flush_cv_cond(); - if (status == ExecutorStatus::finalized && task_queue.empty()) { + if (status_ == ExecutorStatus::finalized && task_queue_.empty()) { done = true; } } @@ -188,16 +188,15 @@ ExecutionQueue::ExecutionQueue( } AsyncEngine::AsyncEngine(const CompileConfig *const config, - const std::unordered_map &snodes, const BackendExecCompilationFunc &compile_to_backend) : queue(&ir_bank_, compile_to_backend), config_(config), - sfg(std::make_unique(this, &ir_bank_, config, snodes)) { + sfg(std::make_unique(this, &ir_bank_, config)) { Timeline::get_this_thread_instance().set_name("host"); ir_bank_.set_sfg(sfg.get()); } -void AsyncEngine::launch(Kernel *kernel, Context &context) { +void AsyncEngine::launch(Kernel *kernel, RuntimeContext &context) { if (!kernel->lowered()) { kernel->lower(/*to_executable=*/false); } diff --git a/taichi/program/async_engine.h b/taichi/program/async_engine.h index a3f7b08cc26a3..4f0e248f1e324 100644 --- a/taichi/program/async_engine.h +++ b/taichi/program/async_engine.h @@ -32,7 +32,7 @@ class ParallelExecutor { void flush(); int get_num_threads() { - return num_threads; + return num_threads_; } private: @@ -48,15 +48,15 @@ class ParallelExecutor { bool flush_cv_cond(); std::string name_; - int num_threads; - std::atomic thread_counter{0}; - std::mutex mut; + int num_threads_; + std::atomic thread_counter_{0}; + std::mutex mut_; // All guarded by |mut| - ExecutorStatus status; - std::vector threads; - std::deque task_queue; - int running_threads; + ExecutorStatus status_; + std::vector threads_; + std::deque task_queue_; + int running_threads_; // Used to signal the workers that they can start polling from |task_queue|. std::condition_variable init_cv_; @@ -138,14 +138,13 @@ class AsyncEngine { std::unique_ptr sfg; explicit AsyncEngine(const CompileConfig *const config, - const std::unordered_map &snodes, const BackendExecCompilationFunc &compile_to_backend); void clear_cache() { queue.clear_cache(); } - void launch(Kernel *kernel, Context &context); + void launch(Kernel *kernel, RuntimeContext &context); // Flush the tasks only. void flush(); diff --git a/taichi/program/async_utils.cpp b/taichi/program/async_utils.cpp index 3a5183ca4ab30..5058653056757 100644 --- a/taichi/program/async_utils.cpp +++ b/taichi/program/async_utils.cpp @@ -27,7 +27,7 @@ TaskLaunchRecord::TaskLaunchRecord() : kernel(nullptr), ir_handle(nullptr, 0) { // Initial node has rec.id == 0, so we start from rec.id == 1. std::atomic TaskLaunchRecord::task_counter = 1; -TaskLaunchRecord::TaskLaunchRecord(Context context, +TaskLaunchRecord::TaskLaunchRecord(RuntimeContext context, Kernel *kernel, IRHandle ir_handle) : context(context), kernel(kernel), ir_handle(ir_handle) { @@ -150,7 +150,7 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { t.kernel->name + "_" + offloaded_task_type_name(root_stmt->task_type); meta.type = root_stmt->task_type; get_meta_input_value_states(root_stmt, &meta, ir_bank); - meta.loop_unique = gather_uniquely_accessed_pointers(root_stmt); + meta.loop_unique = gather_uniquely_accessed_pointers(root_stmt).first; std::unordered_set activates, deactivates; diff --git a/taichi/program/async_utils.h b/taichi/program/async_utils.h index 8952f8b9404d9..513ded3147e3d 100644 --- a/taichi/program/async_utils.h +++ b/taichi/program/async_utils.h @@ -63,14 +63,14 @@ class IRHandle { // Records the necessary data for launching an offloaded task. class TaskLaunchRecord { public: - Context context; + RuntimeContext context; Kernel *kernel; // TODO: remove this IRHandle ir_handle; int id; TaskLaunchRecord(); - TaskLaunchRecord(Context context, Kernel *kernel, IRHandle ir_handle); + TaskLaunchRecord(RuntimeContext context, Kernel *kernel, IRHandle ir_handle); OffloadedStmt *stmt() const; diff --git a/taichi/program/callable.cpp b/taichi/program/callable.cpp index c97be6f1c9bbd..914a5973ac937 100644 --- a/taichi/program/callable.cpp +++ b/taichi/program/callable.cpp @@ -4,8 +4,12 @@ namespace taichi { namespace lang { -int Callable::insert_arg(const DataType &dt, bool is_external_array) { - args.emplace_back(dt->get_compute_type(), is_external_array, /*size=*/0); +Callable::Callable() = default; + +Callable::~Callable() = default; + +int Callable::insert_arg(const DataType &dt, bool is_array) { + args.emplace_back(dt->get_compute_type(), is_array); return (int)args.size() - 1; } @@ -13,16 +17,23 @@ int Callable::insert_ret(const DataType &dt) { rets.emplace_back(dt->get_compute_type()); return (int)rets.size() - 1; } +int Callable::insert_arr_arg(const DataType &dt, + int total_dim, + std::vector element_shape) { + args.emplace_back(dt->get_compute_type(), true, /*size=*/0, total_dim, + element_shape); + return (int)args.size() - 1; +} Callable::CurrentCallableGuard::CurrentCallableGuard(Program *program, Callable *callable) - : program(program) { - old_callable = program->current_callable; + : program_(program) { + old_callable_ = program->current_callable; program->current_callable = callable; } Callable::CurrentCallableGuard::~CurrentCallableGuard() { - program->current_callable = old_callable; + program_->current_callable = old_callable_; } } // namespace lang diff --git a/taichi/program/callable.h b/taichi/program/callable.h index 43f49ad6b1e93..dc9910a55fff3 100644 --- a/taichi/program/callable.h +++ b/taichi/program/callable.h @@ -9,21 +9,30 @@ class Program; class IRNode; class FrontendContext; -class Callable { +class TI_DLL_EXPORT Callable { public: - Program *program; - std::unique_ptr ir; - std::unique_ptr context; + Program *program{nullptr}; + std::unique_ptr ir{nullptr}; + std::unique_ptr context{nullptr}; struct Arg { DataType dt; - bool is_external_array; - std::size_t size; + bool is_array{ + false}; // This is true for both ndarray and external array args. + std::size_t size{0}; // TODO: size is runtime information, maybe remove? + std::size_t total_dim{0}; // total dim of array + std::vector element_shape = {}; // shape of each element explicit Arg(const DataType &dt = PrimitiveType::unknown, - bool is_external_array = false, - std::size_t size = 0) - : dt(dt), is_external_array(is_external_array), size(size) { + bool is_array = false, + std::size_t size = 0, + int total_dim = 0, + std::vector element_shape = {}) + : dt(dt), + is_array(is_array), + size(size), + total_dim(total_dim), + element_shape(std::move(element_shape)) { } }; @@ -37,17 +46,22 @@ class Callable { std::vector args; std::vector rets; - virtual ~Callable() = default; + Callable(); + virtual ~Callable(); - int insert_arg(const DataType &dt, bool is_external_array); + int insert_arg(const DataType &dt, bool is_array); + + int insert_arr_arg(const DataType &dt, + int total_dim, + std::vector element_shape); int insert_ret(const DataType &dt); [[nodiscard]] virtual std::string get_name() const = 0; class CurrentCallableGuard { - Callable *old_callable; - Program *program; + Callable *old_callable_; + Program *program_; public: CurrentCallableGuard(Program *program, Callable *callable); diff --git a/taichi/program/compile_config.cpp b/taichi/program/compile_config.cpp index 7c260e179e38e..356dea7dc61fe 100644 --- a/taichi/program/compile_config.cpp +++ b/taichi/program/compile_config.cpp @@ -7,9 +7,11 @@ TLANG_NAMESPACE_BEGIN CompileConfig::CompileConfig() { arch = host_arch(); simd_width = default_simd_width(arch); + opt_level = 1; external_optimization_level = 3; packed = false; print_ir = false; + print_preprocessed_ir = false; print_accessor_ir = false; print_evaluator_ir = false; print_benchmark_stat = false; @@ -42,6 +44,7 @@ CompileConfig::CompileConfig() { make_thread_local = true; make_block_local = true; detect_read_only = true; + ndarray_use_cached_allocator = true; saturating_grid_dim = 0; max_block_dim = 0; diff --git a/taichi/program/compile_config.h b/taichi/program/compile_config.h index 03666d7ec06ba..9b329f6f1e4ae 100644 --- a/taichi/program/compile_config.h +++ b/taichi/program/compile_config.h @@ -1,9 +1,10 @@ #pragma once -#include "arch.h" +#include "taichi/backends/arch.h" #include "taichi/lang_util.h" -TLANG_NAMESPACE_BEGIN +namespace taichi { +namespace lang { struct CompileConfig { Arch arch; @@ -12,9 +13,11 @@ struct CompileConfig { bool check_out_of_bound; int simd_width; bool lazy_compilation; + int opt_level; int external_optimization_level; int max_vector_width; bool packed; + bool print_preprocessed_ir; bool print_ir; bool print_accessor_ir; bool print_evaluator_ir; @@ -39,6 +42,7 @@ struct CompileConfig { bool make_thread_local; bool make_block_local; bool detect_read_only; + bool ndarray_use_cached_allocator; DataType default_fp; DataType default_ip; std::string extra_flags; @@ -69,6 +73,10 @@ struct CompileConfig { std::string cc_compile_cmd; std::string cc_link_cmd; + // Opengl backend options: + bool allow_nv_shader_extension{true}; + bool use_gles{false}; + // Async options int async_opt_passes{3}; bool async_opt_fusion{true}; @@ -86,9 +94,20 @@ struct CompileConfig { bool quant_opt_store_fusion{true}; bool quant_opt_atomic_demotion{true}; + // Mesh related. + // MeshTaichi options + bool make_mesh_block_local{true}; + bool optimize_mesh_reordered_mapping{true}; + bool mesh_localize_to_end_mapping{true}; + bool mesh_localize_from_end_mapping{false}; + bool mesh_localize_all_attr_mappings{false}; + bool demote_no_access_mesh_fors{true}; + bool experimental_auto_mesh_local{false}; + int auto_mesh_local_default_occupacy{4}; + CompileConfig(); }; -extern CompileConfig default_compile_config; +extern TI_DLL_EXPORT CompileConfig default_compile_config; TLANG_NAMESPACE_END diff --git a/taichi/program/context.h b/taichi/program/context.h index 6641f8ae3152b..0d6d7eef413a6 100644 --- a/taichi/program/context.h +++ b/taichi/program/context.h @@ -10,13 +10,25 @@ namespace lang { struct LLVMRuntime; -// "Context" holds necessary data for kernel body execution, such as a pointer -// to the LLVMRuntime struct, kernel arguments, and the thread id (if on CPU). -struct Context { +// "RuntimeContext" holds necessary data for kernel body execution, such as a +// pointer to the LLVMRuntime struct, kernel arguments, and the thread id (if on +// CPU). +struct RuntimeContext { LLVMRuntime *runtime; + // args can contain: + // - primitive_types + // - raw ptrs: for external array, or torch-based ndarray + // - DeviceAllocation*: for taichi ndaray uint64 args[taichi_max_num_args_total]; int32 extra_args[taichi_max_num_args_extra][taichi_max_num_indices]; int32 cpu_thread_id; + // |is_device_allocation| is true iff args[i] is a DeviceAllocation*. + bool is_device_allocation[taichi_max_num_args_total]{false}; + // We move the pointer of result buffer from LLVMRuntime to RuntimeContext + // because each real function need a place to store its result, but + // LLVMRuntime is shared among functions. So we moved the pointer to + // RuntimeContext which each function have one. + uint64 *result_buffer; static constexpr size_t extra_args_size = sizeof(extra_args); @@ -34,6 +46,10 @@ struct Context { void set_arg(int i, T v) { args[i] = taichi_union_cast_with_different_sizes(v); } + + void set_device_allocation(int i, bool is_device_allocation_) { + is_device_allocation[i] = is_device_allocation_; + } #endif }; diff --git a/taichi/program/extension.cpp b/taichi/program/extension.cpp index 86a8d2cb604d1..0db7e436fca61 100644 --- a/taichi/program/extension.cpp +++ b/taichi/program/extension.cpp @@ -12,7 +12,7 @@ bool is_extension_supported(Arch arch, Extension ext) { {Extension::sparse, Extension::async_mode, Extension::quant, Extension::quant_basic, Extension::data64, Extension::adstack, Extension::assertion, Extension::extfunc, Extension::packed, - Extension::dynamic_index}}, + Extension::dynamic_index, Extension::mesh}}, {Arch::arm64, {Extension::sparse, Extension::async_mode, Extension::quant, Extension::quant_basic, Extension::data64, Extension::adstack, @@ -21,7 +21,7 @@ bool is_extension_supported(Arch arch, Extension ext) { {Extension::sparse, Extension::async_mode, Extension::quant, Extension::quant_basic, Extension::data64, Extension::adstack, Extension::bls, Extension::assertion, Extension::packed, - Extension::dynamic_index}}, + Extension::dynamic_index, Extension::mesh}}, // TODO: supporting quant & async in metal(tests randomly crashed) {Arch::metal, {Extension::adstack, Extension::assertion, Extension::sparse}}, diff --git a/taichi/program/extension.h b/taichi/program/extension.h index 24ad3f4fd2080..29619f96955bd 100644 --- a/taichi/program/extension.h +++ b/taichi/program/extension.h @@ -1,10 +1,11 @@ #pragma once -#include "arch.h" +#include "taichi/backends/arch.h" #include -TLANG_NAMESPACE_BEGIN +namespace taichi { +namespace lang { // The Taichi core feature set (dense SNode) should probably be supported by all // the backends. In addition, each backend can optionally support features in @@ -23,4 +24,5 @@ enum class Extension { bool is_extension_supported(Arch arch, Extension ext); -TLANG_NAMESPACE_END +} // namespace lang +} // namespace taichi diff --git a/taichi/program/function.cpp b/taichi/program/function.cpp index d083c6e92a992..665423e805380 100644 --- a/taichi/program/function.cpp +++ b/taichi/program/function.cpp @@ -11,7 +11,7 @@ Function::Function(Program *program, const FunctionKey &func_key) } void Function::set_function_body(const std::function &func) { - context = std::make_unique(); + context = std::make_unique(program->config.arch); ir = context->get_root(); { // Note: this is not a mutex diff --git a/taichi/program/ir_bank.cpp b/taichi/program/ir_bank.cpp index cfb7402d04fcb..522debb2e33f7 100644 --- a/taichi/program/ir_bank.cpp +++ b/taichi/program/ir_bank.cpp @@ -60,7 +60,7 @@ bool IRBank::insert(std::unique_ptr &&ir, uint64 hash) { } void IRBank::insert_to_trash_bin(std::unique_ptr &&ir) { - trash_bin.push_back(std::move(ir)); + trash_bin_.push_back(std::move(ir)); } IRNode *IRBank::find(IRHandle ir_handle) { @@ -94,6 +94,7 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) { TI_ASSERT(!task_a->tls_prologue && !task_a->bls_prologue && !task_a->tls_epilogue && !task_a->tls_epilogue && !task_b->tls_prologue && !task_b->bls_prologue && + !task_a->mesh_prologue && !task_b->mesh_prologue && !task_b->tls_epilogue && !task_b->tls_epilogue); // TODO: in certain cases this optimization can be wrong! // Fuse task b into task_a diff --git a/taichi/program/ir_bank.h b/taichi/program/ir_bank.h index 2b6c6f2bd2799..f598343b21611 100644 --- a/taichi/program/ir_bank.h +++ b/taichi/program/ir_bank.h @@ -45,7 +45,7 @@ class IRBank { StateFlowGraph *sfg_; std::unordered_map hash_bank_; std::unordered_map> ir_bank_; - std::vector> trash_bin; // prevent IR from deleted + std::vector> trash_bin_; // prevent IR from deleted std::unordered_map, IRHandle> fuse_bank_; std::unordered_map demote_activation_bank_; diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 60369fbb490ce..582ee70c9eb70 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -10,7 +10,10 @@ #include "taichi/program/program.h" #include "taichi/util/action_recorder.h" #include "taichi/util/statistics.h" + +#ifdef TI_WITH_LLVM #include "taichi/llvm/llvm_program.h" +#endif TLANG_NAMESPACE_BEGIN @@ -19,38 +22,15 @@ class Function; Kernel::Kernel(Program &program, const std::function &func, const std::string &primal_name, - bool grad) - : grad(grad), lowered_(false) { - this->program = &program; - if (auto *llvm_program_impl = program.get_llvm_program_impl()) { - llvm_program_impl->maybe_initialize_cuda_llvm_context(); - } - is_accessor = false; - is_evaluator = false; - compiled_ = nullptr; - context = std::make_unique(); - ir = context->get_root(); - ir_is_ast_ = true; - - { - // Note: this is NOT a mutex. If we want to call Kernel::Kernel() - // concurrently, we need to lock this block of code together with - // taichi::lang::context with a mutex. - CurrentCallableGuard _(this->program, this); - func(); - ir->as()->kernel = this; - } - - arch = program.config.arch; - - if (!grad) { - name = primal_name; - } else { - name = primal_name + "_grad"; - } + bool grad) { + this->init(program, func, primal_name, grad); +} - if (!program.config.lazy_compilation) - compile(); +Kernel::Kernel(Program &program, + const std::function &func, + const std::string &primal_name, + bool grad) { + this->init(program, std::bind(func, this), primal_name, grad); } Kernel::Kernel(Program &program, @@ -94,9 +74,17 @@ void Kernel::lower(bool to_executable) { (is_evaluator && !config.print_evaluator_ir)) verbose = false; + if (config.print_preprocessed_ir) { + TI_INFO("[{}] {}:", get_name(), "Preprocessed IR"); + std::cout << std::flush; + irpass::re_id(ir.get()); + irpass::print(ir.get()); + std::cout << std::flush; + } + if (to_executable) { irpass::compile_to_executable( - ir.get(), config, this, /*vectorize*/ arch_is_cpu(arch), grad, + ir.get(), config, this, grad, /*ad_use_stack=*/true, verbose, /*lower_global_access=*/to_executable, /*make_thread_local=*/config.make_thread_local, /*make_block_local=*/ @@ -104,8 +92,7 @@ void Kernel::lower(bool to_executable) { config.make_block_local, /*start_from_ast=*/ir_is_ast_); } else { - irpass::compile_to_offloads(ir.get(), config, this, verbose, - /*vectorize=*/arch_is_cpu(arch), grad, + irpass::compile_to_offloads(ir.get(), config, this, verbose, grad, /*ad_use_stack=*/true, /*start_from_ast=*/ir_is_ast_); } @@ -146,18 +133,19 @@ Kernel::LaunchContextBuilder Kernel::make_launch_context() { return LaunchContextBuilder(this); } -Kernel::LaunchContextBuilder::LaunchContextBuilder(Kernel *kernel, Context *ctx) +Kernel::LaunchContextBuilder::LaunchContextBuilder(Kernel *kernel, + RuntimeContext *ctx) : kernel_(kernel), owned_ctx_(nullptr), ctx_(ctx) { } Kernel::LaunchContextBuilder::LaunchContextBuilder(Kernel *kernel) : kernel_(kernel), - owned_ctx_(std::make_unique()), + owned_ctx_(std::make_unique()), ctx_(owned_ctx_.get()) { } void Kernel::LaunchContextBuilder::set_arg_float(int arg_id, float64 d) { - TI_ASSERT_INFO(!kernel_->args[arg_id].is_external_array, + TI_ASSERT_INFO(!kernel_->args[arg_id].is_array, "Assigning scalar value to external (numpy) array argument is " "not allowed."); @@ -187,13 +175,16 @@ void Kernel::LaunchContextBuilder::set_arg_float(int arg_id, float64 d) { ctx_->set_arg(arg_id, (uint32)d); } else if (dt->is_primitive(PrimitiveTypeID::u64)) { ctx_->set_arg(arg_id, (uint64)d); + } else if (dt->is_primitive(PrimitiveTypeID::f16)) { + // use f32 to interact with python + ctx_->set_arg(arg_id, (float32)d); } else { TI_NOT_IMPLEMENTED } } void Kernel::LaunchContextBuilder::set_arg_int(int arg_id, int64 d) { - TI_ASSERT_INFO(!kernel_->args[arg_id].is_external_array, + TI_ASSERT_INFO(!kernel_->args[arg_id].is_array, "Assigning scalar value to external (numpy) array argument is " "not allowed."); @@ -219,10 +210,6 @@ void Kernel::LaunchContextBuilder::set_arg_int(int arg_id, int64 d) { ctx_->set_arg(arg_id, (uint32)d); } else if (dt->is_primitive(PrimitiveTypeID::u64)) { ctx_->set_arg(arg_id, (uint64)d); - } else if (dt->is_primitive(PrimitiveTypeID::f32)) { - ctx_->set_arg(arg_id, (float32)d); - } else if (dt->is_primitive(PrimitiveTypeID::f64)) { - ctx_->set_arg(arg_id, (float64)d); } else { TI_INFO(dt->to_string()); TI_NOT_IMPLEMENTED @@ -233,11 +220,13 @@ void Kernel::LaunchContextBuilder::set_extra_arg_int(int i, int j, int32 d) { ctx_->extra_args[i][j] = d; } -void Kernel::LaunchContextBuilder::set_arg_external_array(int arg_id, - uint64 ptr, - uint64 size) { +void Kernel::LaunchContextBuilder::set_arg_external_array( + int arg_id, + uintptr_t ptr, + uint64 size, + bool is_device_allocation) { TI_ASSERT_INFO( - kernel_->args[arg_id].is_external_array, + kernel_->args[arg_id].is_array, "Assigning external (numpy) array to scalar argument is not allowed."); ActionRecorder::get_instance().record( @@ -248,10 +237,36 @@ void Kernel::LaunchContextBuilder::set_arg_external_array(int arg_id, kernel_->args[arg_id].size = size; ctx_->set_arg(arg_id, ptr); + ctx_->set_device_allocation(arg_id, is_device_allocation); +} + +void Kernel::LaunchContextBuilder::set_arg_external_array_with_shape( + int arg_id, + uintptr_t ptr, + uint64 size, + const std::vector &shape) { + this->set_arg_external_array(arg_id, ptr, size, false); + TI_ASSERT_INFO(shape.size() <= taichi_max_num_indices, + "External array cannot have > {max_num_indices} indices"); + for (uint64 i = 0; i < shape.size(); ++i) { + this->set_extra_arg_int(arg_id, i, shape[i]); + } +} + +void Kernel::LaunchContextBuilder::set_arg_ndarray(int arg_id, + const Ndarray &arr) { + intptr_t ptr = arr.get_device_allocation_ptr_as_int(); + uint64 arr_size = arr.get_element_size() * arr.get_nelement(); + this->set_arg_external_array(arg_id, ptr, arr_size, true); + TI_ASSERT_INFO(arr.shape.size() <= taichi_max_num_indices, + "External array cannot have > {max_num_indices} indices"); + for (uint64 i = 0; i < arr.shape.size(); ++i) { + this->set_extra_arg_int(arg_id, i, arr.shape[i]); + } } void Kernel::LaunchContextBuilder::set_arg_raw(int arg_id, uint64 d) { - TI_ASSERT_INFO(!kernel_->args[arg_id].is_external_array, + TI_ASSERT_INFO(!kernel_->args[arg_id].is_array, "Assigning scalar value to external (numpy) array argument is " "not allowed."); @@ -264,65 +279,74 @@ void Kernel::LaunchContextBuilder::set_arg_raw(int arg_id, uint64 d) { ctx_->set_arg(arg_id, d); } -Context &Kernel::LaunchContextBuilder::get_context() { +RuntimeContext &Kernel::LaunchContextBuilder::get_context() { +#ifdef TI_WITH_LLVM if (auto *llvm_program_impl = kernel_->program->get_llvm_program_impl()) { ctx_->runtime = llvm_program_impl->get_llvm_runtime(); } +#endif + ctx_->result_buffer = kernel_->program->result_buffer; return *ctx_; } -float64 Kernel::get_ret_float(int i) { - auto dt = rets[i].dt->get_compute_type(); +template +T Kernel::fetch_ret(DataType dt, int i) { if (dt->is_primitive(PrimitiveTypeID::f32)) { - return (float64)program->fetch_result(i); + return (T)program->fetch_result(i); } else if (dt->is_primitive(PrimitiveTypeID::f64)) { - return (float64)program->fetch_result(i); + return (T)program->fetch_result(i); } else if (dt->is_primitive(PrimitiveTypeID::i32)) { - return (float64)program->fetch_result(i); + return (T)program->fetch_result(i); } else if (dt->is_primitive(PrimitiveTypeID::i64)) { - return (float64)program->fetch_result(i); + return (T)program->fetch_result(i); } else if (dt->is_primitive(PrimitiveTypeID::i8)) { - return (float64)program->fetch_result(i); + return (T)program->fetch_result(i); } else if (dt->is_primitive(PrimitiveTypeID::i16)) { - return (float64)program->fetch_result(i); + return (T)program->fetch_result(i); } else if (dt->is_primitive(PrimitiveTypeID::u8)) { - return (float64)program->fetch_result(i); + return (T)program->fetch_result(i); } else if (dt->is_primitive(PrimitiveTypeID::u16)) { - return (float64)program->fetch_result(i); + return (T)program->fetch_result(i); } else if (dt->is_primitive(PrimitiveTypeID::u32)) { - return (float64)program->fetch_result(i); + return (T)program->fetch_result(i); } else if (dt->is_primitive(PrimitiveTypeID::u64)) { - return (float64)program->fetch_result(i); + return (T)program->fetch_result(i); + } else if (dt->is_primitive(PrimitiveTypeID::f16)) { + // use f32 to interact with python + return (T)program->fetch_result(i); } else { TI_NOT_IMPLEMENTED } } +float64 Kernel::get_ret_float(int i) { + auto dt = rets[i].dt->get_compute_type(); + return fetch_ret(dt, i); +} + int64 Kernel::get_ret_int(int i) { auto dt = rets[i].dt->get_compute_type(); - if (dt->is_primitive(PrimitiveTypeID::i32)) { - return (int64)program->fetch_result(i); - } else if (dt->is_primitive(PrimitiveTypeID::i64)) { - return (int64)program->fetch_result(i); - } else if (dt->is_primitive(PrimitiveTypeID::i8)) { - return (int64)program->fetch_result(i); - } else if (dt->is_primitive(PrimitiveTypeID::i16)) { - return (int64)program->fetch_result(i); - } else if (dt->is_primitive(PrimitiveTypeID::u8)) { - return (int64)program->fetch_result(i); - } else if (dt->is_primitive(PrimitiveTypeID::u16)) { - return (int64)program->fetch_result(i); - } else if (dt->is_primitive(PrimitiveTypeID::u32)) { - return (int64)program->fetch_result(i); - } else if (dt->is_primitive(PrimitiveTypeID::u64)) { - return (int64)program->fetch_result(i); - } else if (dt->is_primitive(PrimitiveTypeID::f32)) { - return (int64)program->fetch_result(i); - } else if (dt->is_primitive(PrimitiveTypeID::f64)) { - return (int64)program->fetch_result(i); - } else { - TI_NOT_IMPLEMENTED + return fetch_ret(dt, i); +} + +std::vector Kernel::get_ret_int_tensor(int i) { + DataType dt = rets[i].dt->as()->get_element_type(); + int size = rets[i].dt->as()->get_num_elements(); + std::vector res; + for (int j = 0; j < size; j++) { + res.emplace_back(fetch_ret(dt, j)); + } + return res; +} + +std::vector Kernel::get_ret_float_tensor(int i) { + DataType dt = rets[i].dt->as()->get_element_type(); + int size = rets[i].dt->as()->get_num_elements(); + std::vector res; + for (int j = 0; j < size; j++) { + res.emplace_back(fetch_ret(dt, j)); } + return res; } void Kernel::set_arch(Arch arch) { @@ -349,6 +373,9 @@ void Kernel::account_for_offloaded(OffloadedStmt *stmt) { } else if (task_type == OffloadedStmt::TaskType::struct_for) { stat.add("launched_tasks_compute", 1.0); stat.add("launched_tasks_struct_for", 1.0); + } else if (task_type == OffloadedStmt::TaskType::mesh_for) { + stat.add("launched_tasks_compute", 1.0); + stat.add("launched_tasks_mesh_for", 1.0); } else if (task_type == OffloadedStmt::TaskType::gc) { stat.add("launched_tasks_garbage_collect", 1.0); } @@ -358,6 +385,46 @@ std::string Kernel::get_name() const { return name; } +void Kernel::init(Program &program, + const std::function &func, + const std::string &primal_name, + bool grad) { + this->grad = grad; + this->lowered_ = false; + this->program = &program; +#ifdef TI_WITH_LLVM + if (auto *llvm_program_impl = program.get_llvm_program_impl()) { + llvm_program_impl->maybe_initialize_cuda_llvm_context(); + } +#endif + is_accessor = false; + is_evaluator = false; + compiled_ = nullptr; + context = std::make_unique(program.config.arch); + ir = context->get_root(); + ir_is_ast_ = true; + + this->arch = program.config.arch; + + if (!grad) { + this->name = primal_name; + } else { + this->name = primal_name + "_grad"; + } + + { + // Note: this is NOT a mutex. If we want to call Kernel::Kernel() + // concurrently, we need to lock this block of code together with + // taichi::lang::context with a mutex. + CurrentCallableGuard _(this->program, this); + func(); + ir->as()->kernel = this; + } + + if (!program.config.lazy_compilation) + compile(); +} + // static bool Kernel::supports_lowering(Arch arch) { return arch_is_cpu(arch) || (arch == Arch::cuda) || (arch == Arch::metal); diff --git a/taichi/program/kernel.h b/taichi/program/kernel.h index 691a83d22e61e..c3a270ab098ed 100644 --- a/taichi/program/kernel.h +++ b/taichi/program/kernel.h @@ -3,14 +3,15 @@ #include "taichi/lang_util.h" #include "taichi/ir/snode.h" #include "taichi/ir/ir.h" -#include "taichi/program/arch.h" +#include "taichi/backends/arch.h" #include "taichi/program/callable.h" +#include "taichi/program/ndarray.h" TLANG_NAMESPACE_BEGIN class Program; -class Kernel : public Callable { +class TI_DLL_EXPORT Kernel : public Callable { public: std::string name; std::vector no_activate; @@ -20,10 +21,9 @@ class Kernel : public Callable { bool is_evaluator{false}; bool grad{false}; - // TODO: Give "Context" a more specific name. class LaunchContextBuilder { public: - LaunchContextBuilder(Kernel *kernel, Context *ctx); + LaunchContextBuilder(Kernel *kernel, RuntimeContext *ctx); explicit LaunchContextBuilder(Kernel *kernel); LaunchContextBuilder(LaunchContextBuilder &&) = default; @@ -37,22 +37,32 @@ class Kernel : public Callable { void set_extra_arg_int(int i, int j, int32 d); - void set_arg_external_array(int arg_id, uint64 ptr, uint64 size); + void set_arg_external_array(int arg_id, + uintptr_t ptr, + uint64 size, + bool is_device_allocation); + + void set_arg_external_array_with_shape(int arg_id, + uintptr_t ptr, + uint64 size, + const std::vector &shape); + + void set_arg_ndarray(int arg_id, const Ndarray &arr); // Sets the |arg_id|-th arg in the context to the bits stored in |d|. // This ignores the underlying kernel's |arg_id|-th arg type. void set_arg_raw(int arg_id, uint64 d); - Context &get_context(); + RuntimeContext &get_context(); private: Kernel *kernel_; - std::unique_ptr owned_ctx_; + std::unique_ptr owned_ctx_; // |ctx_| *almost* always points to |owned_ctx_|. However, it is possible - // that the caller passes a Context pointer externally. In that case, + // that the caller passes a RuntimeContext pointer externally. In that case, // |owned_ctx_| will be nullptr. // Invariant: |ctx_| will never be nullptr. - Context *ctx_; + RuntimeContext *ctx_; }; Kernel(Program &program, @@ -60,6 +70,11 @@ class Kernel : public Callable { const std::string &name = "", bool grad = false); + Kernel(Program &program, + const std::function &func, + const std::string &name = "", + bool grad = false); + Kernel(Program &program, std::unique_ptr &&ir, const std::string &name = "", @@ -83,14 +98,25 @@ class Kernel : public Callable { LaunchContextBuilder make_launch_context(); + template + T fetch_ret(DataType dt, int i); + float64 get_ret_float(int i); int64 get_ret_int(int i); + std::vector get_ret_int_tensor(int i); + + std::vector get_ret_float_tensor(int i); + void set_arch(Arch arch); void account_for_offloaded(OffloadedStmt *stmt); + uint64 get_next_task_id() { + return task_counter_++; + } + [[nodiscard]] std::string get_name() const override; /** * Whether the given |arch| is supported in the lower() method. @@ -101,6 +127,11 @@ class Kernel : public Callable { static bool supports_lowering(Arch arch); private: + void init(Program &program, + const std::function &func, + const std::string &name = "", + bool grad = false); + // True if |ir| is a frontend AST. False if it's already offloaded to CHI IR. bool ir_is_ast_{false}; // The closure that, if invoked, lauches the backend kernel (shader) @@ -109,6 +140,7 @@ class Kernel : public Callable { // lower inital AST all the way down to a bunch of // OffloadedStmt for async execution bool lowered_{false}; + std::atomic task_counter_{0}; }; TLANG_NAMESPACE_END diff --git a/taichi/program/kernel_profiler.cpp b/taichi/program/kernel_profiler.cpp index 308295d47ee8e..1756257240745 100644 --- a/taichi/program/kernel_profiler.cpp +++ b/taichi/program/kernel_profiler.cpp @@ -114,6 +114,8 @@ class DefaultProfiler : public KernelProfilerBase { } // namespace std::unique_ptr make_profiler(Arch arch, bool enable) { + if (!enable) + return nullptr; if (arch == Arch::cuda) { #if defined(TI_WITH_CUDA) return std::make_unique(enable); diff --git a/taichi/program/kernel_profiler.h b/taichi/program/kernel_profiler.h index 3eb7f52f0f3eb..dc049d19cfef0 100644 --- a/taichi/program/kernel_profiler.h +++ b/taichi/program/kernel_profiler.h @@ -1,6 +1,6 @@ #pragma once -#include "taichi/program/arch.h" +#include "taichi/backends/arch.h" #include "taichi/lang_util.h" #include @@ -61,6 +61,10 @@ class KernelProfilerBase { virtual void sync() = 0; + virtual bool set_profiler_toolkit(std::string toolkit_name) { + return false; + } + // TODO: remove start and always use start_with_handle virtual void start(const std::string &kernel_name){TI_NOT_IMPLEMENTED}; diff --git a/taichi/program/ndarray.cpp b/taichi/program/ndarray.cpp new file mode 100644 index 0000000000000..a54923d353ade --- /dev/null +++ b/taichi/program/ndarray.cpp @@ -0,0 +1,108 @@ +#include + +#include "taichi/program/ndarray.h" +#include "taichi/program/program.h" + +#ifdef TI_WITH_LLVM +#include "taichi/llvm/llvm_context.h" +#include "taichi/llvm/llvm_program.h" +#endif + +namespace taichi { +namespace lang { + +Ndarray::Ndarray(Program *prog, + const DataType type, + const std::vector &shape) + : dtype(type), + shape(shape), + num_active_indices(shape.size()), + nelement_(std::accumulate(std::begin(shape), + std::end(shape), + 1, + std::multiplies<>())), + element_size_(data_type_size(dtype)), + device_(prog->get_device_shared()), + prog_impl_(prog->get_llvm_program_impl()), + rw_accessors_bank_(&prog->get_ndarray_rw_accessors_bank()) { + ndarray_alloc_ = prog->allocate_memory_ndarray(nelement_ * element_size_, + prog->result_buffer); +#ifdef TI_WITH_LLVM + if (arch_is_cpu(prog->config.arch) || prog->config.arch == Arch::cuda) { + // For the LLVM backends, device allocation is a physical pointer. + data_ptr_ = prog->get_llvm_program_impl()->get_ndarray_alloc_info_ptr( + ndarray_alloc_); + } +#else + TI_ERROR("Llvm disabled"); +#endif +} + +Ndarray::~Ndarray() { + if (device_) { + device_->dealloc_memory(ndarray_alloc_); + } +} + +intptr_t Ndarray::get_data_ptr_as_int() const { + return reinterpret_cast(data_ptr_); +} + +intptr_t Ndarray::get_device_allocation_ptr_as_int() const { + // taichi's own ndarray's ptr points to its |DeviceAllocation| on the + // specified device. Note that torch-based ndarray's ptr is a raw ptr but + // we'll get rid of it soon. + return reinterpret_cast(&ndarray_alloc_); +} + +std::size_t Ndarray::get_element_size() const { + return element_size_; +} + +std::size_t Ndarray::get_nelement() const { + return nelement_; +} + +void Ndarray::fill_float(float val) { + buffer_fill(reinterpret_cast(val)); +} + +void Ndarray::fill_int(int32_t val) { + buffer_fill(reinterpret_cast(val)); +} + +void Ndarray::fill_uint(uint32_t val) { + buffer_fill(reinterpret_cast(val)); +} + +int64 Ndarray::read_int(const std::vector &i) { + return rw_accessors_bank_->get(this).read_int(i); +} + +uint64 Ndarray::read_uint(const std::vector &i) { + return rw_accessors_bank_->get(this).read_uint(i); +} + +float64 Ndarray::read_float(const std::vector &i) { + return rw_accessors_bank_->get(this).read_float(i); +} + +void Ndarray::write_int(const std::vector &i, int64 val) { + rw_accessors_bank_->get(this).write_int(i, val); +} + +void Ndarray::write_float(const std::vector &i, float64 val) { + rw_accessors_bank_->get(this).write_float(i, val); +} + +void Ndarray::buffer_fill(uint32_t val) { + // This is a temporary solution to bypass device api + // should be moved to commandList when available in CUDA +#ifdef TI_WITH_LLVM + prog_impl_->fill_ndarray(ndarray_alloc_, nelement_, val); +#else + TI_ERROR("Llvm disabled"); +#endif +} +} // namespace lang +} // namespace taichi diff --git a/taichi/program/ndarray.h b/taichi/program/ndarray.h new file mode 100644 index 0000000000000..dafec26cbd1b9 --- /dev/null +++ b/taichi/program/ndarray.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include + +#include "taichi/inc/constants.h" +#include "taichi/ir/type_utils.h" +#include "taichi/backends/device.h" + +namespace taichi { +namespace lang { + +class Program; +class LlvmProgramImpl; +class NdarrayRwAccessorsBank; + +class Ndarray { + public: + explicit Ndarray(Program *prog, + const DataType type, + const std::vector &shape); + + DataType dtype; + // Invariant: Since ndarray indices are flattened for vector/matrix, this is + // always true: + // num_active_indices = shape.size() + std::vector shape; + int num_active_indices{0}; + + intptr_t get_data_ptr_as_int() const; + intptr_t get_device_allocation_ptr_as_int() const; + std::size_t get_element_size() const; + std::size_t get_nelement() const; + void fill_float(float val); + void fill_int(int32_t val); + void fill_uint(uint32_t val); + int64 read_int(const std::vector &i); + uint64 read_uint(const std::vector &i); + float64 read_float(const std::vector &i); + void write_int(const std::vector &i, int64 val); + void write_float(const std::vector &i, float64 val); + ~Ndarray(); + + private: + DeviceAllocation ndarray_alloc_{kDeviceNullAllocation}; + // Invariant: + // data_ptr_ is not nullptr iff arch is a llvm backend + uint64_t *data_ptr_{nullptr}; + std::size_t nelement_{1}; + std::size_t element_size_{1}; + // Ndarrays manage their own |DeviceAllocation| so this must be shared with + // |OpenGlRuntime|. Without the ownership, when the program exits |device_| + // might be destructed earlier than Ndarray object, leaving a segfault when + // you try to deallocate in Ndarray destructor. + // Note that we might consider changing this logic later if we implement + // dynamic tensor rematerialization. + std::shared_ptr device_{nullptr}; + void buffer_fill(uint32_t val); + LlvmProgramImpl *prog_impl_{nullptr}; + NdarrayRwAccessorsBank *rw_accessors_bank_{nullptr}; +}; + +} // namespace lang +} // namespace taichi diff --git a/taichi/program/ndarray_rw_accessors_bank.cpp b/taichi/program/ndarray_rw_accessors_bank.cpp new file mode 100644 index 0000000000000..1d5078f9dbd9f --- /dev/null +++ b/taichi/program/ndarray_rw_accessors_bank.cpp @@ -0,0 +1,116 @@ +#include "taichi/program/ndarray_rw_accessors_bank.h" +#include "taichi/program/program.h" + +namespace taichi { +namespace lang { + +namespace { +void set_kernel_args(const std::vector &I, + int num_active_indices, + Kernel::LaunchContextBuilder *launch_ctx) { + for (int i = 0; i < num_active_indices; i++) { + launch_ctx->set_arg_int(i, I[i]); + } +} +void set_kernel_extra_args(const Ndarray *ndarray, + int arg_id, + Kernel::LaunchContextBuilder *launch_ctx) { + for (int i = 0; i < ndarray->num_active_indices; i++) { + launch_ctx->set_extra_arg_int(arg_id, i, ndarray->shape[i]); + } +} +} // namespace + +NdarrayRwAccessorsBank::Accessors NdarrayRwAccessorsBank::get( + Ndarray *ndarray) { + NdarrayRwKeys keys{ndarray->num_active_indices, ndarray->dtype}; + if (ndarray_to_kernels_.find(keys) == ndarray_to_kernels_.end()) { + ndarray_to_kernels_[keys] = {&(program_->get_ndarray_reader(ndarray)), + &(program_->get_ndarray_writer(ndarray))}; + } + return Accessors(ndarray, ndarray_to_kernels_[keys], program_); +} + +NdarrayRwAccessorsBank::Accessors::Accessors(const Ndarray *ndarray, + const RwKernels &kernels, + Program *prog) + : ndarray_(ndarray), + prog_(prog), + reader_(kernels.reader), + writer_(kernels.writer) { + TI_ASSERT(reader_ != nullptr); + TI_ASSERT(writer_ != nullptr); +} + +void NdarrayRwAccessorsBank::Accessors::write_float(const std::vector &I, + float64 val) { + auto launch_ctx = writer_->make_launch_context(); + set_kernel_args(I, ndarray_->num_active_indices, &launch_ctx); + launch_ctx.set_arg_float(ndarray_->num_active_indices, val); + launch_ctx.set_arg_external_array( + ndarray_->num_active_indices + 1, + ndarray_->get_device_allocation_ptr_as_int(), + ndarray_->get_nelement() * ndarray_->get_element_size(), + /*is_device_allocation=*/true); + set_kernel_extra_args(ndarray_, ndarray_->num_active_indices + 1, + &launch_ctx); + prog_->synchronize(); + (*writer_)(launch_ctx); +} + +float64 NdarrayRwAccessorsBank::Accessors::read_float( + const std::vector &I) { + prog_->synchronize(); + auto launch_ctx = reader_->make_launch_context(); + set_kernel_args(I, ndarray_->num_active_indices, &launch_ctx); + launch_ctx.set_arg_external_array( + ndarray_->num_active_indices, + ndarray_->get_device_allocation_ptr_as_int(), + ndarray_->get_nelement() * ndarray_->get_element_size(), + /*is_device_allocation=*/true); + set_kernel_extra_args(ndarray_, ndarray_->num_active_indices, &launch_ctx); + (*reader_)(launch_ctx); + prog_->synchronize(); + auto ret = reader_->get_ret_float(0); + return ret; +} + +// for int32 and int64 +void NdarrayRwAccessorsBank::Accessors::write_int(const std::vector &I, + int64 val) { + auto launch_ctx = writer_->make_launch_context(); + set_kernel_args(I, ndarray_->num_active_indices, &launch_ctx); + launch_ctx.set_arg_int(ndarray_->num_active_indices, val); + launch_ctx.set_arg_external_array( + ndarray_->num_active_indices + 1, + ndarray_->get_device_allocation_ptr_as_int(), + ndarray_->get_nelement() * ndarray_->get_element_size(), + /*is_device_allocation=*/true); + set_kernel_extra_args(ndarray_, ndarray_->num_active_indices + 1, + &launch_ctx); + prog_->synchronize(); + (*writer_)(launch_ctx); +} + +int64 NdarrayRwAccessorsBank::Accessors::read_int(const std::vector &I) { + prog_->synchronize(); + auto launch_ctx = reader_->make_launch_context(); + set_kernel_args(I, ndarray_->num_active_indices, &launch_ctx); + launch_ctx.set_arg_external_array( + ndarray_->num_active_indices, + ndarray_->get_device_allocation_ptr_as_int(), + ndarray_->get_nelement() * ndarray_->get_element_size(), + /*is_device_allocation=*/true); + set_kernel_extra_args(ndarray_, ndarray_->num_active_indices, &launch_ctx); + (*reader_)(launch_ctx); + prog_->synchronize(); + auto ret = reader_->get_ret_int(0); + return ret; +} + +uint64 NdarrayRwAccessorsBank::Accessors::read_uint(const std::vector &I) { + return (uint64)read_int(I); +} + +} // namespace lang +} // namespace taichi diff --git a/taichi/program/ndarray_rw_accessors_bank.h b/taichi/program/ndarray_rw_accessors_bank.h new file mode 100644 index 0000000000000..567859bbf8e09 --- /dev/null +++ b/taichi/program/ndarray_rw_accessors_bank.h @@ -0,0 +1,89 @@ +#pragma once + +#include + +#include "taichi/program/kernel.h" +#include "taichi/program/ndarray.h" + +namespace taichi { +namespace lang { + +class Program; +class Ndarray; + +/* Note: [Ndarray host reader & writer] + * Unlike snodes/fields which are persistent global storage that can safely + * use SNode* as keys to cache reader & writer kernels, ndarrays' life-cycle + * depends on their corresponding python objects. In other words we cannot + * use Ndarray* here as caching keys since it's possible that one ndarray reuses + * exactly the same address where a freed ndarray instance was. + * + * Fortunately since ndarray reader & writer don't hardcode ndarray address in + * the kernel, their caching mechanism can also be more efficient than the snode + * ones. Currently we only use ndarray's num_active_indices & dtype information + * (saved in NdarrayRwKeys) in the reader & writer kernels Details can be found + * in get_ndarray_reader/writer in program.cpp. + */ +struct NdarrayRwKeys { + int num_active_indices; + DataType dtype; + + struct Hasher { + std::size_t operator()(const NdarrayRwKeys &k) const { + auto h1 = std::hash{}(k.num_active_indices); + auto h2 = k.dtype.hash(); + return h1 ^ h2; + } + }; + + bool operator==(const NdarrayRwKeys &other) const { + return num_active_indices == other.num_active_indices && + dtype == other.dtype; + } +}; + +/** A mapping from a Ndarray to its read/write access kernels. + */ +class NdarrayRwAccessorsBank { + private: + struct RwKernels { + Kernel *reader{nullptr}; + Kernel *writer{nullptr}; + }; + + public: + class Accessors { + public: + explicit Accessors(const Ndarray *ndarray, + const RwKernels &kernels, + Program *prog); + + // for float and double + void write_float(const std::vector &I, float64 val); + float64 read_float(const std::vector &I); + + // for int32 and int64 + void write_int(const std::vector &I, int64 val); + int64 read_int(const std::vector &I); + uint64 read_uint(const std::vector &I); + + private: + const Ndarray *ndarray_; + Program *prog_; + Kernel *reader_; + Kernel *writer_; + }; + + explicit NdarrayRwAccessorsBank(Program *program) : program_(program) { + } + + Accessors get(Ndarray *ndarray); + + private: + Program *const program_; + std::unordered_map + ndarray_to_kernels_; +}; + +} // namespace lang +} // namespace taichi diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index a2352e210eeb6..48fd81c676000 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -21,14 +21,20 @@ #include "taichi/program/snode_expr_utils.h" #include "taichi/util/statistics.h" #include "taichi/math/arithmetic.h" +#ifdef TI_WITH_LLVM #include "taichi/llvm/llvm_program.h" +#endif #if defined(TI_WITH_CC) #include "taichi/backends/cc/cc_program.h" #endif #ifdef TI_WITH_VULKAN #include "taichi/backends/vulkan/vulkan_program.h" -#include "taichi/backends/vulkan/loader.h" +#include "taichi/backends/vulkan/vulkan_loader.h" +#endif +#ifdef TI_WITH_DX11 +#include "taichi/backends/dx/dx_program.h" +#include "taichi/backends/dx/dx_api.h" #endif #if defined(TI_ARCH_x64) @@ -38,10 +44,10 @@ namespace taichi { namespace lang { -Program *current_program = nullptr; std::atomic Program::num_instances_; -Program::Program(Arch desired_arch) : snode_rw_accessors_bank_(this) { +Program::Program(Arch desired_arch) + : snode_rw_accessors_bank_(this), ndarray_rw_accessors_bank_(this) { TI_TRACE("Program initializing..."); // For performance considerations and correctness of CustomFloatType @@ -49,7 +55,7 @@ Program::Program(Arch desired_arch) : snode_rw_accessors_bank_(this) { // backends (including CPUs). #if defined(TI_ARCH_x64) _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON); -#else +#elif !defined(TI_EMSCRIPTENED) // Enforce flush to zero on arm64 CPUs // https://developer.arm.com/documentation/100403/0201/register-descriptions/advanced-simd-and-floating-point-registers/aarch64-register-descriptions/fpcr--floating-point-control-register?lang=en std::uint64_t fpcr; @@ -69,63 +75,62 @@ Program::Program(Arch desired_arch) : snode_rw_accessors_bank_(this) { profiler = make_profiler(config.arch, config.kernel_profiler); if (arch_uses_llvm(config.arch)) { +#ifdef TI_WITH_LLVM program_impl_ = std::make_unique(config, profiler.get()); - +#else + TI_ERROR("This taichi is not compiled with LLVM"); +#endif } else if (config.arch == Arch::metal) { - if (!metal::is_metal_api_available()) { - TI_WARN("No Metal API detected."); - config.arch = host_arch(); - } else { - program_impl_ = std::make_unique(config); - } - } +#ifdef TI_WITH_METAL + TI_ASSERT(metal::is_metal_api_available()); + program_impl_ = std::make_unique(config); +#else + TI_ERROR("This taichi is not compiled with Metal") +#endif + } else if (config.arch == Arch::vulkan) { #ifdef TI_WITH_VULKAN - else if (config.arch == Arch::vulkan) { - if (!vulkan::is_vulkan_api_available()) { - TI_WARN("No Vulkan API detected."); - config.arch = host_arch(); - } else { - program_impl_ = std::make_unique(config); - } - } + TI_ASSERT(vulkan::is_vulkan_api_available()); + program_impl_ = std::make_unique(config); +#else + TI_ERROR("This taichi is not compiled with Vulkan") #endif - - if (config.arch == Arch::opengl) { - if (!opengl::is_opengl_api_available()) { - TI_WARN("No OpenGL API detected."); - config.arch = host_arch(); - } else { - program_impl_ = std::make_unique(config); - } - } - - if (config.arch == Arch::cc) { + } else if (config.arch == Arch::dx11) { +#ifdef TI_WITH_DX11 + TI_ASSERT(directx11::is_dx_api_available()); + program_impl_ = std::make_unique(config); +#else + TI_ERROR("This taichi is not compiled with DX11"); +#endif + } else if (config.arch == Arch::opengl) { + TI_ASSERT(opengl::initialize_opengl(config.use_gles)); + program_impl_ = std::make_unique(config); + } else if (config.arch == Arch::cc) { #ifdef TI_WITH_CC program_impl_ = std::make_unique(config); #else - TI_WARN("No C backend detected."); - config.arch = host_arch(); + TI_ERROR("No C backend detected."); #endif + } else { + TI_NOT_IMPLEMENTED } - if (config.arch != desired_arch) { - TI_WARN("Falling back to {}", arch_name(config.arch)); - } + // program_impl_ should be set in the if-else branch above + TI_ASSERT(program_impl_); Device *compute_device = nullptr; - if (program_impl_.get()) { - compute_device = program_impl_->get_compute_device(); - } + compute_device = program_impl_->get_compute_device(); // Must have handled all the arch fallback logic by this point. memory_pool_ = std::make_unique(config.arch, compute_device); TI_ASSERT_INFO(num_instances_ == 0, "Only one instance at a time"); total_compilation_time_ = 0; num_instances_ += 1; SNode::counter = 0; - TI_ASSERT(current_program == nullptr); - current_program = this; if (arch_uses_llvm(config.arch)) { +#if TI_WITH_LLVM static_cast(program_impl_.get())->initialize_host(); +#else + TI_NOT_IMPLEMENTED +#endif } result_buffer = nullptr; @@ -137,7 +142,7 @@ Program::Program(Arch desired_arch) : snode_rw_accessors_bank_(this) { TI_WARN("Running in async mode. This is experimental."); TI_ASSERT(is_extension_supported(config.arch, Extension::async_mode)); async_engine = std::make_unique( - &config, snodes, [this](Kernel &kernel, OffloadedStmt *offloaded) { + &config, [this](Kernel &kernel, OffloadedStmt *offloaded) { return this->compile(kernel, offloaded); }); } @@ -176,40 +181,40 @@ Function *Program::create_function(const FunctionKey &func_key) { FunctionType Program::compile(Kernel &kernel, OffloadedStmt *offloaded) { auto start_t = Time::get_time(); TI_AUTO_PROF; - FunctionType ret = nullptr; - if (arch_uses_llvm(config.arch) || kernel.arch == Arch::metal || - kernel.arch == Arch::vulkan || kernel.arch == Arch::opengl || - kernel.arch == Arch::cc) { - return program_impl_->compile(&kernel, offloaded); - } else { - TI_NOT_IMPLEMENTED; - } + auto ret = program_impl_->compile(&kernel, offloaded); TI_ASSERT(ret); total_compilation_time_ += Time::get_time() - start_t; return ret; } void Program::materialize_runtime() { - if (arch_uses_llvm(config.arch) || config.arch == Arch::metal || - config.arch == Arch::vulkan || config.arch == Arch::opengl || - config.arch == Arch::cc) { - program_impl_->materialize_runtime(memory_pool_.get(), profiler.get(), - &result_buffer); - } + program_impl_->materialize_runtime(memory_pool_.get(), profiler.get(), + &result_buffer); } void Program::destroy_snode_tree(SNodeTree *snode_tree) { TI_ASSERT(arch_uses_llvm(config.arch) || config.arch == Arch::vulkan); program_impl_->destroy_snode_tree(snode_tree); + free_snode_tree_ids_.push(snode_tree->id()); } -SNodeTree *Program::add_snode_tree(std::unique_ptr root) { - const int id = snode_trees_.size(); +SNodeTree *Program::add_snode_tree(std::unique_ptr root, + bool compile_only) { + const int id = allocate_snode_tree_id(); auto tree = std::make_unique(id, std::move(root)); tree->root()->set_snode_tree_id(id); - materialize_snode_tree(tree.get()); - snode_trees_.push_back(std::move(tree)); - + if (compile_only) { + program_impl_->compile_snode_tree_types(tree.get(), snode_trees_); + } else { + program_impl_->materialize_snode_tree(tree.get(), snode_trees_, + result_buffer); + } + if (id < snode_trees_.size()) { + snode_trees_[id] = std::move(tree); + } else { + TI_ASSERT(id == snode_trees_.size()); + snode_trees_.push_back(std::move(tree)); + } return snode_trees_[id].get(); } @@ -217,19 +222,14 @@ SNode *Program::get_snode_root(int tree_id) { return snode_trees_[tree_id]->root(); } -void Program::materialize_snode_tree(SNodeTree *tree) { - if (arch_is_cpu(config.arch) || config.arch == Arch::cuda || - config.arch == Arch::metal || config.arch == Arch::vulkan || - config.arch == Arch::opengl || config.arch == Arch::cc) { - program_impl_->materialize_snode_tree(tree, snode_trees_, snodes, - result_buffer); - } -} - void Program::check_runtime_error() { +#ifdef TI_WITH_LLVM TI_ASSERT(arch_uses_llvm(config.arch)); static_cast(program_impl_.get()) ->check_runtime_error(result_buffer); +#else + TI_ERROR("Llvm disabled"); +#endif } void Program::synchronize() { @@ -237,9 +237,6 @@ void Program::synchronize() { if (config.async_mode) { async_engine->synchronize(); } - if (profiler) { - profiler->sync(); - } if (arch_uses_llvm(config.arch) || config.arch == Arch::metal || config.arch == Arch::vulkan) { program_impl_->synchronize(); @@ -343,7 +340,7 @@ void Program::visualize_layout(const std::string &fn) { trash(system(fmt::format("pdflatex {}", fn).c_str())); } -Arch Program::get_snode_accessor_arch() { +Arch Program::get_accessor_arch() { if (config.arch == Arch::opengl) { return Arch::opengl; } else if (config.arch == Arch::vulkan) { @@ -354,6 +351,8 @@ Arch Program::get_snode_accessor_arch() { return Arch::metal; } else if (config.arch == Arch::cc) { return Arch::cc; + } else if (config.arch == Arch::dx11) { + return Arch::dx11; } else { return get_host_arch(); } @@ -368,10 +367,10 @@ Kernel &Program::get_snode_reader(SNode *snode) { indices.push_back(Expr::make(i, PrimitiveType::i32)); } auto ret = Stmt::make( - load_if_ptr(Expr(snode_to_glb_var_exprs_.at(snode))[indices])); - current_ast_builder().insert(std::move(ret)); + ExprGroup(Expr(snode_to_glb_var_exprs_.at(snode))[indices])); + this->current_ast_builder()->insert(std::move(ret)); }); - ker.set_arch(get_snode_accessor_arch()); + ker.set_arch(get_accessor_arch()); ker.name = kernel_name; ker.is_accessor = true; for (int i = 0; i < snode->num_active_indices; i++) @@ -388,11 +387,12 @@ Kernel &Program::get_snode_writer(SNode *snode) { for (int i = 0; i < snode->num_active_indices; i++) { indices.push_back(Expr::make(i, PrimitiveType::i32)); } - Expr(snode_to_glb_var_exprs_.at(snode))[indices] = - Expr::make(snode->num_active_indices, - snode->dt->get_compute_type()); + auto expr = Expr(snode_to_glb_var_exprs_.at(snode))[indices]; + this->current_ast_builder()->insert_assignment( + expr, Expr::make(snode->num_active_indices, + snode->dt->get_compute_type())); }); - ker.set_arch(get_snode_accessor_arch()); + ker.set_arch(get_accessor_arch()); ker.name = kernel_name; ker.is_accessor = true; for (int i = 0; i < snode->num_active_indices; i++) @@ -401,10 +401,65 @@ Kernel &Program::get_snode_writer(SNode *snode) { return ker; } +Kernel &Program::get_ndarray_reader(Ndarray *ndarray) { + static uint64 ndarray_reader_counter = 0; + auto kernel_name = fmt::format("ndarray_reader_{}", ndarray_reader_counter++); + NdarrayRwKeys keys{ndarray->num_active_indices, ndarray->dtype}; + auto &ker = kernel([keys, this] { + ExprGroup indices; + for (int i = 0; i < keys.num_active_indices; i++) { + indices.push_back(Expr::make(i, PrimitiveType::i32)); + } + auto ret = Stmt::make( + ExprGroup(Expr(Expr::make( + keys.dtype, keys.num_active_indices, keys.num_active_indices, + 0))[indices])); + this->current_ast_builder()->insert(std::move(ret)); + }); + ker.set_arch(get_accessor_arch()); + ker.name = kernel_name; + ker.is_accessor = true; + for (int i = 0; i < keys.num_active_indices; i++) + ker.insert_arg(PrimitiveType::i32, false); + ker.insert_arg(keys.dtype, true); + ker.insert_ret(keys.dtype); + return ker; +} + +Kernel &Program::get_ndarray_writer(Ndarray *ndarray) { + static uint64 ndarray_writer_counter = 0; + auto kernel_name = fmt::format("ndarray_writer_{}", ndarray_writer_counter++); + NdarrayRwKeys keys{ndarray->num_active_indices, ndarray->dtype}; + auto &ker = kernel([keys, this] { + ExprGroup indices; + for (int i = 0; i < keys.num_active_indices; i++) { + indices.push_back(Expr::make(i, PrimitiveType::i32)); + } + auto expr = Expr(Expr::make( + keys.dtype, keys.num_active_indices, keys.num_active_indices + 1, + 0))[indices]; + this->current_ast_builder()->insert_assignment( + expr, Expr::make(keys.num_active_indices, + keys.dtype->get_compute_type())); + }); + ker.set_arch(get_accessor_arch()); + ker.name = kernel_name; + ker.is_accessor = true; + for (int i = 0; i < keys.num_active_indices; i++) + ker.insert_arg(PrimitiveType::i32, false); + ker.insert_arg(keys.dtype, false); + ker.insert_arg(keys.dtype, true); + return ker; +} + uint64 Program::fetch_result_uint64(int i) { if (arch_uses_llvm(config.arch)) { +#ifdef TI_WITH_LLVM return static_cast(program_impl_.get()) ->fetch_result(i, result_buffer); +#else + TI_NOT_IMPLEMENTED +#endif } return result_buffer[i]; } @@ -414,6 +469,7 @@ void Program::finalize() { if (async_engine) async_engine = nullptr; // Finalize the async engine threads before // anything else gets destoried. + TI_TRACE("Program finalizing..."); if (config.print_benchmark_stat) { const char *current_test = std::getenv("PYTEST_CURRENT_TEST"); @@ -450,11 +506,14 @@ void Program::finalize() { } synchronize(); - current_program = nullptr; memory_pool_->terminate(); if (arch_uses_llvm(config.arch)) { +#if TI_WITH_LLVM static_cast(program_impl_.get())->finalize(); +#else + TI_NOT_IMPLEMENTED +#endif } finalized_ = true; @@ -471,9 +530,13 @@ int Program::default_block_dim(const CompileConfig &config) { } void Program::print_memory_profiler_info() { +#ifdef TI_WITH_LLVM TI_ASSERT(arch_uses_llvm(config.arch)); static_cast(program_impl_.get()) ->print_memory_profiler_info(snode_trees_, result_buffer); +#else + TI_ERROR("Llvm disabled"); +#endif } std::size_t Program::get_snode_num_dynamically_allocated(SNode *snode) { @@ -494,7 +557,11 @@ std::unique_ptr Program::make_aot_module_builder(Arch arch) { // platform. Consider decoupling this part if (arch == Arch::wasm) { // Have to check WASM first, or it dispatches to the LlvmProgramImpl. +#ifdef TI_WITH_LLVM return std::make_unique(); +#else + TI_NOT_IMPLEMENTED +#endif } if (arch_uses_llvm(config.arch) || config.arch == Arch::metal || config.arch == Arch::vulkan || config.arch == Arch::opengl) { @@ -504,7 +571,21 @@ std::unique_ptr Program::make_aot_module_builder(Arch arch) { } LlvmProgramImpl *Program::get_llvm_program_impl() { +#ifdef TI_WITH_LLVM return static_cast(program_impl_.get()); +#else + TI_ERROR("Llvm disabled"); +#endif +} + +int Program::allocate_snode_tree_id() { + if (free_snode_tree_ids_.empty()) { + return snode_trees_.size(); + } else { + int id = free_snode_tree_ids_.top(); + free_snode_tree_ids_.pop(); + return id; + } } } // namespace lang diff --git a/taichi/program/program.h b/taichi/program/program.h index f590c0607dda7..c35b7f45d8868 100644 --- a/taichi/program/program.h +++ b/taichi/program/program.h @@ -5,20 +5,23 @@ #include #include #include +#include #define TI_RUNTIME_HOST +#include "taichi/aot/module_builder.h" +#include "taichi/ir/frontend_ir.h" #include "taichi/ir/ir.h" #include "taichi/ir/type_factory.h" #include "taichi/ir/snode.h" #include "taichi/lang_util.h" #include "taichi/program/program_impl.h" #include "taichi/program/callable.h" -#include "taichi/program/aot_module_builder.h" #include "taichi/program/function.h" #include "taichi/program/kernel.h" #include "taichi/program/kernel_profiler.h" #include "taichi/program/snode_expr_utils.h" #include "taichi/program/snode_rw_accessors_bank.h" +#include "taichi/program/ndarray_rw_accessors_bank.h" #include "taichi/program/context.h" #include "taichi/runtime/runtime.h" #include "taichi/struct/snode_tree.h" @@ -26,6 +29,7 @@ #include "taichi/system/threading.h" #include "taichi/system/unified_allocator.h" #include "taichi/program/sparse_matrix.h" +#include "taichi/ir/mesh.h" namespace taichi { namespace lang { @@ -72,12 +76,6 @@ struct hash { namespace taichi { namespace lang { -extern Program *current_program; - -TI_FORCE_INLINE Program &get_current_program() { - return *current_program; -} - class StructCompiler; class LlvmProgramImpl; class AsyncEngine; @@ -95,7 +93,7 @@ class AsyncEngine; * LlvmProgramImpl, MetalProgramImpl.. */ -class Program { +class TI_DLL_EXPORT Program { public: using Kernel = taichi::lang::Kernel; Callable *current_callable{nullptr}; @@ -104,9 +102,6 @@ class Program { uint64 *result_buffer{nullptr}; // Note result_buffer is used by all backends - std::unordered_map - snodes; // TODO: seems LLVM specific but used by state_flow_graph.cpp. - std::unique_ptr async_engine{nullptr}; std::vector> kernels; @@ -184,6 +179,16 @@ class Program { return *kernels.back(); } + Kernel &kernel(const std::function &body, + const std::string &name = "", + bool grad = false) { + // Expr::set_allow_store(true); + auto func = std::make_unique(*this, body, name, grad); + // Expr::set_allow_store(false); + kernels.emplace_back(std::move(func)); + return *kernels.back(); + } + Function *create_function(const FunctionKey &func_key); // TODO: This function is doing two things: 1) compiling CHI IR, and 2) @@ -198,6 +203,10 @@ class Program { Kernel &get_snode_writer(SNode *snode); + Kernel &get_ndarray_reader(Ndarray *ndarray); + + Kernel &get_ndarray_writer(Ndarray *ndarray); + uint64 fetch_result_uint64(int i); template @@ -209,7 +218,7 @@ class Program { return host_arch(); } - Arch get_snode_accessor_arch(); + Arch get_accessor_arch(); float64 get_total_compilation_time() { return total_compilation_time_; @@ -240,6 +249,10 @@ class Program { return snode_rw_accessors_bank_; } + inline NdarrayRwAccessorsBank &get_ndarray_rw_accessors_bank() { + return ndarray_rw_accessors_bank_; + } + /** * Destroys a new SNode tree. * @@ -251,9 +264,26 @@ class Program { * Adds a new SNode tree. * * @param root The root of the new SNode tree. + * @param compile_only Only generates the compiled type * @return The pointer to SNode tree. + * + * FIXME: compile_only is mostly a hack to make AOT & cross-compilation work. + * E.g. users who would like to AOT to a specific target backend can do so, + * even if their platform doesn't support that backend. Unfortunately, the + * current implementation would leave the backend in a mostly broken state. We + * need a cleaner design to support both AOT and JIT modes. + */ + SNodeTree *add_snode_tree(std::unique_ptr root, bool compile_only); + + /** + * Allocates a SNode tree id for a new SNode tree + * + * @return The SNode tree id allocated + * + * Returns and consumes a free SNode tree id if there is any, + * Otherwise returns the size of `snode_trees_` */ - SNodeTree *add_snode_tree(std::unique_ptr root); + int allocate_snode_tree_id(); /** * Gets the root of a SNode tree. @@ -279,19 +309,28 @@ class Program { return program_impl_->get_graphics_device(); } - private: - /** - * Materializes a new SNodeTree. - * - * JIT compiles the @param tree to backend-specific data types. - */ - void materialize_snode_tree(SNodeTree *tree); + std::shared_ptr get_device_shared() { + return program_impl_->get_device_shared(); + } + + // TODO: do we still need result_buffer? + DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size, + uint64 *result_buffer) { + return program_impl_->allocate_memory_ndarray(alloc_size, result_buffer); + } + + ASTBuilder *current_ast_builder() { + return current_callable ? ¤t_callable->context->builder() : nullptr; + } + private: // SNode information that requires using Program. SNodeGlobalVarExprMap snode_to_glb_var_exprs_; SNodeRwAccessorsBank snode_rw_accessors_bank_; + NdarrayRwAccessorsBank ndarray_rw_accessors_bank_; std::vector> snode_trees_; + std::stack free_snode_tree_ids_; std::vector> functions_; std::unordered_map function_map_; diff --git a/taichi/program/program_impl.cpp b/taichi/program/program_impl.cpp index a74af6efcf08b..08eb8007f1a98 100644 --- a/taichi/program/program_impl.cpp +++ b/taichi/program/program_impl.cpp @@ -6,5 +6,12 @@ namespace lang { ProgramImpl::ProgramImpl(CompileConfig &config_) : config(&config_) { } +void ProgramImpl::compile_snode_tree_types( + SNodeTree *tree, + std::vector> &snode_trees) { + // FIXME: Eventually all the backends should implement this + TI_NOT_IMPLEMENTED; +} + } // namespace lang } // namespace taichi diff --git a/taichi/program/program_impl.h b/taichi/program/program_impl.h index 5fab64ef62aa0..3ba3645598a70 100644 --- a/taichi/program/program_impl.h +++ b/taichi/program/program_impl.h @@ -1,10 +1,12 @@ #pragma once + +#include "taichi/aot/module_builder.h" +#include "taichi/ir/statements.h" #include "taichi/system/memory_pool.h" #include "taichi/common/logging.h" #include "taichi/struct/snode_tree.h" #include "taichi/program/snode_expr_utils.h" #include "taichi/program/kernel_profiler.h" -#include "taichi/program/aot_module_builder.h" #include "taichi/backends/device.h" namespace taichi { @@ -33,12 +35,18 @@ class ProgramImpl { uint64 **result_buffer_ptr) = 0; /** - * Run StructCompiler for the backend. + * JIT compiles @param tree to backend-specific data types. + */ + virtual void compile_snode_tree_types( + SNodeTree *tree, + std::vector> &snode_trees); + + /** + * Compiles the @param tree types and allocates runtime buffer for it. */ virtual void materialize_snode_tree( SNodeTree *tree, std::vector> &snode_trees_, - std::unordered_map &snodes, uint64 *result_buffer_ptr) = 0; virtual void destroy_snode_tree(SNodeTree *snode_tree) = 0; @@ -65,10 +73,18 @@ class ProgramImpl { return nullptr; } + virtual std::shared_ptr get_device_shared() { + return nullptr; + } + virtual DevicePtr get_snode_tree_device_ptr(int tree_id) { return kDeviceNullPtr; } + virtual DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size, + uint64 *result_buffer) { + return kDeviceNullAllocation; + } virtual ~ProgramImpl() { } diff --git a/taichi/program/snode_expr_utils.cpp b/taichi/program/snode_expr_utils.cpp index b7b6823cf3441..666c08790ff91 100644 --- a/taichi/program/snode_expr_utils.cpp +++ b/taichi/program/snode_expr_utils.cpp @@ -1,4 +1,6 @@ #include "taichi/program/snode_expr_utils.h" +#include "taichi/ir/snode.h" +#include "taichi/ir/frontend_ir.h" namespace taichi { namespace lang { diff --git a/taichi/program/snode_expr_utils.h b/taichi/program/snode_expr_utils.h index f4abf2b28a27e..28fd9b3b97046 100644 --- a/taichi/program/snode_expr_utils.h +++ b/taichi/program/snode_expr_utils.h @@ -2,9 +2,7 @@ #include #include - -#include "taichi/ir/snode.h" -#include "taichi/ir/frontend_ir.h" +#include // This file groups the set of helpers that need the Expr associated with a // given SNode. Expr is part of the frontend, which somehow depends on the @@ -12,6 +10,9 @@ // on less, we thus move SNode-Expr related utils away from SNode itself. namespace taichi { namespace lang { +class Expr; +class SNode; +class GlobalVariableExpression; using SNodeGlobalVarExprMap = std::unordered_map(max_num_triplets_ * 3 * element_size); +} + +template +void SparseMatrixBuilder::print_template() { + fmt::print("n={}, m={}, num_triplets={} (max={})\n", rows_, cols_, num_triplets_, max_num_triplets_); + T *data = reinterpret_cast(data_base_ptr_.get()); for (int64 i = 0; i < num_triplets_; i++) { - fmt::print("({}, {}) val={}", data_[i * 3], data_[i * 3 + 1], - taichi_union_cast(data_[i * 3 + 2])); + fmt::print("({}, {}) val={}\n", ((G *)data)[i * 3], ((G *)data)[i * 3 + 1], + taichi_union_cast(data[i * 3 + 2])); } fmt::print("\n"); } -SparseMatrix SparseMatrixBuilder::build() { - TI_ASSERT(built_ == false); - built_ = true; - using T = Eigen::Triplet; - std::vector triplets; +void SparseMatrixBuilder::print_triplets() { + auto element_size = data_type_size(dtype_); + switch (element_size) { + case 4: + print_template(); + break; + case 8: + print_template(); + break; + default: + TI_ERROR("Unsupported sparse matrix data type!"); + break; + } +} + +template +SparseMatrix SparseMatrixBuilder::build_template() { + using V = Eigen::Triplet; + std::vector triplets; + T *data = reinterpret_cast(data_base_ptr_.get()); for (int i = 0; i < num_triplets_; i++) { - triplets.push_back(T(data_[i * 3], data_[i * 3 + 1], - taichi_union_cast(data_[i * 3 + 2]))); + triplets.push_back(V(((G *)data)[i * 3], ((G *)data)[i * 3 + 1], + taichi_union_cast(data[i * 3 + 2]))); } SparseMatrix sm(rows_, cols_); sm.get_matrix().setFromTriplets(triplets.begin(), triplets.end()); @@ -45,6 +64,21 @@ SparseMatrix SparseMatrixBuilder::build() { return sm; } +SparseMatrix SparseMatrixBuilder::build() { + TI_ASSERT(built_ == false); + built_ = true; + auto element_size = data_type_size(dtype_); + switch (element_size) { + case 4: + return build_template(); + case 8: + return build_template(); + default: + TI_ERROR("Unsupported sparse matrix data type!"); + break; + } +} + void SparseMatrixBuilder::clear() { built_ = false; num_triplets_ = 0; diff --git a/taichi/program/sparse_matrix.h b/taichi/program/sparse_matrix.h index 2de867d1ce249..f5baa62539912 100644 --- a/taichi/program/sparse_matrix.h +++ b/taichi/program/sparse_matrix.h @@ -2,6 +2,8 @@ #include "taichi/common/core.h" #include "taichi/inc/constants.h" +#include "taichi/ir/type_utils.h" + #include "Eigen/Sparse" namespace taichi { @@ -11,9 +13,7 @@ class SparseMatrix; class SparseMatrixBuilder { public: - SparseMatrixBuilder(int rows, int cols, int max_num_triplets); - - void *get_data_base_ptr(); + SparseMatrixBuilder(int rows, int cols, int max_num_triplets, DataType dtype); void print_triplets(); @@ -21,14 +21,21 @@ class SparseMatrixBuilder { void clear(); + private: + template + void print_template(); + + template + SparseMatrix build_template(); + private: uint64 num_triplets_{0}; - void *data_base_ptr_{nullptr}; - std::vector data_; + std::unique_ptr data_base_ptr_{nullptr}; int rows_{0}; int cols_{0}; uint64 max_num_triplets_{0}; bool built_{false}; + DataType dtype_{PrimitiveType::f32}; }; class SparseMatrix { diff --git a/taichi/program/sparse_solver.cpp b/taichi/program/sparse_solver.cpp index 52f801bb07b1c..77d0bf8981921 100644 --- a/taichi/program/sparse_solver.cpp +++ b/taichi/program/sparse_solver.cpp @@ -1,24 +1,26 @@ +#include "taichi/ir/type_utils.h" + #include "sparse_solver.h" #include -#define MAKE_SOLVER(type, order) \ - { \ - {#type, #order}, []() -> std::unique_ptr { \ - using T = \ - Eigen::Simplicial##type, Eigen::Lower, \ - Eigen::order##Ordering>; \ - return std::make_unique>(); \ - } \ +#define MAKE_SOLVER(dt, type, order) \ + { \ + {#dt, #type, #order}, []() -> std::unique_ptr { \ + using T = Eigen::Simplicial##type, Eigen::Lower, \ + Eigen::order##Ordering>; \ + return std::make_unique>(); \ + } \ } +using Triplets = std::tuple; namespace { -struct pair_hash { - template - std::size_t operator()(const std::pair &p) const { - auto h1 = std::hash{}(p.first); - auto h2 = std::hash{}(p.second); - return h1 ^ h2; +struct key_hash { + std::size_t operator()(const Triplets &k) const { + auto h1 = std::hash{}(std::get<0>(k)); + auto h2 = std::hash{}(std::get<1>(k)); + auto h3 = std::hash{}(std::get<2>(k)); + return h1 ^ h2 ^ h3; } }; } // namespace @@ -55,20 +57,23 @@ bool EigenSparseSolver::info() { return solver_.info() == Eigen::Success; } -std::unique_ptr make_sparse_solver(const std::string &solver_type, +std::unique_ptr make_sparse_solver(DataType dt, + const std::string &solver_type, const std::string &ordering) { - using key_type = std::pair; + using key_type = Triplets; using func_type = std::unique_ptr (*)(); - static const std::unordered_map + static const std::unordered_map solver_factory = { - MAKE_SOLVER(LLT, AMD), - MAKE_SOLVER(LLT, COLAMD), - MAKE_SOLVER(LDLT, AMD), - MAKE_SOLVER(LDLT, COLAMD), - }; - std::pair solver_key = - std::make_pair(solver_type, ordering); + MAKE_SOLVER(float32, LLT, AMD), MAKE_SOLVER(float32, LLT, COLAMD), + MAKE_SOLVER(float32, LDLT, AMD), MAKE_SOLVER(float32, LDLT, COLAMD)}; + static const std::unordered_map dt_map = { + {"f32", "float32"}, {"f64", "float64"}}; + auto it = dt_map.find(taichi::lang::data_type_name(dt)); + if (it == dt_map.end()) + TI_ERROR("Not supported sparse solver data type: {}", + taichi::lang::data_type_name(dt)); + Triplets solver_key = std::make_tuple(it->second, solver_type, ordering); if (solver_factory.find(solver_key) != solver_factory.end()) { auto solver_func = solver_factory.at(solver_key); return solver_func(); diff --git a/taichi/program/sparse_solver.h b/taichi/program/sparse_solver.h index aeae65269aabb..8325802bdc288 100644 --- a/taichi/program/sparse_solver.h +++ b/taichi/program/sparse_solver.h @@ -1,5 +1,7 @@ #pragma once +#include "taichi/ir/type.h" + #include "sparse_matrix.h" namespace taichi { @@ -21,16 +23,16 @@ class EigenSparseSolver : public SparseSolver { EigenSolver solver_; public: - virtual ~EigenSparseSolver() = default; - virtual bool compute(const SparseMatrix &sm) override; - virtual void analyze_pattern(const SparseMatrix &sm) override; - virtual void factorize(const SparseMatrix &sm) override; - virtual Eigen::VectorXf solve( - const Eigen::Ref &b) override; - virtual bool info() override; + ~EigenSparseSolver() override = default; + bool compute(const SparseMatrix &sm) override; + void analyze_pattern(const SparseMatrix &sm) override; + void factorize(const SparseMatrix &sm) override; + Eigen::VectorXf solve(const Eigen::Ref &b) override; + bool info() override; }; -std::unique_ptr make_sparse_solver(const std::string &solver_type, +std::unique_ptr make_sparse_solver(DataType dt, + const std::string &solver_type, const std::string &ordering); } // namespace lang diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index 5b02f9c52a453..aa4e6a1ae9e3b 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -240,8 +240,7 @@ void StateFlowGraph::Node::disconnect_with(StateFlowGraph::Node *other) { StateFlowGraph::StateFlowGraph(AsyncEngine *engine, IRBank *ir_bank, - const CompileConfig *const config, - const std::unordered_map &snodes) + const CompileConfig *const config) : first_pending_task_index_(1 /*after initial node*/), ir_bank_(ir_bank), engine_(engine), @@ -255,10 +254,6 @@ StateFlowGraph::StateFlowGraph(AsyncEngine *engine, initial_node_->input_edges.node_id = 0; initial_node_->output_edges.node_id = 0; initial_node_->mark_executed(); - - for (const auto snode : snodes) { - list_up_to_date_[snode.second] = false; - } } std::vector StateFlowGraph::get_pending_tasks() const { @@ -1016,7 +1011,7 @@ std::string StateFlowGraph::dump_dot(const std::optional &rankdir, const auto tt = nd->meta->type; if (!nd->is_initial_node && (tt == TaskType::range_for || tt == TaskType::struct_for || - tt == TaskType::serial)) { + tt == TaskType::mesh_for || tt == TaskType::serial)) { // ss << " style=filled fillcolor=lightgray"; } ss << "]\n"; @@ -1329,6 +1324,7 @@ bool StateFlowGraph::optimize_dead_store() { // |mt| is not the desired type. if ((mt == OffloadedTaskType::serial || mt == OffloadedTaskType::struct_for || + mt == OffloadedTaskType::mesh_for || mt == OffloadedTaskType::range_for) && ir->body->statements.empty()) { to_delete.insert(i + first_pending_task_index_); diff --git a/taichi/program/state_flow_graph.h b/taichi/program/state_flow_graph.h index c14785ef964eb..ce3aec83b69b9 100644 --- a/taichi/program/state_flow_graph.h +++ b/taichi/program/state_flow_graph.h @@ -6,8 +6,13 @@ #include #include +#ifdef TI_WITH_LLVM #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" +#else +#include +#include +#endif #include "taichi/ir/ir.h" #include "taichi/lang_util.h" #include "taichi/program/async_utils.h" @@ -28,7 +33,11 @@ class StateFlowGraph { public: static constexpr unsigned kNumInlined = 8u; using Edge = std::pair; +#ifdef TI_WITH_LLVM using Container = llvm::SmallVector; +#else + using Container = std::vector; +#endif StateToNodesMap() = default; @@ -227,8 +236,7 @@ class StateFlowGraph { StateFlowGraph(AsyncEngine *engine, IRBank *ir_bank, - const CompileConfig *const config, - const std::unordered_map &snodes); + const CompileConfig *const config); std::vector get_pending_tasks() const; @@ -318,8 +326,13 @@ class StateFlowGraph { AsyncState get_async_state(Kernel *kernel); void populate_latest_state_owner(std::size_t id); +#ifdef TI_WITH_LLVM using LatestStateReaders = llvm::SmallVector>, 4>; +#else + using LatestStateReaders = + std::vector>>; +#endif private: std::vector> nodes_; diff --git a/taichi/python/exception.cpp b/taichi/python/exception.cpp index 2742b9ddeb5f3..2154ec0762428 100644 --- a/taichi/python/exception.cpp +++ b/taichi/python/exception.cpp @@ -13,6 +13,6 @@ void raise_assertion_failure_in_python(const std::string &msg) { TI_NAMESPACE_END -TI_EXPORT void taichi_raise_assertion_failure_in_python(const char *msg) { +void taichi_raise_assertion_failure_in_python(const char *msg) { taichi::raise_assertion_failure_in_python(std::string(msg)); } diff --git a/taichi/python/exception.h b/taichi/python/exception.h index 00c4244484aa9..1397f2070631c 100644 --- a/taichi/python/exception.h +++ b/taichi/python/exception.h @@ -13,13 +13,13 @@ TI_NAMESPACE_BEGIN class ExceptionForPython : public std::exception { private: - std::string msg; + std::string msg_; public: - ExceptionForPython(const std::string &msg) : msg(msg) { + ExceptionForPython(const std::string &msg) : msg_(msg) { } - char const *what() const throw() { - return msg.c_str(); + char const *what() const throw() override { + return msg_.c_str(); } }; diff --git a/taichi/python/export_ggui.cpp b/taichi/python/export_ggui.cpp index 7313da3a5dbf3..d45e293b68275 100644 --- a/taichi/python/export_ggui.cpp +++ b/taichi/python/export_ggui.cpp @@ -19,6 +19,7 @@ namespace py = pybind11; #include "taichi/ui/backends/vulkan/scene.h" #include "taichi/ui/common/field_info.h" #include "taichi/ui/common/gui_base.h" +#include TI_UI_NAMESPACE_BEGIN @@ -229,19 +230,25 @@ struct PyCanvas { }; struct PyWindow { - WindowBase *window; + std::unique_ptr window{nullptr}; - PyWindow(std::string name, + PyWindow(Program *prog, + std::string name, py::tuple res, bool vsync, + bool show_window, std::string package_path, Arch ti_arch, bool is_packed_mode) { - AppConfig config = {name, res[0].cast(), res[1].cast(), - vsync, package_path, ti_arch, - is_packed_mode}; + AppConfig config = {name, res[0].cast(), res[1].cast(), + vsync, show_window, package_path, + ti_arch, is_packed_mode}; // todo: support other ggui backends - window = new vulkan::Window(config); + window = std::make_unique(prog, config); + } + + void write_image(const std::string &filename) { + window->write_image(filename); } void show() { @@ -293,8 +300,10 @@ struct PyWindow { return py::make_tuple(x, y); } - ~PyWindow() { - delete window; + void destroy() { + if (window) { + window.reset(); + } } }; @@ -302,17 +311,20 @@ void export_ggui(py::module &m) { m.attr("GGUI_AVAILABLE") = py::bool_(true); py::class_(m, "PyWindow") - .def(py::init()) + .def(py::init()) .def("get_canvas", &PyWindow::get_canvas) .def("show", &PyWindow::show) + .def("write_image", &PyWindow::write_image) .def("is_pressed", &PyWindow::is_pressed) .def("get_cursor_pos", &PyWindow::py_get_cursor_pos) .def("is_running", &PyWindow::is_running) .def("set_is_running", &PyWindow::set_is_running) .def("get_event", &PyWindow::get_event) .def("get_events", &PyWindow::get_events) - .def_property("event", &PyWindow::get_current_event, - &PyWindow::set_current_event) + .def("get_current_event", &PyWindow::get_current_event) + .def("set_current_event", &PyWindow::set_current_event) + .def("destroy", &PyWindow::destroy) .def("GUI", &PyWindow::GUI); py::class_(m, "PyCanvas") @@ -381,6 +393,7 @@ void export_ggui(py::module &m) { py::enum_(m, "FieldSource") .value("TaichiCuda", FieldSource::TaichiCuda) .value("TaichiX64", FieldSource::TaichiX64) + .value("TaichiVulkan", FieldSource::TaichiVulkan) .export_values(); py::enum_(m, "FieldType") diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 7d46ace327a98..a8cbfa1bafe33 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -2,8 +2,11 @@ #include #include +#include "taichi/ir/snode.h" +#if TI_WITH_LLVM #include "llvm/Config/llvm-config.h" +#endif #include "pybind11/functional.h" #include "pybind11/pybind11.h" @@ -15,8 +18,7 @@ #include "taichi/ir/statements.h" #include "taichi/program/extension.h" #include "taichi/program/async_engine.h" -#include "taichi/program/snode_expr_utils.h" -#include "taichi/program/snode_rw_accessors_bank.h" +#include "taichi/program/ndarray.h" #include "taichi/common/interface.h" #include "taichi/python/export.h" #include "taichi/gui/gui.h" @@ -27,6 +29,7 @@ #include "taichi/python/snode_registry.h" #include "taichi/program/sparse_matrix.h" #include "taichi/program/sparse_solver.h" +#include "taichi/ir/mesh.h" #include "taichi/program/kernel_profiler.h" @@ -49,28 +52,20 @@ Expr expr_index(const Expr &expr, const Expr &index) { return expr[index]; } -void expr_assign(const Expr &lhs_, const Expr &rhs, std::string tb) { - auto lhs = ptr_if_global(lhs_); - TI_ASSERT(lhs->is_lvalue()); - auto stmt = std::make_unique(lhs, load_if_ptr(rhs)); - stmt->set_tb(tb); - current_ast_builder().insert(std::move(stmt)); -} - -std::vector> scope_stack; - std::string libdevice_path(); -SNodeRwAccessorsBank::Accessors get_snode_rw_accessors(SNode *snode) { - return get_current_program().get_snode_rw_accessors_bank().get(snode); -} - TLANG_NAMESPACE_END TI_NAMESPACE_BEGIN void export_lang(py::module &m) { using namespace taichi::lang; + py::register_exception(m, "TaichiTypeError", + PyExc_TypeError); + py::register_exception(m, "TaichiSyntaxError", + PyExc_SyntaxError); + py::register_exception(m, "TaichiRuntimeError", + PyExc_RuntimeError); py::enum_(m, "Arch", py::arithmetic()) #define PER_ARCH(x) .value(#x, Arch::x) #include "taichi/inc/archs.inc.h" @@ -98,6 +93,7 @@ void export_lang(py::module &m) { .def(py::self == py::self) .def("__hash__", &DataType::hash) .def("to_string", &DataType::to_string) + .def("__str__", &DataType::to_string) .def( "get_ptr", [](DataType *dtype) -> Type * { return *dtype; }, py::return_value_policy::reference) @@ -122,8 +118,11 @@ void export_lang(py::module &m) { py::class_(m, "CompileConfig") .def(py::init<>()) .def_readwrite("arch", &CompileConfig::arch) + .def_readwrite("opt_level", &CompileConfig::opt_level) .def_readwrite("packed", &CompileConfig::packed) .def_readwrite("print_ir", &CompileConfig::print_ir) + .def_readwrite("print_preprocessed_ir", + &CompileConfig::print_preprocessed_ir) .def_readwrite("debug", &CompileConfig::debug) .def_readwrite("cfg_optimization", &CompileConfig::cfg_optimization) .def_readwrite("check_out_of_bound", &CompileConfig::check_out_of_bound) @@ -177,6 +176,8 @@ void export_lang(py::module &m) { .def_readwrite("make_thread_local", &CompileConfig::make_thread_local) .def_readwrite("make_block_local", &CompileConfig::make_block_local) .def_readwrite("detect_read_only", &CompileConfig::detect_read_only) + .def_readwrite("ndarray_use_cached_allocator", + &CompileConfig::ndarray_use_cached_allocator) .def_readwrite("cc_compile_cmd", &CompileConfig::cc_compile_cmd) .def_readwrite("cc_link_cmd", &CompileConfig::cc_link_cmd) .def_readwrite("async_opt_passes", &CompileConfig::async_opt_passes) @@ -197,7 +198,26 @@ void export_lang(py::module &m) { .def_readwrite("quant_opt_store_fusion", &CompileConfig::quant_opt_store_fusion) .def_readwrite("quant_opt_atomic_demotion", - &CompileConfig::quant_opt_atomic_demotion); + &CompileConfig::quant_opt_atomic_demotion) + .def_readwrite("allow_nv_shader_extension", + &CompileConfig::allow_nv_shader_extension) + .def_readwrite("use_gles", &CompileConfig::use_gles) + .def_readwrite("make_mesh_block_local", + &CompileConfig::make_mesh_block_local) + .def_readwrite("mesh_localize_to_end_mapping", + &CompileConfig::mesh_localize_to_end_mapping) + .def_readwrite("mesh_localize_from_end_mapping", + &CompileConfig::mesh_localize_from_end_mapping) + .def_readwrite("optimize_mesh_reordered_mapping", + &CompileConfig::optimize_mesh_reordered_mapping) + .def_readwrite("mesh_localize_all_attr_mappings", + &CompileConfig::mesh_localize_all_attr_mappings) + .def_readwrite("demote_no_access_mesh_fors", + &CompileConfig::demote_no_access_mesh_fors) + .def_readwrite("experimental_auto_mesh_local", + &CompileConfig::experimental_auto_mesh_local) + .def_readwrite("auto_mesh_local_default_occupacy", + &CompileConfig::auto_mesh_local_default_occupacy); m.def("reset_default_compile_config", [&]() { default_compile_config = CompileConfig(); }); @@ -230,9 +250,57 @@ void export_lang(py::module &m) { .def_readwrite("metric_values", &KernelProfileTracedRecord::metric_values); + py::enum_(m, "SNodeAccessFlag", py::arithmetic()) + .value("block_local", SNodeAccessFlag::block_local) + .value("read_only", SNodeAccessFlag::read_only) + .value("mesh_local", SNodeAccessFlag::mesh_local) + .export_values(); + + // Export ASTBuilder + py::class_(m, "ASTBuilder") + .def("create_kernel_exprgroup_return", + &ASTBuilder::create_kernel_exprgroup_return) + .def("create_print", &ASTBuilder::create_print) + .def("begin_func", &ASTBuilder::begin_func) + .def("end_func", &ASTBuilder::end_func) + .def("stop_grad", &ASTBuilder::stop_gradient) + .def("begin_frontend_if", &ASTBuilder::begin_frontend_if) + .def("begin_frontend_if_true", &ASTBuilder::begin_frontend_if_true) + .def("pop_scope", &ASTBuilder::pop_scope) + .def("begin_frontend_if_false", &ASTBuilder::begin_frontend_if_false) + .def("insert_deactivate", Deactivate) + .def("insert_activate", Activate) + .def("insert_external_func_call", &ASTBuilder::insert_external_func_call) + .def("expr_alloca", &ASTBuilder::expr_alloca) + .def("expr_alloca_local_tensor", &ASTBuilder::expr_alloca_local_tensor) + .def("create_assert_stmt", &ASTBuilder::create_assert_stmt) + .def("expr_assign", &ASTBuilder::expr_assign) + .def("begin_frontend_range_for", &ASTBuilder::begin_frontend_range_for) + .def("end_frontend_range_for", &ASTBuilder::pop_scope) + .def("begin_frontend_struct_for", &ASTBuilder::begin_frontend_struct_for) + .def("end_frontend_struct_for", &ASTBuilder::pop_scope) + .def("begin_frontend_mesh_for", &ASTBuilder::begin_frontend_mesh_for) + .def("end_frontend_mesh_for", &ASTBuilder::pop_scope) + .def("begin_frontend_while", &ASTBuilder::begin_frontend_while) + .def("insert_break_stmt", &ASTBuilder::insert_break_stmt) + .def("insert_continue_stmt", &ASTBuilder::insert_continue_stmt) + .def("insert_expr_stmt", &ASTBuilder::insert_expr_stmt) + .def("insert_thread_idx_expr", &ASTBuilder::insert_thread_idx_expr) + .def("insert_patch_idx_expr", &ASTBuilder::insert_patch_idx_expr) + .def("sifakis_svd_f32", sifakis_svd_export) + .def("sifakis_svd_f64", sifakis_svd_export) + .def("expr_var", &ASTBuilder::make_var) + .def("bit_vectorize", &ASTBuilder::bit_vectorize) + .def("parallelize", &ASTBuilder::parallelize) + .def("block_dim", &ASTBuilder::block_dim) + .def("insert_snode_access_flag", &ASTBuilder::insert_snode_access_flag) + .def("reset_snode_access_flag", &ASTBuilder::reset_snode_access_flag); + py::class_(m, "Program") .def(py::init<>()) .def_readonly("config", &Program::config) + .def("sync_kernel_profiler", + [](Program *program) { program->profiler->sync(); }) .def("query_kernel_profile_info", [](Program *program, const std::string &name) { return program->query_kernel_profile_info(name); @@ -250,6 +318,10 @@ void export_lang(py::module &m) { }) .def("kernel_profiler_total_time", [](Program *program) { return program->profiler->get_total_time(); }) + .def("set_kernel_profiler_toolkit", + [](Program *program, const std::string toolkit_name) { + return program->profiler->set_profiler_toolkit(toolkit_name); + }) .def("clear_kernel_profile_info", &Program::clear_kernel_profile_info) .def("timeline_clear", [](Program *) { Timelines::get_instance().clear(); }) @@ -273,7 +345,64 @@ void export_lang(py::module &m) { .def("make_aot_module_builder", &Program::make_aot_module_builder) .def("get_snode_tree_size", &Program::get_snode_tree_size) .def("get_snode_root", &Program::get_snode_root, - py::return_value_policy::reference); + py::return_value_policy::reference) + .def("current_ast_builder", &Program::current_ast_builder, + py::return_value_policy::reference) + .def( + "create_kernel", + [](Program *program, const std::function &body, + const std::string &name, bool grad) -> Kernel * { + py::gil_scoped_release release; + return &program->kernel(body, name, grad); + }, + py::return_value_policy::reference) + .def("create_function", &Program::create_function, + py::return_value_policy::reference) + .def("create_sparse_matrix_builder", + [](Program *program, int n, int m, uint64 max_num_entries, + DataType dtype) { + TI_ERROR_IF(!arch_is_cpu(program->config.arch), + "SparseMatrix only supports CPU for now."); + return SparseMatrixBuilder(n, m, max_num_entries, dtype); + }) + .def("create_sparse_matrix", + [](Program *program, int n, int m) { + TI_ERROR_IF(!arch_is_cpu(program->config.arch), + "SparseMatrix only supports CPU for now."); + return SparseMatrix(n, m); + }) + .def( + "dump_dot", + [](Program *program, std::optional rankdir, + int embed_states_threshold) { + // https://pybind11.readthedocs.io/en/stable/advanced/functions.html#allow-prohibiting-none-arguments + return program->async_engine->sfg->dump_dot(rankdir, + embed_states_threshold); + }, + py::arg("rankdir").none(true), + py::arg("embed_states_threshold")) // FIXME: + .def("no_activate", + [](Program *program, SNode *snode) { + // TODO(#2193): Also apply to @ti.func? + auto *kernel = dynamic_cast(program->current_callable); + TI_ASSERT(kernel); + kernel->no_activate.push_back(snode); + }) + .def("print_sfg", + [](Program *program) { return program->async_engine->sfg->print(); }) + .def("decl_arg", + [&](Program *program, const DataType &dt, bool is_array) { + return program->current_callable->insert_arg(dt, is_array); + }) + .def("decl_arr_arg", + [&](Program *program, const DataType &dt, int total_dim, + std::vector shape) { + return program->current_callable->insert_arr_arg(dt, total_dim, + shape); + }) + .def("decl_ret", [&](Program *program, const DataType &dt) { + return program->current_callable->insert_ret(dt); + }); py::class_(m, "AotModuleBuilder") .def("add_field", &AotModuleBuilder::add_field) @@ -281,14 +410,6 @@ void export_lang(py::module &m) { .def("add_kernel_template", &AotModuleBuilder::add_kernel_template) .def("dump", &AotModuleBuilder::dump); - m.def("get_current_program", get_current_program, - py::return_value_policy::reference); - - m.def( - "current_compile_config", - [&]() -> CompileConfig & { return get_current_program().config; }, - py::return_value_policy::reference); - py::class_(m, "Axis").def(py::init()); py::class_(m, "SNode") .def(py::init<>()) @@ -316,11 +437,7 @@ void export_lang(py::module &m) { py::return_value_policy::reference) .def("bit_struct", &SNode::bit_struct, py::return_value_policy::reference) .def("bit_array", &SNode::bit_array, py::return_value_policy::reference) - .def("place", - [](SNode *snode, Expr &expr, const std::vector &offset) { - place_child(&expr, offset, snode, - get_current_program().get_snode_to_glb_var_exprs()); - }) + .def("place", &SNode::place) .def("data_type", [](SNode *snode) { return snode->dt; }) .def("name", [](SNode *snode) { return snode->name; }) .def("get_num_ch", @@ -329,39 +446,16 @@ void export_lang(py::module &m) { "get_ch", [](SNode *snode, int i) -> SNode * { return snode->ch[i].get(); }, py::return_value_policy::reference) - .def("lazy_grad", - [](SNode *snode) { - make_lazy_grad(snode, - get_current_program().get_snode_to_glb_var_exprs()); - }) - .def("read_int", - [](SNode *snode, const std::vector &I) -> int64 { - return get_snode_rw_accessors(snode).read_int(I); - }) - .def("read_uint", - [](SNode *snode, const std::vector &I) -> uint64 { - return get_snode_rw_accessors(snode).read_uint(I); - }) - .def("read_float", - [](SNode *snode, const std::vector &I) -> float64 { - return get_snode_rw_accessors(snode).read_float(I); - }) + .def("lazy_grad", &SNode::lazy_grad) + .def("read_int", &SNode::read_int) + .def("read_uint", &SNode::read_uint) + .def("read_float", &SNode::read_float) .def("has_grad", &SNode::has_grad) .def("is_primal", &SNode::is_primal) .def("is_place", &SNode::is_place) - .def("get_expr", - [](SNode *snode) { - return Expr( - get_current_program().get_snode_to_glb_var_exprs()->at(snode)); - }) - .def("write_int", - [](SNode *snode, const std::vector &I, int64 val) { - get_snode_rw_accessors(snode).write_int(I, val); - }) - .def("write_float", - [](SNode *snode, const std::vector &I, float64 val) { - get_snode_rw_accessors(snode).write_float(I, val); - }) + .def("get_expr", &SNode::get_expr) + .def("write_int", &SNode::write_int) + .def("write_float", &SNode::write_float) .def("get_shape_along_axis", &SNode::shape_along_axis) .def("get_physical_index_position", [](SNode *snode) { @@ -372,6 +466,8 @@ void export_lang(py::module &m) { .def("num_active_indices", [](SNode *snode) { return snode->num_active_indices; }) .def_readonly("cell_size_bytes", &SNode::cell_size_bytes) + .def_readonly("offset_bytes_in_parent_cell", + &SNode::offset_bytes_in_parent_cell) .def("begin_shared_exp_placement", &SNode::begin_shared_exp_placement) .def("end_shared_exp_placement", &SNode::end_shared_exp_placement); @@ -381,10 +477,35 @@ void export_lang(py::module &m) { program->destroy_snode_tree(snode_tree); }); + py::class_(m, "Ndarray") + .def(py::init &>()) + .def("data_ptr", &Ndarray::get_data_ptr_as_int) + .def("device_allocation_ptr", &Ndarray::get_device_allocation_ptr_as_int) + .def("element_size", &Ndarray::get_element_size) + .def("nelement", &Ndarray::get_nelement) + .def("fill_float", &Ndarray::fill_float) + .def("fill_int", &Ndarray::fill_int) + .def("fill_uint", &Ndarray::fill_uint) + .def("read_int", &Ndarray::read_int) + .def("read_uint", &Ndarray::read_uint) + .def("read_float", &Ndarray::read_float) + .def("write_int", &Ndarray::write_int) + .def("write_float", &Ndarray::write_float) + .def_readonly("dtype", &Ndarray::dtype) + .def_readonly("shape", &Ndarray::shape); + py::class_(m, "Kernel") .def("get_ret_int", &Kernel::get_ret_int) .def("get_ret_float", &Kernel::get_ret_float) + .def("get_ret_int_tensor", &Kernel::get_ret_int_tensor) + .def("get_ret_float_tensor", &Kernel::get_ret_float_tensor) .def("make_launch_context", &Kernel::make_launch_context) + .def( + "ast_builder", + [](Kernel *self) -> ASTBuilder * { + return &self->context->builder(); + }, + py::return_value_policy::reference) .def("__call__", [](Kernel *kernel, Kernel::LaunchContextBuilder &launch_ctx) { py::gil_scoped_release release; @@ -396,16 +517,25 @@ void export_lang(py::module &m) { .def("set_arg_float", &Kernel::LaunchContextBuilder::set_arg_float) .def("set_arg_external_array", &Kernel::LaunchContextBuilder::set_arg_external_array) + .def("set_arg_external_array_with_shape", + &Kernel::LaunchContextBuilder::set_arg_external_array_with_shape) + .def("set_arg_ndarray", &Kernel::LaunchContextBuilder::set_arg_ndarray) .def("set_extra_arg_int", &Kernel::LaunchContextBuilder::set_extra_arg_int); py::class_(m, "Function") .def("set_function_body", py::overload_cast &>( - &Function::set_function_body)); + &Function::set_function_body)) + .def( + "ast_builder", + [](Function *self) -> ASTBuilder * { + return &self->context->builder(); + }, + py::return_value_policy::reference); py::class_ expr(m, "Expr"); - expr.def("serialize", &Expr::serialize) + expr.def("serialize", [](Expr *expr) { return expr->serialize(); }) .def("snode", &Expr::snode, py::return_value_policy::reference) .def("is_global_var", [](Expr *expr) { return expr->is(); }) @@ -428,6 +558,8 @@ void export_lang(py::module &m) { }) .def("set_grad", &Expr::set_grad) .def("set_attribute", &Expr::set_attribute) + .def("get_ret_type", &Expr::get_ret_type) + .def("type_check", &Expr::type_check) .def("get_expr_name", [](Expr *expr) { return expr->cast()->name; @@ -449,18 +581,10 @@ void export_lang(py::module &m) { .def(py::init<>()) .def("size", [](ExprGroup *eg) { return eg->exprs.size(); }) .def("push_back", &ExprGroup::push_back) - .def("serialize", &ExprGroup::serialize); + .def("serialize", [](ExprGroup *eg) { eg->serialize(); }); py::class_(m, "Stmt"); - m.def("insert_deactivate", [](SNode *snode, const ExprGroup &indices) { - return Deactivate(snode, indices); - }); - - m.def("insert_activate", [](SNode *snode, const ExprGroup &indices) { - return Activate(snode, indices); - }); - m.def("expr_get_addr", [](SNode *snode, const ExprGroup &indices) { return Expr::make(snode, SNodeOpType::get_addr, indices); }); @@ -470,15 +594,6 @@ void export_lang(py::module &m) { return Append(snode, indices, val); }); - m.def("insert_external_func_call", - [](std::size_t func_addr, std::string source, const ExprGroup &args, - const ExprGroup &outputs) { - auto expr = Expr::make( - (void *)func_addr, source, args.exprs, outputs.exprs); - - current_ast_builder().insert(Stmt::make(expr)); - }); - m.def("insert_is_active", [](SNode *snode, const ExprGroup &indices) { return is_active(snode, indices); }); @@ -487,85 +602,11 @@ void export_lang(py::module &m) { return Length(snode, indices); }); - m.def("create_assert_stmt", [&](const Expr &cond, const std::string &msg, - const std::vector &args) { - auto stmt_unique = std::make_unique(cond, msg, args); - current_ast_builder().insert(std::move(stmt_unique)); - }); - m.def("insert_internal_func_call", [&](const std::string &func_name, const ExprGroup &args) { return Expr::make(func_name, args.exprs); }); - m.def("begin_frontend_while", [&](const Expr &cond) { - auto stmt_unique = std::make_unique(cond); - auto stmt = stmt_unique.get(); - current_ast_builder().insert(std::move(stmt_unique)); - scope_stack.push_back(current_ast_builder().create_scope(stmt->body)); - }); - - m.def("begin_frontend_range_for", - [&](const Expr &i, const Expr &s, const Expr &e) { - auto stmt_unique = std::make_unique(i, s, e); - auto stmt = stmt_unique.get(); - current_ast_builder().insert(std::move(stmt_unique)); - scope_stack.push_back(current_ast_builder().create_scope(stmt->body)); - }); - - m.def("begin_frontend_struct_for", [&](const ExprGroup &loop_vars, - const Expr &global) { - auto stmt_unique = std::make_unique(loop_vars, global); - auto stmt = stmt_unique.get(); - current_ast_builder().insert(std::move(stmt_unique)); - scope_stack.push_back(current_ast_builder().create_scope(stmt->body)); - }); - - m.def("end_frontend_range_for", [&]() { scope_stack.pop_back(); }); - m.def("pop_scope", [&]() { scope_stack.pop_back(); }); - - m.def("begin_frontend_if", [&](const Expr &cond) { - auto stmt_tmp = std::make_unique(cond); - current_ast_builder().insert(std::move(stmt_tmp)); - }); - - m.def("begin_frontend_if_true", [&]() { - auto if_stmt = current_ast_builder().get_last_stmt()->as(); - scope_stack.push_back( - current_ast_builder().create_scope(if_stmt->true_statements)); - }); - - m.def("begin_frontend_if_false", [&]() { - auto if_stmt = current_ast_builder().get_last_stmt()->as(); - scope_stack.push_back( - current_ast_builder().create_scope(if_stmt->false_statements)); - }); - - m.def("insert_break_stmt", [&]() { - current_ast_builder().insert(Stmt::make()); - }); - - m.def("create_kernel_return", [&](const Expr &value) { - current_ast_builder().insert(Stmt::make(value)); - }); - - m.def("insert_continue_stmt", [&]() { - current_ast_builder().insert(Stmt::make()); - }); - - m.def("insert_expr_stmt", [&](const Expr &val) { - current_ast_builder().insert(Stmt::make(val)); - }); - - m.def("begin_func", [&](const std::string &funcid) { - auto stmt_unique = std::make_unique(funcid); - auto stmt = stmt_unique.get(); - current_ast_builder().insert(std::move(stmt_unique)); - scope_stack.push_back(current_ast_builder().create_scope(stmt->body)); - }); - - m.def("end_func", [&](const std::string &funcid) { scope_stack.pop_back(); }); - m.def("make_func_call_expr", Expr::make); @@ -574,38 +615,31 @@ void export_lang(py::module &m) { static_cast(bit_cast)); m.def("expr_atomic_add", [&](const Expr &a, const Expr &b) { - return Expr::make(AtomicOpType::add, ptr_if_global(a), - load_if_ptr(b)); + return Expr::make(AtomicOpType::add, a, b); }); m.def("expr_atomic_sub", [&](const Expr &a, const Expr &b) { - return Expr::make(AtomicOpType::sub, ptr_if_global(a), - load_if_ptr(b)); + return Expr::make(AtomicOpType::sub, a, b); }); m.def("expr_atomic_min", [&](const Expr &a, const Expr &b) { - return Expr::make(AtomicOpType::min, ptr_if_global(a), - load_if_ptr(b)); + return Expr::make(AtomicOpType::min, a, b); }); m.def("expr_atomic_max", [&](const Expr &a, const Expr &b) { - return Expr::make(AtomicOpType::max, ptr_if_global(a), - load_if_ptr(b)); + return Expr::make(AtomicOpType::max, a, b); }); m.def("expr_atomic_bit_and", [&](const Expr &a, const Expr &b) { - return Expr::make(AtomicOpType::bit_and, - ptr_if_global(a), load_if_ptr(b)); + return Expr::make(AtomicOpType::bit_and, a, b); }); m.def("expr_atomic_bit_or", [&](const Expr &a, const Expr &b) { - return Expr::make(AtomicOpType::bit_or, - ptr_if_global(a), load_if_ptr(b)); + return Expr::make(AtomicOpType::bit_or, a, b); }); m.def("expr_atomic_bit_xor", [&](const Expr &a, const Expr &b) { - return Expr::make(AtomicOpType::bit_xor, - ptr_if_global(a), load_if_ptr(b)); + return Expr::make(AtomicOpType::bit_xor, a, b); }); m.def("expr_add", expr_add); @@ -648,6 +682,7 @@ void export_lang(py::module &m) { m.def("expr_neg", [&](const Expr &e) { return -e; }); DEFINE_EXPRESSION_OP_UNARY(sqrt) + DEFINE_EXPRESSION_OP_UNARY(round) DEFINE_EXPRESSION_OP_UNARY(floor) DEFINE_EXPRESSION_OP_UNARY(ceil) DEFINE_EXPRESSION_OP_UNARY(abs) @@ -663,40 +698,6 @@ void export_lang(py::module &m) { DEFINE_EXPRESSION_OP_UNARY(exp) DEFINE_EXPRESSION_OP_UNARY(log) - m.def("expr_var", [](const Expr &e) { return Var(e); }); - m.def("expr_alloca", []() { - auto var = Expr(std::make_shared()); - current_ast_builder().insert(std::make_unique( - std::static_pointer_cast(var.expr)->id, - PrimitiveType::unknown)); - return var; - }); - m.def("expr_alloca_local_tensor", [](const std::vector &shape, - const DataType &element_type, - const ExprGroup &elements) { - auto var = Expr(std::make_shared()); - current_ast_builder().insert(std::make_unique( - std::static_pointer_cast(var.expr)->id, shape, - element_type)); - for (int i = 0; i < (int)elements.exprs.size(); ++i) { - ExprGroup reversed_indices; - int linearized_index = i; - for (int d = (int)shape.size() - 1; d >= 0; --d) { - reversed_indices.push_back( - Expr::make(linearized_index % shape[d])); - linearized_index /= shape[d]; - } - ExprGroup indices; - for (int d = 0; d < (int)shape.size(); ++d) - indices.push_back(reversed_indices[(int)shape.size() - 1 - d]); - current_ast_builder().insert(std::make_unique( - Expr::make(var, indices, shape, 1), - load_if_ptr(elements.exprs[i]))); - } - return var; - }); - m.def("expr_assign", expr_assign); - m.def("make_global_load_stmt", Stmt::make); m.def("make_global_store_stmt", Stmt::make); m.def("make_frontend_assign_stmt", @@ -706,16 +707,18 @@ void export_lang(py::module &m) { Expr::make); m.def("make_external_tensor_expr", - Expr::make); + Expr::make &>); m.def("make_id_expr", Expr::make); m.def("make_rand_expr", Expr::make); - m.def("make_const_expr_i32", Expr::make); - m.def("make_const_expr_i64", Expr::make); - m.def("make_const_expr_f32", Expr::make); - m.def("make_const_expr_f64", Expr::make); + m.def("make_const_expr_int", + Expr::make); + + m.def("make_const_expr_fp", + Expr::make); m.def("make_global_ptr_expr", Expr::make); @@ -740,6 +743,8 @@ void export_lang(py::module &m) { #include "taichi/inc/data_type.inc.h" #undef PER_TYPE + m.def("data_type_size", data_type_size); + m.def("is_custom_type", is_custom_type); m.def("is_integral", is_integral); m.def("is_signed", is_signed); m.def("is_real", is_real); @@ -756,31 +761,12 @@ void export_lang(py::module &m) { return expr[expr_group]; }); - m.def("global_subscript_with_offset", - [](const Expr &var, const ExprGroup &indices, - const std::vector &shape, bool is_aos) { - // TODO: Add test for dimension check - if (is_aos) - return Expr::make(var, indices, shape, 1); - else { - SNode *snode = var.cast() - ->var.cast() - ->snode; - return Expr::make( - var, indices, shape, - snode->get_total_num_elements_towards_root()); - } - }); - - m.def("local_subscript_with_offset", - [](const Expr &var, const ExprGroup &indices, - const std::vector &shape) { - // TODO: Add test for dimension check - return Expr::make(var, indices, shape, 1); - }); + m.def("make_tensor_element_expr", + Expr::make &, int>); m.def("subscript", [](SNode *snode, const ExprGroup &indices) { - return Expr::make(snode, indices.loaded()); + return Expr::make(snode, indices); }); m.def("get_external_tensor_dim", [](const Expr &expr) { @@ -791,42 +777,31 @@ void export_lang(py::module &m) { m.def("get_external_tensor_shape_along_axis", Expr::make); - m.def( - "create_kernel", - [&](const std::function &body, const std::string &name, - bool grad) -> Kernel * { - py::gil_scoped_release release; - return &get_current_program().kernel(body, name, grad); - }, - py::return_value_policy::reference); + // Mesh related. + m.def("get_relation_size", [](mesh::MeshPtr mesh_ptr, const Expr &mesh_idx, + mesh::MeshElementType to_type) { + return Expr::make(mesh_ptr.ptr.get(), + mesh_idx, to_type); + }); - m.def( - "create_function", - [&](const FunctionKey &funcid) { - return get_current_program().create_function(funcid); - }, - py::return_value_policy::reference); + m.def("get_relation_access", + [](mesh::MeshPtr mesh_ptr, const Expr &mesh_idx, + mesh::MeshElementType to_type, const Expr &neighbor_idx) { + return Expr::make( + mesh_ptr.ptr.get(), mesh_idx, to_type, neighbor_idx); + }); + + m.def("get_index_conversion", + [](mesh::MeshPtr mesh_ptr, mesh::MeshElementType idx_type, + const Expr &idx, mesh::ConvType &conv_type) { + return Expr::make( + mesh_ptr.ptr.get(), idx_type, idx, conv_type); + }); py::class_(m, "FunctionKey") .def(py::init()) .def_readonly("instance_id", &FunctionKey::instance_id); - // This function will call `Expr &Expr::operator=(const Expr &o)` implicitly. - m.def("create_print", - [&](std::vector> contents) { - current_ast_builder().insert( - std::make_unique(contents)); - }); - - m.def("decl_arg", [&](const DataType &dt, bool is_external_array) { - return get_current_program().current_callable->insert_arg( - dt, is_external_array); - }); - - m.def("decl_ret", [&](const DataType &dt) { - return get_current_program().current_callable->insert_ret(dt); - }); - m.def("test_throw", [] { try { throw IRModified(); @@ -834,33 +809,13 @@ void export_lang(py::module &m) { TI_INFO("caught"); } }); - // Schedules - m.def("parallelize", Parallelize); - m.def("vectorize", Vectorize); - m.def("bit_vectorize", BitVectorize); - m.def("block_dim", BlockDim); - - py::enum_(m, "SNodeAccessFlag", py::arithmetic()) - .value("block_local", SNodeAccessFlag::block_local) - .value("read_only", SNodeAccessFlag::read_only) - .export_values(); - - m.def("insert_snode_access_flag", insert_snode_access_flag); - m.def("reset_snode_access_flag", reset_snode_access_flag); - m.def("no_activate", [](SNode *snode) { - // TODO(#2193): Also apply to @ti.func? - auto *kernel = - dynamic_cast(get_current_program().current_callable); - TI_ASSERT(kernel); - kernel->no_activate.push_back(snode); - }); - m.def("stop_grad", - [](SNode *snode) { current_ast_builder().stop_gradient(snode); }); m.def("test_throw", [] { throw IRModified(); }); m.def("needs_grad", needs_grad); +#if TI_WITH_LLVM m.def("libdevice_path", libdevice_path); +#endif m.def("host_arch", host_arch); @@ -872,15 +827,15 @@ void export_lang(py::module &m) { m.def("get_version_major", get_version_major); m.def("get_version_minor", get_version_minor); m.def("get_version_patch", get_version_patch); +#if TI_WITH_LLVM m.def("get_llvm_version_string", [] { return LLVM_VERSION_STRING; }); +#endif m.def("test_printf", [] { printf("test_printf\n"); }); m.def("test_logging", [] { TI_INFO("test_logging"); }); m.def("trigger_crash", [] { *(int *)(1) = 0; }); m.def("get_max_num_indices", [] { return taichi_max_num_indices; }); m.def("get_max_num_args", [] { return taichi_max_num_args; }); m.def("test_threading", test_threading); - m.def("sifakis_svd_f32", sifakis_svd_export); - m.def("sifakis_svd_f64", sifakis_svd_export); m.def("global_var_expr_from_snode", [](SNode *snode) { return Expr::make(snode); }); @@ -917,14 +872,6 @@ void export_lang(py::module &m) { m.def("stop_recording", []() { ActionRecorder::get_instance().stop_recording(); }); - // A temporary option which will be removed soon in the future - m.def("toggle_advanced_optimization", [](bool option) { - TI_WARN( - "'ti.core.toggle_advance_optimization(False)' is deprecated." - " Use 'ti.init(advanced_optimization=False)' instead"); - get_current_program().config.advanced_optimization = option; - }); - m.def("query_int64", [](const std::string &key) { if (key == "cuda_compute_capability") { #if defined(TI_WITH_CUDA) @@ -937,17 +884,6 @@ void export_lang(py::module &m) { } }); - m.def("print_sfg", - []() { return get_current_program().async_engine->sfg->print(); }); - m.def( - "dump_dot", - [](std::optional rankdir, int embed_states_threshold) { - // https://pybind11.readthedocs.io/en/stable/advanced/functions.html#allow-prohibiting-none-arguments - return get_current_program().async_engine->sfg->dump_dot( - rankdir, embed_states_threshold); - }, - py::arg("rankdir").none(true), py::arg("embed_states_threshold")); - // Type system py::class_(m, "Type").def("to_string", &Type::to_string); @@ -967,15 +903,21 @@ void export_lang(py::module &m) { m.def("get_type_factory_instance", TypeFactory::get_instance, py::return_value_policy::reference); + m.def("decl_tensor_type", + [&](std::vector shape, const DataType &element_type) { + return TypeFactory::create_tensor_type(shape, element_type); + }); + py::class_(m, "SNodeRegistry") .def(py::init<>()) .def("create_root", &SNodeRegistry::create_root, py::return_value_policy::reference); + m.def( "finalize_snode_tree", - [](SNodeRegistry *registry, const SNode *root, - Program *program) -> SNodeTree * { - return program->add_snode_tree(registry->finalize(root)); + [](SNodeRegistry *registry, const SNode *root, Program *program, + bool compile_only) -> SNodeTree * { + return program->add_snode_tree(registry->finalize(root), compile_only); }, py::return_value_policy::reference); @@ -984,13 +926,6 @@ void export_lang(py::module &m) { .def("build", &SparseMatrixBuilder::build) .def("get_addr", [](SparseMatrixBuilder *mat) { return uint64(mat); }); - m.def("create_sparse_matrix_builder", - [](int n, int m, uint64 max_num_entries) { - TI_ERROR_IF(!arch_is_cpu(get_current_program().config.arch), - "SparseMatrix only supports CPU for now."); - return SparseMatrixBuilder(n, m, max_num_entries); - }); - py::class_(m, "SparseMatrix") .def("to_string", &SparseMatrix::to_string) .def(py::self + py::self, py::return_value_policy::reference_internal) @@ -1008,12 +943,6 @@ void export_lang(py::module &m) { .def("num_rows", &SparseMatrix::num_rows) .def("num_cols", &SparseMatrix::num_cols); - m.def("create_sparse_matrix", [](int n, int m) { - TI_ERROR_IF(!arch_is_cpu(get_current_program().config.arch), - "SparseMatrix only supports CPU for now."); - return SparseMatrix(n, m); - }); - py::class_(m, "SparseSolver") .def("compute", &SparseSolver::compute) .def("analyze_pattern", &SparseSolver::analyze_pattern) @@ -1022,6 +951,114 @@ void export_lang(py::module &m) { .def("info", &SparseSolver::info); m.def("make_sparse_solver", &make_sparse_solver); + + // Mesh Class + // Mesh related. + py::enum_(m, "MeshTopology", py::arithmetic()) + .value("Triangle", mesh::MeshTopology::Triangle) + .value("Tetrahedron", mesh::MeshTopology::Tetrahedron) + .export_values(); + + py::enum_(m, "MeshElementType", py::arithmetic()) + .value("Vertex", mesh::MeshElementType::Vertex) + .value("Edge", mesh::MeshElementType::Edge) + .value("Face", mesh::MeshElementType::Face) + .value("Cell", mesh::MeshElementType::Cell) + .export_values(); + + py::enum_(m, "MeshRelationType", py::arithmetic()) + .value("VV", mesh::MeshRelationType::VV) + .value("VE", mesh::MeshRelationType::VE) + .value("VF", mesh::MeshRelationType::VF) + .value("VC", mesh::MeshRelationType::VC) + .value("EV", mesh::MeshRelationType::EV) + .value("EE", mesh::MeshRelationType::EE) + .value("EF", mesh::MeshRelationType::EF) + .value("EC", mesh::MeshRelationType::EC) + .value("FV", mesh::MeshRelationType::FV) + .value("FE", mesh::MeshRelationType::FE) + .value("FF", mesh::MeshRelationType::FF) + .value("FC", mesh::MeshRelationType::FC) + .value("CV", mesh::MeshRelationType::CV) + .value("CE", mesh::MeshRelationType::CE) + .value("CF", mesh::MeshRelationType::CF) + .value("CC", mesh::MeshRelationType::CC) + .export_values(); + + py::enum_(m, "ConvType", py::arithmetic()) + .value("l2g", mesh::ConvType::l2g) + .value("l2r", mesh::ConvType::l2r) + .value("g2r", mesh::ConvType::g2r) + .export_values(); + + py::class_(m, "Mesh"); + py::class_(m, "MeshPtr"); + + m.def("element_order", mesh::element_order); + m.def("from_end_element_order", mesh::from_end_element_order); + m.def("to_end_element_order", mesh::to_end_element_order); + m.def("relation_by_orders", mesh::relation_by_orders); + m.def("inverse_relation", mesh::inverse_relation); + m.def("element_type_name", mesh::element_type_name); + + m.def( + "create_mesh", + []() { + auto mesh_shared = std::make_shared(); + mesh::MeshPtr mesh_ptr = mesh::MeshPtr{mesh_shared}; + return mesh_ptr; + }, + py::return_value_policy::reference); + + // ad-hoc setters + m.def("set_owned_offset", + [](mesh::MeshPtr &mesh_ptr, mesh::MeshElementType type, SNode *snode) { + mesh_ptr.ptr->owned_offset.insert(std::pair(type, snode)); + }); + m.def("set_total_offset", + [](mesh::MeshPtr &mesh_ptr, mesh::MeshElementType type, SNode *snode) { + mesh_ptr.ptr->total_offset.insert(std::pair(type, snode)); + }); + m.def("set_num_patches", [](mesh::MeshPtr &mesh_ptr, int num_patches) { + mesh_ptr.ptr->num_patches = num_patches; + }); + + m.def("set_num_elements", [](mesh::MeshPtr &mesh_ptr, + mesh::MeshElementType type, int num_elements) { + mesh_ptr.ptr->num_elements.insert(std::pair(type, num_elements)); + }); + + m.def("get_num_elements", + [](mesh::MeshPtr &mesh_ptr, mesh::MeshElementType type) { + return mesh_ptr.ptr->num_elements.find(type)->second; + }); + + m.def("set_patch_max_element_num", + [](mesh::MeshPtr &mesh_ptr, mesh::MeshElementType type, + int max_element_num) { + mesh_ptr.ptr->patch_max_element_num.insert( + std::pair(type, max_element_num)); + }); + + m.def("set_index_mapping", + [](mesh::MeshPtr &mesh_ptr, mesh::MeshElementType element_type, + mesh::ConvType conv_type, SNode *snode) { + mesh_ptr.ptr->index_mapping.insert( + std::make_pair(std::make_pair(element_type, conv_type), snode)); + }); + + m.def("set_relation_fixed", + [](mesh::MeshPtr &mesh_ptr, mesh::MeshRelationType type, SNode *value) { + mesh_ptr.ptr->relations.insert( + std::pair(type, mesh::MeshLocalRelation(value))); + }); + + m.def("set_relation_dynamic", + [](mesh::MeshPtr &mesh_ptr, mesh::MeshRelationType type, SNode *value, + SNode *offset) { + mesh_ptr.ptr->relations.insert( + std::pair(type, mesh::MeshLocalRelation(value, offset))); + }); } TI_NAMESPACE_END diff --git a/taichi/python/export_misc.cpp b/taichi/python/export_misc.cpp index b73efbba6ae93..e75df0b88d856 100644 --- a/taichi/python/export_misc.cpp +++ b/taichi/python/export_misc.cpp @@ -6,6 +6,7 @@ #include "taichi/backends/metal/api.h" #include "taichi/backends/opengl/opengl_api.h" #include "taichi/backends/vulkan/runtime.h" +#include "taichi/backends/dx/dx_api.h" #include "taichi/common/core.h" #include "taichi/common/interface.h" #include "taichi/common/task.h" @@ -17,6 +18,7 @@ #include "taichi/python/memory_usage_monitor.h" #include "taichi/system/benchmark.h" #include "taichi/system/dynamic_loader.h" +#include "taichi/system/hacked_signal_handler.h" #include "taichi/system/profiler.h" #include "taichi/util/statistics.h" #if defined(TI_WITH_CUDA) @@ -24,7 +26,7 @@ #endif #ifdef TI_WITH_VULKAN -#include "taichi/backends/vulkan/loader.h" +#include "taichi/backends/vulkan/vulkan_loader.h" #endif #ifdef TI_WITH_CC @@ -33,16 +35,7 @@ extern bool is_c_backend_available(); } #endif -TI_NAMESPACE_BEGIN - -Config config_from_py_dict(py::dict &c) { - Config config; - for (auto item : c) { - config.set(std::string(py::str(item.first)), - std::string(py::str(item.second))); - } - return config; -} +namespace taichi { void test_raise_error() { raise_assertion_failure_in_python("Just a test."); @@ -69,27 +62,6 @@ void print_all_units() { std::cout << all_units << " units in all." << std::endl; } -void duplicate_stdout_to_file(const std::string &fn) { -/* -static int stdout_fd = -1; -int fd[2]; -pipe(fd); -stdout = fdopen(fd[1], "w"); -auto file_fd = fdopen(fd[0], "w"); -FILE *file = freopen(fn.c_str(), "w", file_fd); -*/ -#if defined(TI_PLATFORM_UNIX) - std::cerr.rdbuf(std::cout.rdbuf()); - dup2(fileno(popen(fmt::format("tee {}", fn).c_str(), "w")), STDOUT_FILENO); -#else - TI_NOT_IMPLEMENTED; -#endif -} - -void stop_duplicating_stdout_to_file(const std::string &fn) { - TI_NOT_IMPLEMENTED; -} - void export_misc(py::module &m) { py::class_(m, "Config"); py::register_exception_translator([](std::exception_ptr p) { @@ -126,8 +98,6 @@ void export_misc(py::module &m) { TI_EXPORT_LOGGING(error); TI_EXPORT_LOGGING(critical); - m.def("duplicate_stdout_to_file", duplicate_stdout_to_file); - m.def("print_all_units", print_all_units); m.def("set_core_state_python_imported", CoreState::set_python_imported); m.def("set_logging_level", [](const std::string &level) { @@ -141,7 +111,6 @@ void export_misc(py::module &m) { m.def("set_core_trigger_gdb_when_crash", CoreState::set_trigger_gdb_when_crash); m.def("test_raise_error", test_raise_error); - m.def("config_from_dict", config_from_py_dict); m.def("get_default_float_size", []() { return sizeof(real); }); m.def("trigger_sig_fpe", []() { int a = 2; @@ -169,12 +138,20 @@ void export_misc(py::module &m) { m.def("toggle_python_print_buffer", [](bool opt) { py_cout.enabled = opt; }); m.def("with_cuda", is_cuda_api_available); m.def("with_metal", taichi::lang::metal::is_metal_api_available); - m.def("with_opengl", taichi::lang::opengl::is_opengl_api_available); + m.def("with_opengl", taichi::lang::opengl::is_opengl_api_available, + py::arg("use_gles") = false); #ifdef TI_WITH_VULKAN m.def("with_vulkan", taichi::lang::vulkan::is_vulkan_api_available); + m.def("set_vulkan_visible_device", + taichi::lang::vulkan::set_vulkan_visible_device); #else m.def("with_vulkan", []() { return false; }); #endif +#ifdef TI_WITH_DX11 + m.def("with_dx11", taichi::lang::directx11::is_dx_api_available); +#else + m.def("with_dx11", []() { return false; }); +#endif #ifdef TI_WITH_CC m.def("with_cc", taichi::lang::cccp::is_c_backend_available); @@ -189,6 +166,8 @@ void export_misc(py::module &m) { m.def( "get_kernel_stats", []() -> Statistics & { return stat; }, py::return_value_policy::reference); + + py::class_(m, "HackedSignalRegister").def(py::init<>()); } -TI_NAMESPACE_END +} // namespace taichi diff --git a/taichi/python/memory_usage_monitor.cpp b/taichi/python/memory_usage_monitor.cpp index 78ee6bce9613a..c6d053d8ac194 100644 --- a/taichi/python/memory_usage_monitor.cpp +++ b/taichi/python/memory_usage_monitor.cpp @@ -39,17 +39,17 @@ uint64 get_memory_usage(int pid) { } MemoryMonitor::MemoryMonitor(int pid, std::string output_fn) { - log.open(output_fn, std::ios_base::out); - locals = new py::dict; - (*reinterpret_cast(locals))["pid"] = pid; + log_.open(output_fn, std::ios_base::out); + locals_ = new py::dict; + (*reinterpret_cast(locals_))["pid"] = pid; py::exec(R"( import os, psutil process = psutil.Process(pid))", - py::globals(), *reinterpret_cast(locals)); + py::globals(), *reinterpret_cast(locals_)); } MemoryMonitor::~MemoryMonitor() { - delete reinterpret_cast(locals); + delete reinterpret_cast(locals_); } uint64 MemoryMonitor::get_usage() const { @@ -59,17 +59,17 @@ uint64 MemoryMonitor::get_usage() const { mem = process.memory_info().rss except: mem = -1)", - py::globals(), *reinterpret_cast(locals)); - return (*reinterpret_cast(locals))["mem"].cast(); + py::globals(), *reinterpret_cast(locals_)); + return (*reinterpret_cast(locals_))["mem"].cast(); } void MemoryMonitor::append_sample() { auto t = std::chrono::system_clock::now(); - log << fmt::format( + log_ << fmt::format( "{:.5f} {}\n", (t.time_since_epoch() / std::chrono::nanoseconds(1)) / 1e9_f64, get_usage()); - log.flush(); + log_.flush(); } void start_memory_monitoring(std::string output_fn, int pid, real interval) { diff --git a/taichi/python/memory_usage_monitor.h b/taichi/python/memory_usage_monitor.h index 187393d8d4a97..ac6cb5f2c2b73 100644 --- a/taichi/python/memory_usage_monitor.h +++ b/taichi/python/memory_usage_monitor.h @@ -7,8 +7,8 @@ TI_NAMESPACE_BEGIN class MemoryMonitor { // avoid including py::dict // py::dict locals; - void *locals; - std::ofstream log; + void *locals_; + std::ofstream log_; public: MemoryMonitor(int pid, std::string output_fn); diff --git a/taichi/python/py_exception_translator.cpp b/taichi/python/py_exception_translator.cpp index 7d7160c50fbd3..ff1bc11f1daad 100644 --- a/taichi/python/py_exception_translator.cpp +++ b/taichi/python/py_exception_translator.cpp @@ -18,8 +18,6 @@ class ExceptionTranslationImpl { std::rethrow_exception(p); } catch (const std::string &e) { PyErr_SetString(PyExc_RuntimeError, e.c_str()); - } catch (const std::exception &e) { - PyErr_SetString(PyExc_RuntimeError, e.what()); } }); } diff --git a/taichi/python/snode_registry.cpp b/taichi/python/snode_registry.cpp index ef6f284c0adcd..c98a39efffc51 100644 --- a/taichi/python/snode_registry.cpp +++ b/taichi/python/snode_registry.cpp @@ -1,12 +1,17 @@ #include "taichi/python/snode_registry.h" +#include "taichi/common/logging.h" #include "taichi/ir/snode.h" +#include "taichi/program/program.h" namespace taichi { namespace lang { -SNode *SNodeRegistry::create_root() { - auto n = std::make_unique(/*depth=*/0, SNodeType::root); +SNode *SNodeRegistry::create_root(Program *prog) { + TI_ASSERT(prog != nullptr); + auto n = std::make_unique(/*depth=*/0, SNodeType::root, + prog->get_snode_to_glb_var_exprs(), + &prog->get_snode_rw_accessors_bank()); auto *res = n.get(); snodes_.push_back(std::move(n)); return res; diff --git a/taichi/python/snode_registry.h b/taichi/python/snode_registry.h index bb1f29c4817d8..869e146899c27 100644 --- a/taichi/python/snode_registry.h +++ b/taichi/python/snode_registry.h @@ -7,6 +7,7 @@ namespace taichi { namespace lang { class SNode; +class Program; /** * A helper class to keep the root SNodes that aren't materialized yet. @@ -25,7 +26,7 @@ class SNodeRegistry { * * @return Pointer to the created SNode. */ - SNode *create_root(); + SNode *create_root(Program *prog); /** * Transfers the ownership of @param snode to the caller. diff --git a/taichi/runtime/llvm/atomic.h b/taichi/runtime/llvm/atomic.h index 53f2645eccf90..08b3d4d76ccf1 100644 --- a/taichi/runtime/llvm/atomic.h +++ b/taichi/runtime/llvm/atomic.h @@ -36,29 +36,20 @@ DEFINE_ATOMIC_OP_INTRINSIC(xor, i64) DEFINE_ATOMIC_OP_INTRINSIC(xor, u32) DEFINE_ATOMIC_OP_INTRINSIC(xor, u64) -inline f32 add_f32(f32 a, f32 b) { - return a + b; -} - -inline f64 add_f64(f64 a, f64 b) { - return a + b; -} - -inline f32 min_f32(f32 a, f32 b) { - return b > a ? a : b; -} - -inline f64 min_f64(f64 a, f64 b) { - return b > a ? a : b; -} +#define DEFINE_ADD(T) \ + T add_##T(T a, T b) { \ + return a + b; \ + } -inline f32 max_f32(f32 a, f32 b) { - return b < a ? a : b; -} +#define DEFINE_MIN(T) \ + T min_##T(T a, T b) { \ + return b > a ? a : b; \ + } -inline f64 max_f64(f64 a, f64 b) { - return b < a ? a : b; -} +#define DEFINE_MAX(T) \ + T max_##T(T a, T b) { \ + return b < a ? a : b; \ + } #define DEFINE_ATOMIC_OP_COMP_EXCH(OP, T) \ T atomic_##OP##_##T(volatile T *dest, T inc) { \ @@ -74,13 +65,24 @@ inline f64 max_f64(f64 a, f64 b) { return old_val; \ } +DEFINE_ADD(f32) +DEFINE_ADD(f64) +DEFINE_MIN(f32) +DEFINE_MIN(f64) +DEFINE_MAX(f32) +DEFINE_MAX(f64) + DEFINE_ATOMIC_OP_COMP_EXCH(add, f32) DEFINE_ATOMIC_OP_COMP_EXCH(add, f64) -DEFINE_ATOMIC_OP_COMP_EXCH(min, f32) -DEFINE_ATOMIC_OP_COMP_EXCH(min, f64) DEFINE_ATOMIC_OP_COMP_EXCH(min, i32) DEFINE_ATOMIC_OP_COMP_EXCH(min, i64) -DEFINE_ATOMIC_OP_COMP_EXCH(max, f32) -DEFINE_ATOMIC_OP_COMP_EXCH(max, f64) +DEFINE_ATOMIC_OP_COMP_EXCH(min, f32) +DEFINE_ATOMIC_OP_COMP_EXCH(min, f64) DEFINE_ATOMIC_OP_COMP_EXCH(max, i32) DEFINE_ATOMIC_OP_COMP_EXCH(max, i64) +DEFINE_ATOMIC_OP_COMP_EXCH(max, f32) +DEFINE_ATOMIC_OP_COMP_EXCH(max, f64) +DEFINE_ATOMIC_OP_COMP_EXCH(min, u32) +DEFINE_ATOMIC_OP_COMP_EXCH(min, u64) +DEFINE_ATOMIC_OP_COMP_EXCH(max, u32) +DEFINE_ATOMIC_OP_COMP_EXCH(max, u64) diff --git a/taichi/runtime/llvm/internal_functions.h b/taichi/runtime/llvm/internal_functions.h index ec620c9055622..591ccddc64bcf 100644 --- a/taichi/runtime/llvm/internal_functions.h +++ b/taichi/runtime/llvm/internal_functions.h @@ -9,39 +9,54 @@ } \ } while (0) -i32 do_nothing(Context *context) { +#define ATOMIC_INSERT(T) \ + do { \ + auto base_ptr = (int64 *)base_ptr_; \ + int64 *num_triplets = base_ptr; \ + auto data_base_ptr = *(T **)(base_ptr + 1); \ + auto triplet_id = atomic_add_i64(num_triplets, 1); \ + data_base_ptr[triplet_id * 3] = i; \ + data_base_ptr[triplet_id * 3 + 1] = j; \ + data_base_ptr[triplet_id * 3 + 2] = taichi_union_cast(value); \ + } while (0); + +i32 do_nothing(RuntimeContext *context) { return 0; } -i32 refresh_counter(Context *context) { +i32 refresh_counter(RuntimeContext *context) { auto runtime = context->runtime; auto queue = runtime->mem_req_queue; queue->tail++; return 0; } -i32 insert_triplet(Context *context, - int64 base_ptr_, - int i, - int j, - float value) { - auto base_ptr = (int64 *)base_ptr_; - - int64 *num_triplets = base_ptr; - auto data_base_ptr = *(int32 **)(base_ptr + 1); +i32 insert_triplet_f32(RuntimeContext *context, + int64 base_ptr_, + int i, + int j, + float value) { + ATOMIC_INSERT(int32); + return 0; +} - auto triplet_id = atomic_add_i64(num_triplets, 1); - data_base_ptr[triplet_id * 3] = i; - data_base_ptr[triplet_id * 3 + 1] = j; - data_base_ptr[triplet_id * 3 + 2] = taichi_union_cast(value); +i32 insert_triplet_f64(RuntimeContext *context, + int64 base_ptr_, + int i, + int j, + float64 value) { + ATOMIC_INSERT(int64); return 0; } -i32 test_internal_func_args(Context *context, float32 i, float32 j, int32 k) { +i32 test_internal_func_args(RuntimeContext *context, + float32 i, + float32 j, + int32 k) { return static_cast((i + j) * k); } -i32 test_stack(Context *context) { +i32 test_stack(RuntimeContext *context) { auto stack = new u8[132]; stack_push(stack, 16, 4); stack_push(stack, 16, 4); @@ -50,7 +65,7 @@ i32 test_stack(Context *context) { return 0; } -i32 test_list_manager(Context *context) { +i32 test_list_manager(RuntimeContext *context) { auto runtime = context->runtime; taichi_printf(runtime, "LLVMRuntime %p\n", runtime); auto list = context->runtime->create(runtime, 4, 16); @@ -65,7 +80,7 @@ i32 test_list_manager(Context *context) { return 0; } -i32 test_node_allocator(Context *context) { +i32 test_node_allocator(RuntimeContext *context) { auto runtime = context->runtime; taichi_printf(runtime, "LLVMRuntime %p\n", runtime); auto nodes = context->runtime->create(runtime, sizeof(i64), 4); @@ -98,7 +113,7 @@ i32 test_node_allocator(Context *context) { return 0; } -i32 test_node_allocator_gc_cpu(Context *context) { +i32 test_node_allocator_gc_cpu(RuntimeContext *context) { auto runtime = context->runtime; taichi_printf(runtime, "LLVMRuntime %p\n", runtime); auto nodes = context->runtime->create(runtime, sizeof(i64), 4); @@ -142,7 +157,7 @@ i32 test_node_allocator_gc_cpu(Context *context) { return 0; } -i32 test_active_mask(Context *context) { +i32 test_active_mask(RuntimeContext *context) { auto rt = context->runtime; taichi_printf(rt, "%d activemask %x\n", thread_idx(), cuda_active_mask()); @@ -159,7 +174,7 @@ i32 test_active_mask(Context *context) { return 0; } -i32 test_shfl(Context *context) { +i32 test_shfl(RuntimeContext *context) { auto rt = context->runtime; auto s = cuda_shfl_down_sync_i32(cuda_active_mask(), warp_idx() + 1000, 2, 31); diff --git a/taichi/runtime/llvm/runtime.cpp b/taichi/runtime/llvm/runtime.cpp index 27abab797169d..0be7f410ec593 100644 --- a/taichi/runtime/llvm/runtime.cpp +++ b/taichi/runtime/llvm/runtime.cpp @@ -26,7 +26,7 @@ #include "taichi/inc/cuda_kernel_utils.inc.h" #include "taichi/math/arithmetic.h" -struct Context; +struct RuntimeContext; using assert_failed_type = void (*)(const char *); using host_printf_type = void (*)(const char *, ...); using host_vsnprintf_type = int (*)(char *, @@ -34,7 +34,8 @@ using host_vsnprintf_type = int (*)(char *, const char *, std::va_list); using vm_allocator_type = void *(*)(void *, std::size_t, std::size_t); -using RangeForTaskFunc = void(Context *, const char *tls, int i); +using RangeForTaskFunc = void(RuntimeContext *, const char *tls, int i); +using MeshForTaskFunc = void(RuntimeContext *, const char *tls, uint32_t i); using parallel_for_type = void (*)(void *thread_pool, int splits, int num_desired_threads, @@ -105,7 +106,7 @@ using f64 = float64; using uint8 = uint8_t; using Ptr = uint8 *; -using ContextArgType = long long; +using RuntimeContextArgType = long long; #if ARCH_cuda extern "C" { @@ -214,27 +215,51 @@ i64 floordiv_i64(i64 a, i64 b) { return ifloordiv(a, b); } +u16 min_u16(u16 a, u16 b) { + return a < b ? a : b; +} + +i16 min_i16(i16 a, i16 b) { + return a < b ? a : b; +} + +u32 min_u32(u32 a, u32 b) { + return a < b ? a : b; +} + int min_i32(i32 a, i32 b) { return a < b ? a : b; } -int min_i64(i64 a, i64 b) { +u64 min_u64(u64 a, u64 b) { return a < b ? a : b; } -int max_i32(i32 a, i32 b) { +i64 min_i64(i64 a, i64 b) { + return a < b ? a : b; +} + +u16 max_u16(u16 a, u16 b) { return a > b ? a : b; } -u32 min_u32(u32 a, u32 b) { - return a < b ? a : b; +i16 max_i16(i16 a, i16 b) { + return a > b ? a : b; } u32 max_u32(u32 a, u32 b) { return a > b ? a : b; } -int max_i64(i64 a, i64 b) { +int max_i32(i32 a, i32 b) { + return a > b ? a : b; +} + +u64 max_u64(u64 a, u64 b) { + return a > b ? a : b; +} + +i64 max_i64(i64 a, i64 b) { return a > b ? a : b; } @@ -321,10 +346,11 @@ STRUCT_FIELD_ARRAY(PhysicalCoordinates, val); #include "taichi/program/context.h" #include "taichi/runtime/llvm/mem_request.h" -STRUCT_FIELD_ARRAY(Context, args); -STRUCT_FIELD(Context, runtime); +STRUCT_FIELD_ARRAY(RuntimeContext, args); +STRUCT_FIELD(RuntimeContext, runtime); +STRUCT_FIELD(RuntimeContext, result_buffer) -int32 Context_get_extra_args(Context *ctx, int32 i, int32 j) { +int32 RuntimeContext_get_extra_args(RuntimeContext *ctx, int32 i, int32 j) { return ctx->extra_args[i][j]; } @@ -353,7 +379,7 @@ struct StructMeta { PhysicalCoordinates *refined_coord, int index); - Context *context; + RuntimeContext *context; }; STRUCT_FIELD(StructMeta, snode_id) @@ -370,7 +396,7 @@ struct LLVMRuntime; constexpr bool enable_assert = true; -void taichi_assert(Context *context, i32 test, const char *msg); +void taichi_assert(RuntimeContext *context, i32 test, const char *msg); void taichi_assert_runtime(LLVMRuntime *runtime, i32 test, const char *msg); #define TI_ASSERT_INFO(x, msg) taichi_assert(context, (int)(x), msg) #define TI_ASSERT(x) TI_ASSERT_INFO(x, #x) @@ -671,8 +697,8 @@ struct NodeManager { extern "C" { -void LLVMRuntime_store_result(LLVMRuntime *runtime, u64 ret) { - runtime->set_result(taichi_result_buffer_ret_value_id, ret); +void RuntimeContext_store_result(RuntimeContext *ctx, u64 ret, u32 idx) { + ctx->result_buffer[taichi_result_buffer_ret_value_id + idx] = ret; } void LLVMRuntime_profiler_start(LLVMRuntime *runtime, Ptr kernel_name) { @@ -722,7 +748,7 @@ RUNTIME_STRUCT_FIELD(ListManager, num_elements); RUNTIME_STRUCT_FIELD(ListManager, max_num_elements_per_chunk); RUNTIME_STRUCT_FIELD(ListManager, element_size); -void taichi_assert(Context *context, i32 test, const char *msg) { +void taichi_assert(RuntimeContext *context, i32 test, const char *msg) { taichi_assert_runtime(context->runtime, test, msg); } @@ -911,13 +937,17 @@ void runtime_initialize_snodes(LLVMRuntime *runtime, const int num_snodes, const int snode_tree_id, std::size_t rounded_size, - Ptr ptr) { + Ptr ptr, + bool all_dense) { // For Metal runtime, we have to make sure that both the beginning address // and the size of the root buffer memory are aligned to page size. runtime->root_mem_sizes[snode_tree_id] = rounded_size; runtime->roots[snode_tree_id] = ptr; // runtime->request_allocate_aligned ready to use // initialize the root node element list + if (all_dense) { + return; + } for (int i = root_id; i < root_id + num_snodes; i++) { // TODO: some SNodes do not actually need an element list. runtime->element_lists[i] = @@ -1219,10 +1249,10 @@ void element_listgen_nonroot(LLVMRuntime *runtime, } } -using BlockTask = void(Context *, char *, Element *, int, int); +using BlockTask = void(RuntimeContext *, char *, Element *, int, int); struct cpu_block_task_helper_context { - Context *context; + RuntimeContext *context; BlockTask *task; ListManager *list; int element_size; @@ -1248,7 +1278,7 @@ void cpu_struct_for_block_helper(void *ctx_, int thread_id, int i) { upper = std::min(upper, e.loop_bounds[1]); alignas(8) char tls_buffer[ctx->tls_buffer_size]; - Context this_thread_context = *ctx->context; + RuntimeContext this_thread_context = *ctx->context; this_thread_context.cpu_thread_id = thread_id; if (lower < upper) { (*ctx->task)(&this_thread_context, tls_buffer, @@ -1256,7 +1286,7 @@ void cpu_struct_for_block_helper(void *ctx_, int thread_id, int i) { } } -void parallel_struct_for(Context *context, +void parallel_struct_for(RuntimeContext *context, int snode_id, int element_size, int element_split, @@ -1300,10 +1330,13 @@ void parallel_struct_for(Context *context, #endif } -using range_for_xlogue = void (*)(Context *, /*TLS*/ char *tls_base); +using range_for_xlogue = void (*)(RuntimeContext *, /*TLS*/ char *tls_base); +using mesh_for_xlogue = void (*)(RuntimeContext *, + /*TLS*/ char *tls_base, + uint32_t patch_idx); struct range_task_helper_context { - Context *context; + RuntimeContext *context; range_for_xlogue prologue{nullptr}; RangeForTaskFunc *body{nullptr}; range_for_xlogue epilogue{nullptr}; @@ -1323,7 +1356,7 @@ void cpu_parallel_range_for_task(void *range_context, if (ctx.prologue) ctx.prologue(ctx.context, tls_ptr); - Context this_thread_context = *ctx.context; + RuntimeContext this_thread_context = *ctx.context; this_thread_context.cpu_thread_id = thread_id; if (ctx.step == 1) { int block_start = ctx.begin + task_id * ctx.block_size; @@ -1342,7 +1375,7 @@ void cpu_parallel_range_for_task(void *range_context, ctx.epilogue(ctx.context, tls_ptr); } -void cpu_parallel_range_for(Context *context, +void cpu_parallel_range_for(RuntimeContext *context, int num_threads, int begin, int end, @@ -1379,7 +1412,7 @@ void cpu_parallel_range_for(Context *context, &ctx, cpu_parallel_range_for_task); } -void gpu_parallel_range_for(Context *context, +void gpu_parallel_range_for(RuntimeContext *context, int begin, int end, range_for_xlogue prologue, @@ -1399,7 +1432,84 @@ void gpu_parallel_range_for(Context *context, epilogue(context, tls_ptr); } -i32 linear_thread_idx(Context *context) { +struct mesh_task_helper_context { + RuntimeContext *context; + mesh_for_xlogue prologue{nullptr}; + RangeForTaskFunc *body{nullptr}; + mesh_for_xlogue epilogue{nullptr}; + std::size_t tls_size{1}; + int num_patches; + int block_size; +}; + +void cpu_parallel_mesh_for_task(void *range_context, + int thread_id, + int task_id) { + auto ctx = *(mesh_task_helper_context *)range_context; + alignas(8) char tls_buffer[ctx.tls_size]; + auto tls_ptr = &tls_buffer[0]; + + RuntimeContext this_thread_context = *ctx.context; + this_thread_context.cpu_thread_id = thread_id; + + int block_start = task_id * ctx.block_size; + int block_end = std::min(block_start + ctx.block_size, ctx.num_patches); + + for (int idx = block_start; idx < block_end; idx++) { + if (ctx.prologue) + ctx.prologue(ctx.context, tls_ptr, idx); + ctx.body(&this_thread_context, tls_ptr, idx); + if (ctx.epilogue) + ctx.epilogue(ctx.context, tls_ptr, idx); + } +} + +void cpu_parallel_mesh_for(RuntimeContext *context, + int num_threads, + int num_patches, + int block_dim, + mesh_for_xlogue prologue, + RangeForTaskFunc *body, + mesh_for_xlogue epilogue, + std::size_t tls_size) { + mesh_task_helper_context ctx; + ctx.context = context; + ctx.prologue = prologue; + ctx.tls_size = tls_size; + ctx.body = body; + ctx.epilogue = epilogue; + ctx.num_patches = num_patches; + if (block_dim == 0) { + // adaptive block dim + // ensure each thread has at least ~32 tasks for load balancing + // and each task has at least 512 items to amortize scheduler overhead + block_dim = std::min(512, std::max(1, num_patches / (num_threads * 32))); + } + ctx.block_size = block_dim; + auto runtime = context->runtime; + runtime->parallel_for(runtime->thread_pool, + (num_patches + block_dim - 1) / block_dim, num_threads, + &ctx, cpu_parallel_mesh_for_task); +} + +void gpu_parallel_mesh_for(RuntimeContext *context, + int num_patches, + mesh_for_xlogue prologue, + MeshForTaskFunc *func, + mesh_for_xlogue epilogue, + const std::size_t tls_size) { + alignas(8) char tls_buffer[tls_size]; + auto tls_ptr = &tls_buffer[0]; + for (int idx = block_idx(); idx < num_patches; idx += grid_dim()) { + if (prologue) + prologue(context, tls_ptr, idx); + func(context, tls_ptr, idx); + if (epilogue) + epilogue(context, tls_ptr, idx); + } +} + +i32 linear_thread_idx(RuntimeContext *context) { #if ARCH_cuda return block_idx() * block_dim() + thread_idx(); #else @@ -1443,7 +1553,7 @@ void node_gc(LLVMRuntime *runtime, int snode_id) { runtime->node_allocators[snode_id]->gc_serial(); } -void gc_parallel_0(Context *context, int snode_id) { +void gc_parallel_0(RuntimeContext *context, int snode_id) { LLVMRuntime *runtime = context->runtime; auto allocator = runtime->node_allocators[snode_id]; auto free_list = allocator->free_list; @@ -1471,7 +1581,7 @@ void gc_parallel_0(Context *context, int snode_id) { } } -void gc_parallel_1(Context *context, int snode_id) { +void gc_parallel_1(RuntimeContext *context, int snode_id) { LLVMRuntime *runtime = context->runtime; auto allocator = runtime->node_allocators[snode_id]; auto free_list = allocator->free_list; @@ -1485,7 +1595,7 @@ void gc_parallel_1(Context *context, int snode_id) { allocator->recycled_list->clear(); } -void gc_parallel_2(Context *context, int snode_id) { +void gc_parallel_2(RuntimeContext *context, int snode_id) { LLVMRuntime *runtime = context->runtime; auto allocator = runtime->node_allocators[snode_id]; auto elements = allocator->recycle_list_size_backup; @@ -1529,7 +1639,7 @@ void gc_parallel_2(Context *context, int snode_id) { extern "C" { -u32 rand_u32(Context *context) { +u32 rand_u32(RuntimeContext *context) { auto state = &((LLVMRuntime *)context->runtime) ->rand_states[linear_thread_idx(context)]; @@ -1548,23 +1658,23 @@ u32 rand_u32(Context *context) { // it decorrelates streams of PRNGs. } -uint64 rand_u64(Context *context) { +uint64 rand_u64(RuntimeContext *context) { return ((u64)rand_u32(context) << 32) + rand_u32(context); } -f32 rand_f32(Context *context) { - return rand_u32(context) * (1.0f / 4294967296.0f); +f32 rand_f32(RuntimeContext *context) { + return (rand_u32(context) >> 8) * (1.0f / 16777216.0f); } -f64 rand_f64(Context *context) { - return rand_u64(context) * (1.0 / 18446744073709551616.0); +f64 rand_f64(RuntimeContext *context) { + return (rand_u64(context) >> 11) * (1.0 / 9007199254740992.0); } -i32 rand_i32(Context *context) { +i32 rand_i32(RuntimeContext *context) { return rand_u32(context); } -i64 rand_i64(Context *context) { +i64 rand_i64(RuntimeContext *context) { return rand_u64(context); } }; @@ -1726,18 +1836,20 @@ i32 kWasmPrintBufferSize = 1024 * 1024; } extern "C" { -// The input means starting address of Context, which should be set to +// The input means starting address of RuntimeContext, which should be set to // '__heap_base' to avoid conflicts with C++ stack data which is stored in // memory. The function returns starting address of root buffer. The print -// buffer locates just before Context (8MB). Here is an illustration for +// buffer locates just before RuntimeContext (8MB). Here is an illustration for // proper memory layout in WASM: -// ━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━ -// Print ┃▄ Context ┃ ▄ Runtime ▄ ┃ RandState[0] ┃ Root Buffer ... -// ━━━━━━━┻│━━━━━━━━━▲━━│━━━━━━━━━│━━▲━━━━━━━━━━━━━━▲━━━━━━━━━━━━━━━━━━━ -// └─────────┘ │ └──┘ │ -// └───────────────────────────┘ -i32 wasm_materialize(Context *context) { - context->runtime = (LLVMRuntime *)((size_t)context + sizeof(Context)); +// clang-format off +// ━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━ +// Print ┃▄ RuntimeContext ┃ ▄ Runtime ▄ ┃ RandState[0] ┃ Root Buffer ... +// ━━━━━━━┻│━━━━━━━━━━━━━━━━▲━━│━━━━━━━━━│━━▲━━━━━━━━━━━━━━▲━━━━━━━━━━━━━━━━━━━ +// └────────────────┘ │ └──┘ │ +// └───────────────────────────┘ +// clang-format on +i32 wasm_materialize(RuntimeContext *context) { + context->runtime = (LLVMRuntime *)((size_t)context + sizeof(RuntimeContext)); context->runtime->rand_states = (RandState *)((size_t)context->runtime + sizeof(LLVMRuntime)); // set random seed to (1, 0, 0, 0) @@ -1758,11 +1870,11 @@ i32 wasm_materialize(Context *context) { // char c[4]; // } data; //} wasm_buffer_buffer[kWasmPrintBufferSize]; -void wasm_set_print_buffer(Context *context, Ptr buffer) { +void wasm_set_print_buffer(RuntimeContext *context, Ptr buffer) { context->runtime->wasm_print_buffer = buffer; } -void wasm_print_i32(Context *context, i32 value) { +void wasm_print_i32(RuntimeContext *context, i32 value) { Ptr buffer = context->runtime->wasm_print_buffer; if (buffer == nullptr) return; @@ -1772,7 +1884,7 @@ void wasm_print_i32(Context *context, i32 value) { ((i32 *)buffer)[print_pos * 2 + 2] = value; } -void wasm_print_f32(Context *context, f32 value) { +void wasm_print_f32(RuntimeContext *context, f32 value) { Ptr buffer = context->runtime->wasm_print_buffer; if (buffer == nullptr) return; @@ -1782,7 +1894,7 @@ void wasm_print_f32(Context *context, f32 value) { ((f32 *)buffer)[print_pos * 2 + 2] = value; } -void wasm_print_char(Context *context, +void wasm_print_char(RuntimeContext *context, i8 value0, i8 value1, i8 value2, @@ -1799,11 +1911,15 @@ void wasm_print_char(Context *context, ((i8 *)buffer)[print_pos * 8 + 11] = value3; } -void wasm_set_kernel_parameter_i32(Context *context, int index, i32 value) { +void wasm_set_kernel_parameter_i32(RuntimeContext *context, + int index, + i32 value) { *(i32 *)(&context->args[index]) = value; } -void wasm_set_kernel_parameter_f32(Context *context, int index, f32 value) { +void wasm_set_kernel_parameter_f32(RuntimeContext *context, + int index, + f32 value) { *(f32 *)(&context->args[index]) = value; } } diff --git a/taichi/runtime/runtime.h b/taichi/runtime/runtime.h index f2d5714af7f3d..523675abae4e6 100644 --- a/taichi/runtime/runtime.h +++ b/taichi/runtime/runtime.h @@ -1,7 +1,7 @@ #pragma once #include "taichi/common/core.h" -#include "taichi/program/arch.h" +#include "taichi/backends/arch.h" #include "taichi/program/kernel_profiler.h" #include diff --git a/taichi/struct/struct_llvm.cpp b/taichi/struct/struct_llvm.cpp index 3ece99d9525ec..146d095ea952b 100644 --- a/taichi/struct/struct_llvm.cpp +++ b/taichi/struct/struct_llvm.cpp @@ -1,3 +1,4 @@ +#ifdef TI_WITH_LLVM #include "taichi/struct/struct_llvm.h" #include "llvm/IR/Verifier.h" @@ -13,21 +14,25 @@ namespace lang { StructCompilerLLVM::StructCompilerLLVM(Arch arch, const CompileConfig *config, TaichiLLVMContext *tlctx, - std::unique_ptr &&module) + std::unique_ptr &&module, + int snode_tree_id) : LLVMModuleBuilder(std::move(module), tlctx), arch_(arch), config_(config), tlctx_(tlctx), - llvm_ctx_(tlctx_->get_this_thread_context()) { + llvm_ctx_(tlctx_->get_this_thread_context()), + snode_tree_id_(snode_tree_id) { } StructCompilerLLVM::StructCompilerLLVM(Arch arch, LlvmProgramImpl *prog, - std::unique_ptr &&module) + std::unique_ptr &&module, + int snode_tree_id) : StructCompilerLLVM(arch, prog->config, prog->get_llvm_context(arch), - std::move(module)) { + std::move(module), + snode_tree_id) { } void StructCompilerLLVM::generate_types(SNode &snode) { @@ -56,6 +61,13 @@ void StructCompilerLLVM::generate_types(SNode &snode) { snode.cell_size_bytes = tlctx_->get_type_size(ch_type); + for (int i = 0; i < snode.ch.size(); i++) { + if (!snode.ch[i]->is_bit_level) { + snode.ch[i]->offset_bytes_in_parent_cell = + tlctx_->get_struct_element_offset(ch_type, i); + } + } + llvm::Type *body_type = nullptr, *aux_type = nullptr; if (type == SNodeType::dense || type == SNodeType::bitmasked) { TI_ASSERT(snode._morton == false); @@ -160,8 +172,7 @@ void StructCompilerLLVM::generate_types(SNode &snode) { // Create a dummy function in the module with the type stub as return type // so that the type is referenced in the module auto ft = llvm::FunctionType::get(llvm::PointerType::get(stub, 0), false); - llvm::Function::Create(ft, llvm::Function::ExternalLinkage, - type_stub_name(&snode) + "_func", module.get()); + create_function(ft, type_stub_name(&snode) + "_func"); } void StructCompilerLLVM::generate_refine_coordinates(SNode *snode) { @@ -174,9 +185,7 @@ void StructCompilerLLVM::generate_refine_coordinates(SNode *snode) { {coord_type_ptr, coord_type_ptr, llvm::Type::getInt32Ty(*llvm_ctx_)}, false); - auto func = - llvm::Function::Create(ft, llvm::Function::ExternalLinkage, - snode->refine_coordinates_func_name(), *module); + auto func = create_function(ft, snode->refine_coordinates_func_name()); auto bb = llvm::BasicBlock::Create(*llvm_ctx_, "entry", func); @@ -250,9 +259,7 @@ void StructCompilerLLVM::generate_child_accessors(SNode &snode) { llvm::FunctionType::get(llvm::Type::getInt8PtrTy(*llvm_ctx_), {llvm::Type::getInt8PtrTy(*llvm_ctx_)}, false); - auto func = - llvm::Function::Create(ft, llvm::Function::ExternalLinkage, - snode.get_ch_from_parent_func_name(), *module); + auto func = create_function(ft, snode.get_ch_from_parent_func_name()); auto bb = llvm::BasicBlock::Create(*llvm_ctx_, "entry", func); @@ -306,7 +313,7 @@ void StructCompilerLLVM::run(SNode &root) { TI_ASSERT((int)snodes.size() <= taichi_max_num_snodes); auto node_type = get_llvm_node_type(module.get(), &root); - root_size = tlctx_->get_data_layout().getTypeAllocSize(node_type); + root_size = tlctx_->get_type_size(node_type); tlctx_->set_struct_module(module); } @@ -345,5 +352,14 @@ llvm::Type *StructCompilerLLVM::get_llvm_element_type(llvm::Module *module, return get_stub(module, snode, 3); } +llvm::Function *StructCompilerLLVM::create_function(llvm::FunctionType *ft, + std::string func_name) { + tlctx_->add_function_to_snode_tree(snode_tree_id_, func_name); + return llvm::Function::Create(ft, llvm::Function::ExternalLinkage, func_name, + *module); +} + } // namespace lang } // namespace taichi + +#endif //#ifdef TI_WITH_LLVM diff --git a/taichi/struct/struct_llvm.h b/taichi/struct/struct_llvm.h index b55fd7672feae..df16363742b1b 100644 --- a/taichi/struct/struct_llvm.h +++ b/taichi/struct/struct_llvm.h @@ -1,4 +1,5 @@ #pragma once +#ifdef TI_WITH_LLVM // Codegen for the hierarchical data structure (LLVM) #include "taichi/llvm/llvm_program.h" #include "taichi/llvm/llvm_codegen_utils.h" @@ -6,16 +7,20 @@ namespace taichi { namespace lang { + +class LlvmProgramImpl; class StructCompilerLLVM : public StructCompiler, public LLVMModuleBuilder { public: StructCompilerLLVM(Arch arch, const CompileConfig *config, TaichiLLVMContext *tlctx, - std::unique_ptr &&module); + std::unique_ptr &&module, + int snode_tree_id); StructCompilerLLVM(Arch arch, LlvmProgramImpl *prog, - std::unique_ptr &&module); + std::unique_ptr &&module, + int snode_tree_id); void generate_types(SNode &snode) override; @@ -23,6 +28,9 @@ class StructCompilerLLVM : public StructCompiler, public LLVMModuleBuilder { void run(SNode &node) override; + llvm::Function *create_function(llvm::FunctionType *ft, + std::string func_name); + void generate_refine_coordinates(SNode *snode); static std::string type_stub_name(SNode *snode); @@ -42,7 +50,10 @@ class StructCompilerLLVM : public StructCompiler, public LLVMModuleBuilder { const CompileConfig *const config_; TaichiLLVMContext *const tlctx_; llvm::LLVMContext *const llvm_ctx_; + int snode_tree_id_; }; } // namespace lang } // namespace taichi + +#endif //#ifdef TI_WITH_LLVM diff --git a/taichi/system/benchmark.h b/taichi/system/benchmark.h index 6642510c285ca..fd1ca6d24f8bb 100644 --- a/taichi/system/benchmark.h +++ b/taichi/system/benchmark.h @@ -24,7 +24,7 @@ class Benchmark : public Unit { virtual void finalize(){}; public: - virtual void initialize(const Config &config) override { + void initialize(const Config &config) override { warm_up_iterations = config.get("warm_up_iterations", 16); workload = config.get("workload", int64(1024)); returns_time = config.get("returns_time", false); @@ -54,7 +54,7 @@ class Benchmark : public Unit { return elapsed / (iterations * workload); } - virtual bool test() const override { + bool test() const override { return true; } }; diff --git a/taichi/system/demangling.cpp b/taichi/system/demangling.cpp index 09c0fd461f54e..5059905885fb8 100644 --- a/taichi/system/demangling.cpp +++ b/taichi/system/demangling.cpp @@ -28,7 +28,7 @@ std::string cpp_demangle(const std::string &mangled_name) { } class Demangling : public Task { - virtual std::string run(const std::vector ¶meters) { + std::string run(const std::vector ¶meters) override { if (parameters.size() == 0) { printf("There should be at least one parameter for demangling.\n"); } diff --git a/taichi/system/dynamic_loader.cpp b/taichi/system/dynamic_loader.cpp index 2923fa34c8642..e1a866d74d88e 100644 --- a/taichi/system/dynamic_loader.cpp +++ b/taichi/system/dynamic_loader.cpp @@ -14,18 +14,18 @@ DynamicLoader::DynamicLoader(const std::string &dll_path) { void DynamicLoader::load_dll(const std::string &dll_path) { #ifdef WIN32 - dll = (HMODULE)LoadLibraryA(dll_path.c_str()); + dll_ = (HMODULE)LoadLibraryA(dll_path.c_str()); #else - dll = dlopen(dll_path.c_str(), RTLD_LAZY); + dll_ = dlopen(dll_path.c_str(), RTLD_LAZY); #endif } void *DynamicLoader::load_function(const std::string &func_name) { TI_ASSERT_INFO(loaded(), "DLL not opened"); #ifdef WIN32 - auto func = (void *)GetProcAddress((HMODULE)dll, func_name.c_str()); + auto func = (void *)GetProcAddress((HMODULE)dll_, func_name.c_str()); #else - auto func = dlsym(dll, func_name.c_str()); + auto func = dlsym(dll_, func_name.c_str()); const char *dlsym_error = dlerror(); TI_ERROR_IF(dlsym_error, "Cannot load function: {}", dlsym_error); #endif @@ -36,11 +36,11 @@ void *DynamicLoader::load_function(const std::string &func_name) { void DynamicLoader::close_dll() { TI_ASSERT_INFO(loaded(), "DLL not opened"); #ifdef WIN32 - FreeLibrary((HMODULE)dll); + FreeLibrary((HMODULE)dll_); #else - dlclose(dll); + dlclose(dll_); #endif - dll = nullptr; + dll_ = nullptr; } DynamicLoader::~DynamicLoader() { @@ -49,7 +49,7 @@ DynamicLoader::~DynamicLoader() { } bool DynamicLoader::loaded() const { - return dll != nullptr; + return dll_ != nullptr; } TI_NAMESPACE_END diff --git a/taichi/system/dynamic_loader.h b/taichi/system/dynamic_loader.h index a62d4dec95832..3812f13555960 100644 --- a/taichi/system/dynamic_loader.h +++ b/taichi/system/dynamic_loader.h @@ -30,7 +30,7 @@ class DynamicLoader { ~DynamicLoader(); private: - void *dll = nullptr; + void *dll_ = nullptr; }; TI_NAMESPACE_END diff --git a/taichi/system/hacked_signal_handler.cpp b/taichi/system/hacked_signal_handler.cpp index 79a309415c550..77995be9b557f 100644 --- a/taichi/system/hacked_signal_handler.cpp +++ b/taichi/system/hacked_signal_handler.cpp @@ -1,11 +1,11 @@ #include #include "taichi/common/logging.h" +#include "taichi/system/hacked_signal_handler.h" #include "taichi/system/threading.h" #include "taichi/system/traceback.h" namespace taichi { - namespace { std::string signal_name(int sig) { @@ -40,30 +40,45 @@ void signal_handler(int signo) { TI_UNREACHABLE; } -class HackedSignalRegister { - public: - explicit HackedSignalRegister() { +} // namespace + +HackedSignalRegister::HackedSignalRegister() { #define TI_REGISTER_SIGNAL_HANDLER(name, handler) \ { \ if (std::signal(name, handler) == SIG_ERR) \ std::printf("Cannot register signal handler for" #name "\n"); \ } - TI_REGISTER_SIGNAL_HANDLER(SIGSEGV, signal_handler); - TI_REGISTER_SIGNAL_HANDLER(SIGABRT, signal_handler); + TI_REGISTER_SIGNAL_HANDLER(SIGSEGV, signal_handler); + TI_REGISTER_SIGNAL_HANDLER(SIGABRT, signal_handler); #if !defined(_WIN64) - TI_REGISTER_SIGNAL_HANDLER(SIGBUS, signal_handler); + TI_REGISTER_SIGNAL_HANDLER(SIGBUS, signal_handler); #endif - TI_REGISTER_SIGNAL_HANDLER(SIGFPE, signal_handler); + TI_REGISTER_SIGNAL_HANDLER(SIGFPE, signal_handler); #undef TI_REGISTER_SIGNAL_HANDLER - Logger::get_instance().set_print_stacktrace_func(print_traceback); - TI_TRACE("Taichi core started. Thread ID = {}", PID::get_pid()); + Logger::get_instance().set_print_stacktrace_func(print_traceback); + TI_TRACE("Taichi signal handlers registered. Thread ID = {}", PID::get_pid()); +} + +HackedSignalRegister::~HackedSignalRegister() { +#define TI_UNREGISTER_SIGNAL_HANDLER(name) \ + { \ + if (std::signal(name, SIG_DFL) == SIG_ERR) \ + std::printf("Cannot unregister signal handler for" #name "\n"); \ } -}; -HackedSignalRegister _; + TI_UNREGISTER_SIGNAL_HANDLER(SIGSEGV); + TI_UNREGISTER_SIGNAL_HANDLER(SIGABRT); +#if !defined(_WIN64) + TI_UNREGISTER_SIGNAL_HANDLER(SIGBUS); +#endif + TI_UNREGISTER_SIGNAL_HANDLER(SIGFPE); + +#undef TI_UNREGISTER_SIGNAL_HANDLER + TI_TRACE("Taichi signal handlers unregistered. Thread ID = {}", + PID::get_pid()); +} -} // namespace } // namespace taichi diff --git a/taichi/system/hacked_signal_handler.h b/taichi/system/hacked_signal_handler.h new file mode 100644 index 0000000000000..f780268442e6d --- /dev/null +++ b/taichi/system/hacked_signal_handler.h @@ -0,0 +1,11 @@ +#pragma once + +namespace taichi { + +class HackedSignalRegister { + public: + explicit HackedSignalRegister(); + ~HackedSignalRegister(); +}; + +} // namespace taichi diff --git a/taichi/system/interface_registry.cpp b/taichi/system/interface_registry.cpp index f1d9e8877a2db..765fa0ce358a4 100644 --- a/taichi/system/interface_registry.cpp +++ b/taichi/system/interface_registry.cpp @@ -1,6 +1,5 @@ #include -#include "pybind11/pybind11.h" #include "taichi/common/interface.h" #include "taichi/common/task.h" #include "taichi/system/benchmark.h" diff --git a/taichi/system/memory_pool.cpp b/taichi/system/memory_pool.cpp index 4f1468181336b..0da8b424f84c4 100644 --- a/taichi/system/memory_pool.cpp +++ b/taichi/system/memory_pool.cpp @@ -3,7 +3,8 @@ #include "taichi/backends/cuda/cuda_driver.h" #include "taichi/backends/cuda/cuda_device.h" -TLANG_NAMESPACE_BEGIN +namespace taichi { +namespace lang { // In the future we wish to move the MemoryPool inside each Device // so that the memory allocated from each Device can be used as-is. @@ -136,4 +137,5 @@ MemoryPool::~MemoryPool() { } } -TLANG_NAMESPACE_END +} // namespace lang +} // namespace taichi diff --git a/taichi/system/memory_pool.h b/taichi/system/memory_pool.h index 317891de0d1c1..54da6df300eec 100644 --- a/taichi/system/memory_pool.h +++ b/taichi/system/memory_pool.h @@ -9,11 +9,12 @@ #include #include -TLANG_NAMESPACE_BEGIN +namespace taichi { +namespace lang { // A memory pool that runs on the host -class MemoryPool { +class TI_DLL_EXPORT MemoryPool { public: std::vector> allocators; static constexpr std::size_t default_allocator_size = @@ -53,4 +54,5 @@ class MemoryPool { Device *device_; }; -TLANG_NAMESPACE_END +} // namespace lang +} // namespace taichi diff --git a/taichi/system/profiler.cpp b/taichi/system/profiler.cpp index 284d512383aba..438e5c3f5bd0c 100644 --- a/taichi/system/profiler.cpp +++ b/taichi/system/profiler.cpp @@ -213,19 +213,19 @@ void ProfilerRecords::print(ProfilerRecordNode *node, int depth) { } ScopedProfiler::ScopedProfiler(std::string name, uint64 elements) { - start_time = Time::get_time(); - this->name = name; - this->elements = elements; - stopped = false; + start_time_ = Time::get_time(); + this->name_ = name; + this->elements_ = elements; + stopped_ = false; ProfilerRecords::get_this_thread_instance().push(name); } void ScopedProfiler::stop() { - TI_ASSERT_INFO(!stopped, "Profiler already stopped."); - float64 elapsed = Time::get_time() - start_time; - if ((int64)elements != -1) { + TI_ASSERT_INFO(!stopped_, "Profiler already stopped."); + float64 elapsed = Time::get_time() - start_time_; + if ((int64)elements_ != -1) { ProfilerRecords::get_this_thread_instance().insert_sample(elapsed, - elements); + elements_); } else { ProfilerRecords::get_this_thread_instance().insert_sample(elapsed); } @@ -241,7 +241,7 @@ void ScopedProfiler::enable() { } ScopedProfiler::~ScopedProfiler() { - if (!stopped) { + if (!stopped_) { stop(); } } @@ -252,27 +252,27 @@ Profiling &Profiling::get_instance() { } ProfilerRecords *Profiling::get_this_thread_profiler() { - std::lock_guard _(mut); + std::lock_guard _(mut_); auto id = std::this_thread::get_id(); std::stringstream ss; ss << id; - if (profilers.find(id) == profilers.end()) { + if (profilers_.find(id) == profilers_.end()) { // Note: thread id may be reused - profilers[id] = new ProfilerRecords(fmt::format("thread {}", ss.str())); + profilers_[id] = new ProfilerRecords(fmt::format("thread {}", ss.str())); } - return profilers[id]; + return profilers_[id]; } void Profiling::print_profile_info() { - std::lock_guard _(mut); - for (auto p : profilers) { + std::lock_guard _(mut_); + for (auto p : profilers_) { p.second->print(); } } void Profiling::clear_profile_info() { - std::lock_guard _(mut); - for (auto p : profilers) { + std::lock_guard _(mut_); + for (auto p : profilers_) { p.second->clear(); } } diff --git a/taichi/system/profiler.h b/taichi/system/profiler.h index 50fcdcce05106..5cd8d19731a8e 100644 --- a/taichi/system/profiler.h +++ b/taichi/system/profiler.h @@ -34,10 +34,10 @@ class ScopedProfiler { ~ScopedProfiler(); private: - std::string name; - float64 start_time; - uint64 elements; - bool stopped; + std::string name_; + float64 start_time_; + uint64 elements_; + bool stopped_; }; // A profiling system for multithreaded applications @@ -49,8 +49,8 @@ class Profiling { static Profiling &get_instance(); private: - std::mutex mut; - std::unordered_map profilers; + std::mutex mut_; + std::unordered_map profilers_; }; #define TI_PROFILER(name) taichi::ScopedProfiler _profiler_##__LINE__(name); diff --git a/taichi/system/run_tests.cpp b/taichi/system/run_tests.cpp index 3d1c58a4e4b8a..68628afd44f92 100644 --- a/taichi/system/run_tests.cpp +++ b/taichi/system/run_tests.cpp @@ -10,7 +10,7 @@ TI_NAMESPACE_BEGIN class RunTests : public Task { - virtual std::string run(const std::vector ¶meters) { + std::string run(const std::vector ¶meters) override { return std::to_string(run_tests(parameters)); } }; diff --git a/taichi/system/snode_tree_buffer_manager.cpp b/taichi/system/snode_tree_buffer_manager.cpp index 6106d0940e542..7a0d2accd90f6 100644 --- a/taichi/system/snode_tree_buffer_manager.cpp +++ b/taichi/system/snode_tree_buffer_manager.cpp @@ -1,10 +1,12 @@ #include "snode_tree_buffer_manager.h" #include "taichi/program/program.h" +#ifdef TI_WITH_LLVM #include "taichi/llvm/llvm_program.h" +#endif TLANG_NAMESPACE_BEGIN -SNodeTreeBufferManager::SNodeTreeBufferManager(LlvmProgramImpl *prog) +SNodeTreeBufferManager::SNodeTreeBufferManager(ProgramImpl *prog) : prog_(prog) { TI_TRACE("SNode tree buffer manager created."); } @@ -38,6 +40,7 @@ Ptr SNodeTreeBufferManager::allocate(JITModule *runtime_jit, std::size_t alignment, const int snode_tree_id, uint64 *result_buffer) { +#ifdef TI_WITH_LLVM TI_TRACE("allocating memory for SNode Tree {}", snode_tree_id); TI_ASSERT_INFO(snode_tree_id < kMaxNumSnodeTreesLlvm, "LLVM backend supports up to {} snode trees", @@ -46,8 +49,9 @@ Ptr SNodeTreeBufferManager::allocate(JITModule *runtime_jit, if (set_it == size_set_.end()) { runtime_jit->call( "runtime_memory_allocate_aligned", runtime, size, alignment); - auto ptr = prog_->fetch_result(taichi_result_buffer_runtime_query_id, - result_buffer); + LlvmProgramImpl *llvm_prog = static_cast(prog_); + auto ptr = llvm_prog->fetch_result( + taichi_result_buffer_runtime_query_id, result_buffer); roots_[snode_tree_id] = ptr; sizes_[snode_tree_id] = size; return ptr; @@ -64,6 +68,9 @@ Ptr SNodeTreeBufferManager::allocate(JITModule *runtime_jit, sizes_[snode_tree_id] = size; return x.second; } +#else + TI_ERROR("Llvm disabled"); +#endif } void SNodeTreeBufferManager::destroy(SNodeTree *snode_tree) { diff --git a/taichi/system/snode_tree_buffer_manager.h b/taichi/system/snode_tree_buffer_manager.h index 92d8d18d6c12d..e4e557351e6e2 100644 --- a/taichi/system/snode_tree_buffer_manager.h +++ b/taichi/system/snode_tree_buffer_manager.h @@ -1,5 +1,4 @@ #pragma once -#include "taichi/llvm/llvm_context.h" #include "taichi/inc/constants.h" #include "taichi/struct/snode_tree.h" #define TI_RUNTIME_HOST @@ -10,11 +9,12 @@ using Ptr = uint8_t *; TLANG_NAMESPACE_BEGIN -class LlvmProgramImpl; +class JITModule; +class ProgramImpl; class SNodeTreeBufferManager { public: - SNodeTreeBufferManager(LlvmProgramImpl *prog); + SNodeTreeBufferManager(ProgramImpl *prog); void merge_and_insert(Ptr ptr, std::size_t size); @@ -30,7 +30,7 @@ class SNodeTreeBufferManager { private: std::set> size_set_; std::map ptr_map_; - LlvmProgramImpl *prog_; + ProgramImpl *prog_; Ptr roots_[kMaxNumSnodeTreesLlvm]; std::size_t sizes_[kMaxNumSnodeTreesLlvm]; }; diff --git a/taichi/system/timer.h b/taichi/system/timer.h index fcaf75e0688cf..d9ccf2e19ea31 100644 --- a/taichi/system/timer.h +++ b/taichi/system/timer.h @@ -31,7 +31,7 @@ TI_NAMESPACE_BEGIN #include -class Time { +class TI_DLL_EXPORT Time { public: static double get_time(); static uint64 get_cycles(); @@ -69,14 +69,16 @@ class Time { class TickTimer : public Timer { protected: - double get_time(); + double get_time() override; - void print_record(const char *left, double elapsed, double average); + void print_record(const char *left, + double elapsed, + double average) override; public: TickTimer(std::string name); - ~TickTimer() { + ~TickTimer() override { output(); } }; diff --git a/taichi/system/traceback.cpp b/taichi/system/traceback.cpp index 2c4c6c406856c..a9a2222e93499 100644 --- a/taichi/system/traceback.cpp +++ b/taichi/system/traceback.cpp @@ -15,7 +15,8 @@ #include #include "spdlog/fmt/bundled/color.h" -#ifdef __APPLE__ +#if defined(__APPLE__) || (defined(__unix__) && !defined(__linux__)) && \ + !defined(ANDROID) && !defined(TI_EMSCRIPTENED) #include #include #endif @@ -91,7 +92,7 @@ inline std::vector stack_trace() { HANDLE thread = GetCurrentThread(); if (SymInitialize(process, NULL, TRUE) == FALSE) { - trace(__FUNCTION__ ": Failed to call SymInitialize."); + trace("Failed to call SymInitialize."); return std::vector(); } @@ -149,15 +150,14 @@ inline std::vector stack_trace() { #endif char symbolBuffer[sizeof(IMAGEHLP_SYMBOL) + 255]; PIMAGEHLP_SYMBOL symbol = (PIMAGEHLP_SYMBOL)symbolBuffer; - symbol->SizeOfStruct = (sizeof IMAGEHLP_SYMBOL) + 255; + symbol->SizeOfStruct = sizeof(IMAGEHLP_SYMBOL) + 255; symbol->MaxNameLength = 254; if (SymGetSymFromAddr(process, frame.AddrPC.Offset, &offset, symbol)) { f.name = symbol->Name; } else { DWORD error = GetLastError(); - trace(__FUNCTION__ ": Failed to resolve address 0x%X: %u\n", - frame.AddrPC.Offset, error); + trace("Failed to resolve address 0x%X: %u\n", frame.AddrPC.Offset, error); f.name = "Unknown Function"; } @@ -170,8 +170,8 @@ inline std::vector stack_trace() { f.line = line.LineNumber; } else { DWORD error = GetLastError(); - trace(__FUNCTION__ ": Failed to resolve line for 0x%X: %u\n", - frame.AddrPC.Offset, error); + trace("Failed to resolve line for 0x%X: %u\n", frame.AddrPC.Offset, + error); f.line = 0; } @@ -187,7 +187,7 @@ inline std::vector stack_trace() { } } // namespace dbg #endif -#ifdef __linux__ +#if defined(__linux__) && !defined(ANDROID) #include #include #include @@ -197,7 +197,7 @@ inline std::vector stack_trace() { TI_NAMESPACE_BEGIN -TI_EXPORT void print_traceback() { +void print_traceback() { #ifdef __APPLE__ static std::mutex traceback_printer_mutex; // Modified based on @@ -302,6 +302,18 @@ TI_EXPORT void print_traceback() { fmt::print(fg(fmt::color::magenta), fmt::format(" in {}\n", stack[i].module)); } +#elif defined(ANDROID) + // Not supported + fmt::print(fg(fmt::color::magenta), "***********************************\n"); + fmt::print(fg(fmt::color::magenta), "* Taichi Compiler Stack Traceback *\n"); + fmt::print(fg(fmt::color::magenta), "***********************************\n"); + fmt::print(fg(fmt::color::magenta), "NOT SUPPORTED ON ANDROID\n"); +#elif defined(TI_EMSCRIPTENED) + // Not supported + fmt::print(fg(fmt::color::magenta), "***********************************\n"); + fmt::print(fg(fmt::color::magenta), + "* Emscriptened Taichi Compiler Stack Traceback *\n"); + fmt::print(fg(fmt::color::magenta), "***********************************\n"); #else // Based on http://man7.org/linux/man-pages/man3/backtrace.3.html constexpr int BT_BUF_SIZE = 1024; diff --git a/taichi/system/traceback.h b/taichi/system/traceback.h index bb8b8db782129..8963ab9c91624 100644 --- a/taichi/system/traceback.h +++ b/taichi/system/traceback.h @@ -2,6 +2,6 @@ namespace taichi { -TI_EXPORT void print_traceback(); +void print_traceback(); } // namespace taichi diff --git a/taichi/system/unified_allocator.cpp b/taichi/system/unified_allocator.cpp index 5d827fe280f78..57ab154070109 100644 --- a/taichi/system/unified_allocator.cpp +++ b/taichi/system/unified_allocator.cpp @@ -16,9 +16,10 @@ TLANG_NAMESPACE_BEGIN UnifiedAllocator::UnifiedAllocator(std::size_t size, Arch arch, Device *device) - : size(size), arch_(arch), device_(device) { + : size_(size), arch_(arch), device_(device) { auto t = Time::get_time(); if (arch_ == Arch::x64) { +#ifdef TI_WITH_LLVM Device::AllocParams alloc_params; alloc_params.size = size; alloc_params.host_read = true; @@ -27,11 +28,14 @@ UnifiedAllocator::UnifiedAllocator(std::size_t size, Arch arch, Device *device) cpu::CpuDevice *cpu_device = static_cast(device); alloc = cpu_device->allocate_memory(alloc_params); data = (uint8 *)cpu_device->get_alloc_info(alloc).ptr; +#else + TI_NOT_IMPLEMENTED +#endif } else { TI_TRACE("Allocating virtual address space of size {} MB", size / 1024 / 1024); - cpu_vm = std::make_unique(size); - data = (uint8 *)cpu_vm->ptr; + cpu_vm_ = std::make_unique(size); + data = (uint8 *)cpu_vm_->ptr; } TI_ASSERT(data != nullptr); TI_ASSERT(uint64(data) % 4096 == 0); @@ -52,7 +56,7 @@ taichi::lang::UnifiedAllocator::~UnifiedAllocator() { } void taichi::lang::UnifiedAllocator::memset(unsigned char val) { - std::memset(data, val, size); + std::memset(data, val, size_); } TLANG_NAMESPACE_END diff --git a/taichi/system/unified_allocator.h b/taichi/system/unified_allocator.h index b83ef039480db..a52a23fdca998 100644 --- a/taichi/system/unified_allocator.h +++ b/taichi/system/unified_allocator.h @@ -3,7 +3,7 @@ #include #include -#include "taichi/program/arch.h" +#include "taichi/backends/arch.h" #include "taichi/backends/device.h" namespace taichi { @@ -14,8 +14,8 @@ TLANG_NAMESPACE_BEGIN // This class can only have one instance class UnifiedAllocator { - std::unique_ptr cpu_vm; - std::size_t size; + std::unique_ptr cpu_vm_; + std::size_t size_; Arch arch_; // put these two on the unified memory so that GPU can have access diff --git a/taichi/system/virtual_memory.h b/taichi/system/virtual_memory.h index 3bd3869ae0272..47ee75ee15a61 100644 --- a/taichi/system/virtual_memory.h +++ b/taichi/system/virtual_memory.h @@ -20,7 +20,7 @@ class VirtualMemoryAllocator { // http://pages.cs.wisc.edu/~sifakis/papers/SPGrid.pdf Sec 3.1 #if defined(TI_PLATFORM_UNIX) ptr = mmap(nullptr, size, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS | MAP_NORESERVE, -1, 0); + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); TI_ERROR_IF(ptr == MAP_FAILED, "Virtual memory allocation ({} B) failed.", size); #else diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 1072ec8e45042..a67c02d05122d 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -24,7 +24,7 @@ class AlgSimp : public BasicStmtVisitor { void replace_with_zero(Stmt *stmt) { auto zero = Stmt::make(LaneAttribute(stmt->ret_type)); - stmt->replace_with(zero.get()); + stmt->replace_usages_with(zero.get()); modifier.insert_before(stmt, std::move(zero)); modifier.erase(stmt); } @@ -34,7 +34,7 @@ class AlgSimp : public BasicStmtVisitor { auto one_raw = one.get(); modifier.insert_before(stmt, std::move(one)); cast_to_result_type(one_raw, stmt); - stmt->replace_with(one_raw); + stmt->replace_usages_with(one_raw); modifier.erase(stmt); } @@ -79,7 +79,7 @@ class AlgSimp : public BasicStmtVisitor { void visit(UnaryOpStmt *stmt) override { if (stmt->is_cast()) { if (stmt->cast_type == stmt->operand->ret_type) { - stmt->replace_with(stmt->operand); + stmt->replace_usages_with(stmt->operand); modifier.erase(stmt); } else if (stmt->operand->is() && stmt->operand->as()->is_cast()) { @@ -105,7 +105,7 @@ class AlgSimp : public BasicStmtVisitor { TI_ASSERT(stmt->op_type == BinaryOpType::mul); if (alg_is_one(lhs) || alg_is_one(rhs)) { // 1 * a -> a, a * 1 -> a - stmt->replace_with(alg_is_one(lhs) ? stmt->rhs : stmt->lhs); + stmt->replace_usages_with(alg_is_one(lhs) ? stmt->rhs : stmt->lhs); modifier.erase(stmt); return true; } @@ -127,7 +127,7 @@ class AlgSimp : public BasicStmtVisitor { auto result = Stmt::make(BinaryOpType::bit_shl, stmt->lhs, new_rhs.get()); result->ret_type = stmt->ret_type; - stmt->replace_with(result.get()); + stmt->replace_usages_with(result.get()); modifier.insert_before(stmt, std::move(new_rhs)); modifier.insert_before(stmt, std::move(result)); modifier.erase(stmt); @@ -141,7 +141,7 @@ class AlgSimp : public BasicStmtVisitor { cast_to_result_type(a, stmt); auto sum = Stmt::make(BinaryOpType::add, a, a); sum->ret_type = a->ret_type; - stmt->replace_with(sum.get()); + stmt->replace_usages_with(sum.get()); modifier.insert_before(stmt, std::move(sum)); modifier.erase(stmt); return true; @@ -156,7 +156,7 @@ class AlgSimp : public BasicStmtVisitor { stmt->op_type == BinaryOpType::floordiv); if (alg_is_one(rhs)) { // a / 1 -> a - stmt->replace_with(stmt->lhs); + stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); return true; } @@ -186,7 +186,7 @@ class AlgSimp : public BasicStmtVisitor { auto product = Stmt::make(BinaryOpType::mul, stmt->lhs, reciprocal.get()); product->ret_type = stmt->ret_type; - stmt->replace_with(product.get()); + stmt->replace_usages_with(product.get()); modifier.insert_before(stmt, std::move(reciprocal)); modifier.insert_before(stmt, std::move(product)); modifier.erase(stmt); @@ -202,7 +202,7 @@ class AlgSimp : public BasicStmtVisitor { auto result = Stmt::make(BinaryOpType::bit_sar, stmt->lhs, new_rhs.get()); result->ret_type = stmt->ret_type; - stmt->replace_with(result.get()); + stmt->replace_usages_with(result.get()); modifier.insert_before(stmt, std::move(new_rhs)); modifier.insert_before(stmt, std::move(result)); modifier.erase(stmt); @@ -228,16 +228,16 @@ class AlgSimp : public BasicStmtVisitor { stmt->op_type == BinaryOpType::bit_xor) { if (alg_is_zero(rhs)) { // a +-|^ 0 -> a - stmt->replace_with(stmt->lhs); + stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); } else if (stmt->op_type != BinaryOpType::sub && alg_is_zero(lhs)) { // 0 +|^ a -> a - stmt->replace_with(stmt->rhs); + stmt->replace_usages_with(stmt->rhs); modifier.erase(stmt); } else if (stmt->op_type == BinaryOpType::bit_or && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // a | a -> a - stmt->replace_with(stmt->lhs); + stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); } else if ((stmt->op_type == BinaryOpType::sub || stmt->op_type == BinaryOpType::bit_xor) && @@ -250,7 +250,7 @@ class AlgSimp : public BasicStmtVisitor { float64 exponent = rhs->val[0].val_cast_to_float64(); if (exponent == 1) { // a ** 1 -> a - stmt->replace_with(stmt->lhs); + stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); } else if (exponent == 0) { // a ** 0 -> 1 @@ -261,7 +261,7 @@ class AlgSimp : public BasicStmtVisitor { cast_to_result_type(a, stmt); auto result = Stmt::make(UnaryOpType::sqrt, a); result->ret_type = a->ret_type; - stmt->replace_with(result.get()); + stmt->replace_usages_with(result.get()); modifier.insert_before(stmt, std::move(result)); modifier.erase(stmt); } else if (exponent == std::round(exponent) && exponent > 0 && @@ -294,7 +294,7 @@ class AlgSimp : public BasicStmtVisitor { a_power_of_2 = new_a_power.get(); modifier.insert_before(stmt, std::move(new_a_power)); } - stmt->replace_with(result); + stmt->replace_usages_with(result); modifier.erase(stmt); } else if (exponent == std::round(exponent) && exponent < 0 && exponent >= -max_weaken_exponent) { @@ -309,7 +309,7 @@ class AlgSimp : public BasicStmtVisitor { a_to_n->ret_type = stmt->ret_type; auto result = Stmt::make(BinaryOpType::div, one_raw, a_to_n.get()); - stmt->replace_with(result.get()); + stmt->replace_usages_with(result.get()); modifier.insert_before(stmt, std::move(new_exponent)); modifier.insert_before(stmt, std::move(a_to_n)); modifier.insert_before(stmt, std::move(result)); @@ -318,18 +318,18 @@ class AlgSimp : public BasicStmtVisitor { } else if (stmt->op_type == BinaryOpType::bit_and) { if (alg_is_minus_one(rhs)) { // a & -1 -> a - stmt->replace_with(stmt->lhs); + stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); } else if (alg_is_minus_one(lhs)) { // -1 & a -> a - stmt->replace_with(stmt->rhs); + stmt->replace_usages_with(stmt->rhs); modifier.erase(stmt); } else if (alg_is_zero(lhs) || alg_is_zero(rhs)) { // 0 & a -> 0, a & 0 -> 0 replace_with_zero(stmt); } else if (irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // a & a -> a - stmt->replace_with(stmt->lhs); + stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); } } else if (stmt->op_type == BinaryOpType::bit_sar || @@ -341,7 +341,7 @@ class AlgSimp : public BasicStmtVisitor { // 0 << a -> 0 // 0 >> a -> 0 TI_ASSERT(stmt->lhs->ret_type == stmt->ret_type); - stmt->replace_with(stmt->lhs); + stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); } } else if (is_comparison(stmt->op_type)) { diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index c09e3a0fd1ada..d9f5c52818527 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -8,6 +8,78 @@ #include TLANG_NAMESPACE_BEGIN +class IndependentBlocksJudger : public BasicStmtVisitor { + public: + using BasicStmtVisitor::visit; + + void visit(LocalLoadStmt *stmt) override { + for (auto &lane : stmt->src.data) { + touched_allocas_.insert(lane.var->as()); + } + } + + void visit(LocalStoreStmt *stmt) override { + touched_allocas_.insert(stmt->dest->as()); + } + + void visit(AtomicOpStmt *stmt) override { + // We don't need to check the global atomics inside the range for-loops + // because + // 1. If the range for-loop is innermost, they will be captured by + // MakeAdjoint anyway + // 2. If the range for-loop is not innermost, they will be processed by + // another IndependentBlocksJudger + if (is_inside_loop_) + return; + TI_ASSERT(stmt->dest->is()); + for (const auto &node : stmt->dest->cast()->snodes.data) { + if (node->has_grad()) { + qualified_atomics_ = false; + break; + } + } + } + + void visit(RangeForStmt *stmt) override { + inner_most_loop_ = false; + is_inside_loop_ = true; + stmt->body->accept(this); + is_inside_loop_ = false; + } + + static bool run(IRNode *root) { + IndependentBlocksJudger Judger; + Block *block = root->as(); + root->accept(&Judger); + std::set outside_blocks; + // Collect all parent blocks (i.e. outside blocks) of the current block for + // local load/store stmt checks + for (auto b = block->parent_block(); b; b = b->parent_block()) { + if (b) + outside_blocks.insert(b); + } + for (const auto &alloca : Judger.touched_allocas_) { + // Test if the alloca belongs to the current block + if (outside_blocks.find(alloca->parent) != outside_blocks.end()) { + // This block is not an IB since it loads/modifies outside variables + return false; + } + } + + // To judge whether a block is an IB + // 1. No local load/store to allocas *outside* itself has been strictly + // enforced + // 2. If the #1 is satisfied, either an inner most loop or a block without + // global atomics is an IB + return Judger.qualified_atomics_ || Judger.inner_most_loop_; + } + + private: + std::set touched_allocas_; + bool qualified_atomics_ = true; + bool inner_most_loop_ = true; + bool is_inside_loop_ = false; +}; // Do automatic differentiation pass in the reverse order (reverse-mode AD) @@ -26,15 +98,15 @@ class IdentifyIndependentBlocks : public BasicStmtVisitor { public: using BasicStmtVisitor::visit; - void visit(WhileStmt *stmt) { + void visit(WhileStmt *stmt) override { TI_ERROR("WhileStmt is not supported in AutoDiff."); } - void visit(ContinueStmt *stmt) { + void visit(ContinueStmt *stmt) override { TI_ERROR("ContinueStmt is not supported in AutoDiff."); } - void visit(WhileControlStmt *stmt) { + void visit(WhileControlStmt *stmt) override { TI_ERROR("WhileControlStmt (break) is not supported in AutoDiff."); } @@ -43,65 +115,40 @@ class IdentifyIndependentBlocks : public BasicStmtVisitor { // Note: // - Local atomics should have been demoted before this pass. // - It is OK for an IB to have more than two for loops. - bool qualified = true; - std::set touched_allocas; - // TODO: remove this abuse since it *gathers nothing* - irpass::analysis::gather_statements(block, [&](Stmt *stmt) -> bool { - if (auto local_load = stmt->cast(); local_load) { - for (auto &lane : local_load->src.data) { - touched_allocas.insert(lane.var->as()); - } - } else if (auto local_store = stmt->cast(); local_store) { - touched_allocas.insert(local_store->dest->as()); - } - return false; - }); - for (auto alloca : touched_allocas) { - // Test if the alloca belongs to the current block - bool belong_to_this_block = false; - for (auto b = alloca->parent; b; b = b->parent_block()) { - if (b == block) { - belong_to_this_block = true; - } - } - if (!belong_to_this_block) { - // This block is not an IB since it loads/modifies outside variables - qualified = false; - break; - } - } - return qualified; + // - No atomics operations to the global variables which require gradient + + return IndependentBlocksJudger::run(block); } void visit_loop_body(Block *block) { if (is_independent_block(block)) { - current_ib = block; + current_ib_ = block; block->accept(this); } else { // No need to dive further } } - void visit(StructForStmt *stmt) { - TI_ASSERT(depth == 0); - depth++; - current_ib = stmt->body.get(); + void visit(StructForStmt *stmt) override { + TI_ASSERT(depth_ == 0); + depth_++; + current_ib_ = stmt->body.get(); visit_loop_body(stmt->body.get()); - depth--; - if (depth == 0) { - independent_blocks.push_back(current_ib); + depth_--; + if (depth_ == 0) { + independent_blocks_.push_back(current_ib_); } } - void visit(RangeForStmt *stmt) { - if (depth == 0) { - current_ib = stmt->body.get(); + void visit(RangeForStmt *stmt) override { + if (depth_ == 0) { + current_ib_ = stmt->body.get(); } - depth++; + depth_++; visit_loop_body(stmt->body.get()); - depth--; - if (depth == 0) { - independent_blocks.push_back(current_ib); + depth_--; + if (depth_ == 0) { + independent_blocks_.push_back(current_ib_); } } @@ -116,18 +163,18 @@ class IdentifyIndependentBlocks : public BasicStmtVisitor { } if (!has_for) { // The whole block is an IB - pass.independent_blocks.push_back(block); + pass.independent_blocks_.push_back(block); } else { root->accept(&pass); } - TI_ASSERT(!pass.independent_blocks.empty()); - return pass.independent_blocks; + TI_ASSERT(!pass.independent_blocks_.empty()); + return pass.independent_blocks_; } private: - std::vector independent_blocks; - int depth{0}; - Block *current_ib{nullptr}; + std::vector independent_blocks_; + int depth_{0}; + Block *current_ib_{nullptr}; }; // Note that SSA does not mean the instruction will be executed at most once. @@ -137,43 +184,65 @@ class PromoteSSA2LocalVar : public BasicStmtVisitor { using BasicStmtVisitor::visit; PromoteSSA2LocalVar(Block *block) { - alloca_block = block; + alloca_block_ = block; invoke_default_visitor = true; - execute_once = true; + execute_once_ = true; } void visit(Stmt *stmt) override { - if (execute_once) + if (execute_once_) return; TI_ASSERT(stmt->width() == 1); if (!(stmt->is() || stmt->is() || stmt->is() || stmt->is() || - stmt->is())) { + stmt->is() || stmt->is())) { // TODO: this list may be incomplete return; } - // Create a alloc - auto alloc = Stmt::make(1, stmt->ret_type); - auto alloc_ptr = alloc.get(); - TI_ASSERT(alloca_block); - alloca_block->insert(std::move(alloc), 0); - auto load = stmt->insert_after_me( - Stmt::make(LocalAddress(alloc_ptr, 0))); - irpass::replace_all_usages_with(stmt->parent, stmt, load); - // Create the load first so that the operand of the store won't get replaced - stmt->insert_after_me(Stmt::make(alloc_ptr, stmt)); + + if (stmt->is()) { + // Create a new alloc at the top of an ib to replace the old alloca + auto alloc = Stmt::make(1, stmt->ret_type); + auto alloc_ptr = alloc.get(); + TI_ASSERT(alloca_block_); + alloca_block_->insert(std::move(alloc), 0); + // Replace all the usages of the old stmt with that of the new one + irpass::replace_all_usages_with(stmt->parent, stmt, alloc_ptr); + + // Replace the old alloca with a local store + // and it will be replaced by a AdStackPushStmt in the following + // ReplaceLocalVarWithStacks pass + auto dtype = stmt->ret_type; + auto zero = + stmt->insert_after_me(Stmt::make(TypedConstant(dtype, 0))); + zero->insert_after_me(Stmt::make(alloc_ptr, zero)); + // Remove the old stmt + stmt->parent->erase(stmt); + } else { + // Create a alloc + auto alloc = Stmt::make(1, stmt->ret_type); + auto alloc_ptr = alloc.get(); + TI_ASSERT(alloca_block_); + alloca_block_->insert(std::move(alloc), 0); + auto load = stmt->insert_after_me( + Stmt::make(LocalAddress(alloc_ptr, 0))); + irpass::replace_all_usages_with(stmt->parent, stmt, load); + // Create the load first so that the operand of the store won't get + // replaced + stmt->insert_after_me(Stmt::make(alloc_ptr, stmt)); + } } void visit(RangeForStmt *stmt) override { - auto old_execute_once = execute_once; - execute_once = false; // loop body may be executed many times + auto old_execute_once = execute_once_; + execute_once_ = false; // loop body may be executed many times stmt->body->accept(this); - execute_once = old_execute_once; + execute_once_ = old_execute_once; } private: - Block *alloca_block{nullptr}; - bool execute_once; + Block *alloca_block_{nullptr}; + bool execute_once_; public: static void run(Block *block) { @@ -194,9 +263,9 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor { TI_ASSERT(alloc->width() == 1); bool load_only = irpass::analysis::gather_statements(alloc->parent, [&](Stmt *s) { - if (auto store = s->cast()) + if (auto store = s->cast()) { return store->dest == alloc; - else if (auto atomic = s->cast()) { + } else if (auto atomic = s->cast()) { return atomic->dest == alloc; } else { return false; @@ -234,32 +303,32 @@ class ReverseOuterLoops : public BasicStmtVisitor { using BasicStmtVisitor::visit; private: - ReverseOuterLoops(const std::vector &IB) : loop_depth(0), IB(IB) { + ReverseOuterLoops(const std::vector &IB) : loop_depth_(0), ib_(IB) { } - bool is_IB(Block *block) const { - return std::find(IB.begin(), IB.end(), block) != IB.end(); + bool is_ib(Block *block) const { + return std::find(ib_.begin(), ib_.end(), block) != ib_.end(); } - void visit(StructForStmt *stmt) { - loop_depth += 1; - if (!is_IB(stmt->body.get())) + void visit(StructForStmt *stmt) override { + loop_depth_ += 1; + if (!is_ib(stmt->body.get())) stmt->body->accept(this); - loop_depth -= 1; + loop_depth_ -= 1; } - void visit(RangeForStmt *stmt) { - if (loop_depth >= 1) { + void visit(RangeForStmt *stmt) override { + if (loop_depth_ >= 1) { stmt->reversed = !stmt->reversed; } - loop_depth += 1; - if (!is_IB(stmt->body.get())) + loop_depth_ += 1; + if (!is_ib(stmt->body.get())) stmt->body->accept(this); - loop_depth -= 1; + loop_depth_ -= 1; } - int loop_depth; - std::vector IB; + int loop_depth_; + std::vector ib_; public: static void run(IRNode *root, const std::vector &IB) { @@ -430,7 +499,8 @@ class MakeAdjoint : public IRVisitor { } void visit(UnaryOpStmt *stmt) override { - if (stmt->op_type == UnaryOpType::floor) { + if (stmt->op_type == UnaryOpType::floor || + stmt->op_type == UnaryOpType::ceil) { // do nothing } else if (stmt->op_type == UnaryOpType::neg) { accumulate(stmt->operand, negate(adjoint(stmt))); diff --git a/taichi/transforms/binary_op_simplify.cpp b/taichi/transforms/binary_op_simplify.cpp index 69955ecb1345d..64b1d32d40b4e 100644 --- a/taichi/transforms/binary_op_simplify.cpp +++ b/taichi/transforms/binary_op_simplify.cpp @@ -51,7 +51,7 @@ class BinaryOpSimp : public BasicStmtVisitor { modifier.insert_before(stmt, std::move(bin_op)); // Replace stmt now to avoid being "simplified" again - stmt->replace_with(new_stmt.get()); + stmt->replace_usages_with(new_stmt.get()); modifier.insert_before(stmt, std::move(new_stmt)); modifier.erase(stmt); return true; @@ -72,7 +72,7 @@ class BinaryOpSimp : public BasicStmtVisitor { modifier.insert_before(stmt, std::move(mask_stmt)); // Replace stmt now to avoid being "simplified" again - stmt->replace_with(new_stmt.get()); + stmt->replace_usages_with(new_stmt.get()); modifier.insert_before(stmt, std::move(new_stmt)); modifier.erase(stmt); return true; @@ -117,7 +117,7 @@ class BinaryOpSimp : public BasicStmtVisitor { modifier.insert_before(stmt, std::move(mask_stmt)); // Replace stmt now to avoid being "simplified" again - stmt->replace_with(new_stmt.get()); + stmt->replace_usages_with(new_stmt.get()); modifier.insert_before(stmt, std::move(new_stmt)); modifier.erase(stmt); return; diff --git a/taichi/transforms/bit_loop_vectorize.cpp b/taichi/transforms/bit_loop_vectorize.cpp index 8a127e9a5a091..f684c02144d54 100644 --- a/taichi/transforms/bit_loop_vectorize.cpp +++ b/taichi/transforms/bit_loop_vectorize.cpp @@ -40,7 +40,7 @@ class BitLoopVectorize : public IRVisitor { void visit(GlobalLoadStmt *stmt) override { auto ptr_type = stmt->src->ret_type->as(); if (in_struct_for_loop && bit_vectorize != 1) { - if (auto cit = ptr_type->get_pointee_type()->cast()) { + if (ptr_type->get_pointee_type()->cast()) { // rewrite the previous GlobalPtrStmt's return type from *cit to // *phy_type auto ptr = stmt->src->cast(); @@ -116,7 +116,7 @@ class BitLoopVectorize : public IRVisitor { stmt->insert_before_me(std::move(base_shift_op)); stmt->insert_before_me(std::move(offsetted_shift_offset)); stmt->insert_before_me(std::move(offsetted_shift_op)); - stmt->replace_with(or_op.get()); + stmt->replace_usages_with(or_op.get()); offsetted_shift_op_p->insert_after_me(std::move(or_op)); } } @@ -127,7 +127,7 @@ class BitLoopVectorize : public IRVisitor { void visit(GlobalStoreStmt *stmt) override { auto ptr_type = stmt->dest->ret_type->as(); if (in_struct_for_loop && bit_vectorize != 1) { - if (auto cit = ptr_type->get_pointee_type()->cast()) { + if (ptr_type->get_pointee_type()->cast()) { // rewrite the previous GlobalPtrStmt's return type from *cit to // *phy_type auto ptr = stmt->dest->cast(); @@ -168,7 +168,7 @@ class BitLoopVectorize : public IRVisitor { if (lhs_val == 1) { if (auto rhs = stmt->rhs->cast(); rhs && rhs->is_bit_vectorized) { - stmt->replace_with(stmt->rhs); + stmt->replace_usages_with(stmt->rhs); } } } else if (stmt->op_type == BinaryOpType::cmp_eq) { @@ -191,7 +191,7 @@ class BitLoopVectorize : public IRVisitor { // modify IR auto zero_p = zero.get(); stmt->insert_before_me(std::move(zero)); - stmt->replace_with(add.get()); + stmt->replace_usages_with(add.get()); zero_p->insert_after_me(std::move(add)); } } else if (auto lhs = stmt->lhs->cast()) { @@ -232,7 +232,7 @@ class BitLoopVectorize : public IRVisitor { stmt->insert_before_me(std::move(not_a)); stmt->insert_before_me(std::move(not_c)); stmt->insert_before_me(std::move(and_a_b)); - stmt->replace_with(and_b_c.get()); + stmt->replace_usages_with(and_b_c.get()); and_a_b_p->insert_after_me(std::move(and_b_c)); } } diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 24778f99673f0..74b0de7cd7971 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -33,7 +33,6 @@ void compile_to_offloads(IRNode *ir, const CompileConfig &config, Kernel *kernel, bool verbose, - bool vectorize, bool grad, bool ad_use_stack, bool start_from_ast) { @@ -48,6 +47,7 @@ void compile_to_offloads(IRNode *ir, } if (start_from_ast) { + irpass::frontend_type_check(ir); irpass::lower_ast(ir); print("Lowered"); } @@ -68,16 +68,6 @@ void compile_to_offloads(IRNode *ir, return; } - if (vectorize) { - irpass::loop_vectorize(ir, config); - print("Loop Vectorized"); - irpass::analysis::verify(ir); - - irpass::vector_split(ir, config.max_vector_width, config.serial_schedule); - print("Loop Split"); - irpass::analysis::verify(ir); - } - // TODO: strictly enforce bit vectorization for x86 cpu and CUDA now // create a separate CompileConfig flag for the new pass if (arch_is_cpu(config.arch) || config.arch == Arch::cuda) { @@ -91,9 +81,8 @@ void compile_to_offloads(IRNode *ir, print("Simplified I"); irpass::analysis::verify(ir); - if (irpass::inlining(ir, config, {})) { - print("Functions inlined"); - irpass::analysis::verify(ir); + if (is_extension_supported(config.arch, Extension::mesh)) { + irpass::analysis::gather_meshfor_relation_types(ir); } if (grad) { @@ -127,7 +116,7 @@ void compile_to_offloads(IRNode *ir, // TODO: This pass may be redundant as cfg_optimization() is already called // in full_simplify(). - if (config.cfg_optimization) { + if (config.opt_level > 0 && config.cfg_optimization) { irpass::cfg_optimization(ir, false); print("Optimized by CFG"); irpass::analysis::verify(ir); @@ -179,16 +168,40 @@ void offload_to_executable(IRNode *ir, irpass::analysis::verify(ir); } + if (is_extension_supported(config.arch, Extension::mesh) && + config.demote_no_access_mesh_fors) { + irpass::demote_no_access_mesh_fors(ir); + irpass::type_check(ir, config); + print("No-access mesh-for demoted"); + irpass::analysis::verify(ir); + } + if (make_thread_local) { irpass::make_thread_local(ir, config); print("Make thread local"); } + if (is_extension_supported(config.arch, Extension::mesh)) { + irpass::make_mesh_thread_local(ir, config, {kernel->get_name()}); + print("Make mesh thread local"); + if (config.make_mesh_block_local && config.arch == Arch::cuda) { + irpass::make_mesh_block_local(ir, config, {kernel->get_name()}); + print("Make mesh block local"); + irpass::full_simplify(ir, config, {false, kernel->program}); + print("Simplified X"); + } + } + if (make_block_local) { irpass::make_block_local(ir, config, {kernel->get_name()}); print("Make block local"); } + if (is_extension_supported(config.arch, Extension::mesh)) { + irpass::demote_mesh_statements(ir, config, {kernel->get_name()}); + print("Demote mesh statements"); + } + irpass::demote_atomics(ir, config); print("Atomics demoted II"); irpass::analysis::verify(ir); @@ -243,7 +256,6 @@ void offload_to_executable(IRNode *ir, void compile_to_executable(IRNode *ir, const CompileConfig &config, Kernel *kernel, - bool vectorize, bool grad, bool ad_use_stack, bool verbose, @@ -253,8 +265,8 @@ void compile_to_executable(IRNode *ir, bool start_from_ast) { TI_AUTO_PROF; - compile_to_offloads(ir, config, kernel, verbose, vectorize, grad, - ad_use_stack, start_from_ast); + compile_to_offloads(ir, config, kernel, verbose, grad, ad_use_stack, + start_from_ast); offload_to_executable(ir, config, kernel, verbose, /*determine_ad_stack_size=*/grad && ad_use_stack, @@ -279,9 +291,21 @@ void compile_inline_function(IRNode *ir, } if (start_from_ast) { + irpass::frontend_type_check(ir); irpass::lower_ast(ir); print("Lowered"); } + irpass::lower_access(ir, config, {{}, true}); + print("Access lowered"); + irpass::analysis::verify(ir); + + irpass::die(ir); + print("DIE"); + irpass::analysis::verify(ir); + + irpass::flag_access(ir); + print("Access flagged III"); + irpass::analysis::verify(ir); irpass::type_check(ir, config); print("Typechecked"); diff --git a/taichi/transforms/constant_fold.cpp b/taichi/transforms/constant_fold.cpp index 8757e63d9c2f1..2dc67c60701f9 100644 --- a/taichi/transforms/constant_fold.cpp +++ b/taichi/transforms/constant_fold.cpp @@ -33,7 +33,7 @@ class ConstantFold : public BasicStmtVisitor { return it->second.get(); auto kernel_name = fmt::format("jit_evaluator_{}", cache.size()); - auto func = [&id]() { + auto func = [&id, this]() { auto lhstmt = Stmt::make(/*arg_id=*/0, id.lhs, /*is_ptr=*/false); auto rhstmt = @@ -49,11 +49,11 @@ class ConstantFold : public BasicStmtVisitor { } } auto ret = Stmt::make(oper.get()); - current_ast_builder().insert(std::move(lhstmt)); + program->current_ast_builder()->insert(std::move(lhstmt)); if (id.is_binary) - current_ast_builder().insert(std::move(rhstmt)); - current_ast_builder().insert(std::move(oper)); - current_ast_builder().insert(std::move(ret)); + program->current_ast_builder()->insert(std::move(rhstmt)); + program->current_ast_builder()->insert(std::move(oper)); + program->current_ast_builder()->insert(std::move(ret)); }; auto ker = std::make_unique(*program, func, kernel_name); @@ -144,7 +144,7 @@ class ConstantFold : public BasicStmtVisitor { if (jit_evaluate_binary_op(new_constant, stmt, lhs->val[0], rhs->val[0])) { auto evaluated = Stmt::make(LaneAttribute(new_constant)); - stmt->replace_with(evaluated.get()); + stmt->replace_usages_with(evaluated.get()); modifier.insert_before(stmt, std::move(evaluated)); modifier.erase(stmt); } @@ -152,7 +152,7 @@ class ConstantFold : public BasicStmtVisitor { void visit(UnaryOpStmt *stmt) override { if (stmt->is_cast() && stmt->cast_type == stmt->operand->ret_type) { - stmt->replace_with(stmt->operand); + stmt->replace_usages_with(stmt->operand); modifier.erase(stmt); return; } @@ -180,7 +180,7 @@ class ConstantFold : public BasicStmtVisitor { if (cast_available) { auto evaluated = Stmt::make(LaneAttribute(new_constant)); - stmt->replace_with(evaluated.get()); + stmt->replace_usages_with(evaluated.get()); modifier.insert_before(stmt, std::move(evaluated)); modifier.erase(stmt); return; @@ -191,7 +191,7 @@ class ConstantFold : public BasicStmtVisitor { if (jit_evaluate_unary_op(new_constant, stmt, operand->val[0])) { auto evaluated = Stmt::make(LaneAttribute(new_constant)); - stmt->replace_with(evaluated.get()); + stmt->replace_usages_with(evaluated.get()); modifier.insert_before(stmt, std::move(evaluated)); modifier.erase(stmt); } @@ -215,7 +215,7 @@ class ConstantFold : public BasicStmtVisitor { result_stmt = Stmt::make(LaneAttribute( TypedConstant(input->val[0].dt, result))); } - stmt->replace_with(result_stmt.get()); + stmt->replace_usages_with(result_stmt.get()); modifier.insert_before(stmt, std::move(result_stmt)); modifier.erase(stmt); } diff --git a/taichi/transforms/demote_atomics.cpp b/taichi/transforms/demote_atomics.cpp index 0e9a1f2e6a4cb..0bdbf062abc81 100644 --- a/taichi/transforms/demote_atomics.cpp +++ b/taichi/transforms/demote_atomics.cpp @@ -13,6 +13,7 @@ TLANG_NAMESPACE_BEGIN class DemoteAtomics : public BasicStmtVisitor { private: std::unordered_map loop_unique_ptr_; + std::unordered_map loop_unique_arr_ptr_; public: using BasicStmtVisitor::visit; @@ -40,23 +41,59 @@ class DemoteAtomics : public BasicStmtVisitor { } if (!demote && (current_offloaded->task_type == OffloadedTaskType::range_for || - current_offloaded->task_type == OffloadedTaskType::struct_for) && - stmt->dest->is()) { - demote = true; - auto dest = stmt->dest->as(); - for (auto snode : dest->snodes.data) { - if (loop_unique_ptr_[snode] == nullptr || - loop_unique_ptr_[snode]->indices.empty()) { - // not uniquely accessed - demote = false; - break; + current_offloaded->task_type == OffloadedTaskType::mesh_for || + current_offloaded->task_type == OffloadedTaskType::struct_for)) { + if (stmt->dest->is()) { + demote = true; + auto dest = stmt->dest->as(); + for (auto snode : dest->snodes.data) { + if (loop_unique_ptr_[snode] == nullptr || + loop_unique_ptr_[snode]->indices.empty()) { + // not uniquely accessed + demote = false; + break; + } + if (current_offloaded->mem_access_opt.has_flag( + snode, SNodeAccessFlag::block_local) || + current_offloaded->mem_access_opt.has_flag( + snode, SNodeAccessFlag::mesh_local)) { + // BLS does not support write access yet so we keep atomic_adds. + demote = false; + break; + } } - if (current_offloaded->mem_access_opt.has_flag( - snode, SNodeAccessFlag::block_local)) { - // BLS does not support write access yet so we keep atomic_adds. + // demote from-end atomics + if (current_offloaded->task_type == OffloadedTaskType::mesh_for) { + if (dest->indices.size() == 1 && + dest->indices[0]->is()) { + auto idx = dest->indices[0]->as()->idx; + while (idx->is()) { // special case: l2g + // + g2r + idx = idx->as()->idx; + } + if (idx->is() && + idx->as()->is_mesh_index() && + loop_unique_ptr_[stmt->dest->as() + ->snodes.data[0]] != nullptr) { + demote = true; + } + } + } + } else if (stmt->dest->is()) { + ExternalPtrStmt *dest_ptr = stmt->dest->as(); + demote = true; + if (dest_ptr->indices.empty()) { demote = false; - break; } + for (Stmt *base_stmt : dest_ptr->base_ptrs.data) { + ArgLoadStmt *arg_load_stmt = base_stmt->as(); + int arg_id = arg_load_stmt->arg_id; + if (loop_unique_arr_ptr_[arg_id] == nullptr) { + // Not loop unique + demote = false; + } + } + // TODO: Is BLS / Mem Access Opt a thing for any_arr? } } } @@ -117,7 +154,7 @@ class DemoteAtomics : public BasicStmtVisitor { // value. The correct thing is to replace |stmt| $d with the loaded // old value $d'. // See also: https://github.com/taichi-dev/taichi/issues/332 - stmt->replace_with(load); + stmt->replace_usages_with(load); modifier.replace_with(stmt, std::move(new_stmts), /*replace_usages=*/false); } @@ -126,9 +163,12 @@ class DemoteAtomics : public BasicStmtVisitor { void visit(OffloadedStmt *stmt) override { current_offloaded = stmt; if (stmt->task_type == OffloadedTaskType::range_for || + stmt->task_type == OffloadedTaskType::mesh_for || stmt->task_type == OffloadedTaskType::struct_for) { - loop_unique_ptr_ = + auto uniquely_accessed_pointers = irpass::analysis::gather_uniquely_accessed_pointers(stmt); + loop_unique_ptr_ = std::move(uniquely_accessed_pointers.first); + loop_unique_arr_ptr_ = std::move(uniquely_accessed_pointers.second); } // We don't need to visit TLS/BLS prologues/epilogues. if (stmt->body) { diff --git a/taichi/transforms/demote_dense_struct_fors.cpp b/taichi/transforms/demote_dense_struct_fors.cpp index 4c058d40dd590..c7ef6d25293a2 100644 --- a/taichi/transforms/demote_dense_struct_fors.cpp +++ b/taichi/transforms/demote_dense_struct_fors.cpp @@ -29,7 +29,7 @@ void convert_to_range_for(OffloadedStmt *offloaded, bool packed) { TI_ASSERT(total_bits <= 30); // general shape calculation - no dependence on POT - int total_n = 1; + int64 total_n = 1; std::array total_shape; total_shape.fill(1); for (const auto *s : snodes) { diff --git a/taichi/transforms/demote_mesh_statements.cpp b/taichi/transforms/demote_mesh_statements.cpp new file mode 100644 index 0000000000000..9b67c64c962b9 --- /dev/null +++ b/taichi/transforms/demote_mesh_statements.cpp @@ -0,0 +1,153 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/analysis.h" +#include "taichi/transforms/demote_mesh_statements.h" +#include "taichi/ir/visitors.h" + +namespace taichi { +namespace lang { + +const PassID DemoteMeshStatements::id = "DemoteMeshStatements"; + +namespace irpass { + +auto get_load = [](SNode *snode, Stmt *idx, VecStatement &block) { + const auto lane = std::vector{idx}; + Stmt *globalptr = + block.push_back(LaneAttribute{snode}, lane); + Stmt *load = block.push_back(globalptr); + return load; +}; + +class ReplaceIndexConversion : public BasicStmtVisitor { + public: + using BasicStmtVisitor::visit; + + ReplaceIndexConversion(OffloadedStmt *node) { + allow_undefined_visitor = true; + invoke_default_visitor = true; + + offload = node; + visit(node); + } + + void visit(MeshIndexConversionStmt *stmt) override { + SNode *mapping = (stmt->mesh->index_mapping + .find(std::make_pair(stmt->idx_type, stmt->conv_type)) + ->second); + + VecStatement block; + if (stmt->conv_type == mesh::ConvType::g2r) { + // E.g, v_reordered = v_g2r[v_global] + Stmt *val = get_load(mapping, stmt->idx, block); + } else { + // E.g, v_global = v_l2g[v_local + total_vertices_offset] + Stmt *offset = offload->total_offset_local.find(stmt->idx_type)->second; + Stmt *index = + block.push_back(BinaryOpType::add, stmt->idx, offset); + [[maybe_unused]] Stmt *val = get_load(mapping, index, block); + } + stmt->replace_with(std::move(block)); + } + + OffloadedStmt *offload; +}; + +void demote_mesh_statements_offload(OffloadedStmt *offload, + const CompileConfig &config, + const std::string &kernel_name) { + ReplaceIndexConversion rep_conv( + offload); // This demote should work for any offloaed statement + + if (offload->task_type != OffloadedStmt::TaskType::mesh_for) { + return; + } + + auto stmts = irpass::analysis::gather_statements( + offload->body.get(), + [&](Stmt *stmt) { return stmt->is(); }); + + for (int i = stmts.size() - 1; i >= 0; --i) { + auto stmt = stmts[i]->cast(); + mesh::MeshElementType from_type = stmt->from_type(); + + auto from_order = mesh::element_order(from_type); + auto to_order = mesh::element_order(stmt->to_type); + mesh::MeshRelationType rel_type = + mesh::relation_by_orders(from_order, to_order); + if (from_order > to_order) { // high-to-low relation + if (stmt->is_size()) { + stmt->replace_with(Stmt::make(LaneAttribute{ + from_type == mesh::MeshElementType::Cell && + stmt->to_type == mesh::MeshElementType::Edge + ? /*Cell-Edge=*/6 + : (from_order + 1)})); + } else { + SNode *rel_value = stmt->mesh->relations.find(rel_type)->second.value; + VecStatement block; + Stmt *to_size = block.push_back(LaneAttribute{ + from_type == mesh::MeshElementType::Cell && + stmt->to_type == mesh::MeshElementType::Edge + ? /*Cell-Edge=*/6 + : (from_order + 1)}); + // E.g, v_2 = CV[(c + total_cells_offset) * 4 + 2] + Stmt *offset = offload->total_offset_local.find(from_type)->second; + Stmt *tmp0 = block.push_back(BinaryOpType::add, offset, + stmt->mesh_idx); + Stmt *tmp1 = + block.push_back(BinaryOpType::mul, tmp0, to_size); + Stmt *index = block.push_back(BinaryOpType::add, tmp1, + stmt->neighbor_idx); + [[maybe_unused]] Stmt *val = get_load(rel_value, index, block); + stmt->replace_with(std::move(block)); + } + } else { // low-to-high or same-order + SNode *rel_offset = stmt->mesh->relations.find(rel_type)->second.offset; + VecStatement block; + Stmt *patch_idx = block.push_back(); + Stmt *owned_offset = offload->owned_offset_local.find(from_type)->second; + Stmt *index_offset = block.push_back( + BinaryOpType::add, patch_idx, owned_offset); + Stmt *index = block.push_back(BinaryOpType::add, + index_offset, stmt->mesh_idx); + Stmt *offset = get_load(rel_offset, index, block); + if (stmt->is_size()) { + Stmt *one = block.push_back(LaneAttribute{1}); + Stmt *index_1 = + block.push_back(BinaryOpType::add, index, one); + Stmt *offset_1 = get_load(rel_offset, index_1, block); + [[maybe_unused]] Stmt *val = + block.push_back(BinaryOpType::sub, offset_1, offset); + } else { + SNode *rel_value = stmt->mesh->relations.find(rel_type)->second.value; + Stmt *val_index = block.push_back( + BinaryOpType::add, offset, stmt->neighbor_idx); + [[maybe_unused]] Stmt *val = get_load(rel_value, val_index, block); + } + stmt->replace_with(std::move(block)); + } + } +} + +// This pass should happen after offloading but before lower_access +void demote_mesh_statements(IRNode *root, + const CompileConfig &config, + const DemoteMeshStatements::Args &args) { + TI_AUTO_PROF; + + if (auto root_block = root->cast()) { + for (auto &offload : root_block->statements) { + demote_mesh_statements_offload(offload->cast(), config, + args.kernel_name); + } + } else { + demote_mesh_statements_offload(root->as(), config, + args.kernel_name); + } + type_check(root, config); +} + +} // namespace irpass +} // namespace lang +} // namespace taichi diff --git a/taichi/transforms/demote_mesh_statements.h b/taichi/transforms/demote_mesh_statements.h new file mode 100644 index 0000000000000..dad295315d7fb --- /dev/null +++ b/taichi/transforms/demote_mesh_statements.h @@ -0,0 +1,18 @@ +#pragma once + +#include "taichi/ir/pass.h" + +namespace taichi { +namespace lang { + +class DemoteMeshStatements : public Pass { + public: + static const PassID id; + + struct Args { + std::string kernel_name; + }; +}; + +} // namespace lang +} // namespace taichi diff --git a/taichi/transforms/demote_no_access_mesh_fors.cpp b/taichi/transforms/demote_no_access_mesh_fors.cpp new file mode 100644 index 0000000000000..0f2fc28c4efdd --- /dev/null +++ b/taichi/transforms/demote_no_access_mesh_fors.cpp @@ -0,0 +1,72 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" +#include "taichi/transforms/utils.h" + +TLANG_NAMESPACE_BEGIN + +namespace { + +void convert_to_range_for(OffloadedStmt *offloaded) { + TI_ASSERT(offloaded->task_type == OffloadedTaskType::mesh_for); + + DelayedIRModifier modifier; + auto stmts = irpass::analysis::gather_statements( + offloaded->body.get(), + [&](Stmt *stmt) { return stmt->is(); }); + for (size_t i = 0; i < stmts.size(); ++i) { + auto conv_stmt = stmts[i]->cast(); + if (conv_stmt->conv_type == mesh::ConvType::l2g) { + stmts[i]->replace_usages_with(conv_stmt->idx); + modifier.erase(stmts[i]); + } else if (conv_stmt->conv_type == mesh::ConvType::l2r) { + stmts[i]->as()->conv_type = mesh::ConvType::g2r; + } + } + + modifier.modify_ir(); + + offloaded->const_begin = true; + offloaded->const_end = true; + offloaded->begin_value = 0; + offloaded->end_value = + offloaded->mesh->num_elements.find(offloaded->major_from_type)->second; + offloaded->mesh = nullptr; + offloaded->task_type = OffloadedTaskType::range_for; +} + +void maybe_convert(OffloadedStmt *offloaded) { + if (offloaded->task_type == OffloadedTaskType::mesh_for && + offloaded->major_to_types.size() == 0) { + auto stmts = irpass::analysis::gather_statements( // ti.mesh_patch_idx() + // relies on mesh-for + offloaded->body.get(), + [&](Stmt *stmt) { return stmt->is(); }); + if (stmts.size() == 0) { + convert_to_range_for(offloaded); + } + } +} + +} // namespace + +namespace irpass { + +void demote_no_access_mesh_fors(IRNode *root) { + if (auto *block = root->cast()) { + for (auto &s_ : block->statements) { + if (auto *s = s_->cast()) { + maybe_convert(s); + } + } + } else if (auto *s = root->cast()) { + maybe_convert(s); + } + re_id(root); +} + +} // namespace irpass + +TLANG_NAMESPACE_END diff --git a/taichi/transforms/demote_operations.cpp b/taichi/transforms/demote_operations.cpp index ebae903db8791..a1d0372bc164a 100644 --- a/taichi/transforms/demote_operations.cpp +++ b/taichi/transforms/demote_operations.cpp @@ -30,7 +30,7 @@ class DemoteOperations : public BasicStmtVisitor { auto ret = statements.push_back(BinaryOpType::bit_and, input_sar_begin, mask); - stmt->replace_with(ret); + stmt->replace_usages_with(ret); modifier.insert_before(stmt, std::move(statements)); modifier.erase(stmt); } @@ -68,7 +68,7 @@ class DemoteOperations : public BasicStmtVisitor { auto real_ret = Stmt::make(BinaryOpType::add, ret.get(), cond.get()); - stmt->replace_with(real_ret.get()); + stmt->replace_usages_with(real_ret.get()); modifier.insert_before(stmt, std::move(ret)); modifier.insert_before(stmt, std::move(zero)); modifier.insert_before(stmt, std::move(lhs_ltz)); @@ -89,7 +89,7 @@ class DemoteOperations : public BasicStmtVisitor { // return ti.floor(r) auto div = Stmt::make(BinaryOpType::div, lhs, rhs); auto floor = Stmt::make(UnaryOpType::floor, div.get()); - stmt->replace_with(floor.get()); + stmt->replace_usages_with(floor.get()); modifier.insert_before(stmt, std::move(div)); modifier.insert_before(stmt, std::move(floor)); modifier.erase(stmt); @@ -112,7 +112,7 @@ class DemoteOperations : public BasicStmtVisitor { auto signed_cast = Stmt::make(UnaryOpType::cast_bits, shift.get()); signed_cast->as()->cast_type = lhs->element_type(); - stmt->replace_with(signed_cast.get()); + stmt->replace_usages_with(signed_cast.get()); modifier.insert_before(stmt, std::move(unsigned_cast)); modifier.insert_before(stmt, std::move(shift)); modifier.insert_before(stmt, std::move(signed_cast)); diff --git a/taichi/transforms/die.cpp b/taichi/transforms/die.cpp index 355c2ae803308..3176d8f576949 100644 --- a/taichi/transforms/die.cpp +++ b/taichi/transforms/die.cpp @@ -52,7 +52,7 @@ class DIE : public IRVisitor { } } - void visit(Stmt *stmt) { + void visit(Stmt *stmt) override { TI_ASSERT(!stmt->erased); if (phase == 0) { register_usage(stmt); @@ -64,13 +64,13 @@ class DIE : public IRVisitor { } } - void visit(Block *stmt_list) { + void visit(Block *stmt_list) override { for (auto &stmt : stmt_list->statements) { stmt->accept(this); } } - void visit(IfStmt *if_stmt) { + void visit(IfStmt *if_stmt) override { register_usage(if_stmt); if (if_stmt->true_statements) if_stmt->true_statements->accept(this); @@ -79,23 +79,34 @@ class DIE : public IRVisitor { } } - void visit(WhileStmt *stmt) { + void visit(WhileStmt *stmt) override { register_usage(stmt); stmt->body->accept(this); } - void visit(RangeForStmt *for_stmt) { + void visit(RangeForStmt *for_stmt) override { register_usage(for_stmt); for_stmt->body->accept(this); } - void visit(StructForStmt *for_stmt) { + void visit(StructForStmt *for_stmt) override { register_usage(for_stmt); for_stmt->body->accept(this); } - void visit(OffloadedStmt *stmt) { - stmt->all_blocks_accept(this); + void visit(MeshForStmt *for_stmt) override { + register_usage(for_stmt); + for_stmt->body->accept(this); + } + + void visit(OffloadedStmt *stmt) override { + // TODO: A hack to make sure end_stmt is registered. + // Ideally end_stmt should be its own Block instead. + if (stmt->end_stmt && + used.find(stmt->end_stmt->instance_id) == used.end()) { + used.insert(stmt->end_stmt->instance_id); + } + stmt->all_blocks_accept(this, true); } }; diff --git a/taichi/transforms/extract_constant.cpp b/taichi/transforms/extract_constant.cpp index 991d22b38a960..d7db7c76e01ea 100644 --- a/taichi/transforms/extract_constant.cpp +++ b/taichi/transforms/extract_constant.cpp @@ -8,30 +8,30 @@ TLANG_NAMESPACE_BEGIN class ExtractConstant : public BasicStmtVisitor { private: - Block *top_level; - DelayedIRModifier modifier; + Block *top_level_; + DelayedIRModifier modifier_; public: using BasicStmtVisitor::visit; - explicit ExtractConstant(IRNode *node) : top_level(nullptr) { + explicit ExtractConstant(IRNode *node) : top_level_(nullptr) { if (node->is()) - top_level = node->as(); + top_level_ = node->as(); } void visit(ConstStmt *stmt) override { - TI_ASSERT(top_level); - if (stmt->parent != top_level) { - modifier.extract_to_block_front(stmt, top_level); + TI_ASSERT(top_level_); + if (stmt->parent != top_level_) { + modifier_.extract_to_block_front(stmt, top_level_); } } void visit(OffloadedStmt *offload) override { if (offload->body) { - Block *backup = top_level; - top_level = offload->body.get(); + Block *backup = top_level_; + top_level_ = offload->body.get(); offload->body->accept(this); - top_level = backup; + top_level_ = backup; } } @@ -40,7 +40,7 @@ class ExtractConstant : public BasicStmtVisitor { bool ir_modified = false; while (true) { node->accept(&extractor); - if (extractor.modifier.modify_ir()) { + if (extractor.modifier_.modify_ir()) { ir_modified = true; } else { break; diff --git a/taichi/transforms/flag_access.cpp b/taichi/transforms/flag_access.cpp index 92ef2883a542d..cc127abdc2384 100644 --- a/taichi/transforms/flag_access.cpp +++ b/taichi/transforms/flag_access.cpp @@ -15,13 +15,13 @@ class FlagAccess : public IRVisitor { node->accept(this); } - void visit(Block *stmt_list) { // block itself has no id + void visit(Block *stmt_list) override { // block itself has no id for (auto &stmt : stmt_list->statements) { stmt->accept(this); } } - void visit(IfStmt *if_stmt) { + void visit(IfStmt *if_stmt) override { if (if_stmt->true_statements) if_stmt->true_statements->accept(this); if (if_stmt->false_statements) { @@ -29,28 +29,32 @@ class FlagAccess : public IRVisitor { } } - void visit(WhileStmt *stmt) { + void visit(WhileStmt *stmt) override { stmt->body->accept(this); } - void visit(RangeForStmt *for_stmt) { + void visit(RangeForStmt *for_stmt) override { for_stmt->body->accept(this); } - void visit(StructForStmt *for_stmt) { + void visit(StructForStmt *for_stmt) override { for_stmt->body->accept(this); } - void visit(OffloadedStmt *stmt) { + void visit(MeshForStmt *for_stmt) override { + for_stmt->body->accept(this); + } + + void visit(OffloadedStmt *stmt) override { stmt->all_blocks_accept(this); } // Assuming pointers will be visited before global load/st - void visit(GlobalPtrStmt *stmt) { + void visit(GlobalPtrStmt *stmt) override { stmt->activate = false; } - void visit(GlobalStoreStmt *stmt) { + void visit(GlobalStoreStmt *stmt) override { if (stmt->dest->is()) { stmt->dest->as()->activate = true; } @@ -62,7 +66,7 @@ class FlagAccess : public IRVisitor { } } - void visit(AtomicOpStmt *stmt) { + void visit(AtomicOpStmt *stmt) override { if (stmt->dest->is()) { stmt->dest->as()->activate = true; } @@ -84,28 +88,28 @@ class WeakenAccess : public BasicStmtVisitor { WeakenAccess(IRNode *node) { allow_undefined_visitor = true; invoke_default_visitor = false; - current_struct_for = nullptr; - current_offload = nullptr; + current_struct_for_ = nullptr; + current_offload_ = nullptr; node->accept(this); } - void visit(Block *stmt_list) { // block itself has no id + void visit(Block *stmt_list) override { // block itself has no id for (auto &stmt : stmt_list->statements) { stmt->accept(this); } } - void visit(StructForStmt *stmt) { - current_struct_for = stmt; + void visit(StructForStmt *stmt) override { + current_struct_for_ = stmt; stmt->body->accept(this); - current_struct_for = nullptr; + current_struct_for_ = nullptr; } - void visit(OffloadedStmt *stmt) { - current_offload = stmt; + void visit(OffloadedStmt *stmt) override { + current_offload_ = stmt; if (stmt->body) stmt->body->accept(this); - current_offload = nullptr; + current_offload_ = nullptr; } static SNode *least_sparse_ancestor(SNode *a) { @@ -121,20 +125,20 @@ class WeakenAccess : public BasicStmtVisitor { return least_sparse_ancestor(a) == least_sparse_ancestor(b); } - void visit(GlobalPtrStmt *stmt) { + void visit(GlobalPtrStmt *stmt) override { if (stmt->activate) { bool is_struct_for = - (current_offload && - current_offload->task_type == OffloadedStmt::TaskType::struct_for) || - current_struct_for; + (current_offload_ && current_offload_->task_type == + OffloadedStmt::TaskType::struct_for) || + current_struct_for_; if (is_struct_for) { bool same_as_loop_snode = true; for (auto snode : stmt->snodes.data) { SNode *loop_snode = nullptr; - if (current_struct_for) { - loop_snode = current_struct_for->snode; + if (current_struct_for_) { + loop_snode = current_struct_for_->snode; } else { - loop_snode = current_offload->snode; + loop_snode = current_offload_->snode; } TI_ASSERT(loop_snode); if (!share_sparsity(snode, loop_snode)) { @@ -160,8 +164,8 @@ class WeakenAccess : public BasicStmtVisitor { } private: - OffloadedStmt *current_offload; - StructForStmt *current_struct_for; + OffloadedStmt *current_offload_; + StructForStmt *current_struct_for_; }; namespace irpass { diff --git a/taichi/transforms/frontend_type_check.cpp b/taichi/transforms/frontend_type_check.cpp new file mode 100644 index 0000000000000..9e6c21871665c --- /dev/null +++ b/taichi/transforms/frontend_type_check.cpp @@ -0,0 +1,111 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/frontend_ir.h" +#include "taichi/ir/statements.h" + +namespace taichi { +namespace lang { + +class FrontendTypeCheck : public IRVisitor { + void check_cond_type(const Expr &cond, std::string stmt_name) { + if (!cond->ret_type->is_primitive(PrimitiveTypeID::i32)) + throw TaichiTypeError(fmt::format( + "`{0}` conditions must be of type int32; found {1}. Consider using " + "`{0} x != 0` instead of `{0} x` for float values.", + stmt_name, cond->ret_type->to_string())); + } + + public: + explicit FrontendTypeCheck() { + allow_undefined_visitor = true; + } + + void visit(Block *block) override { + std::vector stmts; + // Make a copy since type casts may be inserted for type promotion. + for (auto &stmt : block->statements) + stmts.push_back(stmt.get()); + for (auto stmt : stmts) + stmt->accept(this); + } + + void visit(FrontendExternalFuncStmt *stmt) override { + // TODO: noop for now; add typechecking after we have type specification + } + + void visit(FrontendExprStmt *stmt) override { + // Noop + } + + void visit(FrontendAllocaStmt *stmt) override { + // Noop + } + + void visit(FrontendSNodeOpStmt *stmt) override { + // Noop + } + + void visit(FrontendAssertStmt *stmt) override { + check_cond_type(stmt->cond, "assert"); + } + + void visit(FrontendAssignStmt *stmt) override { + // No implicit cast at frontend for now + } + + void visit(FrontendIfStmt *stmt) override { + // TODO: use PrimitiveType::u1 when it's supported + check_cond_type(stmt->condition, "if"); + if (stmt->true_statements) + stmt->true_statements->accept(this); + if (stmt->false_statements) + stmt->false_statements->accept(this); + } + + void visit(FrontendPrintStmt *stmt) override { + // Noop + } + + void visit(FrontendEvalStmt *stmt) override { + // Noop + } + + void visit(FrontendForStmt *stmt) override { + stmt->body->accept(this); + } + + void visit(FrontendFuncDefStmt *stmt) override { + stmt->body->accept(this); + // Determine ret_type after this is actually used + } + + void visit(FrontendBreakStmt *stmt) override { + // Noop + } + + void visit(FrontendContinueStmt *stmt) override { + // Noop + } + + void visit(FrontendWhileStmt *stmt) override { + check_cond_type(stmt->cond, "while"); + stmt->body->accept(this); + } + + void visit(FrontendReturnStmt *stmt) override { + // Noop + } +}; + +namespace irpass { + +void frontend_type_check(IRNode *root) { + TI_AUTO_PROF; + FrontendTypeCheck checker; + root->accept(&checker); +} + +} // namespace irpass + +} // namespace lang + +} // namespace taichi diff --git a/taichi/transforms/inlining.cpp b/taichi/transforms/inlining.cpp index 7dc9f8dc7f9c7..9a0a47fbdd3a7 100644 --- a/taichi/transforms/inlining.cpp +++ b/taichi/transforms/inlining.cpp @@ -32,8 +32,8 @@ class Inliner : public BasicStmtVisitor { [&](Stmt *s) { return stmt->args[s->as()->arg_id]; }); } if (func->rets.empty()) { - modifier.replace_with(stmt, - std::move(inlined_ir->as()->statements)); + modifier_.replace_with(stmt, + std::move(inlined_ir->as()->statements)); } else { if (irpass::analysis::gather_statements(inlined_ir.get(), [&](Stmt *s) { return s->is(); @@ -50,13 +50,14 @@ class Inliner : public BasicStmtVisitor { /*filter=*/[&](Stmt *s) { return s->is(); }, /*generator=*/ [&](Stmt *s) { + TI_ASSERT(s->as()->values.size() == 1); return Stmt::make(return_address, - s->as()->value); + s->as()->values[0]); }); - modifier.insert_before(stmt, - std::move(inlined_ir->as()->statements)); + modifier_.insert_before(stmt, + std::move(inlined_ir->as()->statements)); // Load the return value here - modifier.replace_with( + modifier_.replace_with( stmt, Stmt::make(LocalAddress(return_address, 0))); } } @@ -66,7 +67,7 @@ class Inliner : public BasicStmtVisitor { bool modified = false; while (true) { node->accept(&inliner); - if (inliner.modifier.modify_ir()) + if (inliner.modifier_.modify_ir()) modified = true; else break; @@ -75,7 +76,7 @@ class Inliner : public BasicStmtVisitor { } private: - DelayedIRModifier modifier; + DelayedIRModifier modifier_; }; const PassID InliningPass::id = "InliningPass"; diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index ce108d219e5ba..8039981f328f1 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -62,7 +62,7 @@ class IRPrinter : public IRVisitor { void print_raw(std::string f) { for (int i = 0; i < current_indent; i++) - f = " " + f; + f.insert(0, " "); f += "\n"; if (output) { ss << f; @@ -96,7 +96,7 @@ class IRPrinter : public IRVisitor { } void visit(FrontendExprStmt *stmt) override { - print("{}", stmt->val->serialize()); + print("{}", stmt->val.serialize()); } void visit(FrontendBreakStmt *stmt) override { @@ -108,7 +108,7 @@ class IRPrinter : public IRVisitor { } void visit(FrontendAssignStmt *assign) override { - print("{} = {}", assign->lhs->serialize(), assign->rhs->serialize()); + print("{} = {}", assign->lhs.serialize(), assign->rhs.serialize()); } void visit(FrontendAllocaStmt *alloca) override { @@ -117,7 +117,7 @@ class IRPrinter : public IRVisitor { } void visit(FrontendAssertStmt *assert) override { - print("{} : assert {}", assert->id, assert->cond->serialize()); + print("{} : assert {}", assert->id, assert->cond.serialize()); } void visit(AssertStmt *assert) override { @@ -131,7 +131,15 @@ class IRPrinter : public IRVisitor { } void visit(ExternalFuncCallStmt *stmt) override { - std::string extras = "inputs="; + std::string extras; + if (stmt->so_func != nullptr) { + extras += fmt::format("so {:x} ", (uint64)stmt->so_func); + } else if (!stmt->asm_source.empty()) { + extras += fmt::format("asm \"{}\" ", stmt->asm_source); + } else { + extras += fmt::format("bc {}:{} ", stmt->bc_filename, stmt->bc_funcname); + } + extras += "inputs="; for (auto &arg : stmt->arg_stmts) { extras += ", "; extras += arg->name(); @@ -141,14 +149,13 @@ class IRPrinter : public IRVisitor { extras += ", "; extras += output->name(); } - print("{} : func_call {:x}, {}", stmt->name(), (std::size_t)stmt->func, - extras); + print("{} : {}", stmt->name(), extras); } void visit(FrontendSNodeOpStmt *stmt) override { std::string extras = "["; for (int i = 0; i < (int)stmt->indices.size(); i++) { - extras += stmt->indices[i]->serialize(); + extras += stmt->indices[i].serialize(); if (i + 1 < (int)stmt->indices.size()) extras += ", "; } @@ -180,6 +187,18 @@ class IRPrinter : public IRVisitor { print("{}{} = rand()", stmt->type_hint(), stmt->name()); } + void visit(DecorationStmt *stmt) override { + if (stmt->decoration.size() == 2 && + stmt->decoration[0] == + uint32_t(DecorationStmt::Decoration::kLoopUnique)) { + print("decorate {} : Loop-unique {}", stmt->operand->name(), + stmt->decoration[0], stmt->decoration[1]); + } else { + print("decorate {} : ... size = {}", stmt->operand->name(), + stmt->decoration.size()); + } + } + void visit(UnaryOpStmt *stmt) override { if (stmt->is_cast()) { std::string reint = @@ -223,7 +242,7 @@ class IRPrinter : public IRVisitor { } void visit(FrontendIfStmt *if_stmt) override { - print("{} : if {} {{", if_stmt->name(), if_stmt->condition->serialize()); + print("{} : if {} {{", if_stmt->name(), if_stmt->condition.serialize()); if (if_stmt->true_statements) if_stmt->true_statements->accept(this); if (if_stmt->false_statements) { @@ -233,10 +252,6 @@ class IRPrinter : public IRVisitor { print("}}"); } - void visit(FrontendEvalStmt *stmt) override { - print("{} = eval {}", stmt->name(), stmt->expr.serialize()); - } - void visit(FrontendPrintStmt *print_stmt) override { std::vector contents; for (auto const &c : print_stmt->contents) { @@ -297,12 +312,6 @@ class IRPrinter : public IRVisitor { print("}}"); } - void visit(FuncBodyStmt *stmt) override { - print("func \"{}\" {{"); - stmt->body->accept(this); - print("}}"); - } - void visit(WhileStmt *stmt) override { print("{} : while true {{", stmt->name()); stmt->body->accept(this); @@ -310,7 +319,7 @@ class IRPrinter : public IRVisitor { } void visit(FrontendWhileStmt *stmt) override { - print("{} : while {} {{", stmt->name(), stmt->cond->serialize()); + print("{} : while {} {{", stmt->name(), stmt->cond.serialize()); stmt->body->accept(this); print("}}"); } @@ -321,14 +330,16 @@ class IRPrinter : public IRVisitor { [](const Identifier &id) -> std::string { return id.name(); }); if (for_stmt->is_ranged()) { print("{} : for {} in range({}, {}) {}{{", for_stmt->name(), vars, - for_stmt->begin->serialize(), for_stmt->end->serialize(), + for_stmt->begin.serialize(), for_stmt->end.serialize(), block_dim_info(for_stmt->block_dim)); + } else if (for_stmt->mesh_for) { + print("{} : for {} in mesh {{", for_stmt->name(), vars); } else { print("{} : for {} in {} {}{}{{", for_stmt->name(), vars, for_stmt->global_var.is() ? for_stmt->global_var.cast() ->snode->get_node_type_name_hinted() - : for_stmt->global_var->serialize(), + : for_stmt->global_var.serialize(), scratch_pad_info(for_stmt->mem_access_opt), block_dim_info(for_stmt->block_dim)); } @@ -337,24 +348,34 @@ class IRPrinter : public IRVisitor { } void visit(RangeForStmt *for_stmt) override { - print("{} : {}for in range({}, {}) (vectorize {}) (bit_vectorize {}) {}{{", + print("{} : {}for in range({}, {}) (bit_vectorize {}) {}{{", for_stmt->name(), for_stmt->reversed ? "reversed " : "", - for_stmt->begin->name(), for_stmt->end->name(), for_stmt->vectorize, + for_stmt->begin->name(), for_stmt->end->name(), for_stmt->bit_vectorize, block_dim_info(for_stmt->block_dim)); for_stmt->body->accept(this); print("}}"); } void visit(StructForStmt *for_stmt) override { - print("{} : struct for in {} (vectorize {}) (bit_vectorize {}) {}{}{{", - for_stmt->name(), for_stmt->snode->get_node_type_name_hinted(), - for_stmt->vectorize, for_stmt->bit_vectorize, + print("{} : struct for in {} (bit_vectorize {}) {}{}{{", for_stmt->name(), + for_stmt->snode->get_node_type_name_hinted(), for_stmt->bit_vectorize, scratch_pad_info(for_stmt->mem_access_opt), block_dim_info(for_stmt->block_dim)); for_stmt->body->accept(this); print("}}"); } + void visit(MeshForStmt *for_stmt) override { + print("{} : mesh for ({} -> {}) {}{{", for_stmt->name(), + mesh::element_type_name(for_stmt->major_from_type), + for_stmt->major_to_types.size() == 0 + ? "Unknown" + : mesh::element_type_name(*for_stmt->major_to_types.begin()), + scratch_pad_info(for_stmt->mem_access_opt)); + for_stmt->body->accept(this); + print("}}"); + } + void visit(GlobalPtrStmt *stmt) override { std::string s = fmt::format("{}{} = global ptr [", stmt->type_hint(), stmt->name()); @@ -397,13 +418,13 @@ class IRPrinter : public IRVisitor { } void visit(FrontendReturnStmt *stmt) override { - print("{}{} : return {}", stmt->type_hint(), stmt->name(), - stmt->value->serialize()); + print("{}{} : return [{}]", stmt->type_hint(), stmt->name(), + stmt->values.serialize()); } void visit(ReturnStmt *stmt) override { print("{}{} : return {}", stmt->type_hint(), stmt->name(), - stmt->value->name()); + stmt->values_raw_names()); } void visit(LocalLoadStmt *stmt) override { @@ -513,6 +534,16 @@ class IRPrinter : public IRVisitor { } } s += "]"; + if (stmt->element_shape.size()) { + s += ", ("; + for (int i = 0; i < (int)stmt->element_shape.size(); i++) { + s += fmt::format("{}", stmt->element_shape[i]); + if (i + 1 < (int)stmt->element_shape.size()) { + s += ", "; + } + } + s += ")"; + } print(fmt::format("{}{} = external_ptr {}", stmt->type_hint(), stmt->name(), s)); @@ -529,6 +560,9 @@ class IRPrinter : public IRVisitor { } if (stmt->const_end) { end_str = std::to_string(stmt->end_value); + } else if (stmt->end_stmt && !stmt->end_stmt->is()) { + // range_for end is a non-const stmt (e.g. ndarray axis) + end_str = stmt->end_stmt->name(); } else { end_str = fmt::format("tmp(offset={}B)", stmt->end_offset); } @@ -540,6 +574,15 @@ class IRPrinter : public IRVisitor { fmt::format("struct_for({}) grid_dim={} block_dim={} bls={}", stmt->snode->get_node_type_name_hinted(), stmt->grid_dim, stmt->block_dim, scratch_pad_info(stmt->mem_access_opt)); + } else if (stmt->task_type == OffloadedTaskType::mesh_for) { + details = fmt::format( + "mesh_for({} -> {}) num_patches={} grid_dim={} block_dim={} bls={}", + mesh::element_type_name(stmt->major_from_type), + stmt->major_to_types.size() == 0 + ? "Unknown" + : mesh::element_type_name(*stmt->major_to_types.begin()), + stmt->mesh->num_patches, stmt->grid_dim, stmt->block_dim, + scratch_pad_info(stmt->mem_access_opt)); } if (stmt->task_type == OffloadedTaskType::listgen) { print("{} = offloaded listgen {}->{}", stmt->name(), @@ -555,6 +598,12 @@ class IRPrinter : public IRVisitor { stmt->tls_prologue->accept(this); print("}}"); } + if (stmt->mesh_prologue) { + TI_ASSERT(stmt->task_type == OffloadedTaskType::mesh_for); + print("body prologue {{"); + stmt->mesh_prologue->accept(this); + print("}}"); + } if (stmt->bls_prologue) { print("bls prologue {{"); stmt->bls_prologue->accept(this); @@ -597,10 +646,6 @@ class IRPrinter : public IRVisitor { stmt->name(), stmt->loop->name(), stmt->index); } - void visit(BlockDimStmt *stmt) override { - print("{}{} = block dim", stmt->type_hint(), stmt->name()); - } - void visit(GlobalTemporaryStmt *stmt) override { print("{}{} = global tmp var (offset = {} B)", stmt->type_hint(), stmt->name(), stmt->offset); @@ -679,6 +724,48 @@ class IRPrinter : public IRVisitor { print("{} : {}bit_struct_store {}, ch_ids=[{}], values=[{}]", stmt->name(), stmt->is_atomic ? "atomic " : "", stmt->ptr->name(), ch_ids, values); } + + // Mesh related. + + void visit(MeshRelationAccessStmt *stmt) override { + if (stmt->is_size()) { + print("{}{} = {} idx relation {} size", stmt->type_hint(), stmt->name(), + stmt->mesh_idx->name(), mesh::element_type_name(stmt->to_type)); + } else { + print("{}{} = {} idx relation {}[{}]", stmt->type_hint(), stmt->name(), + stmt->mesh_idx->name(), mesh::element_type_name(stmt->to_type), + stmt->neighbor_idx->name()); + } + } + + void visit(MeshIndexConversionStmt *stmt) override { + print("{}{} = {} {} {}", stmt->type_hint(), stmt->name(), + mesh::conv_type_name(stmt->conv_type), + mesh::element_type_name(stmt->idx_type), stmt->idx->name()); + } + + void visit(MeshPatchIndexStmt *stmt) override { + print("{}{} = mesh patch idx", stmt->type_hint(), stmt->name()); + } + + void visit(FrontendExternalFuncStmt *stmt) override { + if (stmt->so_func != nullptr) { + print("so {:x}", (uint64)stmt->so_func); + } else if (!stmt->asm_source.empty()) { + print("asm \"{}\"", stmt->asm_source); + } else { + print("bc {}:{}", stmt->bc_filename, stmt->bc_funcname); + } + print(" (inputs="); + for (auto &s : stmt->args) { + print(s.serialize()); + } + print(", outputs="); + for (auto &s : stmt->outputs) { + print(s.serialize()); + } + print(")"); + } }; } // namespace diff --git a/taichi/transforms/loop_invariant_code_motion.cpp b/taichi/transforms/loop_invariant_code_motion.cpp index 42af66fdab4ef..49f735ff2ea7b 100644 --- a/taichi/transforms/loop_invariant_code_motion.cpp +++ b/taichi/transforms/loop_invariant_code_motion.cpp @@ -73,7 +73,7 @@ class LoopInvariantCodeMotion : public BasicStmtVisitor { void visit(BinaryOpStmt *stmt) override { if (stmt_can_be_moved(stmt)) { auto replacement = stmt->clone(); - stmt->replace_with(replacement.get()); + stmt->replace_usages_with(replacement.get()); modifier.insert_before(stmt->parent->parent_stmt, std::move(replacement)); modifier.erase(stmt); @@ -83,7 +83,7 @@ class LoopInvariantCodeMotion : public BasicStmtVisitor { void visit(UnaryOpStmt *stmt) override { if (stmt_can_be_moved(stmt)) { auto replacement = stmt->clone(); - stmt->replace_with(replacement.get()); + stmt->replace_usages_with(replacement.get()); modifier.insert_before(stmt->parent->parent_stmt, std::move(replacement)); modifier.erase(stmt); @@ -111,6 +111,10 @@ class LoopInvariantCodeMotion : public BasicStmtVisitor { visit_loop(stmt->body.get()); } + void visit(MeshForStmt *stmt) override { + visit_loop(stmt->body.get()); + } + void visit(WhileStmt *stmt) override { visit_loop(stmt->body.get()); } @@ -119,6 +123,9 @@ class LoopInvariantCodeMotion : public BasicStmtVisitor { if (stmt->tls_prologue) stmt->tls_prologue->accept(this); + if (stmt->mesh_prologue) + stmt->mesh_prologue->accept(this); + if (stmt->bls_prologue) stmt->bls_prologue->accept(this); diff --git a/taichi/transforms/loop_vectorize.cpp b/taichi/transforms/loop_vectorize.cpp deleted file mode 100644 index 80e2d99ab7409..0000000000000 --- a/taichi/transforms/loop_vectorize.cpp +++ /dev/null @@ -1,173 +0,0 @@ -// The loop vectorizer - -#include "taichi/program/program.h" -#include "taichi/ir/ir.h" -#include "taichi/ir/type_factory.h" -#include "taichi/ir/statements.h" -#include "taichi/ir/transforms.h" -#include "taichi/ir/visitors.h" - -TLANG_NAMESPACE_BEGIN - -// Lower Expr tree to a bunch of binary/unary(binary/unary) statements -// Goal: eliminate Expression, and mutable local variables. Make AST SSA. -class LoopVectorize : public IRVisitor { - public: - int vectorize; - Stmt *loop_var; // an alloca... - const CompileConfig &config; - - explicit LoopVectorize(const CompileConfig &config) : config(config) { - allow_undefined_visitor = true; - invoke_default_visitor = true; - loop_var = nullptr; - vectorize = 1; - } - - static void widen_type(DataType &type, int width) { - if (width != 1) { - type = Program::get_type_factory().get_vector_type(width, type); - } - } - - void visit(Stmt *stmt) override { - widen_type(stmt->ret_type, vectorize); - } - - void visit(ConstStmt *stmt) override { - stmt->val.repeat(vectorize); - widen_type(stmt->ret_type, vectorize); - } - - void visit(Block *stmt_list) override { - std::vector statements; - for (auto &stmt : stmt_list->statements) { - statements.push_back(stmt.get()); - } - for (auto stmt : statements) { - stmt->accept(this); - } - } - - void visit(GlobalPtrStmt *ptr) override { - ptr->snodes.repeat(vectorize); - widen_type(ptr->ret_type, vectorize); - } - - void visit(AllocaStmt *alloca) override { - widen_type(alloca->ret_type, vectorize); - } - - void visit(SNodeOpStmt *stmt) override { - if (vectorize == 1) - return; - // TI_NOT_IMPLEMENTED; - /* - stmt->snodes.repeat(vectorize); - stmt->ret_type.width *= vectorize; - */ - } - - void visit(ElementShuffleStmt *stmt) override { - if (vectorize == 1) - return; - int original_width = stmt->width(); - widen_type(stmt->ret_type, vectorize); - stmt->elements.repeat(vectorize); - // TODO: this can be buggy - int stride = stmt->elements[original_width - 1].index + 1; - if (stmt->elements[0].stmt->width() != 1) { - for (int i = 0; i < vectorize; i++) { - for (int j = 0; j < original_width; j++) { - stmt->elements[i * original_width + j].index += i * stride; - } - } - } - } - - void visit(LocalLoadStmt *stmt) override { - if (vectorize == 1) - return; - int original_width = stmt->width(); - widen_type(stmt->ret_type, vectorize); - stmt->src.repeat(vectorize); - // TODO: this can be buggy - int stride = stmt->src[original_width - 1].offset + 1; - if (stmt->src[0].var->width() != 1) { - for (int i = 0; i < vectorize; i++) { - for (int j = 0; j < original_width; j++) { - stmt->src[i * original_width + j].offset += i * stride; - } - } - } - if (loop_var && stmt->same_source() && stmt->src[0].var == loop_var) { - // insert_before_me - LaneAttribute const_offsets; - const_offsets.resize(vectorize * original_width); - for (int i = 0; i < vectorize * original_width; i++) { - const_offsets[i] = TypedConstant(i / original_width); - } - auto offsets = std::make_unique(const_offsets); - auto add_op = std::make_unique(BinaryOpType::add, stmt, - offsets.get()); - irpass::type_check(add_op.get(), config); - auto offsets_p = offsets.get(); - stmt->replace_with(add_op.get()); - stmt->insert_after_me(std::move(offsets)); - offsets_p->insert_after_me(std::move(add_op)); - } - } - - void visit(IfStmt *if_stmt) override { - if (if_stmt->true_statements) - if_stmt->true_statements->accept(this); - if (if_stmt->false_statements) { - if_stmt->false_statements->accept(this); - } - } - - void visit(RangeForStmt *for_stmt) override { - auto old_vectorize = for_stmt->vectorize; - if (for_stmt->vectorize != 1) - vectorize = for_stmt->vectorize; - // TODO: RangeForStmt::loop_var is deprecated - // loop_var = for_stmt->loop_var; - for_stmt->body->accept(this); - // loop_var = nullptr; - vectorize = old_vectorize; - } - - void visit(StructForStmt *for_stmt) override { - // TODO: StructForStmt::loop_var is deprecated - return; - /*if (for_stmt->loop_vars.empty()) - return; - auto old_vectorize = for_stmt->vectorize; - if (for_stmt->vectorize != 1) - vectorize = for_stmt->vectorize; - loop_var = for_stmt->loop_vars.back(); - for_stmt->body->accept(this); - loop_var = nullptr; - vectorize = old_vectorize;*/ - } - - void visit(WhileStmt *stmt) override { - stmt->body->accept(this); - } - - static void run(IRNode *node, const CompileConfig &config) { - LoopVectorize inst(config); - node->accept(&inst); - } -}; - -namespace irpass { - -void loop_vectorize(IRNode *root, const CompileConfig &config) { - TI_AUTO_PROF; - return LoopVectorize::run(root, config); -} - -} // namespace irpass - -TLANG_NAMESPACE_END diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp index 651ec5d734499..73d437cc7b605 100644 --- a/taichi/transforms/lower_access.cpp +++ b/taichi/transforms/lower_access.cpp @@ -269,7 +269,7 @@ Stmt *PtrLowererImpl::handle_snode_at_level(int level, if (!diff.linear_related()) { on_loop_tree = false; } else if (j == (int)indices_.size() - 1) { - if (!(0 <= diff.low && diff.high <= current_struct_for->vectorize)) { + if (!(0 <= diff.low && diff.high <= 1)) { // TODO: Vectorize on_loop_tree = false; } } else { diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 112aec7f085d0..81daee000c54d 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -30,33 +30,37 @@ std::vector make_raw_pointer_list( // AST SSA. class LowerAST : public IRVisitor { private: - Stmt *capturing_loop; - std::unordered_set detected_fors_with_break; - Block *current_block; + Stmt *capturing_loop_; + std::unordered_set detected_fors_with_break_; + Block *current_block_; + int current_block_depth_; FlattenContext make_flatten_ctx() { FlattenContext fctx; - fctx.current_block = this->current_block; + fctx.current_block = this->current_block_; return fctx; } public: explicit LowerAST(const std::unordered_set &_detected_fors_with_break) - : detected_fors_with_break(_detected_fors_with_break), - current_block(nullptr) { + : detected_fors_with_break_(_detected_fors_with_break), + current_block_(nullptr), + current_block_depth_(0) { // TODO: change this to false allow_undefined_visitor = true; - capturing_loop = nullptr; + capturing_loop_ = nullptr; } void visit(Block *stmt_list) override { - auto backup_block = this->current_block; - this->current_block = stmt_list; + auto backup_block = this->current_block_; + this->current_block_ = stmt_list; auto stmts = make_raw_pointer_list(stmt_list->statements); + current_block_depth_++; for (auto &stmt : stmts) { stmt->accept(this); } - this->current_block = backup_block; + current_block_depth_--; + this->current_block_ = backup_block; } void visit(FrontendAllocaStmt *stmt) override { @@ -75,12 +79,11 @@ class LowerAST : public IRVisitor { block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get())); stmt->parent->replace_with(stmt, std::move(lowered)); } - throw IRModified(); } void visit(FrontendIfStmt *stmt) override { auto fctx = make_flatten_ctx(); - stmt->condition->flatten(&fctx); + flatten_rvalue(stmt->condition, &fctx); auto new_if = std::make_unique(stmt->condition->stmt); @@ -100,10 +103,10 @@ class LowerAST : public IRVisitor { new_if->set_false_statements(std::move(stmt->false_statements)); new_if->false_statements->mask_var = new_if->false_mask; } - + auto pif = new_if.get(); fctx.push_back(std::move(new_if)); stmt->parent->replace_with(stmt, std::move(fctx.stmts)); - throw IRModified(); + pif->accept(this); } void visit(IfStmt *if_stmt) override { @@ -122,7 +125,7 @@ class LowerAST : public IRVisitor { for (auto c : stmt->contents) { if (std::holds_alternative(c)) { auto x = std::get(c); - x->flatten(&fctx); + flatten_rvalue(x, &fctx); stmts.push_back(x->stmt); new_contents.push_back(x->stmt); } else { @@ -132,16 +135,14 @@ class LowerAST : public IRVisitor { } fctx.push_back(new_contents); stmt->parent->replace_with(stmt, std::move(fctx.stmts)); - throw IRModified(); } void visit(FrontendBreakStmt *stmt) override { - auto while_stmt = capturing_loop->as(); + auto while_stmt = capturing_loop_->as(); VecStatement stmts; auto const_true = stmts.push_back(TypedConstant((int32)0)); stmts.push_back(while_stmt->mask, const_true); stmt->parent->replace_with(stmt, std::move(stmts)); - throw IRModified(); } void visit(FrontendContinueStmt *stmt) override { @@ -153,7 +154,7 @@ class LowerAST : public IRVisitor { // while (1) { cond; if (no active) break; original body...} auto cond = stmt->cond; auto fctx = make_flatten_ctx(); - cond->flatten(&fctx); + flatten_rvalue(cond, &fctx); auto cond_stmt = fctx.back_stmt(); auto &&new_while = std::make_unique(std::move(stmt->body)); @@ -174,16 +175,17 @@ class LowerAST : public IRVisitor { stmt->insert_before_me( std::make_unique(new_while->mask, const_stmt_ptr)); new_while->body->mask_var = new_while->mask; + auto pwhile = new_while.get(); stmt->parent->replace_with(stmt, std::move(new_while)); + pwhile->accept(this); // insert an alloca for the mask - throw IRModified(); } void visit(WhileStmt *stmt) override { - auto old_capturing_loop = capturing_loop; - capturing_loop = stmt; + auto old_capturing_loop = capturing_loop_; + capturing_loop_ = stmt; stmt->body->accept(this); - capturing_loop = old_capturing_loop; + capturing_loop_ = old_capturing_loop; } void visit(LoopIndexStmt *stmt) override { @@ -200,18 +202,17 @@ class LowerAST : public IRVisitor { TI_ASSERT(stmt->loop_var_id.size() == 1); auto begin = stmt->begin; auto end = stmt->end; - begin->flatten(&fctx); - end->flatten(&fctx); + flatten_rvalue(begin, &fctx); + flatten_rvalue(end, &fctx); bool is_good_range_for = - capturing_loop == nullptr || - detected_fors_with_break.find(stmt) == detected_fors_with_break.end(); + current_block_depth_ == 1 || detected_fors_with_break_.find(stmt) == + detected_fors_with_break_.end(); // #578: a good range for is a range for that doesn't contains a break // statement if (is_good_range_for) { auto &&new_for = std::make_unique( - begin->stmt, end->stmt, std::move(stmt->body), stmt->vectorize, - stmt->bit_vectorize, stmt->num_cpu_threads, stmt->block_dim, - stmt->strictly_serialized); + begin->stmt, end->stmt, std::move(stmt->body), stmt->bit_vectorize, + stmt->num_cpu_threads, stmt->block_dim, stmt->strictly_serialized); new_for->body->insert(std::make_unique(new_for.get(), 0), 0); new_for->body->local_var_to_stmt[stmt->loop_var_id[0]] = @@ -267,6 +268,17 @@ class LowerAST : public IRVisitor { new_while->body->mask_var = new_while->mask; fctx.push_back(std::move(new_while)); } + } else if (stmt->mesh_for) { + auto &&new_for = std::make_unique( + stmt->mesh, stmt->element_type, std::move(stmt->body), + stmt->bit_vectorize, stmt->num_cpu_threads, stmt->block_dim); + new_for->body->insert(std::make_unique(new_for.get(), 0), + 0); + new_for->body->local_var_to_stmt[stmt->loop_var_id[0]] = + new_for->body->statements[0].get(); + new_for->mem_access_opt = stmt->mem_access_opt; + new_for->fields_registered = true; + fctx.push_back(std::move(new_for)); } else if (stmt->global_var.is()) { auto snode = stmt->global_var.cast()->snode; std::vector offsets; @@ -293,7 +305,7 @@ class LowerAST : public IRVisitor { snode = snode->parent; auto &&new_for = std::make_unique( - snode, std::move(stmt->body), stmt->vectorize, stmt->bit_vectorize, + snode, std::move(stmt->body), stmt->bit_vectorize, stmt->num_cpu_threads, stmt->block_dim); new_for->index_offsets = offsets; VecStatement new_statements; @@ -331,91 +343,91 @@ class LowerAST : public IRVisitor { for (int i = 0; i < (int)shape.size(); i++) { end = fctx.push_back(BinaryOpType::mul, end, shape[i]); } + // TODO: add a note explaining why shape might be empty. auto &&new_for = std::make_unique( - begin, end, std::move(stmt->body), stmt->vectorize, - stmt->bit_vectorize, stmt->num_cpu_threads, stmt->block_dim, - stmt->strictly_serialized); + begin, end, std::move(stmt->body), stmt->bit_vectorize, + stmt->num_cpu_threads, stmt->block_dim, stmt->strictly_serialized, + /*range_hint=*/fmt::format("arg {}", tensor->arg_id)); VecStatement new_statements; Stmt *loop_index = new_statements.push_back(new_for.get(), 0); for (int i = (int)shape.size() - 1; i >= 0; i--) { - new_for->body->local_var_to_stmt[stmt->loop_var_id[i]] = - new_statements.push_back(BinaryOpType::mod, - loop_index, shape[i]); + Stmt *loop_var = new_statements.push_back( + BinaryOpType::mod, loop_index, shape[i]); + new_for->body->local_var_to_stmt[stmt->loop_var_id[i]] = loop_var; + std::vector decoration = { + uint32_t(DecorationStmt::Decoration::kLoopUnique), uint32_t(i)}; + new_statements.push_back(loop_var, decoration); loop_index = new_statements.push_back( BinaryOpType::div, loop_index, shape[i]); } new_for->body->insert(std::move(new_statements), 0); fctx.push_back(std::move(new_for)); } + auto pfor = fctx.stmts.back().get(); stmt->parent->replace_with(stmt, std::move(fctx.stmts)); - throw IRModified(); + pfor->accept(this); } void visit(RangeForStmt *for_stmt) override { - auto old_capturing_loop = capturing_loop; - capturing_loop = for_stmt; + auto old_capturing_loop = capturing_loop_; + capturing_loop_ = for_stmt; for_stmt->body->accept(this); - capturing_loop = old_capturing_loop; + capturing_loop_ = old_capturing_loop; } void visit(StructForStmt *for_stmt) override { - auto old_capturing_loop = capturing_loop; - capturing_loop = for_stmt; + auto old_capturing_loop = capturing_loop_; + capturing_loop_ = for_stmt; for_stmt->body->accept(this); - capturing_loop = old_capturing_loop; + capturing_loop_ = old_capturing_loop; } - void visit(FrontendReturnStmt *stmt) override { - auto expr = stmt->value; - auto fctx = make_flatten_ctx(); - expr->flatten(&fctx); - fctx.push_back(fctx.back_stmt()); - stmt->parent->replace_with(stmt, std::move(fctx.stmts)); - throw IRModified(); + void visit(MeshForStmt *for_stmt) override { + auto old_capturing_loop = capturing_loop_; + capturing_loop_ = for_stmt; + for_stmt->body->accept(this); + capturing_loop_ = old_capturing_loop; } - void visit(FrontendEvalStmt *stmt) override { - // expand rhs - auto expr = stmt->expr; + void visit(FrontendReturnStmt *stmt) override { + auto expr_group = stmt->values; auto fctx = make_flatten_ctx(); - expr->flatten(&fctx); - if (stmt->eval_expr.expr && stmt->eval_expr.is()) { - stmt->eval_expr.cast()->stmt_ptr = stmt->expr->stmt; + std::vector return_ele; + for (auto &x : expr_group.exprs) { + flatten_rvalue(x, &fctx); + return_ele.push_back(fctx.back_stmt()); } + fctx.push_back(return_ele); stmt->parent->replace_with(stmt, std::move(fctx.stmts)); - throw IRModified(); } void visit(FrontendAssignStmt *assign) override { - // expand rhs + auto dest = assign->lhs; auto expr = assign->rhs; auto fctx = make_flatten_ctx(); - expr->flatten(&fctx); - if (assign->lhs.is()) { // local variable - // emit local store stmt + flatten_rvalue(expr, &fctx); + if (dest.is()) { fctx.push_back( assign->parent->lookup_var(assign->lhs.cast()->id), expr->stmt); - } else if (assign->lhs.is()) { - auto tensor_ptr = assign->lhs.cast(); - tensor_ptr->flatten(&fctx); + } else if (dest.is()) { + flatten_lvalue(dest, &fctx); + auto tensor_ptr = dest.cast(); if (tensor_ptr->is_local_tensor()) { - fctx.push_back(tensor_ptr->stmt, expr->stmt); + fctx.push_back(dest->stmt, expr->stmt); } else if (tensor_ptr->is_global_tensor()) { - fctx.push_back(tensor_ptr->stmt, expr->stmt); + fctx.push_back(dest->stmt, expr->stmt); } else { TI_NOT_IMPLEMENTED } } else { // global variable - TI_ASSERT(assign->lhs.is()); - auto global_ptr = assign->lhs.cast(); - global_ptr->flatten(&fctx); - fctx.push_back(fctx.back_stmt(), expr->stmt); + TI_ASSERT(dest.is()); + flatten_lvalue(dest, &fctx); + fctx.push_back(dest->stmt, expr->stmt); } fctx.stmts.back()->set_tb(assign->tb); assign->parent->replace_with(assign, std::move(fctx.stmts)); - throw IRModified(); } void visit(FrontendSNodeOpStmt *stmt) override { @@ -424,13 +436,13 @@ class LowerAST : public IRVisitor { auto fctx = make_flatten_ctx(); if (stmt->val.expr) { auto expr = stmt->val; - expr->flatten(&fctx); + flatten_rvalue(expr, &fctx); val_stmt = expr->stmt; } std::vector indices_stmt(stmt->indices.size(), nullptr); for (int i = 0; i < (int)stmt->indices.size(); i++) { - stmt->indices[i]->flatten(&fctx); + flatten_rvalue(stmt->indices[i], &fctx); indices_stmt[i] = stmt->indices[i]->stmt; } @@ -452,7 +464,6 @@ class LowerAST : public IRVisitor { } stmt->parent->replace_with(stmt, std::move(fctx.stmts)); - throw IRModified(); } void visit(FrontendAssertStmt *stmt) override { @@ -461,40 +472,65 @@ class LowerAST : public IRVisitor { auto fctx = make_flatten_ctx(); if (stmt->cond.expr) { auto expr = stmt->cond; - expr->flatten(&fctx); + flatten_rvalue(expr, &fctx); val_stmt = expr->stmt; } auto &fargs = stmt->args; // frontend stmt args std::vector args_stmts(fargs.size()); for (int i = 0; i < (int)fargs.size(); ++i) { - fargs[i]->flatten(&fctx); + flatten_rvalue(fargs[i], &fctx); args_stmts[i] = fargs[i]->stmt; } fctx.push_back(val_stmt, stmt->text, args_stmts); stmt->parent->replace_with(stmt, std::move(fctx.stmts)); - throw IRModified(); } void visit(FrontendExprStmt *stmt) override { auto fctx = make_flatten_ctx(); - stmt->val->flatten(&fctx); + flatten_rvalue(stmt->val, &fctx); stmt->parent->replace_with(stmt, std::move(fctx.stmts)); - throw IRModified(); } - static void run(IRNode *node) { - LowerAST inst(irpass::analysis::detect_fors_with_break(node)); - while (true) { - bool modified = false; - try { - node->accept(&inst); - } catch (IRModified) { - modified = true; + void visit(FrontendExternalFuncStmt *stmt) override { + auto ctx = make_flatten_ctx(); + TI_ASSERT((int)(stmt->so_func != nullptr) + + (int)(!stmt->asm_source.empty()) + + (int)(!stmt->bc_filename.empty()) == + 1) + std::vector arg_statements, output_statements; + if (stmt->so_func != nullptr || !stmt->asm_source.empty()) { + for (auto &s : stmt->args) { + flatten_rvalue(s, &ctx); + arg_statements.push_back(s->stmt); + } + for (auto &s : stmt->outputs) { + flatten_lvalue(s, &ctx); + output_statements.push_back(s->stmt); + } + ctx.push_back(std::make_unique( + (stmt->so_func != nullptr) ? ExternalFuncCallStmt::SHARED_OBJECT + : ExternalFuncCallStmt::ASSEMBLY, + stmt->so_func, stmt->asm_source, "", "", arg_statements, + output_statements)); + } else { + for (auto &s : stmt->args) { + TI_ASSERT_INFO( + s.is(), + "external func call via bitcode must pass in local variables.") + flatten_lvalue(s, &ctx); + arg_statements.push_back(s->stmt); } - if (!modified) - break; + ctx.push_back(std::make_unique( + ExternalFuncCallStmt::BITCODE, nullptr, "", stmt->bc_filename, + stmt->bc_funcname, arg_statements, output_statements)); } + stmt->parent->replace_with(stmt, std::move(ctx.stmts)); + } + + static void run(IRNode *node) { + LowerAST inst(irpass::analysis::detect_fors_with_break(node)); + node->accept(&inst); } }; diff --git a/taichi/transforms/make_mesh_block_local.cpp b/taichi/transforms/make_mesh_block_local.cpp new file mode 100644 index 0000000000000..98f5319141a71 --- /dev/null +++ b/taichi/transforms/make_mesh_block_local.cpp @@ -0,0 +1,681 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/analysis.h" +#include "taichi/transforms/make_mesh_block_local.h" + +namespace taichi { +namespace lang { + +const PassID MakeMeshBlockLocal::id = "MakeMeshBlockLocal"; + +void MakeMeshBlockLocal::simplify_nested_conversion() { + std::vector stmts; + std::vector ori_indices; + + irpass::analysis::gather_statements(offload_->body.get(), [&](Stmt *stmt) { + if (auto conv1 = stmt->cast()) { + if (auto conv2 = conv1->idx->cast()) { + if (conv1->conv_type == mesh::ConvType::g2r && + conv2->conv_type == mesh::ConvType::l2g && + conv1->mesh == conv2->mesh && + conv1->idx_type == conv2->idx_type) { // nested + stmts.push_back(conv1); + ori_indices.push_back(conv2->idx); + } + } + } + return false; + }); + + for (size_t i = 0; i < stmts.size(); ++i) { + stmts[i]->replace_with(Stmt::make( + stmts[i]->mesh, stmts[i]->idx_type, ori_indices[i], + mesh::ConvType::l2r)); + } +} + +void MakeMeshBlockLocal::gather_candidate_mapping() { + irpass::analysis::gather_statements(offload_->body.get(), [&](Stmt *stmt) { + if (auto conv = stmt->cast()) { + if (conv->conv_type != mesh::ConvType::g2r) { + bool is_from_end = (conv->idx_type == offload_->major_from_type); + bool is_to_end = false; + for (auto type : offload_->major_to_types) { + is_to_end |= (conv->idx_type == type); + } + for (auto rel : offload_->minor_relation_types) { + auto from_type = + mesh::MeshElementType(mesh::from_end_element_order(rel)); + auto to_type = mesh::MeshElementType(mesh::to_end_element_order(rel)); + is_from_end |= (conv->idx_type == from_type); + is_to_end |= (conv->idx_type == to_type); + } + if ((is_to_end && config_.mesh_localize_to_end_mapping) || + (is_from_end && config_.mesh_localize_from_end_mapping)) { + mappings_.insert(std::make_pair(conv->idx_type, conv->conv_type)); + } + } + } + return false; + }); +} + +void MakeMeshBlockLocal::replace_conv_statements() { + std::vector idx_conv_stmts; + + irpass::analysis::gather_statements(offload_->body.get(), [&](Stmt *stmt) { + if (auto idx_conv = stmt->cast()) { + if (idx_conv->mesh == offload_->mesh && + idx_conv->conv_type == conv_type_ && + idx_conv->idx_type == element_type_) { + idx_conv_stmts.push_back(idx_conv); + } + } + return false; + }); + + for (auto stmt : idx_conv_stmts) { + VecStatement bls; + Stmt *bls_element_offset_bytes = bls.push_back( + LaneAttribute{(int32)mapping_bls_offset_in_bytes_}); + Stmt *idx_byte = bls.push_back( + BinaryOpType::mul, stmt->idx, + bls.push_back(TypedConstant(mapping_dtype_size_))); + Stmt *offset = bls.push_back( + BinaryOpType::add, bls_element_offset_bytes, idx_byte); + Stmt *bls_ptr = bls.push_back( + offset, + TypeFactory::create_vector_or_scalar_type(1, mapping_data_type_, true)); + [[maybe_unused]] Stmt *bls_load = bls.push_back(bls_ptr); + stmt->replace_with(std::move(bls)); + } +} + +void MakeMeshBlockLocal::replace_global_ptrs(SNode *snode) { + auto data_type = snode->dt.ptr_removed(); + auto dtype_size = data_type_size(data_type); + auto offset_in_bytes = attr_bls_offset_in_bytes_.find(snode)->second; + + std::vector global_ptrs; + irpass::analysis::gather_statements(offload_->body.get(), [&](Stmt *stmt) { + if (auto global_ptr = stmt->cast()) { + TI_ASSERT(global_ptr->width() == 1); + if (global_ptr->snodes[0] == snode && + global_ptr->indices[0]->is()) { + global_ptrs.push_back(global_ptr); + } + } + return false; + }); + + for (auto global_ptr : global_ptrs) { + VecStatement bls; + Stmt *local_idx = + global_ptr->indices[0]->as()->idx; + Stmt *local_idx_byte = bls.push_back( + BinaryOpType::mul, local_idx, + bls.push_back(TypedConstant(dtype_size))); + Stmt *offset = + bls.push_back(TypedConstant(int32(offset_in_bytes))); + Stmt *index = + bls.push_back(BinaryOpType::add, offset, local_idx_byte); + [[maybe_unused]] Stmt *bls_ptr = bls.push_back( + index, TypeFactory::create_vector_or_scalar_type(1, data_type, true)); + global_ptr->replace_with(std::move(bls)); + } + + // in the cpu backend, atomic op in body block could be demoted to non-atomic + if (config_.arch != Arch::x64) { + return; + } + std::vector atomic_ops; + irpass::analysis::gather_statements(offload_->body.get(), [&](Stmt *stmt) { + if (auto atomic_op = stmt->cast()) { + if (atomic_op->op_type == AtomicOpType::add && + atomic_op->dest->is()) { + atomic_ops.push_back(atomic_op); + } + } + return false; + }); + + for (auto atomic_op : atomic_ops) { + VecStatement non_atomic; + Stmt *dest_val = non_atomic.push_back(atomic_op->dest); + Stmt *res_val = non_atomic.push_back( + BinaryOpType::add, dest_val, atomic_op->val); + non_atomic.push_back(atomic_op->dest, res_val); + atomic_op->replace_with(std::move(non_atomic)); + } +} + +// This function creates loop like: +// int i = start_val; +// while (i < end_val) { +// body(i); +// i += blockDim.x; +// } +Stmt *MakeMeshBlockLocal::create_xlogue( + Stmt *start_val, + Stmt *end_val, + std::function body_) { + Stmt *idx = block_->push_back(mapping_data_type_); + [[maybe_unused]] Stmt *init_val = + block_->push_back(idx, start_val); + Stmt *block_dim_val; + if (config_.arch == Arch::x64) { + block_dim_val = block_->push_back(TypedConstant(1)); + } else { + block_dim_val = block_->push_back( + LaneAttribute{offload_->block_dim}); + } + + std::unique_ptr body = std::make_unique(); + { + Stmt *idx_val = body->push_back(LocalAddress{idx, 0}); + Stmt *cond = + body->push_back(BinaryOpType::cmp_lt, idx_val, end_val); + body->push_back(nullptr, cond); + body_(body.get(), idx_val); + Stmt *idx_val_ = body->push_back(BinaryOpType::add, idx_val, + block_dim_val); + [[maybe_unused]] Stmt *idx_store = + body->push_back(idx, idx_val_); + } + block_->push_back(std::move(body)); + Stmt *idx_val = block_->push_back(LocalAddress{idx, 0}); + return idx_val; +} + +// This function creates loop like: +// int i = start_val; +// while (i < end_val) { +// mapping_shared[i] = global_val(i); +// i += blockDim.x; +// } +Stmt *MakeMeshBlockLocal::create_cache_mapping( + Stmt *start_val, + Stmt *end_val, + std::function global_val) { + Stmt *bls_element_offset_bytes = block_->push_back( + LaneAttribute{(int32)mapping_bls_offset_in_bytes_}); + return create_xlogue(start_val, end_val, [&](Block *body, Stmt *idx_val) { + Stmt *idx_val_byte = body->push_back( + BinaryOpType::mul, idx_val, + body->push_back(TypedConstant(mapping_dtype_size_))); + Stmt *offset = body->push_back( + BinaryOpType::add, bls_element_offset_bytes, idx_val_byte); + Stmt *bls_ptr = body->push_back( + offset, + TypeFactory::create_vector_or_scalar_type(1, mapping_data_type_, true)); + [[maybe_unused]] Stmt *bls_store = + body->push_back(bls_ptr, global_val(body, idx_val)); + }); +} + +void MakeMeshBlockLocal::fetch_attr_to_bls(Block *body, + Stmt *idx_val, + Stmt *mapping_val) { + auto attrs = rec_.find(std::make_pair(element_type_, conv_type_)); + if (attrs == rec_.end()) { + return; + } + for (auto [snode, total_flags] : attrs->second) { + auto data_type = snode->dt.ptr_removed(); + auto dtype_size = data_type_size(data_type); + + bool bls_has_read = total_flags & AccessFlag::read; + bool bls_has_write = total_flags & AccessFlag::write; + bool bls_has_accumulate = total_flags & AccessFlag::accumulate; + + TI_ASSERT_INFO(!bls_has_write, "BLS with write accesses is not supported."); + TI_ASSERT_INFO(!(bls_has_accumulate && bls_has_read), + "BLS with both read and accumulation is not supported."); + + bool first_allocate = {false}; + if (attr_bls_offset_in_bytes_.find(snode) == + attr_bls_offset_in_bytes_.end()) { + first_allocate = {true}; + bls_offset_in_bytes_ += + (dtype_size - bls_offset_in_bytes_ % dtype_size) % dtype_size; + attr_bls_offset_in_bytes_.insert( + std::make_pair(snode, bls_offset_in_bytes_)); + bls_offset_in_bytes_ += + dtype_size * + offload_->mesh->patch_max_element_num.find(element_type_)->second; + } + auto offset_in_bytes = attr_bls_offset_in_bytes_.find(snode)->second; + + Stmt *value{nullptr}; + if (bls_has_read) { + // Read access + // Fetch from global to BLS + Stmt *global_ptr = body->push_back( + LaneAttribute{snode}, std::vector{mapping_val}); + value = body->push_back(global_ptr); + } else { + // Accumulation access + // Zero-fill + value = body->push_back(TypedConstant(data_type, 0)); + } + + Stmt *offset = + body->push_back(TypedConstant(int32(offset_in_bytes))); + Stmt *idx_val_byte = body->push_back( + BinaryOpType::mul, idx_val, + body->push_back(TypedConstant(dtype_size))); + Stmt *index = + body->push_back(BinaryOpType::add, offset, idx_val_byte); + Stmt *bls_ptr = body->push_back( + index, TypeFactory::create_vector_or_scalar_type(1, data_type, true)); + body->push_back(bls_ptr, value); + + // Step 3-2-1: + // Make loop body load from BLS instead of global fields + // NOTE that first_allocate ensures this step only do ONCE + if (first_allocate) { + replace_global_ptrs(snode); + } + } +} + +void MakeMeshBlockLocal::push_attr_to_global(Block *body, + Stmt *idx_val, + Stmt *mapping_val) { + auto attrs = rec_.find(std::make_pair(element_type_, conv_type_)); + if (attrs == rec_.end()) { + return; + } + for (auto [snode, total_flags] : attrs->second) { + bool bls_has_accumulate = total_flags & AccessFlag::accumulate; + if (!bls_has_accumulate) { + continue; + } + auto data_type = snode->dt.ptr_removed(); + auto dtype_size = data_type_size(data_type); + auto offset_in_bytes = attr_bls_offset_in_bytes_.find(snode)->second; + + Stmt *offset = + body->push_back(TypedConstant(int32(offset_in_bytes))); + Stmt *idx_val_byte = body->push_back( + BinaryOpType::mul, idx_val, + body->push_back(TypedConstant(dtype_size))); + Stmt *index = + body->push_back(BinaryOpType::add, offset, idx_val_byte); + Stmt *bls_ptr = body->push_back( + index, TypeFactory::create_vector_or_scalar_type(1, data_type, true)); + Stmt *bls_val = body->push_back(bls_ptr); + + Stmt *global_ptr = body->push_back( + LaneAttribute{snode}, std::vector{mapping_val}); + body->push_back(AtomicOpType::add, global_ptr, bls_val); + } +} + +void MakeMeshBlockLocal::fetch_mapping( + std::function< + Stmt *(Stmt * /*start_val*/, + Stmt * /*end_val*/, + std::function)> + mapping_callback_handler, + std::function + attr_callback_handler) { + Stmt *thread_idx_stmt; + if (config_.arch == Arch::x64) { + thread_idx_stmt = block_->push_back(TypedConstant(0)); + } else { + thread_idx_stmt = block_->push_back( + offload_); // Equivalent to CUDA threadIdx + } + Stmt *total_element_num = + offload_->total_num_local.find(element_type_)->second; + Stmt *total_element_offset = + offload_->total_offset_local.find(element_type_)->second; + + if (config_.optimize_mesh_reordered_mapping && + conv_type_ == mesh::ConvType::l2r) { + // int i = threadIdx.x; + // while (i < owned_{}_num) { + // mapping_shared[i] = i + owned_{}_offset; + // { + // x0_shared[i] = x0[mapping_shared[i]]; + // ... + // } + // i += blockDim.x; + // } + // while (i < total_{}_num) { + // mapping_shared[i] = mapping[i + total_{}_offset]; + // { + // x0_shared[i] = x0[mapping_shared[i]]; + // ... + // } + // i += blockDim.x; + // } + Stmt *owned_element_num = + offload_->owned_num_local.find(element_type_)->second; + Stmt *owned_element_offset = + offload_->owned_offset_local.find(element_type_)->second; + Stmt *pre_idx_val = mapping_callback_handler( + thread_idx_stmt, owned_element_num, [&](Block *body, Stmt *idx_val) { + Stmt *global_index = body->push_back( + BinaryOpType::add, idx_val, owned_element_offset); + attr_callback_handler(body, idx_val, global_index); + return global_index; + }); + mapping_callback_handler( + pre_idx_val, total_element_num, [&](Block *body, Stmt *idx_val) { + Stmt *global_offset = body->push_back( + BinaryOpType::add, total_element_offset, idx_val); + Stmt *global_ptr = body->push_back( + LaneAttribute{mapping_snode_}, + std::vector{global_offset}); + Stmt *global_load = body->push_back(global_ptr); + attr_callback_handler(body, idx_val, global_load); + return global_load; + }); + } else { + // int i = threadIdx.x; + // while (i < total_{}_num) { + // mapping_shared[i] = mapping[i + total_{}_offset]; + // { + // x0_shared[i] = x0[mapping_shared[i]]; + // ... + // } + // i += blockDim.x; + // } + mapping_callback_handler( + thread_idx_stmt, total_element_num, [&](Block *body, Stmt *idx_val) { + Stmt *global_offset = body->push_back( + BinaryOpType::add, total_element_offset, idx_val); + Stmt *global_ptr = body->push_back( + LaneAttribute{mapping_snode_}, + std::vector{global_offset}); + Stmt *global_load = body->push_back(global_ptr); + attr_callback_handler(body, idx_val, global_load); + return global_load; + }); + } +} + +MakeMeshBlockLocal::MakeMeshBlockLocal(OffloadedStmt *offload, + const CompileConfig &config) + : config_(config), offload_(offload) { + // Step 0: simplify l2g + g2r -> l2r + simplify_nested_conversion(); + + // Step 1: A analyzer to determine which mapping should be localized + mappings_.clear(); + gather_candidate_mapping(); + + // Step 1: use Mesh BLS analyzer to gather which mesh attributes user declared + // to cache + bool auto_mesh_local = config.experimental_auto_mesh_local; + if (offload->major_to_types.size() != + 1 || // not support multiple major relations yet + offload->minor_relation_types.size() > + 0 || // not support minor relations yet + offload->mem_access_opt.get_snodes_with_flag(SNodeAccessFlag::mesh_local) + .size() > 0) { // disable when user determine which attributes to + // be cached manually + auto_mesh_local = false; + } + auto caches = irpass::analysis::initialize_mesh_local_attribute( + offload, auto_mesh_local, config); + + if (auto_mesh_local && config.arch == Arch::cuda) { + const auto to_type = *offload->major_to_types.begin(); + std::size_t shared_mem_size_per_block = + default_shared_mem_size / config.auto_mesh_local_default_occupacy; + int available_bytes = + shared_mem_size_per_block / + offload->mesh->patch_max_element_num.find(to_type)->second; + if (mappings_.find(std::make_pair(to_type, mesh::ConvType::l2g)) != + mappings_.end()) { + available_bytes -= 4; + } + if (mappings_.find(std::make_pair(to_type, mesh::ConvType::l2r)) != + mappings_.end()) { + available_bytes -= 4; + } + TI_TRACE("available cache attributes bytes = {}", available_bytes); + TI_TRACE("caches size = {}", caches->caches.size()); + std::vector priority_caches; + for (const auto [snode, cache] : caches->caches) { + priority_caches.push_back(cache); + } + std::sort(priority_caches.begin(), priority_caches.end(), + [](const MeshBLSCache &a, const MeshBLSCache &b) { + return a.total_flags > b.total_flags || + (a.total_flags == b.total_flags && + a.loop_index > b.loop_index) || + (a.total_flags == b.total_flags && + a.loop_index == b.loop_index && + a.unique_accessed > b.unique_accessed); + }); + caches->caches.clear(); + for (const auto &cache : priority_caches) { + available_bytes -= data_type_size(cache.snode->dt); + if (available_bytes < 0) { + break; // not enough space to ensure occupacy + } + TI_TRACE("available = {}, x = {}, loop_index = {}, unique_access = {}", + available_bytes, cache.total_flags, int(cache.loop_index), + cache.unique_accessed); + caches->caches.insert(std::make_pair(cache.snode, cache)); + } + } + rec_ = caches->finalize(); + + // If a mesh attribute is in bls, the config makes its index mapping must also + // be in bls + if (config.mesh_localize_all_attr_mappings && + !config.experimental_auto_mesh_local) { + for (auto [mapping, attr_set] : rec_) { + if (mappings_.find(mapping) == mappings_.end()) { + mappings_.insert(mapping); + } + } + } + + auto has_acc = [&](mesh::MeshElementType element_type, + mesh::ConvType conv_type) { + auto ptr = rec_.find(std::make_pair(element_type, conv_type)); + if (ptr == rec_.end()) { + return false; + } + bool has_accumulate = {false}; + for (auto [snode, total_flags] : ptr->second) { + has_accumulate |= (total_flags & AccessFlag::accumulate); + } + return has_accumulate; + }; + + // Step 3: Cache the mappings and the attributes + bls_offset_in_bytes_ = offload->bls_size; + if (offload->bls_prologue == nullptr) { + offload->bls_prologue = std::make_unique(); + offload->bls_prologue->parent_stmt = offload; + } + if (offload->bls_epilogue == nullptr) { + offload->bls_epilogue = std::make_unique(); + offload->bls_epilogue->parent_stmt = offload; + } + + // Cache both mappings and mesh attribute + for (auto [element_type, conv_type] : mappings_) { + this->element_type_ = element_type; + this->conv_type_ = conv_type; + TI_ASSERT(conv_type != mesh::ConvType::g2r); // g2r will not be cached. + // There is not corresponding mesh element attribute read/write, + // It's useless to localize this mapping + if (offload->total_offset_local.find(element_type) == + offload->total_offset_local.end()) { + continue; + } + + mapping_snode_ = (offload->mesh->index_mapping + .find(std::make_pair(element_type, conv_type)) + ->second); + mapping_data_type_ = mapping_snode_->dt.ptr_removed(); + mapping_dtype_size_ = data_type_size(mapping_data_type_); + + // Ensure BLS alignment + bls_offset_in_bytes_ += + (mapping_dtype_size_ - bls_offset_in_bytes_ % mapping_dtype_size_) % + mapping_dtype_size_; + mapping_bls_offset_in_bytes_ = bls_offset_in_bytes_; + // allocate storage for the BLS variable + bls_offset_in_bytes_ += + mapping_dtype_size_ * + offload->mesh->patch_max_element_num.find(element_type)->second; + + // Step 3-1: + // Fetch index mapping to the BLS block + // Step 3-2 + // Fetch mesh attributes to the BLS block at the same time + // TODO(changyu): better way to use lambda + block_ = offload->bls_prologue.get(); + fetch_mapping( + [&](Stmt *start_val, Stmt *end_val, + std::function + global_val) { + return create_cache_mapping(start_val, end_val, global_val); + }, + [&](Block *body, Stmt *idx_val, Stmt *mapping_val) { + fetch_attr_to_bls(body, idx_val, mapping_val); + }); + + // Step 3-3: + // Make mesh index mapping load from BLS instead of global fields + replace_conv_statements(); + + // Step 3-4 + // Atomic-add BLS contribution to its global version if necessary + if (!has_acc(element_type, conv_type)) { + continue; + } + block_ = offload->bls_epilogue.get(); + { + Stmt *thread_idx_stmt = block_->push_back( + offload); // Equivalent to CUDA threadIdx + Stmt *total_element_num = + offload->total_num_local.find(element_type)->second; + Stmt *total_element_offset = + offload->total_offset_local.find(element_type)->second; + create_xlogue( + thread_idx_stmt, total_element_num, [&](Block *body, Stmt *idx_val) { + Stmt *bls_element_offset_bytes = + body->push_back(LaneAttribute{ + (int32)mapping_bls_offset_in_bytes_}); + Stmt *idx_byte = body->push_back( + BinaryOpType::mul, idx_val, + body->push_back(TypedConstant(mapping_dtype_size_))); + Stmt *offset = body->push_back( + BinaryOpType::add, bls_element_offset_bytes, idx_byte); + Stmt *bls_ptr = body->push_back( + offset, TypeFactory::create_vector_or_scalar_type( + 1, mapping_data_type_, true)); + Stmt *global_val = body->push_back(bls_ptr); + this->push_attr_to_global(body, idx_val, global_val); + }); + } + } + + // Cache mesh attribute only + for (auto [mapping, attr_set] : rec_) { + if (mappings_.find(mapping) != mappings_.end()) { + continue; + } + + this->element_type_ = mapping.first; + this->conv_type_ = mapping.second; + TI_ASSERT(conv_type_ != mesh::ConvType::g2r); // g2r will not be cached. + + mapping_snode_ = (offload->mesh->index_mapping + .find(std::make_pair(element_type_, conv_type_)) + ->second); + mapping_data_type_ = mapping_snode_->dt.ptr_removed(); + mapping_dtype_size_ = data_type_size(mapping_data_type_); + + // Step 3-1 + // Only fetch mesh attributes to the BLS block + // TODO(changyu): better way to use lambda + block_ = offload->bls_prologue.get(); + fetch_mapping( + [&](Stmt *start_val, Stmt *end_val, + std::function + global_val) { + return create_xlogue( + start_val, end_val, + [&](Block *block, Stmt *idx_val) { global_val(block, idx_val); }); + }, + [&](Block *body, Stmt *idx_val, Stmt *mapping_val) { + fetch_attr_to_bls(body, idx_val, mapping_val); + }); + + // Step 3-2 + // Atomic-add BLS contribution to its global version if necessary + if (!has_acc(element_type_, conv_type_)) { + continue; + } + block_ = offload->bls_epilogue.get(); + fetch_mapping( + [&](Stmt *start_val, Stmt *end_val, + std::function + global_val) { + return create_xlogue( + start_val, end_val, + [&](Block *block, Stmt *idx_val) { global_val(block, idx_val); }); + }, + [&](Block *body, Stmt *idx_val, Stmt *mapping_val) { + push_attr_to_global(body, idx_val, mapping_val); + }); + } + + offload->bls_size = std::max(std::size_t(1), bls_offset_in_bytes_); +} + +void MakeMeshBlockLocal::run(OffloadedStmt *offload, + const CompileConfig &config, + const std::string &kernel_name) { + if (offload->task_type != OffloadedStmt::TaskType::mesh_for) { + return; + } + + MakeMeshBlockLocal(offload, config); +} + +namespace irpass { + +// This pass should happen after offloading but before lower_access +void make_mesh_block_local(IRNode *root, + const CompileConfig &config, + const MakeMeshBlockLocal::Args &args) { + TI_AUTO_PROF; + + // ========================================================================================= + // This pass generates code like this: + // // Load V_l2g + // for (int i = threadIdx.x; i < total_vertices; i += blockDim.x) { + // V_l2g[i] = _V_l2g[i + total_vertices_offset]; + // sx[i] = x[V_l2g[i]]; + // sJ[i] = 0.0f; + // } + + if (auto root_block = root->cast()) { + for (auto &offload : root_block->statements) { + MakeMeshBlockLocal::run(offload->cast(), config, + args.kernel_name); + } + } else { + MakeMeshBlockLocal::run(root->as(), config, + args.kernel_name); + } + + type_check(root, config); +} + +} // namespace irpass +} // namespace lang +} // namespace taichi diff --git a/taichi/transforms/make_mesh_block_local.h b/taichi/transforms/make_mesh_block_local.h new file mode 100644 index 0000000000000..04d8ce04479b0 --- /dev/null +++ b/taichi/transforms/make_mesh_block_local.h @@ -0,0 +1,74 @@ +#pragma once + +#include "taichi/ir/pass.h" +#include "taichi/ir/statements.h" +#include "taichi/analysis/mesh_bls_analyzer.h" + +#include + +namespace taichi { +namespace lang { + +class MakeMeshBlockLocal : public Pass { + public: + static const PassID id; + + struct Args { + std::string kernel_name; + }; + + MakeMeshBlockLocal(OffloadedStmt *offload, const CompileConfig &config); + + static void run(OffloadedStmt *offload, + const CompileConfig &config, + const std::string &kernel_name); + + private: + void simplify_nested_conversion(); + void gather_candidate_mapping(); + void replace_conv_statements(); + void replace_global_ptrs(SNode *snode); + + void fetch_attr_to_bls(Block *body, Stmt *idx_val, Stmt *mapping_val); + void push_attr_to_global(Block *body, Stmt *idx_val, Stmt *mapping_val); + + Stmt *create_xlogue( + Stmt *start_val, + Stmt *end_val, + std::function body); + Stmt *create_cache_mapping( + Stmt *start_val, + Stmt *end_val, + std::function global_val); + + void fetch_mapping( + std::function< + Stmt *(Stmt * /*start_val*/, + Stmt * /*end_val*/, + std::function)/*global_val*/> + mapping_callback_handler, + std::function attr_callback_handler); + + const CompileConfig &config_; + OffloadedStmt *offload_{nullptr}; + std::set> mappings_{}; + MeshBLSCaches::Rec rec_; + + Block *block_; + + std::size_t bls_offset_in_bytes_{0}; + std::size_t mapping_bls_offset_in_bytes_{0}; + std::unordered_map attr_bls_offset_in_bytes_{}; + + mesh::MeshElementType element_type_; + mesh::ConvType conv_type_; + SNode *mapping_snode_{nullptr}; + DataType mapping_data_type_; + int mapping_dtype_size_{0}; +}; + +} // namespace lang +} // namespace taichi diff --git a/taichi/transforms/make_mesh_thread_local.cpp b/taichi/transforms/make_mesh_thread_local.cpp new file mode 100644 index 0000000000000..4b016ad0d0439 --- /dev/null +++ b/taichi/transforms/make_mesh_thread_local.cpp @@ -0,0 +1,167 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/analysis.h" +#include "taichi/transforms/make_mesh_thread_local.h" + +namespace taichi { +namespace lang { + +const PassID MakeMeshThreadLocal::id = "MakeMeshThreadLocal"; + +namespace irpass { + +void make_mesh_thread_local_offload(OffloadedStmt *offload, + const CompileConfig &config, + const std::string &kernel_name) { + if (offload->task_type != OffloadedStmt::TaskType::mesh_for) { + return; + } + + std::pair, + /* total= */ std::unordered_set> + accessed = analysis::gather_mesh_thread_local(offload, config); + + std::size_t tls_offset = offload->tls_size; + + auto data_type = PrimitiveType::u32; // uint32_t type address + auto dtype_size = data_type_size(data_type); + + if (offload->tls_prologue == nullptr) { + offload->tls_prologue = std::make_unique(); + offload->tls_prologue->parent_stmt = offload; + } + + if (offload->mesh_prologue == nullptr) { + offload->mesh_prologue = std::make_unique(); + offload->mesh_prologue->parent_stmt = offload; + } + + auto patch_idx = + offload->tls_prologue->insert(std::make_unique(), -1); + auto one = offload->tls_prologue->insert( + std::make_unique(TypedConstant(data_type, 1)), -1); + auto patch_idx_1 = offload->tls_prologue->insert( + std::make_unique(BinaryOpType::add, patch_idx, one), -1); + + auto make_thread_local_store = + [&](mesh::MeshElementType element_type, + const std::unordered_map &offset_, + std::unordered_map &offset_local, + std::unordered_map &num_local) { + const auto offset_tls_offset = + (tls_offset += (dtype_size - tls_offset % dtype_size) % dtype_size); + tls_offset += dtype_size; // allocate storage for the TLS variable + + const auto num_tls_offset = + (tls_offset += (dtype_size - tls_offset % dtype_size) % dtype_size); + tls_offset += dtype_size; + + // Step 1: + // Create thread local storage + { + auto offset_ptr = + offload->tls_prologue->push_back( + offset_tls_offset, TypeFactory::create_vector_or_scalar_type( + 1, data_type, true)); + auto num_ptr = offload->tls_prologue->push_back( + num_tls_offset, + TypeFactory::create_vector_or_scalar_type(1, data_type, true)); + + const auto offset_snode = offset_.find(element_type); + TI_ASSERT(offset_snode != offset_.end()); + auto offset_globalptr = offload->tls_prologue->insert( + std::make_unique( + LaneAttribute{offset_snode->second}, + std::vector{patch_idx}), + -1); + auto offset_load = offload->tls_prologue->insert( + std::make_unique(offset_globalptr), -1); + auto offset_1_globalptr = offload->tls_prologue->insert( + std::make_unique( + LaneAttribute{offset_snode->second}, + std::vector{patch_idx_1}), + -1); + auto offset_1_load = offload->tls_prologue->insert( + std::make_unique(offset_1_globalptr), -1); + auto num_load = offload->tls_prologue->insert( + std::make_unique(BinaryOpType::sub, offset_1_load, + offset_load), + -1); + + // TODO: do not use GlobalStore for TLS ptr. + offload->tls_prologue->push_back(offset_ptr, + offset_load); + offload->tls_prologue->push_back(num_ptr, num_load); + } + + // Step 2: + // Store TLS mesh_prologue ptr to the offloaded statement + { + auto offset_ptr = + offload->mesh_prologue->push_back( + offset_tls_offset, TypeFactory::create_vector_or_scalar_type( + 1, data_type, true)); + auto offset_val = + offload->mesh_prologue->push_back(offset_ptr); + auto num_ptr = offload->mesh_prologue->push_back( + num_tls_offset, + TypeFactory::create_vector_or_scalar_type(1, data_type, true)); + auto num_val = + offload->mesh_prologue->push_back(num_ptr); + + offset_local.insert(std::pair(element_type, offset_val)); + num_local.insert(std::pair(element_type, num_val)); + } + }; + + for (auto element_type : accessed.first) { + make_thread_local_store(element_type, offload->mesh->owned_offset, + offload->owned_offset_local, + offload->owned_num_local); + } + + for (auto element_type : accessed.second) { + make_thread_local_store(element_type, offload->mesh->total_offset, + offload->total_offset_local, + offload->total_num_local); + } + offload->tls_size = std::max(std::size_t(1), tls_offset); +} + +// This pass should happen after offloading but before lower_access +void make_mesh_thread_local(IRNode *root, + const CompileConfig &config, + const MakeBlockLocalPass::Args &args) { + TI_AUTO_PROF; + + // ========================================================================================= + // This pass generates code like this: + // uint32_t total_vertices_offset = _total_vertices_offset[blockIdx.x]; + // uint32_t total_vertices = _total_vertices_offset[blockIdx.x + 1] - + // total_vertices_offset; + + // uint32_t total_cells_offset = _total_cells_offset[blockIdx.x]; + // uint32_t total_cells = _total_cells_offset[blockIdx.x + 1] - + // total_cells_offset; + + // uint32_t owned_cells_offset = _owned_cells_offset[blockIdx.x]; + // uint32_t owned_cells = _owned_cells_offset[blockIdx.x + 1] - + // owned_cells_offset; + // ========================================================================================= + + if (auto root_block = root->cast()) { + for (auto &offload : root_block->statements) { + make_mesh_thread_local_offload(offload->cast(), config, + args.kernel_name); + } + } else { + make_mesh_thread_local_offload(root->as(), config, + args.kernel_name); + } + type_check(root, config); +} + +} // namespace irpass +} // namespace lang +} // namespace taichi diff --git a/taichi/transforms/make_mesh_thread_local.h b/taichi/transforms/make_mesh_thread_local.h new file mode 100644 index 0000000000000..aa054a1548456 --- /dev/null +++ b/taichi/transforms/make_mesh_thread_local.h @@ -0,0 +1,18 @@ +#pragma once + +#include "taichi/ir/pass.h" + +namespace taichi { +namespace lang { + +class MakeMeshThreadLocal : public Pass { + public: + static const PassID id; + + struct Args { + std::string kernel_name; + }; +}; + +} // namespace lang +} // namespace taichi diff --git a/taichi/transforms/make_thread_local.cpp b/taichi/transforms/make_thread_local.cpp index c8b4ae8c6cab1..54afb9b6870b3 100644 --- a/taichi/transforms/make_thread_local.cpp +++ b/taichi/transforms/make_thread_local.cpp @@ -162,7 +162,7 @@ void make_thread_local_offload(OffloadedStmt *offload) { tls_offset, TypeFactory::create_vector_or_scalar_type(1, data_type, true)), 0); - dest.first->replace_with(tls_ptr); + dest.first->replace_usages_with(tls_ptr); } // Step 3: diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index 4282b913c8bc2..135e0c7053e16 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -13,7 +13,22 @@ TLANG_NAMESPACE_BEGIN namespace irpass { namespace { - +bool demotable_axis_load(Stmt *stmt) { + // Stmt involving simple arithmetic of ExternalTensorShapeAlongAxisStmt + // shouldn't be saved in global tmp, just clone them to each shader + // separately. + int n_op = stmt->num_operands(); + if (n_op == 0) { + return stmt->is() || + stmt->is(); + } + for (int i = 0; i < n_op; i++) { + auto op = stmt->operand(i); + if (!demotable_axis_load(op)) + return false; + } + return true; +} class SquashPtrOffset : public IRVisitor { public: SquashPtrOffset() { @@ -21,7 +36,7 @@ class SquashPtrOffset : public IRVisitor { invoke_default_visitor = true; } void visit(Stmt *stmt) override { - top_level_ptr = stmt; + top_level_ptr_ = stmt; } void visit(PtrOffsetStmt *stmt) override { stmt->origin->accept(this); @@ -29,11 +44,11 @@ class SquashPtrOffset : public IRVisitor { static Stmt *run(Stmt *root) { SquashPtrOffset v; root->accept(&v); - return v.top_level_ptr; + return v.top_level_ptr_; } private: - Stmt *top_level_ptr = nullptr; + Stmt *top_level_ptr_ = nullptr; }; // Offloaded local variables to its offset in the global tmps memory. @@ -92,23 +107,56 @@ class Offloader { offloaded_ranges.begin_stmts.insert( std::make_pair(offloaded.get(), s->begin)); } + if (auto val = s->end->cast()) { offloaded->const_end = true; offloaded->end_value = val->val[0].val_int32(); } else { + if ((arch == Arch::opengl || arch == Arch::vulkan) && + demotable_axis_load(s->end)) { + // TODO: We need to update codegen for each backend gradually so + // let's limit it to opengl backend for now. + auto end_copy = s->end->clone(); + offloaded->end_stmt = end_copy.get(); + offloaded->body->insert(std::move(end_copy)); + } offloaded_ranges.end_stmts.insert( std::make_pair(offloaded.get(), s->end)); } + offloaded->num_cpu_threads = std::min(s->num_cpu_threads, config.cpu_max_num_threads); replace_all_usages_with(s, s, offloaded.get()); for (int j = 0; j < (int)s->body->statements.size(); j++) { offloaded->body->insert(std::move(s->body->statements[j])); } + offloaded->range_hint = s->range_hint; root_block->insert(std::move(offloaded)); } else if (auto st = stmt->cast()) { assemble_serial_statements(); emit_struct_for(st, root_block, config, st->mem_access_opt); + } else if (auto st = stmt->cast()) { + assemble_serial_statements(); + auto offloaded = Stmt::make_typed( + OffloadedStmt::TaskType::mesh_for, arch); + offloaded->grid_dim = config.saturating_grid_dim; + if (st->block_dim == 0) { + offloaded->block_dim = Program::default_block_dim(config); + } else { + offloaded->block_dim = st->block_dim; + } + offloaded->num_cpu_threads = + std::min(st->num_cpu_threads, config.cpu_max_num_threads); + replace_all_usages_with(st, st, offloaded.get()); + for (int j = 0; j < (int)st->body->statements.size(); j++) { + offloaded->body->insert(std::move(st->body->statements[j])); + } + offloaded->mesh = st->mesh; + offloaded->major_from_type = std::move(st->major_from_type); + offloaded->major_to_types = std::move(st->major_to_types); + offloaded->minor_relation_types = std::move(st->minor_relation_types); + offloaded->mem_access_opt = st->mem_access_opt; + root_block->insert(std::move(offloaded)); } else { pending_serial_statements->body->insert(std::move(stmt)); } @@ -214,29 +262,29 @@ class StmtToOffloaded : public BasicStmtVisitor { StmtToOffloaded() { allow_undefined_visitor = true; invoke_default_visitor = true; - current_offloaded = nullptr; + current_offloaded_ = nullptr; } public: void visit(OffloadedStmt *stmt) override { - current_offloaded = stmt; - stmt_to_offloaded[stmt] = current_offloaded; + current_offloaded_ = stmt; + stmt_to_offloaded_[stmt] = current_offloaded_; if (stmt->body) stmt->body->accept(this); - current_offloaded = nullptr; + current_offloaded_ = nullptr; } void visit(Stmt *stmt) override { - if (current_offloaded != nullptr) { + if (current_offloaded_ != nullptr) { // inside a offloaded stmt, record its belonging offloaded_stmt - stmt_to_offloaded[stmt] = current_offloaded; + stmt_to_offloaded_[stmt] = current_offloaded_; } } void preprocess_container_stmt(Stmt *stmt) override { - if (current_offloaded != nullptr) { + if (current_offloaded_ != nullptr) { // inside a offloaded stmt, record its belonging offloaded_stmt - stmt_to_offloaded[stmt] = current_offloaded; + stmt_to_offloaded_[stmt] = current_offloaded_; } } @@ -244,16 +292,16 @@ class StmtToOffloaded : public BasicStmtVisitor { static std::unordered_map run(IRNode *ir) { StmtToOffloaded pass; ir->accept(&pass); - return pass.stmt_to_offloaded; + return pass.stmt_to_offloaded_; } private: using BasicStmtVisitor::visit; // Local variables to its containing offloaded statement - std::unordered_map stmt_to_offloaded; + std::unordered_map stmt_to_offloaded_; - Stmt *current_offloaded; + Stmt *current_offloaded_; }; /* @@ -276,36 +324,37 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor { const CompileConfig &config, const std::unordered_map &stmt_to_offloaded, OffloadedRanges *offloaded_ranges) - : config(config), - stmt_to_offloaded(stmt_to_offloaded), + : config_(config), + stmt_to_offloaded_(stmt_to_offloaded), offloaded_ranges_(offloaded_ranges) { allow_undefined_visitor = true; invoke_default_visitor = true; - current_offloaded = nullptr; - global_offset = 0; + current_offloaded_ = nullptr; + global_offset_ = 0; } std::size_t allocate_global(DataType type) { TI_ASSERT(type->vector_width() == 1 || type->is()); - auto ret = global_offset; + auto ret = global_offset_; if (type->is()) { auto tensor_type = type->cast(); - global_offset += tensor_type->get_num_elements() * - data_type_size(tensor_type->get_element_type()); + global_offset_ += tensor_type->get_num_elements() * + data_type_size(tensor_type->get_element_type()); } else { std::size_t type_size = data_type_size(type); // align global_offset to a multiple of type_size - global_offset = ((global_offset + type_size - 1) / type_size) * type_size; - ret = global_offset; - global_offset += type_size; + global_offset_ = + ((global_offset_ + type_size - 1) / type_size) * type_size; + ret = global_offset_; + global_offset_ += type_size; } - TI_ASSERT(global_offset < taichi_global_tmp_buffer_size); + TI_ASSERT(global_offset_ < taichi_global_tmp_buffer_size); return ret; } public: void visit(OffloadedStmt *stmt) override { - current_offloaded = stmt; + current_offloaded_ = stmt; if (auto begin = offloaded_ranges_->begin_stmts.find(stmt); begin != offloaded_ranges_->begin_stmts.end()) { test_and_allocate(begin->second); @@ -316,17 +365,17 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor { } if (stmt->body) stmt->body->accept(this); - current_offloaded = nullptr; + current_offloaded_ = nullptr; } void visit(AllocaStmt *stmt) override { - TI_ASSERT(current_offloaded); + TI_ASSERT(current_offloaded_); } void test_and_allocate(Stmt *stmt) { if (stmt == nullptr) return; - if (stmt_to_offloaded[stmt] == current_offloaded) + if (stmt_to_offloaded_[stmt] == current_offloaded_) return; // Directly insert copies of ConstStmts later if (stmt->is()) @@ -336,13 +385,17 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor { if (top_level_ptr->is() || stmt->is() || (stmt->is() && stmt->as()->is_ptr)) return; + if ((config_.arch == Arch::opengl || config_.arch == Arch::vulkan) && + demotable_axis_load(stmt)) + return; // Not yet allocated - if (local_to_global.find(top_level_ptr) == local_to_global.end()) { - local_to_global[top_level_ptr] = allocate_global(top_level_ptr->ret_type); + if (local_to_global_.find(top_level_ptr) == local_to_global_.end()) { + local_to_global_[top_level_ptr] = + allocate_global(top_level_ptr->ret_type); } } - void visit(Stmt *stmt) override { + void generic_visit(Stmt *stmt) { int n_op = stmt->num_operands(); for (int i = 0; i < n_op; i++) { auto op = stmt->operand(i); @@ -350,6 +403,14 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor { } } + void preprocess_container_stmt(Stmt *stmt) override { + generic_visit(stmt); + } + + void visit(Stmt *stmt) override { + generic_visit(stmt); + } + static StmtToOffsetMap run( IRNode *root, const CompileConfig &config, @@ -358,17 +419,17 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor { IdentifyValuesUsedInOtherOffloads pass(config, stmt_to_offloaded, offloaded_ranges); root->accept(&pass); - return pass.local_to_global; + return pass.local_to_global_; } private: - CompileConfig config; - std::unordered_map stmt_to_offloaded; + CompileConfig config_; + std::unordered_map stmt_to_offloaded_; OffloadedRanges *const offloaded_ranges_; // Local variables to global temporary offsets (in bytes) - StmtToOffsetMap local_to_global; - Stmt *current_offloaded; - std::size_t global_offset; + StmtToOffsetMap local_to_global_; + Stmt *current_offloaded_; + std::size_t global_offset_; }; // Store intermediate values to globals so that statements in later offloaded @@ -379,7 +440,7 @@ class PromoteIntermediateToGlobalTmp : public BasicStmtVisitor { private: explicit PromoteIntermediateToGlobalTmp( const StmtToOffsetMap &local_to_global_offset) - : local_to_global_offset(local_to_global_offset) { + : local_to_global_offset_(local_to_global_offset) { allow_undefined_visitor = true; invoke_default_visitor = true; } @@ -387,32 +448,24 @@ class PromoteIntermediateToGlobalTmp : public BasicStmtVisitor { public: void visit(Stmt *stmt) override { if (!stmt->is() && - local_to_global_offset.find(stmt) != local_to_global_offset.end() && - stored_to_global.find(stmt) == stored_to_global.end()) { - stored_to_global.insert(stmt); - auto offset = local_to_global_offset[stmt]; + local_to_global_offset_.find(stmt) != local_to_global_offset_.end() && + stored_to_global_.find(stmt) == stored_to_global_.end()) { + stored_to_global_.insert(stmt); + auto offset = local_to_global_offset_[stmt]; auto ptr = stmt->insert_after_me( Stmt::make(offset, stmt->ret_type)); ptr->insert_after_me(Stmt::make(ptr, stmt)); - throw IRModified(); } } static void run(IRNode *root, const StmtToOffsetMap &local_to_global_offset) { PromoteIntermediateToGlobalTmp pass(local_to_global_offset); - while (true) { - try { - root->accept(&pass); - } catch (IRModified) { - continue; - } - break; - } + root->accept(&pass); } private: - StmtToOffsetMap local_to_global_offset; - std::set stored_to_global; + StmtToOffsetMap local_to_global_offset_; + std::set stored_to_global_; }; class FixCrossOffloadReferences : public BasicStmtVisitor { @@ -424,9 +477,9 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { const StmtToOffsetMap &local_to_global_offset, const std::unordered_map &stmt_to_offloaded, OffloadedRanges *offloaded_ranges) - : config(config), - local_to_global_offset(local_to_global_offset), - stmt_to_offloaded(stmt_to_offloaded), + : config_(config), + local_to_global_offset_(local_to_global_offset), + stmt_to_offloaded_(stmt_to_offloaded), offloaded_ranges_(offloaded_ranges) { allow_undefined_visitor = true; invoke_default_visitor = true; @@ -439,44 +492,49 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { if (!stmt->const_begin) { TI_ASSERT(offloaded_ranges_->begin_stmts.find(stmt) != offloaded_ranges_->begin_stmts.end()) - TI_ASSERT_INFO(local_to_global_offset.find( + TI_ASSERT_INFO(local_to_global_offset_.find( offloaded_ranges_->begin_stmts.find(stmt)->second) != - local_to_global_offset.end(), + local_to_global_offset_.end(), "Begin fails.") stmt->begin_offset = - local_to_global_offset[offloaded_ranges_->begin_stmts.find(stmt) - ->second]; + local_to_global_offset_[offloaded_ranges_->begin_stmts.find(stmt) + ->second]; } if (!stmt->const_end) { - TI_ASSERT(offloaded_ranges_->end_stmts.find(stmt) != - offloaded_ranges_->end_stmts.end()) - TI_ASSERT_INFO(local_to_global_offset.find( - offloaded_ranges_->end_stmts.find(stmt)->second) != - local_to_global_offset.end(), - "End fails.") - stmt->end_offset = - local_to_global_offset[offloaded_ranges_->end_stmts.find(stmt) - ->second]; + if (stmt->end_stmt) { + stmt->end_stmt->accept(this); + stmt->end_offset = 0; + } else { + TI_ASSERT(offloaded_ranges_->end_stmts.find(stmt) != + offloaded_ranges_->end_stmts.end()) + TI_ASSERT_INFO(local_to_global_offset_.find( + offloaded_ranges_->end_stmts.find(stmt)->second) != + local_to_global_offset_.end(), + "End fails.") + stmt->end_offset = + local_to_global_offset_[offloaded_ranges_->end_stmts.find(stmt) + ->second]; + } } } } // Replace alloca with global var initialization (set to 0) void visit(AllocaStmt *stmt) override { - if (local_to_global_offset.find(stmt) == local_to_global_offset.end()) + if (local_to_global_offset_.find(stmt) == local_to_global_offset_.end()) return; VecStatement replacement; auto ret_type = stmt->ret_type; - local_to_global_vector_type[stmt] = ret_type; + local_to_global_vector_type_[stmt] = ret_type; auto ptr = replacement.push_back( - local_to_global_offset[stmt], ret_type); - auto offloaded = stmt_to_offloaded[stmt]; - stmt_to_offloaded[ptr] = offloaded; + local_to_global_offset_[stmt], ret_type); + auto offloaded = stmt_to_offloaded_[stmt]; + stmt_to_offloaded_[ptr] = offloaded; if (auto tensor_type = stmt->ret_type->cast()) { LaneAttribute zero(std::vector( 1, TypedConstant(tensor_type->get_element_type()))); auto const_zero_stmt = replacement.push_back(zero); - stmt_to_offloaded[const_zero_stmt] = offloaded; + stmt_to_offloaded_[const_zero_stmt] = offloaded; for (int i = 0; i < tensor_type->get_num_elements(); ++i) { LaneAttribute offset(std::vector( 1, TypedConstant(i * @@ -486,9 +544,9 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { replacement.push_back(ptr, const_offset_stmt); auto global_store_stmt = replacement.push_back( ptr_offset_stmt, const_zero_stmt); - stmt_to_offloaded[const_offset_stmt] = offloaded; - stmt_to_offloaded[ptr_offset_stmt] = offloaded; - stmt_to_offloaded[global_store_stmt] = offloaded; + stmt_to_offloaded_[const_offset_stmt] = offloaded; + stmt_to_offloaded_[ptr_offset_stmt] = offloaded; + stmt_to_offloaded_[global_store_stmt] = offloaded; } } else { LaneAttribute zeros(std::vector( @@ -496,13 +554,12 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { auto const_zeros = replacement.push_back(zeros); auto global_store_stmt = replacement.push_back(ptr, const_zeros); - stmt_to_offloaded[global_store_stmt] = offloaded; + stmt_to_offloaded_[global_store_stmt] = offloaded; } stmt->parent->replace_with(stmt, std::move(replacement), false); // To deal with the same offloaded visit_operand() - stmt_to_offloaded[stmt] = nullptr; - throw IRModified(); + stmt_to_offloaded_[stmt] = nullptr; } // Replace local LD/ST with global LD/ST @@ -514,9 +571,8 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { if (top_level_ptr->is()) { VecStatement replacement; auto global_load = replacement.push_back(ptr); - stmt_to_offloaded[global_load] = stmt_to_offloaded[stmt]; + stmt_to_offloaded_[global_load] = stmt_to_offloaded_[stmt]; stmt->parent->replace_with(stmt, std::move(replacement)); - throw IRModified(); } } @@ -528,9 +584,8 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { VecStatement replacement; auto global_store = replacement.push_back(ptr, stmt->val); - stmt_to_offloaded[global_store] = stmt_to_offloaded[stmt]; + stmt_to_offloaded_[global_store] = stmt_to_offloaded_[stmt]; stmt->parent->replace_with(stmt, std::move(replacement)); - throw IRModified(); } } @@ -540,35 +595,38 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { auto op = stmt->operand(index); if (op == nullptr) return false; - if (stmt_to_offloaded[stmt] == stmt_to_offloaded[op]) // same OffloadedStmt + if (stmt_to_offloaded_[stmt] == + stmt_to_offloaded_[op]) // same OffloadedStmt return false; - auto offloaded = stmt_to_offloaded[stmt]; + auto offloaded = stmt_to_offloaded_[stmt]; if (op->is()) { auto copy = op->clone(); + auto pcopy = copy.get(); copy->as()->activate = false; - stmt_to_offloaded[copy.get()] = offloaded; + stmt_to_offloaded_[copy.get()] = offloaded; stmt->set_operand(index, copy.get()); stmt->insert_before_me(std::move(copy)); + generic_visit(pcopy); return true; } - if (local_to_global_offset.find(op) == local_to_global_offset.end()) { - TI_ASSERT_INFO( - op->is() || op->is() || - op->is() || op->is() || - (op->is() && op->as()->is_ptr), - "{} is not allowed here.", op->type()); - // For cases like ConstStmt + if (local_to_global_offset_.find(op) == local_to_global_offset_.end()) { + // For stmts that are not promoted to global tmp, clone them into current + // offloaded task. E.g. + // ConstStmt/PtrOffsetStmt/GlobalTemporaryStmt/ExternalTensorShapeAlongAxisStmt + // etc. auto copy = op->clone(); - stmt_to_offloaded[copy.get()] = offloaded; + auto pcopy = copy.get(); + stmt_to_offloaded_[copy.get()] = offloaded; stmt->set_operand(index, copy.get()); stmt->insert_before_me(std::move(copy)); + generic_visit(pcopy); } else { auto global_temporary = Stmt::make( - local_to_global_offset[op], op->ret_type); - stmt_to_offloaded[global_temporary.get()] = offloaded; + local_to_global_offset_[op], op->ret_type); + stmt_to_offloaded_[global_temporary.get()] = offloaded; stmt->set_operand(index, global_temporary.get()); if (op->is() || op->ret_type.is_pointer()) { // For cases like Alloca both TensorType and Scalar which will be @@ -577,7 +635,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { } else { // For other cases like ArgLoadStmt UnaryOpStmt which needs to load. auto load = Stmt::make(global_temporary.get()); - stmt_to_offloaded[load.get()] = offloaded; + stmt_to_offloaded_[load.get()] = offloaded; stmt->set_operand(index, load.get()); stmt->insert_before_me(std::move(global_temporary)); stmt->insert_before_me(std::move(load)); @@ -588,13 +646,9 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { void generic_visit(Stmt *stmt) { int n_op = stmt->num_operands(); - bool modified = false; for (int i = 0; i < n_op; i++) { - if (visit_operand(stmt, i)) - modified = true; + visit_operand(stmt, i); } - if (modified) - throw IRModified(); } void visit(Stmt *stmt) override { @@ -614,22 +668,15 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { OffloadedRanges *offloaded_ranges) { FixCrossOffloadReferences pass(config, local_to_global_offset, stmt_to_offloaded, offloaded_ranges); - while (true) { - try { - root->accept(&pass); - } catch (IRModified) { - continue; - } - break; - } + root->accept(&pass); } private: - const CompileConfig &config; - StmtToOffsetMap local_to_global_offset; - std::unordered_map stmt_to_offloaded; + [[maybe_unused]] const CompileConfig &config_; + StmtToOffsetMap local_to_global_offset_; + std::unordered_map stmt_to_offloaded_; OffloadedRanges *const offloaded_ranges_; - std::unordered_map local_to_global_vector_type; + std::unordered_map local_to_global_vector_type_; }; void insert_gc(IRNode *root, const CompileConfig &config) { diff --git a/taichi/transforms/optimize_bit_struct_stores.cpp b/taichi/transforms/optimize_bit_struct_stores.cpp index 1b492ec47284d..0c771a4b15577 100644 --- a/taichi/transforms/optimize_bit_struct_stores.cpp +++ b/taichi/transforms/optimize_bit_struct_stores.cpp @@ -144,6 +144,7 @@ class DemoteAtomicBitStructStores : public BasicStmtVisitor { if (current_offloaded->task_type == OffloadedTaskType::serial) { demote = true; } else if (current_offloaded->task_type == OffloadedTaskType::range_for || + current_offloaded->task_type == OffloadedTaskType::mesh_for || current_offloaded->task_type == OffloadedTaskType::struct_for) { auto *snode = stmt->get_bit_struct_snode(); // Find the nearest non-bit-level ancestor @@ -165,6 +166,7 @@ class DemoteAtomicBitStructStores : public BasicStmtVisitor { void visit(OffloadedStmt *stmt) override { current_offloaded = stmt; if (stmt->task_type == OffloadedTaskType::range_for || + stmt->task_type == OffloadedTaskType::mesh_for || stmt->task_type == OffloadedTaskType::struct_for) { current_iterator_ = uniquely_accessed_bit_structs_.find(current_offloaded); diff --git a/taichi/transforms/remove_assume_in_range.cpp b/taichi/transforms/remove_assume_in_range.cpp index 4f3985b7967d4..79e3ffe9d5409 100644 --- a/taichi/transforms/remove_assume_in_range.cpp +++ b/taichi/transforms/remove_assume_in_range.cpp @@ -18,7 +18,7 @@ class RemoveRangeAssumption : public BasicStmtVisitor { DelayedIRModifier modifier; void visit(RangeAssumptionStmt *stmt) override { - stmt->replace_with(stmt->input); + stmt->replace_usages_with(stmt->input); modifier.erase(stmt); } diff --git a/taichi/transforms/remove_loop_unique.cpp b/taichi/transforms/remove_loop_unique.cpp index aecb3eb226e97..d6ae1ceb71898 100644 --- a/taichi/transforms/remove_loop_unique.cpp +++ b/taichi/transforms/remove_loop_unique.cpp @@ -16,7 +16,7 @@ class RemoveLoopUnique : public BasicStmtVisitor { DelayedIRModifier modifier; void visit(LoopUniqueStmt *stmt) override { - stmt->replace_with(stmt->input); + stmt->replace_usages_with(stmt->input); modifier.erase(stmt); } diff --git a/taichi/transforms/reverse_segments.cpp b/taichi/transforms/reverse_segments.cpp index 65388359b1f4d..482d2415f4c11 100644 --- a/taichi/transforms/reverse_segments.cpp +++ b/taichi/transforms/reverse_segments.cpp @@ -64,12 +64,15 @@ void reverse_segments(IRNode *root) { } } */ - if (has_for && has_non_for) + if (has_for && has_non_for) { TI_ERROR( - "Invalid program input for autodiff. Please check the documentation " - "for the \"Kernel Simplicity Rule\":\n" + "Invalid program input for autodiff: " + "Mixed usage of for-loops and statements without looping. \n" + "Please split them into two kernels " + "and check the documentation for more details:\n" "https://docs.taichi.graphics/lang/articles/advanced/" - "differentiable_programming#kernel-simplicity-rule"); + "differentiable_programming"); + } for (auto &sblock : statement_blocks) { for (auto &&s : sblock) { block->statements.push_back(std::move(s)); diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 81cb6006f42c5..50acdd97f5d00 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -22,6 +22,7 @@ class BasicBlockSimplify : public IRVisitor { std::set &visited; StructForStmt *current_struct_for; CompileConfig config; + DelayedIRModifier modifier; BasicBlockSimplify(Block *block, std::set &visited, @@ -33,7 +34,6 @@ class BasicBlockSimplify : public IRVisitor { config(config) { allow_undefined_visitor = true; invoke_default_visitor = false; - run(); } bool is_done(Stmt *stmt) { @@ -44,13 +44,30 @@ class BasicBlockSimplify : public IRVisitor { visited.insert(stmt->instance_id); } - void run() { + void accept_block() { for (int i = 0; i < (int)block->statements.size(); i++) { current_stmt_id = i; block->statements[i]->accept(this); } } + static bool run(Block *block, + std::set &visited, + StructForStmt *current_struct_for, + const CompileConfig &config) { + BasicBlockSimplify simplifier(block, visited, current_struct_for, config); + bool ir_modified = false; + while (true) { + simplifier.accept_block(); + if (simplifier.modifier.modify_ir()) { + ir_modified = true; + } else { + break; + } + } + return ir_modified; + } + void visit(ElementShuffleStmt *stmt) override { if (is_done(stmt)) return; @@ -67,9 +84,8 @@ class BasicBlockSimplify : public IRVisitor { if (same_source && inc_index && stmt->width() == stmt->elements[0].stmt->width()) { // useless shuffle. - stmt->replace_with(stmt->elements[0].stmt); - stmt->parent->erase(current_stmt_id); - throw IRModified(); + stmt->replace_usages_with(stmt->elements[0].stmt); + modifier.erase(stmt); } } @@ -120,9 +136,9 @@ class BasicBlockSimplify : public IRVisitor { } } if (!has_store) { - stmt->replace_with(bstmt.get()); - stmt->parent->erase(current_stmt_id); - throw IRModified(); + stmt->replace_usages_with(bstmt.get()); + modifier.erase(stmt); + return; } } } @@ -133,9 +149,8 @@ class BasicBlockSimplify : public IRVisitor { void visit(IntegerOffsetStmt *stmt) override { if (stmt->offset == 0) { - stmt->replace_with(stmt->input); - stmt->parent->erase(stmt); - throw IRModified(); + stmt->replace_usages_with(stmt->input); + modifier.erase(stmt); } } @@ -146,19 +161,19 @@ class BasicBlockSimplify : public IRVisitor { // step 0: eliminate empty extraction if (stmt->bit_begin == stmt->bit_end) { auto zero = Stmt::make(LaneAttribute(0)); - stmt->replace_with(zero.get()); - stmt->insert_after_me(std::move(zero)); - stmt->parent->erase(current_stmt_id); - throw IRModified(); + stmt->replace_usages_with(zero.get()); + modifier.insert_after(stmt, std::move(zero)); + modifier.erase(stmt); + return; } // step 1: eliminate useless extraction of another BitExtractStmt if (stmt->bit_begin == 0 && stmt->input->is()) { auto bstmt = stmt->input->as(); if (stmt->bit_end >= bstmt->bit_end - bstmt->bit_begin) { - stmt->replace_with(bstmt); - stmt->parent->erase(current_stmt_id); - throw IRModified(); + stmt->replace_usages_with(bstmt); + modifier.erase(stmt); + return; } } @@ -167,9 +182,9 @@ class BasicBlockSimplify : public IRVisitor { auto bstmt = stmt->input->as(); const int max_num_bits = bstmt->max_num_bits(); if (max_num_bits != -1 && stmt->bit_end >= max_num_bits) { - stmt->replace_with(bstmt); - stmt->parent->erase(current_stmt_id); - throw IRModified(); + stmt->replace_usages_with(bstmt); + modifier.erase(stmt); + return; } } @@ -182,56 +197,56 @@ class BasicBlockSimplify : public IRVisitor { if (diff.linear_related() && diff.certain()) { // case 1: last loop var, vectorized, has assumption on vec size if (k == num_loop_vars - 1) { - auto load = stmt->insert_before_me( - Stmt::make(current_struct_for, k)); + auto load = Stmt::make(current_struct_for, k); load->ret_type = PrimitiveType::i32; - stmt->input = load; + stmt->input = load.get(); int64 bound = 1LL << stmt->bit_end; auto offset = (((int64)diff.low % bound + bound) % bound) & ~((1LL << (stmt->bit_begin)) - 1); - - if (current_struct_for->vectorize == 1) - offset = diff.low; - if (stmt->bit_begin == 0 && - current_struct_for->vectorize == bound) { + auto load_addr = load.get(); + modifier.insert_before(stmt, std::move(load)); + offset = diff.low; // TODO: Vectorization + if (stmt->bit_begin == 0 && bound == 1) { // TODO: Vectorization // TODO: take care of cases where vectorization width != z // dimension of the block - auto offset_stmt = stmt->insert_after_me( - Stmt::make(stmt, offset)); - stmt->replace_with(offset_stmt); + auto offset_stmt = Stmt::make(stmt, offset); + stmt->replace_usages_with(offset_stmt.get()); // fix the offset stmt operand offset_stmt->as()->input = stmt; + modifier.insert_after(stmt, std::move(offset_stmt)); } else { if (offset != 0) { - auto offset_const = stmt->insert_before_me( + auto offset_const = Stmt::make(LaneAttribute( - TypedConstant(PrimitiveType::i32, offset)))); - auto sum = stmt->insert_before_me(Stmt::make( - BinaryOpType::add, load, offset_const)); - stmt->input = sum; + TypedConstant(PrimitiveType::i32, offset))); + auto sum = Stmt::make( + BinaryOpType::add, load_addr, offset_const.get()); + stmt->input = sum.get(); + modifier.insert_before(stmt, std::move(offset_const)); + modifier.insert_before(stmt, std::move(offset_const)); } } } else { // insert constant - auto load = stmt->insert_before_me( - Stmt::make(current_struct_for, k)); + auto load = Stmt::make(current_struct_for, k); load->ret_type = PrimitiveType::i32; - auto constant = stmt->insert_before_me( - Stmt::make(TypedConstant(diff.low))); - auto add = stmt->insert_before_me( - Stmt::make(BinaryOpType::add, load, constant)); + auto constant = Stmt::make(TypedConstant(diff.low)); + auto add = Stmt::make(BinaryOpType::add, load.get(), + constant.get()); add->ret_type = PrimitiveType::i32; - stmt->input = add; + stmt->input = add.get(); + modifier.insert_before(stmt, std::move(load)); + modifier.insert_before(stmt, std::move(constant)); + modifier.insert_before(stmt, std::move(add)); } stmt->simplified = true; - throw IRModified(); + return; } } } set_done(stmt); } - template static bool identical_vectors(const std::vector &a, const std::vector &b) { @@ -250,13 +265,14 @@ class BasicBlockSimplify : public IRVisitor { if (!stmt->inputs.empty() && stmt->inputs.back()->is()) { auto previous_offset = stmt->inputs.back()->as(); // push forward offset - auto offset_stmt = stmt->insert_after_me( - Stmt::make(stmt, previous_offset->offset)); + auto offset_stmt = + Stmt::make(stmt, previous_offset->offset); stmt->inputs.back() = previous_offset->input; - stmt->replace_with(offset_stmt); + stmt->replace_usages_with(offset_stmt.get()); offset_stmt->as()->input = stmt; - throw IRModified(); + modifier.insert_after(stmt, std::move(offset_stmt)); + return; } // Lower into a series of adds and muls. @@ -269,10 +285,10 @@ class BasicBlockSimplify : public IRVisitor { stride_stmt.get()); auto newsum = Stmt::make(BinaryOpType::add, sum.get(), mul.get()); - stmt->insert_before_me(std::move(sum)); + modifier.insert_before(stmt, std::move(sum)); sum = std::move(newsum); - stmt->insert_before_me(std::move(stride_stmt)); - stmt->insert_before_me(std::move(mul)); + modifier.insert_before(stmt, std::move(stride_stmt)); + modifier.insert_before(stmt, std::move(mul)); stride_product *= stmt->strides[i]; } // Compare the result with 0 to make sure no overflow occurs under Debug @@ -292,20 +308,19 @@ class BasicBlockSimplify : public IRVisitor { auto select = Stmt::make( TernaryOpType::select, check_sum.get(), sum.get(), zero.get()); - stmt->insert_before_me(std::move(zero)); - stmt->insert_before_me(std::move(sum)); - stmt->insert_before_me(std::move(check_sum)); - stmt->insert_before_me(std::move(assert)); - stmt->replace_with(select.get()); - stmt->insert_before_me(std::move(select)); + modifier.insert_before(stmt, std::move(zero)); + modifier.insert_before(stmt, std::move(sum)); + modifier.insert_before(stmt, std::move(check_sum)); + modifier.insert_before(stmt, std::move(assert)); + stmt->replace_usages_with(select.get()); + modifier.insert_before(stmt, std::move(select)); } else { - stmt->replace_with(sum.get()); - stmt->insert_before_me(std::move(sum)); + stmt->replace_usages_with(sum.get()); + modifier.insert_before(stmt, std::move(sum)); } - stmt->parent->erase(stmt); + modifier.erase(stmt); // get types of adds and muls - irpass::type_check(stmt->parent, config); - throw IRModified(); + modifier.type_check(stmt->parent, config); } void visit(SNodeLookupStmt *stmt) override { @@ -324,13 +339,14 @@ class BasicBlockSimplify : public IRVisitor { snode->ch[i]->dt->is_primitive(PrimitiveTypeID::f32)); } - auto offset_stmt = stmt->insert_after_me(Stmt::make( - stmt, previous_offset->offset * sizeof(int32) * (snode->ch.size()))); + auto offset_stmt = Stmt::make( + stmt, previous_offset->offset * sizeof(int32) * (snode->ch.size())); stmt->input_index = previous_offset->input; - stmt->replace_with(offset_stmt); + stmt->replace_usages_with(offset_stmt.get()); offset_stmt->as()->input = stmt; - throw IRModified(); + modifier.insert_after(stmt, std::move(offset_stmt)); + return; } set_done(stmt); @@ -345,15 +361,16 @@ class BasicBlockSimplify : public IRVisitor { // push forward offset // auto snode = stmt->input_snode; - auto offset_stmt = stmt->insert_after_me(Stmt::make( - stmt, stmt->chid * sizeof(int32) + previous_offset->offset)); + auto offset_stmt = Stmt::make( + stmt, stmt->chid * sizeof(int32) + previous_offset->offset); stmt->input_ptr = previous_offset->input; - stmt->replace_with(offset_stmt); + stmt->replace_usages_with(offset_stmt.get()); stmt->chid = 0; stmt->output_snode = stmt->input_snode->ch[stmt->chid].get(); offset_stmt->as()->input = stmt; - throw IRModified(); + modifier.insert_after(stmt, std::move(offset_stmt)); + return; } set_done(stmt); @@ -362,7 +379,8 @@ class BasicBlockSimplify : public IRVisitor { void visit(WhileControlStmt *stmt) override { if (stmt->width() == 1 && stmt->mask) { stmt->mask = nullptr; - throw IRModified(); + modifier.mark_as_modified(); + return; } } @@ -391,7 +409,8 @@ class BasicBlockSimplify : public IRVisitor { if (if_stmt->width() == 1 && (if_stmt->true_mask || if_stmt->false_mask)) { if_stmt->true_mask = nullptr; if_stmt->false_mask = nullptr; - throw IRModified(); + modifier.mark_as_modified(); + return; } auto flatten = [&](std::vector &clause, bool true_branch) { bool plain_clause = true; // no global store, no container @@ -435,18 +454,19 @@ class BasicBlockSimplify : public IRVisitor { for (int l = 0; l < store->width(); l++) { lanes.push_back(LocalAddress(store->dest, l)); } - auto load = - if_stmt->insert_before_me(Stmt::make(lanes)); - irpass::type_check(load, config); - auto select = if_stmt->insert_before_me( - Stmt::make(TernaryOpType::select, if_stmt->cond, - true_branch ? store->val : load, - true_branch ? load : store->val)); - irpass::type_check(select, config); - store->val = select; - if_stmt->insert_before_me(std::move(clause[i])); + auto load = Stmt::make(lanes); + modifier.type_check(load.get(), config); + auto select = Stmt::make( + TernaryOpType::select, if_stmt->cond, + true_branch ? store->val : load.get(), + true_branch ? load.get() : store->val); + modifier.type_check(select.get(), config); + store->val = select.get(); + modifier.insert_before(if_stmt, std::move(load)); + modifier.insert_before(if_stmt, std::move(select)); + modifier.insert_before(if_stmt, std::move(clause[i])); } else { - if_stmt->insert_before_me(std::move(clause[i])); + modifier.insert_before(if_stmt, std::move(clause[i])); } } auto clean_clause = std::vector(); @@ -467,39 +487,43 @@ class BasicBlockSimplify : public IRVisitor { if (config.flatten_if) { if (if_stmt->true_statements && flatten(if_stmt->true_statements->statements, true)) { - throw IRModified(); + modifier.mark_as_modified(); + return; } if (if_stmt->false_statements && flatten(if_stmt->false_statements->statements, false)) { - throw IRModified(); + modifier.mark_as_modified(); + return; } } if (if_stmt->true_statements) { if (if_stmt->true_statements->statements.empty()) { if_stmt->set_true_statements(nullptr); - throw IRModified(); + modifier.mark_as_modified(); + return; } } if (if_stmt->false_statements) { if (if_stmt->false_statements->statements.empty()) { if_stmt->set_false_statements(nullptr); - throw IRModified(); + modifier.mark_as_modified(); + return; } } if (!if_stmt->true_statements && !if_stmt->false_statements) { - if_stmt->parent->erase(if_stmt); - throw IRModified(); + modifier.erase(if_stmt); + return; } if (config.advanced_optimization) { // Merge adjacent if's with the identical condition. // TODO: What about IfStmt::true_mask and IfStmt::false_mask? - if (current_stmt_id > 0 && - block->statements[current_stmt_id - 1]->is()) { - auto bstmt = block->statements[current_stmt_id - 1]->as(); + if (current_stmt_id < block->size() - 1 && + block->statements[current_stmt_id + 1]->is()) { + auto bstmt = block->statements[current_stmt_id + 1]->as(); if (bstmt->cond == if_stmt->cond) { auto concatenate = [](std::unique_ptr &clause1, std::unique_ptr &clause2) { @@ -508,12 +532,12 @@ class BasicBlockSimplify : public IRVisitor { return; } if (clause2 != nullptr) - clause1->insert(VecStatement(std::move(clause2->statements))); + clause1->insert(VecStatement(std::move(clause2->statements)), 0); }; concatenate(bstmt->true_statements, if_stmt->true_statements); concatenate(bstmt->false_statements, if_stmt->false_statements); - if_stmt->parent->erase(if_stmt); - throw IRModified(); + modifier.erase(if_stmt); + return; } } } @@ -521,15 +545,16 @@ class BasicBlockSimplify : public IRVisitor { void visit(OffloadedStmt *stmt) override { if (stmt->has_body() && stmt->body->statements.empty()) { - stmt->parent->erase(stmt); - throw IRModified(); + modifier.erase(stmt); + return; } } void visit(WhileStmt *stmt) override { if (stmt->width() == 1 && stmt->mask) { stmt->mask = nullptr; - throw IRModified(); + modifier.mark_as_modified(); + return; } } }; @@ -550,14 +575,8 @@ class Simplify : public IRVisitor { void visit(Block *block) override { std::set visited; - while (true) { - try { - BasicBlockSimplify _(block, visited, current_struct_for, config); - } catch (IRModified) { - modified = true; - continue; - } - break; + if (BasicBlockSimplify::run(block, visited, current_struct_for, config)) { + modified = true; } for (auto &stmt : block->statements) { stmt->accept(this); @@ -584,6 +603,10 @@ class Simplify : public IRVisitor { current_struct_for = nullptr; } + void visit(MeshForStmt *for_stmt) override { + for_stmt->body->accept(this); + } + void visit(WhileStmt *stmt) override { stmt->body->accept(this); } @@ -639,11 +662,12 @@ void full_simplify(IRNode *root, modified = true; if (die(root)) modified = true; - if (whole_kernel_cse(root)) + if (config.opt_level > 0 && whole_kernel_cse(root)) modified = true; // Don't do this time-consuming optimization pass again if the IR is // not modified. - if ((first_iteration || modified) && config.cfg_optimization && + if (config.opt_level > 0 && (first_iteration || modified) && + config.cfg_optimization && cfg_optimization(root, args.after_lower_access)) modified = true; first_iteration = false; diff --git a/taichi/transforms/statement_usage_replace.cpp b/taichi/transforms/statement_usage_replace.cpp index 2a741e49519a5..a46d104c01aa6 100644 --- a/taichi/transforms/statement_usage_replace.cpp +++ b/taichi/transforms/statement_usage_replace.cpp @@ -53,6 +53,10 @@ class StatementUsageReplace : public IRVisitor { stmt->body->accept(this); } + void visit(MeshForStmt *stmt) override { + stmt->body->accept(this); + } + void visit(OffloadedStmt *stmt) override { stmt->all_blocks_accept(this); } diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 1cc7dd6d0da0f..b187af806401b 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -13,10 +13,10 @@ TLANG_NAMESPACE_BEGIN // Var lookup and Type inference class TypeCheck : public IRVisitor { private: - CompileConfig config; + CompileConfig config_; public: - explicit TypeCheck(const CompileConfig &config) : config(config) { + explicit TypeCheck(const CompileConfig &config) : config_(config) { allow_undefined_visitor = true; } @@ -34,11 +34,6 @@ class TypeCheck : public IRVisitor { } void visit(IfStmt *if_stmt) override { - // TODO: use PrimitiveType::u1 when it's supported - TI_ASSERT_INFO( - if_stmt->cond->ret_type->is_primitive(PrimitiveTypeID::i32), - "`if` conditions must be of type int32, consider using `if x != 0:` " - "instead of `if x:` for float values."); if (if_stmt->true_statements) if_stmt->true_statements->accept(this); if (if_stmt->false_statements) { @@ -61,14 +56,15 @@ class TypeCheck : public IRVisitor { // TODO(type): test_ad_for fails if we assume dest is a pointer type. auto dst_type = stmt->dest->ret_type.ptr_removed(); if (auto cit = dst_type->cast()) { - dst_type = cit->get_compute_type(); + dst_type = cit->get_physical_type(); } else if (auto cft = dst_type->cast()) { - dst_type = cft->get_compute_type(); - } - if (stmt->val->ret_type != dst_type) { - TI_WARN("[{}] Atomic add ({} to {}) may lose precision, at", stmt->name(), - data_type_name(stmt->val->ret_type), data_type_name(dst_type)); - TI_WARN("\n{}", stmt->tb); + auto cit = cft->get_digits_type()->as(); + dst_type = cit->get_physical_type(); + } else if (stmt->val->ret_type != dst_type) { + TI_WARN("[{}] Atomic {} ({} to {}) may lose precision, at\n{}", + stmt->name(), atomic_op_type_name(stmt->op_type), + data_type_name(stmt->val->ret_type), data_type_name(dst_type), + stmt->tb); stmt->val = insert_type_cast_before(stmt, stmt->val, dst_type); } stmt->ret_type = dst_type; @@ -118,9 +114,9 @@ class TypeCheck : public IRVisitor { stmt->val = insert_type_cast_before(stmt, stmt->val, dst_value_type); } if (dst_value_type != promoted && dst_value_type != stmt->val->ret_type) { - TI_WARN("[{}] Local store may lose precision: {} <- {}, at", - stmt->name(), dst_value_type->to_string(), input_type); - TI_WARN("\n{}", stmt->tb); + TI_WARN("[{}] Local store may lose precision: {} <- {}, at\n{}", + stmt->name(), dst_value_type->to_string(), input_type, + stmt->tb); } stmt->ret_type = dst_value_type; return; @@ -140,10 +136,10 @@ class TypeCheck : public IRVisitor { } if (stmt->dest->ret_type != common_container_type) { TI_WARN( - "[{}] Local store may lose precision (target = {}, value = {}) at", + "[{}] Local store may lose precision (target = {}, value = {}), " + "at\n{}", stmt->name(), stmt->dest->ret_data_type_name(), - old_data->ret_data_type_name(), stmt->id); - TI_WARN("\n{}", stmt->tb); + old_data->ret_data_type_name(), stmt->id, stmt->tb); } stmt->ret_type = stmt->dest->ret_type; } @@ -224,9 +220,8 @@ class TypeCheck : public IRVisitor { // TODO: do not use "promoted" here since u8 + u8 = i32 in C++ and storing // u8 to u8 leads to extra warnings. if (dst_value_type != promoted && dst_value_type != stmt->val->ret_type) { - TI_WARN("[{}] Global store may lose precision: {} <- {}, at", - stmt->name(), dst_value_type->to_string(), input_type); - TI_WARN("\n{}", stmt->tb); + TI_WARN("[{}] Global store may lose precision: {} <- {}, at\n{}", + stmt->name(), dst_value_type->to_string(), input_type, stmt->tb); } } @@ -242,6 +237,10 @@ class TypeCheck : public IRVisitor { stmt->body->accept(this); } + void visit(MeshForStmt *stmt) override { + stmt->body->accept(this); + } + void visit(WhileStmt *stmt) override { stmt->body->accept(this); } @@ -252,17 +251,10 @@ class TypeCheck : public IRVisitor { stmt->ret_type = stmt->cast_type; } if (!is_real(stmt->operand->ret_type)) { - if (is_trigonometric(stmt->op_type)) { - TI_ERROR("[{}] Trigonometric operator takes real inputs only. At {}", - stmt->name(), stmt->tb); - } else if (stmt->op_type == UnaryOpType::floor || - stmt->op_type == UnaryOpType::ceil) { - TI_ERROR("[{}] floor/ceil takes real inputs only. At {}", stmt->name(), - stmt->tb); - } else if (stmt->op_type == UnaryOpType::sqrt || - stmt->op_type == UnaryOpType::exp || - stmt->op_type == UnaryOpType::log) { - cast(stmt->operand, config.default_fp); + if (stmt->op_type == UnaryOpType::sqrt || + stmt->op_type == UnaryOpType::exp || + stmt->op_type == UnaryOpType::log) { + cast(stmt->operand, config_.default_fp); } } } @@ -300,14 +292,13 @@ class TypeCheck : public IRVisitor { auto error = [&](std::string comment = "") { if (comment == "") { TI_WARN( - "[{}] Error: type mismatch (left = {}, right = {}, stmt_id = {}) " - "at", + "[{}] Error: type mismatch (left = {}, right = {}, stmt_id = {}), " + "at\n{}", stmt->name(), stmt->lhs->ret_data_type_name(), - stmt->rhs->ret_data_type_name(), stmt->id); + stmt->rhs->ret_data_type_name(), stmt->id, stmt->tb); } else { - TI_WARN("[{}] {} at", stmt->name(), comment); + TI_WARN("[{}] {} at\n{}", stmt->name(), comment, stmt->tb); } - TI_WARN("\n{}", stmt->tb); TI_WARN("Compilation stopped due to type mismatch."); throw std::runtime_error("Binary operator type mismatch"); }; @@ -318,7 +309,7 @@ class TypeCheck : public IRVisitor { // lower truediv into div if (stmt->op_type == BinaryOpType::truediv) { - auto default_fp = config.default_fp; + auto default_fp = config_.default_fp; if (!is_real(stmt->lhs->ret_type)) { cast(stmt->lhs, default_fp); } @@ -357,11 +348,6 @@ class TypeCheck : public IRVisitor { if (!matching) { error(); } - if (binary_is_bitwise(stmt->op_type)) { - if (!is_integral(stmt->lhs->ret_type)) { - error("Error: bitwise operations can only apply to integral types."); - } - } if (is_comparison(stmt->op_type)) { stmt->ret_type = TypeFactory::create_vector_or_scalar_type( stmt->lhs->width(), PrimitiveType::i32); @@ -374,10 +360,8 @@ class TypeCheck : public IRVisitor { if (stmt->op_type == TernaryOpType::select) { auto ret_type = promoted_type(stmt->op2->ret_type, stmt->op3->ret_type); TI_ASSERT(stmt->op1->ret_type->is_primitive(PrimitiveTypeID::i32)) - TI_ASSERT(stmt->op1->ret_type->vector_width() == - stmt->op2->ret_type->vector_width()); - TI_ASSERT(stmt->op2->ret_type->vector_width() == - stmt->op3->ret_type->vector_width()); + TI_ASSERT(stmt->op1->width() == stmt->op2->width()); + TI_ASSERT(stmt->op2->width() == stmt->op3->width()); if (ret_type != stmt->op2->ret_type) { auto cast_stmt = insert_type_cast_before(stmt, stmt->op2, ret_type); stmt->op2 = cast_stmt; @@ -399,7 +383,6 @@ class TypeCheck : public IRVisitor { } void visit(RangeAssumptionStmt *stmt) override { - TI_ASSERT(stmt->input->ret_type == stmt->base->ret_type); stmt->ret_type = stmt->input->ret_type; } @@ -417,19 +400,16 @@ class TypeCheck : public IRVisitor { } void visit(ArgLoadStmt *stmt) override { - const auto &rt = stmt->ret_type; // TODO: Maybe have a type_inference() pass, which takes in the args/rets // defined by the kernel. After that, type_check() pass will purely do // verification, without modifying any types. - TI_ASSERT(rt != PrimitiveType::unknown); - TI_ASSERT(rt->vector_width() == 1); + TI_ASSERT(stmt->width() == 1); stmt->ret_type.set_is_pointer(stmt->is_ptr); } void visit(ReturnStmt *stmt) override { // TODO: Support stmt->ret_id? - stmt->ret_type = stmt->value->ret_type; - TI_ASSERT(stmt->ret_type->vector_width() == 1); + TI_ASSERT(stmt->width() == 1); } void visit(ExternalPtrStmt *stmt) override { @@ -453,11 +433,6 @@ class TypeCheck : public IRVisitor { TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32); } - void visit(BlockDimStmt *stmt) override { - stmt->ret_type = - TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32); - } - void visit(GetRootStmt *stmt) override { stmt->ret_type = TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::gen, true); diff --git a/taichi/transforms/unreachable_code_elimination.cpp b/taichi/transforms/unreachable_code_elimination.cpp index a12a975abd982..973ed54568866 100644 --- a/taichi/transforms/unreachable_code_elimination.cpp +++ b/taichi/transforms/unreachable_code_elimination.cpp @@ -70,6 +70,10 @@ class UnreachableCodeEliminator : public BasicStmtVisitor { visit_loop(stmt->body.get()); } + void visit(MeshForStmt *stmt) override { + visit_loop(stmt->body.get()); + } + void visit(WhileStmt *stmt) override { visit_loop(stmt->body.get()); } @@ -78,10 +82,14 @@ class UnreachableCodeEliminator : public BasicStmtVisitor { if (stmt->tls_prologue) stmt->tls_prologue->accept(this); + if (stmt->mesh_prologue) + stmt->mesh_prologue->accept(this); + if (stmt->bls_prologue) stmt->bls_prologue->accept(this); if (stmt->task_type == OffloadedStmt::TaskType::range_for || + stmt->task_type == OffloadedStmt::TaskType::mesh_for || stmt->task_type == OffloadedStmt::TaskType::struct_for) visit_loop(stmt->body.get()); else if (stmt->body) diff --git a/taichi/transforms/variable_optimization.cpp b/taichi/transforms/variable_optimization.cpp deleted file mode 100644 index 8ad5d4f41d39a..0000000000000 --- a/taichi/transforms/variable_optimization.cpp +++ /dev/null @@ -1,647 +0,0 @@ -#include "taichi/ir/ir.h" -#include "taichi/ir/analysis.h" -#include "taichi/ir/statements.h" -#include "taichi/ir/transforms.h" -#include "taichi/ir/visitors.h" -#include "taichi/ir/state_machine.h" -#include "taichi/program/compile_config.h" - -#include - -TLANG_NAMESPACE_BEGIN - -class VariableOptimize : public IRVisitor { - protected: - bool maybe_run; - - public: - VariableOptimize() { - allow_undefined_visitor = true; - invoke_default_visitor = true; - maybe_run = false; - } - - virtual StateMachine &get_state_machine(Stmt *stmt) = 0; - - virtual void modify_all_state_machines(void (StateMachine::*func)()) = 0; - - virtual void clear() = 0; - - virtual void finalize() { - modify_all_state_machines(&StateMachine::finalize); - } - - void visit(Stmt *stmt) override { - if (stmt->is_container_statement()) { - TI_ERROR("Visitor for container stmt undefined."); - } - } - - void visit(WhileControlStmt *stmt) override { - if (!maybe_run) { - modify_all_state_machines(&StateMachine::continue_or_break); - } - } - - void visit(ContinueStmt *stmt) override { - if (!maybe_run) { - modify_all_state_machines(&StateMachine::continue_or_break); - } - } - - virtual void visit_loop(Block *body) = 0; - - void visit(Block *block) override { - for (auto &stmt : block->statements) { - stmt->accept(this); - } - } - - void visit(WhileStmt *stmt) override { - TI_ASSERT(stmt->mask == nullptr); - visit_loop(stmt->body.get()); - } - - void visit(RangeForStmt *stmt) override { - visit_loop(stmt->body.get()); - } - - void visit(StructForStmt *stmt) override { - visit_loop(stmt->body.get()); - } - - void visit(OffloadedStmt *stmt) override { - if (stmt->body) { - modify_all_state_machines(&StateMachine::begin_offload); - stmt->body->accept(this); - } - } - - void run(IRNode *node) { - StateMachine::rebuild_atomics_usage(node); - while (true) { - bool modified = false; - try { - clear(); - node->accept(this); - finalize(); - } catch (IRModified) { - modified = true; - } - if (!modified) - break; - } - } -}; - -class AllocaOptimize : public VariableOptimize { - private: - std::unordered_map> - state_machines; - - public: - using VariableOptimize::visit; - - StateMachine &get_state_machine(Stmt *stmt) override { - return state_machines[stmt->parent][stmt]; - } - - void modify_all_state_machines(void (StateMachine::*func)()) override { - for (auto &i : state_machines) { - for (auto &j : i.second) { - (j.second.*func)(); - } - } - } - - void clear() override { - state_machines.clear(); - } - - void visit(AllocaStmt *stmt) override { - state_machines[stmt->parent].insert( - std::make_pair(stmt, StateMachine(stmt, true))); - } - - void visit(AtomicOpStmt *stmt) override { - if (!stmt->dest->is()) - return; - if (maybe_run) - get_state_machine(stmt->dest).maybe_atomic_op(); - else - get_state_machine(stmt->dest).atomic_op(stmt); - } - - void visit(LocalStoreStmt *stmt) override { - if (maybe_run) - get_state_machine(stmt->dest).maybe_store(stmt); - else - get_state_machine(stmt->dest).store(stmt); - } - - void visit(LocalLoadStmt *stmt) override { - TI_ASSERT(stmt->width() == 1); - TI_ASSERT(stmt->src[0].offset == 0); - if (maybe_run) - get_state_machine(stmt->src[0].var).maybe_load(); - else - get_state_machine(stmt->src[0].var).load(stmt); - } - - void visit(IfStmt *if_stmt) override { - auto origin = state_machines; - modify_all_state_machines(&StateMachine::begin_if_or_loop); - if (if_stmt->true_statements) { - if_stmt->true_statements->accept(this); - } - auto true_branch = std::move(state_machines); - - state_machines = origin; - modify_all_state_machines(&StateMachine::begin_if_or_loop); - if (if_stmt->false_statements) { - if_stmt->false_statements->accept(this); - } - auto false_branch = std::move(state_machines); - - state_machines = std::move(origin); - for (auto &i : state_machines) { - auto &true_branch_block = true_branch[i.first]; - auto &false_branch_block = false_branch[i.first]; - for (auto &j : i.second) { - j.second.merge_from_if(true_branch_block[j.first], - false_branch_block[j.first]); - } - } - } - - void visit_loop(Block *body) override { - if (maybe_run) { - body->accept(this); - return; - } - - auto origin = state_machines; - modify_all_state_machines(&StateMachine::begin_if_or_loop); - maybe_run = true; - body->accept(this); - maybe_run = false; - body->accept(this); - for (auto &i : origin) { - auto &loop_block = state_machines[i.first]; - for (auto &j : i.second) { - j.second.merge_from_loop(loop_block[j.first]); - } - } - state_machines = std::move(origin); - } - - void visit(Block *block) override { - state_machines.insert( - std::make_pair(block, std::unordered_map())); - - for (auto &stmt : block->statements) { - stmt->accept(this); - } - if (!maybe_run) { - for (auto &it : state_machines[block]) { - it.second.finalize(); - } - } - state_machines.erase(block); - } -}; - -class GlobalTempOptimize : public VariableOptimize { - private: - std::unordered_map state_machines; - - public: - using VariableOptimize::visit; - - StateMachine &get_state_machine(Stmt *stmt) override { - return state_machines[stmt->as()->offset]; - } - - void modify_all_state_machines(void (StateMachine::*func)()) override { - for (auto &i : state_machines) { - (i.second.*func)(); - } - } - - void clear() override { - state_machines.clear(); - } - - void visit(GlobalTemporaryStmt *stmt) override { - if (state_machines.find(stmt->offset) == state_machines.end()) - state_machines.insert( - std::make_pair(stmt->offset, StateMachine(stmt, false))); - } - - void visit(AtomicOpStmt *stmt) override { - if (!stmt->dest->is()) - return; - if (maybe_run) - get_state_machine(stmt->dest).maybe_atomic_op(); - else - get_state_machine(stmt->dest).atomic_op(stmt); - } - - void visit(GlobalStoreStmt *stmt) override { - if (!stmt->dest->is()) - return; - if (maybe_run) - get_state_machine(stmt->dest).maybe_store(stmt); - else - get_state_machine(stmt->dest).store(stmt); - } - - void visit(GlobalLoadStmt *stmt) override { - if (!stmt->src->is()) - return; - if (maybe_run) - get_state_machine(stmt->src).maybe_load(); - else - get_state_machine(stmt->src).load(stmt); - } - - void visit(IfStmt *if_stmt) override { - auto origin = state_machines; - modify_all_state_machines(&StateMachine::begin_if_or_loop); - if (if_stmt->true_statements) { - if_stmt->true_statements->accept(this); - } - auto true_branch = std::move(state_machines); - - state_machines = origin; - modify_all_state_machines(&StateMachine::begin_if_or_loop); - if (if_stmt->false_statements) { - if_stmt->false_statements->accept(this); - } - auto false_branch = std::move(state_machines); - - state_machines = std::move(origin); - for (auto &it : state_machines) { - it.second.merge_from_if(true_branch[it.first], false_branch[it.first]); - } - for (auto &it : true_branch) { - if (state_machines.find(it.first) == state_machines.end()) - state_machines.insert(it); - } - for (auto &it : false_branch) { - if (state_machines.find(it.first) == state_machines.end()) - state_machines.insert(it); - } - } - - void visit_loop(Block *body) override { - if (maybe_run) { - body->accept(this); - return; - } - - auto origin = state_machines; - modify_all_state_machines(&StateMachine::begin_if_or_loop); - maybe_run = true; - body->accept(this); - maybe_run = false; - body->accept(this); - for (auto &it : origin) { - it.second.merge_from_loop(state_machines[it.first]); - } - for (auto &it : state_machines) { - if (origin.find(it.first) == origin.end()) { - StateMachine state_machine(it.second.get_var(), false); - state_machine.merge_from_loop(it.second); - origin.insert(std::make_pair(it.first, state_machine)); - } - } - state_machines = std::move(origin); - } - - void visit(OffloadedStmt *stmt) override { - if (stmt->task_type == OffloadedTaskType::range_for) { - TI_ASSERT(!maybe_run); - if (!stmt->const_begin) { - TI_ASSERT(state_machines.find(stmt->begin_offset) != - state_machines.end()); - state_machines[stmt->begin_offset].load(); - } - if (!stmt->const_end) { - TI_ASSERT(state_machines.find(stmt->end_offset) != - state_machines.end()); - state_machines[stmt->end_offset].load(); - } - } - if (stmt->body) { - modify_all_state_machines(&StateMachine::begin_offload); - stmt->body->accept(this); - } - } -}; - -class GlobalPtrOptimize : public VariableOptimize { - private: - std::unordered_map> - state_machines; - - public: - using VariableOptimize::visit; - - StateMachine &get_state_machine(Stmt *stmt) override { - return state_machines[stmt->as()->snodes[0]->id][stmt]; - } - - void modify_all_state_machines(void (StateMachine::*func)()) override { - for (auto &i : state_machines) { - for (auto &j : i.second) { - (j.second.*func)(); - } - } - } - - void clear() override { - state_machines.clear(); - } - - void finalize() override { - // do nothing - } - - void visit(GlobalPtrStmt *stmt) override { - TI_ASSERT(stmt->width() == 1); - auto &state_machines_map = state_machines[stmt->snodes[0]->id]; - if (state_machines_map.find(stmt) == state_machines_map.end()) - state_machines_map.insert( - std::make_pair(stmt, StateMachine(stmt, false))); - } - - void visit(AtomicOpStmt *stmt) override { - if (!stmt->dest->is()) - return; - if (maybe_run) - get_state_machine(stmt->dest).maybe_atomic_op(); - else - get_state_machine(stmt->dest).atomic_op(stmt); - auto dest = stmt->dest->as(); - for (auto &var : state_machines[dest->snodes[0]->id]) { - if (var.first != dest && - irpass::analysis::maybe_same_address(dest, var.first)) { - var.second.maybe_atomic_op(); - } - } - } - - void visit(GlobalStoreStmt *stmt) override { - if (!stmt->dest->is()) - return; - if (maybe_run) - get_state_machine(stmt->dest).maybe_store(stmt); - else - get_state_machine(stmt->dest).store(stmt); - auto dest = stmt->dest->as(); - for (auto &var : state_machines[dest->snodes[0]->id]) { - if (var.first != dest && - irpass::analysis::maybe_same_address(dest, var.first)) { - var.second.maybe_store(stmt); - } - } - } - - void visit(GlobalLoadStmt *stmt) override { - if (!stmt->src->is()) - return; - if (maybe_run) - get_state_machine(stmt->src).maybe_load(); - else - get_state_machine(stmt->src).load(stmt); - auto dest = stmt->src->as(); - for (auto &var : state_machines[dest->snodes[0]->id]) { - if (var.first != dest && - irpass::analysis::maybe_same_address(dest, var.first)) { - var.second.maybe_load(); - } - } - } - - void visit(IfStmt *if_stmt) override { - auto origin = state_machines; - modify_all_state_machines(&StateMachine::begin_if_or_loop); - if (if_stmt->true_statements) { - if_stmt->true_statements->accept(this); - } - auto true_branch = std::move(state_machines); - - state_machines = origin; - modify_all_state_machines(&StateMachine::begin_if_or_loop); - if (if_stmt->false_statements) { - if_stmt->false_statements->accept(this); - } - auto false_branch = std::move(state_machines); - - state_machines = std::move(origin); - for (auto &i : state_machines) { - auto &true_branch_block = true_branch[i.first]; - auto &false_branch_block = false_branch[i.first]; - for (auto &j : i.second) { - j.second.merge_from_if(true_branch_block[j.first], - false_branch_block[j.first]); - } - } - for (auto &i : true_branch) { - for (auto &j : i.second) { - if (state_machines[i.first].find(j.first) == - state_machines[i.first].end()) - state_machines[i.first].insert(j); - } - } - for (auto &i : false_branch) { - for (auto &j : i.second) { - if (state_machines[i.first].find(j.first) == - state_machines[i.first].end()) - state_machines[i.first].insert(j); - } - } - } - - void visit_loop(Block *body) override { - if (maybe_run) { - body->accept(this); - return; - } - - auto origin = state_machines; - modify_all_state_machines(&StateMachine::begin_if_or_loop); - maybe_run = true; - body->accept(this); - maybe_run = false; - body->accept(this); - for (auto &i : origin) { - auto &loop_snode = state_machines[i.first]; - for (auto &j : i.second) { - j.second.merge_from_loop(loop_snode[j.first]); - } - } - for (auto &i : state_machines) { - auto &origin_snode = origin[i.first]; - for (auto &j : i.second) { - if (origin_snode.find(j.first) == origin_snode.end()) { - StateMachine state_machine(j.second.get_var(), false); - state_machine.merge_from_loop(j.second); - origin_snode.insert(std::make_pair(j.first, state_machine)); - } - } - } - state_machines = std::move(origin); - } -}; - -class OtherVariableOptimize : public VariableOptimize { - private: - std::unordered_map state_machines; - - public: - using VariableOptimize::visit; - - StateMachine &get_state_machine(Stmt *stmt) override { - if (state_machines.find(stmt) == state_machines.end()) - state_machines.insert(std::make_pair(stmt, StateMachine(stmt, false))); - return state_machines[stmt]; - } - - void modify_all_state_machines(void (StateMachine::*func)()) override { - for (auto &i : state_machines) { - (i.second.*func)(); - } - } - - void clear() override { - state_machines.clear(); - } - - void finalize() override { - // do nothing - } - - void visit(AtomicOpStmt *stmt) override { - if (stmt->dest->is() || stmt->dest->is() || - stmt->dest->is()) - return; - if (maybe_run) - get_state_machine(stmt->dest).maybe_atomic_op(); - else - get_state_machine(stmt->dest).atomic_op(stmt); - for (auto &var : state_machines) { - if (var.first != stmt->dest && - irpass::analysis::maybe_same_address(stmt->dest, var.first)) { - var.second.maybe_atomic_op(); - } - } - } - - void visit(GlobalStoreStmt *stmt) override { - if (stmt->dest->is()) - return; - if (maybe_run) - get_state_machine(stmt->dest).maybe_store(stmt); - else - get_state_machine(stmt->dest).store(stmt); - for (auto &var : state_machines) { - if (var.first != stmt->dest && - irpass::analysis::maybe_same_address(stmt->dest, var.first)) { - var.second.maybe_store(stmt); - } - } - } - - void visit(GlobalLoadStmt *stmt) override { - if (stmt->src->is()) - return; - if (maybe_run) - get_state_machine(stmt->src).maybe_load(); - else - get_state_machine(stmt->src).load(stmt); - for (auto &var : state_machines) { - if (var.first != stmt->src && - irpass::analysis::maybe_same_address(stmt->src, var.first)) { - var.second.maybe_load(); - } - } - } - - void visit(IfStmt *if_stmt) override { - auto origin = state_machines; - modify_all_state_machines(&StateMachine::begin_if_or_loop); - if (if_stmt->true_statements) { - if_stmt->true_statements->accept(this); - } - auto true_branch = std::move(state_machines); - - state_machines = origin; - modify_all_state_machines(&StateMachine::begin_if_or_loop); - if (if_stmt->false_statements) { - if_stmt->false_statements->accept(this); - } - auto false_branch = std::move(state_machines); - - state_machines = std::move(origin); - for (auto &it : state_machines) { - it.second.merge_from_if(true_branch[it.first], false_branch[it.first]); - } - for (auto &it : true_branch) { - if (state_machines.find(it.first) == state_machines.end()) - state_machines.insert(it); - } - for (auto &it : false_branch) { - if (state_machines.find(it.first) == state_machines.end()) - state_machines.insert(it); - } - } - - void visit_loop(Block *body) override { - if (maybe_run) { - body->accept(this); - return; - } - - auto origin = state_machines; - modify_all_state_machines(&StateMachine::begin_if_or_loop); - maybe_run = true; - body->accept(this); - maybe_run = false; - body->accept(this); - for (auto &it : origin) { - it.second.merge_from_loop(state_machines[it.first]); - } - for (auto &it : state_machines) { - if (origin.find(it.first) == origin.end()) { - StateMachine state_machine(it.second.get_var(), false); - state_machine.merge_from_loop(it.second); - origin.insert(std::make_pair(it.first, state_machine)); - } - } - state_machines = std::move(origin); - } -}; - -namespace irpass { -void variable_optimization(IRNode *root, bool after_lower_access) { - TI_AUTO_PROF; - // This pass has been replaced with cfg_optimization. - if (!root->get_config().advanced_optimization) - return; - AllocaOptimize alloca_optimizer; - alloca_optimizer.run(root); - GlobalTempOptimize global_temp_optimizer; - global_temp_optimizer.run(root); - if (after_lower_access) { - OtherVariableOptimize other_variable_optimizer; - other_variable_optimizer.run(root); - } else { - GlobalPtrOptimize global_ptr_optimizer; - global_ptr_optimizer.run(root); - } -} -} // namespace irpass - -TLANG_NAMESPACE_END diff --git a/taichi/transforms/vector_split.cpp b/taichi/transforms/vector_split.cpp deleted file mode 100644 index ea7bcdbc0418a..0000000000000 --- a/taichi/transforms/vector_split.cpp +++ /dev/null @@ -1,337 +0,0 @@ -// Split vectors wider than machine vector width into multiple vectors - -#include "taichi/ir/ir.h" -#include "taichi/ir/statements.h" -#include "taichi/ir/transforms.h" -#include "taichi/ir/visitors.h" -#include "taichi/program/program.h" - -#include - -TLANG_NAMESPACE_BEGIN - -class BasicBlockVectorSplit : public IRVisitor { - public: - Block *block; - std::vector statements; - std::vector> splits; - int max_width; - - int current_split_factor; - std::vector current_split; - bool need_split; - bool serial_schedule; - std::unordered_map> origin2split; - - BasicBlockVectorSplit(Block *block, int max_width, bool serial_schedule) - : block(block), max_width(max_width), serial_schedule(serial_schedule) { - // allow_undefined_visitor = true; - // invoke_default_visitor = false; - run(); - } - - int lane_start(int split) { - return split * max_width; - } - - int lane_end(int split) { - return (split + 1) * max_width; - } - - Stmt *lookup(Stmt *old, int index) { - if (origin2split.find(old) == origin2split.end()) { - TI_WARN("VectorSplitter looking for statement outside current block?"); - return old; - } else { - TI_ASSERT(0 <= index); - TI_ASSERT(index < (int)origin2split[old].size()); - return origin2split[old][index]; - } - } - - void run() { - std::vector statements = std::move(block->statements); - for (int i = 0; i < (int)statements.size(); i++) { - auto stmt = statements[i].get(); - if (stmt->width() > max_width) { - TI_ASSERT(stmt->width() % max_width == 0); - current_split_factor = stmt->width() / max_width; - current_split.resize(current_split_factor); - need_split = true; - stmt->accept(this); - origin2split[stmt] = std::vector(current_split_factor, nullptr); - for (int j = 0; j < current_split_factor; j++) { - current_split[j]->ret_type = - Program::get_type_factory().get_vector_type(max_width, - stmt->element_type()); - origin2split[stmt][j] = current_split[j].get(); - } - splits.push_back(std::move(current_split)); - } else { // recreate a statement anyway since the original one may be - // pointing to unknown statements - current_split_factor = 1; - current_split.resize(current_split_factor); - need_split = false; - stmt->accept(this); - origin2split[stmt] = std::vector(1, nullptr); - current_split[0]->element_type() = stmt->element_type(); - current_split[0]->ret_type = - Program::get_type_factory().get_vector_type(stmt->width(), - stmt->element_type()); - origin2split[stmt][0] = current_split[0].get(); - std::vector split; - split.push_back(std::move(current_split[0])); - splits.push_back(std::move(split)); - } - } - block->statements.clear(); - if (!serial_schedule) { - // finish vectors one by one - for (int i = 0; i < (int)splits.size(); i++) { - for (int j = 0;; j++) { - bool modified = false; - if (j < (int)splits[i].size()) { - block->insert(std::move(splits[i][j])); - modified = true; - } - if (!modified) { - break; - } - } - } - } else { - for (int j = 0;; j++) { - bool modified = false; - for (int i = 0; i < (int)splits.size(); i++) { - if (j < (int)splits[i].size()) { - block->insert(std::move(splits[i][j])); - modified = true; - } - } - if (!modified) { - break; - } - } - } - for (int i = 0; i < (int)block->statements.size(); i++) { - auto stmt_ = block->statements[i].get(); - if (stmt_->is()) { - auto stmt = stmt_->as(); - for (int l = 0; l < stmt->width(); l++) { - auto *old_var = stmt->src[l].var; - if (origin2split.find(old_var) != origin2split.end()) { - auto new_var = - origin2split[old_var][stmt->src[l].offset / max_width]; - stmt->src[l].var = new_var; - stmt->src[l].offset %= max_width; - // TI_WARN("replaced..."); - } - } - } - } - } - - // Visitors: set current_split[0...current_split_factor] - - void visit(GlobalPtrStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - std::vector indices; - for (int j = 0; j < (int)stmt->indices.size(); j++) { - indices.push_back(lookup(stmt->indices[j], i)); - } - current_split[i] = Stmt::make( - stmt->snodes.slice(lane_start(i), - need_split ? lane_end(i) : stmt->width()), - indices); - } - } - - void visit(ConstStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make(stmt->val.slice( - lane_start(i), need_split ? lane_end(i) : stmt->width())); - } - } - - void visit(AllocaStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) - current_split[i] = Stmt::make( - need_split ? max_width : stmt->width(), stmt->element_type()); - } - - void visit(ElementShuffleStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - LaneAttribute ptr; - int new_width = need_split ? max_width : stmt->width(); - ptr.resize(new_width); - for (int j = 0; j < new_width; j++) { - VectorElement addr(stmt->elements[lane_start(i) + j]); - if (origin2split.find(addr.stmt) == origin2split.end()) { - ptr[j] = addr; - } else { - ptr[j].stmt = lookup(addr.stmt, addr.index / max_width); - ptr[j].index = addr.index % max_width; - } - } - current_split[i] = Stmt::make(ptr); - } - } - - void visit(LocalLoadStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - LaneAttribute ptr; - int new_width = need_split ? max_width : stmt->width(); - ptr.reserve(new_width); - for (int j = 0; j < new_width; j++) { - LocalAddress addr(stmt->src[lane_start(i) + j]); - if (origin2split.find(addr.var) == origin2split.end()) { - ptr.push_back(addr); - } else { - ptr.push_back(LocalAddress(lookup(addr.var, addr.offset / max_width), - addr.offset % max_width)); - } - } - current_split[i] = Stmt::make(ptr); - } - } - - void visit(LocalStoreStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make(lookup(stmt->dest, i), - lookup(stmt->val, i)); - } - } - - void visit(GlobalLoadStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make(lookup(stmt->src, i)); - } - } - - void visit(GlobalStoreStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make(lookup(stmt->dest, i), - lookup(stmt->val, i)); - } - } - - void visit(UnaryOpStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = - Stmt::make(stmt->op_type, lookup(stmt->operand, i)); - current_split[i]->as()->cast_type = - stmt->as()->cast_type; - } - } - - void visit(BinaryOpStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make( - stmt->op_type, lookup(stmt->lhs, i), lookup(stmt->rhs, i)); - } - } - - void visit(TernaryOpStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = - Stmt::make(stmt->op_type, lookup(stmt->op1, i), - lookup(stmt->op2, i), lookup(stmt->op3, i)); - } - } - - void visit(AtomicOpStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make( - stmt->op_type, lookup(stmt->dest, i), lookup(stmt->val, i)); - } - } - - void visit(PrintStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - std::vector new_contents; - std::transform(stmt->contents.begin(), stmt->contents.end(), - std::back_inserter(new_contents), - [=](auto const &x) -> PrintStmt::EntryType { - if (std::holds_alternative(x)) { - return lookup(std::get(x), i); - } else { - return x; - } - }); - current_split[i] = Stmt::make(new_contents); - } - } - - void visit(RandStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make(stmt->element_type()); - } - } - - void visit(WhileControlStmt *stmt) override { - TI_ASSERT(need_split == false); - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make(lookup(stmt->mask, i), - lookup(stmt->cond, i)); - } - } -}; - -// Goal: eliminate vectors that are longer than physical vector width (e.g. 8 -// on AVX2) -class VectorSplit : public IRVisitor { - public: - int max_width; - bool serial_schedule; - - VectorSplit(IRNode *node, int max_width, bool serial_schedule) - : max_width(max_width), serial_schedule(serial_schedule) { - allow_undefined_visitor = true; - invoke_default_visitor = true; - node->accept(this); - } - - void visit(Block *block) override { - if (!block->has_container_statements()) { - bool all_within_width = true; - for (auto &stmt : block->statements) { - if (stmt->width() > max_width) { - all_within_width = false; - } - } - if (!all_within_width) - BasicBlockVectorSplit(block, max_width, serial_schedule); - } else { - for (auto &stmt : block->statements) { - stmt->accept(this); - } - } - } - - void visit(IfStmt *if_stmt) override { - if (if_stmt->true_statements) - if_stmt->true_statements->accept(this); - if (if_stmt->false_statements) { - if_stmt->false_statements->accept(this); - } - } - - void visit(RangeForStmt *for_stmt) override { - for_stmt->body->accept(this); - } - - void visit(WhileStmt *stmt) override { - stmt->body->accept(this); - } -}; - -namespace irpass { - -void vector_split(IRNode *root, int max_width, bool serial_schedule) { - TI_AUTO_PROF; - VectorSplit(root, max_width, serial_schedule); -} - -} // namespace irpass - -TLANG_NAMESPACE_END diff --git a/taichi/transforms/whole_kernel_cse.cpp b/taichi/transforms/whole_kernel_cse.cpp index bc8b8dfc412c1..1ecb6c1954032 100644 --- a/taichi/transforms/whole_kernel_cse.cpp +++ b/taichi/transforms/whole_kernel_cse.cpp @@ -12,27 +12,27 @@ TLANG_NAMESPACE_BEGIN // A helper class to maintain WholeKernelCSE::visited class MarkUndone : public BasicStmtVisitor { private: - std::unordered_set *const visited; - Stmt *const modified_operand; + std::unordered_set *const visited_; + Stmt *const modified_operand_; public: using BasicStmtVisitor::visit; MarkUndone(std::unordered_set *visited, Stmt *modified_operand) - : visited(visited), modified_operand(modified_operand) { + : visited_(visited), modified_operand_(modified_operand) { allow_undefined_visitor = true; invoke_default_visitor = true; } void visit(Stmt *stmt) override { - if (stmt->has_operand(modified_operand)) { - visited->erase(stmt->instance_id); + if (stmt->has_operand(modified_operand_)) { + visited_->erase(stmt->instance_id); } } void preprocess_container_stmt(Stmt *stmt) override { - if (stmt->has_operand(modified_operand)) { - visited->erase(stmt->instance_id); + if (stmt->has_operand(modified_operand_)) { + visited_->erase(stmt->instance_id); } } @@ -45,11 +45,11 @@ class MarkUndone : public BasicStmtVisitor { // Whole Kernel Common Subexpression Elimination class WholeKernelCSE : public BasicStmtVisitor { private: - std::unordered_set visited; + std::unordered_set visited_; // each scope corresponds to an unordered_set - std::vector>> - visible_stmts; - DelayedIRModifier modifier; + std::vector > > + visible_stmts_; + DelayedIRModifier modifier_; public: using BasicStmtVisitor::visit; @@ -60,16 +60,38 @@ class WholeKernelCSE : public BasicStmtVisitor { } bool is_done(Stmt *stmt) { - return visited.find(stmt->instance_id) != visited.end(); + return visited_.find(stmt->instance_id) != visited_.end(); } void set_done(Stmt *stmt) { - visited.insert(stmt->instance_id); + visited_.insert(stmt->instance_id); + } + + static std::size_t operand_hash(const Stmt *stmt) { + std::size_t hash_code{0}; + auto hash_type = + std::hash{}(std::type_index(typeid(stmt))); + if (stmt->is() || stmt->is()) { + // special cases in common_statement_eliminable() + return hash_type; + } + auto op = stmt->get_operands(); + for (auto &x : op) { + if (x == nullptr) + continue; + // Hash the addresses of the operand pointers. + hash_code = + (hash_code * 33) ^ + (std::hash{}(reinterpret_cast(x))); + } + return hash_type ^ hash_code; } static bool common_statement_eliminable(Stmt *this_stmt, Stmt *prev_stmt) { // Is this_stmt eliminable given that prev_stmt appears before it and has // the same type with it? + if (this_stmt->type() != prev_stmt->type()) + return false; if (this_stmt->is()) { auto this_ptr = this_stmt->as(); auto prev_ptr = prev_stmt->as(); @@ -95,31 +117,35 @@ class WholeKernelCSE : public BasicStmtVisitor { void visit(Stmt *stmt) override { if (!stmt->common_statement_eliminable()) return; + // container_statement does not need to be CSE-ed + if (stmt->is_container_statement()) + return; // Generic visitor for all CSE-able statements. + std::size_t hash_value = operand_hash(stmt); if (is_done(stmt)) { - visible_stmts.back()[std::type_index(typeid(*stmt))].insert(stmt); + visible_stmts_.back()[hash_value].insert(stmt); return; } - for (auto &scope : visible_stmts) { - for (auto &prev_stmt : scope[std::type_index(typeid(*stmt))]) { + for (auto &scope : visible_stmts_) { + for (auto &prev_stmt : scope[hash_value]) { if (common_statement_eliminable(stmt, prev_stmt)) { - MarkUndone::run(&visited, stmt); - stmt->replace_with(prev_stmt); - modifier.erase(stmt); + MarkUndone::run(&visited_, stmt); + stmt->replace_usages_with(prev_stmt); + modifier_.erase(stmt); return; } } } - visible_stmts.back()[std::type_index(typeid(*stmt))].insert(stmt); + visible_stmts_.back()[hash_value].insert(stmt); set_done(stmt); } void visit(Block *stmt_list) override { - visible_stmts.emplace_back(); + visible_stmts_.emplace_back(); for (auto &stmt : stmt_list->statements) { stmt->accept(this); } - visible_stmts.pop_back(); + visible_stmts_.pop_back(); } void visit(IfStmt *if_stmt) override { @@ -148,7 +174,7 @@ class WholeKernelCSE : public BasicStmtVisitor { irpass::replace_all_usages_with(false_clause.get(), false_clause->statements[0].get(), common_stmt.get()); - modifier.insert_before(if_stmt, std::move(common_stmt)); + modifier_.insert_before(if_stmt, std::move(common_stmt)); false_clause->erase(0); } if (!true_clause->statements.empty() && @@ -161,7 +187,7 @@ class WholeKernelCSE : public BasicStmtVisitor { irpass::replace_all_usages_with(false_clause.get(), false_clause->statements.back().get(), common_stmt.get()); - modifier.insert_after(if_stmt, std::move(common_stmt)); + modifier_.insert_after(if_stmt, std::move(common_stmt)); false_clause->erase((int)false_clause->size() - 1); } } @@ -177,7 +203,7 @@ class WholeKernelCSE : public BasicStmtVisitor { bool modified = false; while (true) { node->accept(&eliminator); - if (eliminator.modifier.modify_ir()) + if (eliminator.modifier_.modify_ir()) modified = true; else break; diff --git a/taichi/ui/backends/vulkan/app_context.cpp b/taichi/ui/backends/vulkan/app_context.cpp index 6fc6a460b7dcb..d8626a7fbb7df 100644 --- a/taichi/ui/backends/vulkan/app_context.cpp +++ b/taichi/ui/backends/vulkan/app_context.cpp @@ -1,6 +1,7 @@ #include "taichi/ui/utils/utils.h" #include "taichi/ui/backends/vulkan/app_context.h" #include "taichi/ui/backends/vulkan/swap_chain.h" +#include "taichi/program/program.h" #include @@ -9,9 +10,19 @@ TI_UI_NAMESPACE_BEGIN namespace vulkan { using namespace taichi::lang::vulkan; +using namespace taichi::lang; namespace { std::vector get_required_instance_extensions() { +#ifdef ANDROID + std::vector extensions; + + extensions.push_back(VK_KHR_SURFACE_EXTENSION_NAME); + extensions.push_back(VK_KHR_ANDROID_SURFACE_EXTENSION_NAME); + extensions.push_back(VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME); + + return extensions; +#else uint32_t glfw_ext_count = 0; const char **glfw_extensions; glfw_extensions = glfwGetRequiredInstanceExtensions(&glfw_ext_count); @@ -22,26 +33,29 @@ std::vector get_required_instance_extensions() { extensions.push_back(glfw_extensions[i]); } - // EmbeddedVulkanDevice will check that these are supported + // VulkanDeviceCreator will check that these are supported extensions.push_back(VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME); extensions.push_back(VK_KHR_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME); extensions.push_back(VK_KHR_EXTERNAL_SEMAPHORE_CAPABILITIES_EXTENSION_NAME); extensions.push_back(VK_EXT_DEBUG_UTILS_EXTENSION_NAME); return extensions; +#endif } std::vector get_required_device_extensions() { - static std::vector extensions{ - VK_KHR_SWAPCHAIN_EXTENSION_NAME, - VK_KHR_EXTERNAL_MEMORY_EXTENSION_NAME, - VK_KHR_EXTERNAL_SEMAPHORE_EXTENSION_NAME, + static std::vector extensions { + VK_KHR_SWAPCHAIN_EXTENSION_NAME, +#if !defined(ANDROID) + VK_KHR_EXTERNAL_MEMORY_EXTENSION_NAME, + VK_KHR_EXTERNAL_SEMAPHORE_EXTENSION_NAME, #ifdef _WIN64 - VK_KHR_EXTERNAL_MEMORY_WIN32_EXTENSION_NAME, - VK_KHR_EXTERNAL_SEMAPHORE_WIN32_EXTENSION_NAME, + VK_KHR_EXTERNAL_MEMORY_WIN32_EXTENSION_NAME, + VK_KHR_EXTERNAL_SEMAPHORE_WIN32_EXTENSION_NAME, #else - VK_KHR_EXTERNAL_MEMORY_FD_EXTENSION_NAME, - VK_KHR_EXTERNAL_SEMAPHORE_FD_EXTENSION_NAME, + VK_KHR_EXTERNAL_MEMORY_FD_EXTENSION_NAME, + VK_KHR_EXTERNAL_SEMAPHORE_FD_EXTENSION_NAME, +#endif #endif }; @@ -49,39 +63,80 @@ std::vector get_required_device_extensions() { } } // namespace -void AppContext::init(GLFWwindow *glfw_window, const AppConfig &config) { - glfw_window_ = glfw_window; +void AppContext::init(Program *prog, + TaichiWindow *window, + const AppConfig &config) { + taichi_window_ = window; + prog_ = prog; this->config = config; - EmbeddedVulkanDevice::Params evd_params; - evd_params.additional_instance_extensions = - get_required_instance_extensions(); - evd_params.additional_device_extensions = get_required_device_extensions(); - evd_params.is_for_ui = true; - evd_params.surface_creator = [&](VkInstance instance) -> VkSurfaceKHR { - VkSurfaceKHR surface = VK_NULL_HANDLE; - if (glfwCreateWindowSurface(instance, glfw_window, nullptr, &surface) != - VK_SUCCESS) { - throw std::runtime_error("failed to create window surface!"); - } - return surface; - }; - vulkan_device_ = std::make_unique(evd_params); + + // Create a Vulkan device if the original configuration is not for Vulkan or + // there is no active current program (usage from external library for AOT + // modules for example). + if (config.ti_arch != Arch::vulkan || prog == nullptr) { + VulkanDeviceCreator::Params evd_params; + evd_params.additional_instance_extensions = + get_required_instance_extensions(); + evd_params.additional_device_extensions = get_required_device_extensions(); + evd_params.is_for_ui = config.show_window; + evd_params.surface_creator = [&](VkInstance instance) -> VkSurfaceKHR { + VkSurfaceKHR surface = VK_NULL_HANDLE; +#ifdef ANDROID + VkAndroidSurfaceCreateInfoKHR createInfo{ + .sType = VK_STRUCTURE_TYPE_ANDROID_SURFACE_CREATE_INFO_KHR, + .pNext = nullptr, + .flags = 0, + .window = window}; + + vkCreateAndroidSurfaceKHR(instance, &createInfo, nullptr, &surface); +#else + if (glfwCreateWindowSurface(instance, window, nullptr, &surface) != + VK_SUCCESS) { + throw std::runtime_error("failed to create window surface!"); + } +#endif + return surface; + }; + embedded_vulkan_device_ = std::make_unique(evd_params); + } else { + vulkan_device_ = static_cast(prog->get_graphics_device()); + } } taichi::lang::vulkan::VulkanDevice &AppContext::device() { - return *(vulkan_device_->device()); + if (vulkan_device_) { + return *vulkan_device_; + } + return *(embedded_vulkan_device_->device()); } const taichi::lang::vulkan::VulkanDevice &AppContext::device() const { - return *(vulkan_device_->device()); + if (vulkan_device_) { + return *vulkan_device_; + } + return *(embedded_vulkan_device_->device()); } void AppContext::cleanup() { - vulkan_device_.reset(); + if (embedded_vulkan_device_) { + embedded_vulkan_device_.reset(); + } +} + +bool AppContext::requires_export_sharing() const { + // only the cuda backends needs export_sharing to interop with vk + // with other backends (e.g. vulkan backend on mac), turning export_sharing to + // true leads to crashes + // TODO: investigate this, and think of a more universal solution. + return config.ti_arch == Arch::cuda; +} + +TaichiWindow *AppContext::taichi_window() const { + return taichi_window_; } -GLFWwindow *AppContext::glfw_window() const { - return glfw_window_; +lang::Program *AppContext::prog() const { + return prog_; } } // namespace vulkan diff --git a/taichi/ui/backends/vulkan/app_context.h b/taichi/ui/backends/vulkan/app_context.h index 8e0265ae5c636..6403db378f76e 100644 --- a/taichi/ui/backends/vulkan/app_context.h +++ b/taichi/ui/backends/vulkan/app_context.h @@ -1,32 +1,54 @@ #pragma once #include "taichi/ui/common/app_config.h" #include -#include "taichi/backends/vulkan/embedded_device.h" -#include "taichi/backends/vulkan/loader.h" +#include "taichi/backends/vulkan/vulkan_device_creator.h" +#include "taichi/backends/vulkan/vulkan_loader.h" #include "taichi/backends/vulkan/vulkan_device.h" #include "taichi/ui/backends/vulkan/swap_chain.h" +#ifdef ANDROID +#include +#endif + +namespace taichi { +namespace lang { +class Program; +} // namespace lang +} // namespace taichi TI_UI_NAMESPACE_BEGIN +#ifdef ANDROID +using TaichiWindow = ANativeWindow; +#else +using TaichiWindow = GLFWwindow; +#endif + namespace vulkan { -class AppContext { +class TI_DLL_EXPORT AppContext { public: - void init(GLFWwindow *glfw_window, const AppConfig &config); + void init(lang::Program *prog, TaichiWindow *window, const AppConfig &config); void cleanup(); - GLFWwindow *glfw_window() const; + TaichiWindow *taichi_window() const; + lang::Program *prog() const; taichi::lang::vulkan::VulkanDevice &device(); const taichi::lang::vulkan::VulkanDevice &device() const; + bool requires_export_sharing() const; AppConfig config; private: - std::unique_ptr vulkan_device_{ - nullptr}; + std::unique_ptr + embedded_vulkan_device_{nullptr}; + + // not owned + taichi::lang::vulkan::VulkanDevice *vulkan_device_{nullptr}; + + TaichiWindow *taichi_window_{nullptr}; - GLFWwindow *glfw_window_{nullptr}; + lang::Program *prog_{nullptr}; }; } // namespace vulkan diff --git a/taichi/ui/backends/vulkan/canvas.h b/taichi/ui/backends/vulkan/canvas.h index 97779dc9b1249..287c195362143 100644 --- a/taichi/ui/backends/vulkan/canvas.h +++ b/taichi/ui/backends/vulkan/canvas.h @@ -35,7 +35,7 @@ TI_UI_NAMESPACE_BEGIN namespace vulkan { -class Canvas final : public CanvasBase { +class TI_DLL_EXPORT Canvas final : public CanvasBase { public: Canvas(Renderer *renderer); diff --git a/taichi/ui/backends/vulkan/gui.cpp b/taichi/ui/backends/vulkan/gui.cpp index 76534dac24e85..0d4962a6e5811 100644 --- a/taichi/ui/backends/vulkan/gui.cpp +++ b/taichi/ui/backends/vulkan/gui.cpp @@ -15,8 +15,9 @@ PFN_vkVoidFunction load_vk_function_for_gui(const char *name, void *userData) { return result; } -Gui::Gui(AppContext *app_context, GLFWwindow *window) { +Gui::Gui(AppContext *app_context, SwapChain *swap_chain, TaichiWindow *window) { app_context_ = app_context; + swap_chain_ = swap_chain; create_descriptor_pool(); @@ -26,7 +27,13 @@ Gui::Gui(AppContext *app_context, GLFWwindow *window) { ImGui::StyleColorsDark(); - ImGui_ImplGlfw_InitForVulkan(window, true); + if (app_context->config.show_window) { +#ifdef ANDROID + ImGui_ImplAndroid_Init(window); +#else + ImGui_ImplGlfw_InitForVulkan(window, true); +#endif + } } void Gui::init_render_resources(VkRenderPass render_pass) { @@ -44,8 +51,8 @@ void Gui::init_render_resources(VkRenderPass render_pass) { init_info.PipelineCache = VK_NULL_HANDLE; init_info.DescriptorPool = descriptor_pool_; init_info.Allocator = VK_NULL_HANDLE; - init_info.MinImageCount = 1; - init_info.ImageCount = 1; + init_info.MinImageCount = swap_chain_->surface().get_image_count(); + init_info.ImageCount = swap_chain_->surface().get_image_count(); ImGui_ImplVulkan_Init(&init_info, render_pass); render_pass_ = render_pass; @@ -95,7 +102,20 @@ void Gui::prepare_for_next_frame() { return; } ImGui_ImplVulkan_NewFrame(); - ImGui_ImplGlfw_NewFrame(); + if (app_context_->config.show_window) { +#ifdef ANDROID + ImGui_ImplAndroid_NewFrame(); +#else + ImGui_ImplGlfw_NewFrame(); +#endif + } else { + // io.DisplaySize is set during ImGui_ImplGlfw_NewFrame() + // but since we're headless, we do it explicitly here + auto w = app_context_->config.width; + auto h = app_context_->config.height; + ImGuiIO &io = ImGui::GetIO(); + io.DisplaySize = ImVec2((float)w, (float)h); + } ImGui::NewFrame(); is_empty_ = true; } @@ -130,7 +150,7 @@ void Gui::text(std::string text) { if (!initialized()) { return; } - ImGui::Text(text.c_str()); + ImGui::Text("%s", text.c_str()); } bool Gui::checkbox(std::string name, bool old_value) { if (!initialized()) { @@ -183,7 +203,13 @@ void Gui::cleanup_render_resources() { } void Gui::cleanup() { - ImGui_ImplGlfw_Shutdown(); + if (app_context_->config.show_window) { +#ifdef ANDROID + ImGui_ImplAndroid_Shutdown(); +#else + ImGui_ImplGlfw_Shutdown(); +#endif + } cleanup_render_resources(); ImGui::DestroyContext(); } diff --git a/taichi/ui/backends/vulkan/gui.h b/taichi/ui/backends/vulkan/gui.h index 06b588aa08a51..8896c8d4118c4 100644 --- a/taichi/ui/backends/vulkan/gui.h +++ b/taichi/ui/backends/vulkan/gui.h @@ -7,7 +7,11 @@ #endif #include +#ifdef ANDROID +#include +#else #include +#endif #include #include "taichi/ui/backends/vulkan/app_context.h" #include "taichi/ui/common/gui_base.h" @@ -17,9 +21,9 @@ TI_UI_NAMESPACE_BEGIN namespace vulkan { -class Gui final : public GuiBase { +class TI_DLL_EXPORT Gui final : public GuiBase { public: - Gui(AppContext *app_context, GLFWwindow *window); + Gui(AppContext *app_context, SwapChain *swap_chain, TaichiWindow *window); void cleanup(); void init_render_resources(VkRenderPass render_pass); @@ -55,6 +59,7 @@ class Gui final : public GuiBase { private: bool is_empty_; AppContext *app_context_; + SwapChain *swap_chain_; VkRenderPass render_pass_{VK_NULL_HANDLE}; diff --git a/taichi/ui/backends/vulkan/renderable.cpp b/taichi/ui/backends/vulkan/renderable.cpp index 130c01845e7ed..4ba2da75ffed1 100644 --- a/taichi/ui/backends/vulkan/renderable.cpp +++ b/taichi/ui/backends/vulkan/renderable.cpp @@ -1,4 +1,5 @@ #include "taichi/ui/backends/vulkan/renderable.h" +#include "taichi/program/program.h" #include "taichi/ui/utils/utils.h" TI_UI_NAMESPACE_BEGIN @@ -38,6 +39,13 @@ void Renderable::init_buffers() { } void Renderable::update_data(const RenderableInfo &info) { + // We might not have a current program if GGUI is used in external apps to + // load AOT modules + Program *prog = app_context_->prog(); + if (prog) { + prog->synchronize(); + } + int num_vertices = info.vbo.shape[0]; int num_indices; if (info.indices.valid) { @@ -51,23 +59,31 @@ void Renderable::update_data(const RenderableInfo &info) { } else { num_indices = 1; } - if (num_vertices > config_.vertices_count || - num_indices > config_.indices_count) { + + config_.vertices_count = num_vertices; + config_.indices_count = num_indices; + + if (num_vertices > config_.max_vertices_count || + num_indices > config_.max_indices_count) { free_buffers(); - config_.vertices_count = num_vertices; - config_.indices_count = num_indices; + config_.max_vertices_count = num_vertices; + config_.max_indices_count = num_indices; init_buffers(); } - Program &program = get_current_program(); - DevicePtr vbo_dev_ptr = get_device_ptr(&program, info.vbo.snode); + // If there is no current program, VBO information should be provided directly + // instead of accessing through the current SNode + DevicePtr vbo_dev_ptr = info.vbo.dev_alloc.get_ptr(); + if (prog) { + vbo_dev_ptr = get_device_ptr(prog, info.vbo.snode); + } + uint64_t vbo_size = sizeof(Vertex) * num_vertices; Device::MemcpyCapability memcpy_cap = Device::check_memcpy_capability( vertex_buffer_.get_ptr(), vbo_dev_ptr, vbo_size); if (memcpy_cap == Device::MemcpyCapability::Direct) { - Device::memcpy_direct(vertex_buffer_.get_ptr(), vbo_dev_ptr.get_ptr(), - vbo_size); + Device::memcpy_direct(vertex_buffer_.get_ptr(), vbo_dev_ptr, vbo_size); } else if (memcpy_cap == Device::MemcpyCapability::RequiresStagingBuffer) { Device::memcpy_via_staging(vertex_buffer_.get_ptr(), staging_vertex_buffer_.get_ptr(), vbo_dev_ptr, @@ -78,7 +94,7 @@ void Renderable::update_data(const RenderableInfo &info) { if (info.indices.valid) { indexed_ = true; - DevicePtr ibo_dev_ptr = get_device_ptr(&program, info.indices.snode); + DevicePtr ibo_dev_ptr = get_device_ptr(prog, info.indices.snode); uint64_t ibo_size = num_indices * sizeof(int); if (memcpy_cap == Device::MemcpyCapability::Direct) { Device::memcpy_direct(index_buffer_.get_ptr(), ibo_dev_ptr, ibo_size); @@ -130,16 +146,17 @@ void Renderable::create_graphics_pipeline() { {0, 0, BufferFormat::rgb32f, offsetof(Vertex, pos)}, {1, 0, BufferFormat::rgb32f, offsetof(Vertex, normal)}, {2, 0, BufferFormat::rg32f, offsetof(Vertex, texCoord)}, - {3, 0, BufferFormat::rgb32f, offsetof(Vertex, color)}}; + {3, 0, BufferFormat::rgba32f, offsetof(Vertex, color)}}; pipeline_ = app_context_->device().create_raster_pipeline( source, raster_params, vertex_inputs, vertex_attribs); } void Renderable::create_vertex_buffer() { - size_t buffer_size = sizeof(Vertex) * config_.vertices_count; + size_t buffer_size = sizeof(Vertex) * config_.max_vertices_count; - Device::AllocParams vb_params{buffer_size, false, false, true, + Device::AllocParams vb_params{buffer_size, false, false, + app_context_->requires_export_sharing(), AllocUsage::Vertex}; vertex_buffer_ = app_context_->device().allocate_memory(vb_params); @@ -150,9 +167,10 @@ void Renderable::create_vertex_buffer() { } void Renderable::create_index_buffer() { - size_t buffer_size = sizeof(int) * config_.indices_count; + size_t buffer_size = sizeof(int) * config_.max_indices_count; - Device::AllocParams ib_params{buffer_size, false, false, true, + Device::AllocParams ib_params{buffer_size, false, false, + app_context_->requires_export_sharing(), AllocUsage::Index}; index_buffer_ = app_context_->device().allocate_memory(ib_params); diff --git a/taichi/ui/backends/vulkan/renderable.h b/taichi/ui/backends/vulkan/renderable.h index 84c675784a4df..16e5fae6045ce 100644 --- a/taichi/ui/backends/vulkan/renderable.h +++ b/taichi/ui/backends/vulkan/renderable.h @@ -25,6 +25,8 @@ TI_UI_NAMESPACE_BEGIN namespace vulkan { struct RenderableConfig { + int max_vertices_count; + int max_indices_count; int vertices_count; int indices_count; size_t ubo_size; diff --git a/taichi/ui/backends/vulkan/renderables/circles.cpp b/taichi/ui/backends/vulkan/renderables/circles.cpp index 2198056df263f..9b2a36e695620 100644 --- a/taichi/ui/backends/vulkan/renderables/circles.cpp +++ b/taichi/ui/backends/vulkan/renderables/circles.cpp @@ -18,6 +18,8 @@ void Circles::update_data(const CirclesInfo &info) { void Circles::init_circles(AppContext *app_context, int vertices_count) { RenderableConfig config = { + vertices_count, + 1, vertices_count, 1, sizeof(UniformBufferObject), diff --git a/taichi/ui/backends/vulkan/renderables/lines.cpp b/taichi/ui/backends/vulkan/renderables/lines.cpp index e18b3c6e02ca6..773bb0e12fb4d 100644 --- a/taichi/ui/backends/vulkan/renderables/lines.cpp +++ b/taichi/ui/backends/vulkan/renderables/lines.cpp @@ -23,6 +23,8 @@ void Lines::init_lines(AppContext *app_context, int vertices_count, int indices_count) { RenderableConfig config = { + vertices_count, + indices_count, vertices_count, indices_count, sizeof(UniformBufferObject), diff --git a/taichi/ui/backends/vulkan/renderables/mesh.cpp b/taichi/ui/backends/vulkan/renderables/mesh.cpp index cd29a6670e6a7..b632ec0e35cf8 100644 --- a/taichi/ui/backends/vulkan/renderables/mesh.cpp +++ b/taichi/ui/backends/vulkan/renderables/mesh.cpp @@ -26,6 +26,8 @@ void Mesh::update_ubo(const MeshInfo &info, const Scene &scene) { } void Mesh::update_data(const MeshInfo &info, const Scene &scene) { + Renderable::update_data(info.renderable_info); + size_t correct_ssbo_size = scene.point_lights_.size() * sizeof(PointLight); if (config_.ssbo_size != correct_ssbo_size) { resize_storage_buffers(correct_ssbo_size); @@ -37,8 +39,6 @@ void Mesh::update_data(const MeshInfo &info, const Scene &scene) { app_context_->device().unmap(storage_buffer_); } - Renderable::update_data(info.renderable_info); - update_ubo(info, scene); } @@ -46,6 +46,8 @@ void Mesh::init_mesh(AppContext *app_context, int vertices_count, int indices_count) { RenderableConfig config = { + vertices_count, + indices_count, vertices_count, indices_count, sizeof(UniformBufferObject), diff --git a/taichi/ui/backends/vulkan/renderables/particles.cpp b/taichi/ui/backends/vulkan/renderables/particles.cpp index 18bb90e472c6e..699bb9410cf0a 100644 --- a/taichi/ui/backends/vulkan/renderables/particles.cpp +++ b/taichi/ui/backends/vulkan/renderables/particles.cpp @@ -32,6 +32,7 @@ void Particles::update_ubo(glm::vec3 color, } void Particles::update_data(const ParticlesInfo &info, const Scene &scene) { + Renderable::update_data(info.renderable_info); size_t correct_ssbo_size = scene.point_lights_.size() * sizeof(PointLight); if (config_.ssbo_size != correct_ssbo_size) { resize_storage_buffers(correct_ssbo_size); @@ -43,8 +44,6 @@ void Particles::update_data(const ParticlesInfo &info, const Scene &scene) { app_context_->device().unmap(storage_buffer_); } - Renderable::update_data(info.renderable_info); - update_ubo(info.color, info.renderable_info.has_per_vertex_color, info.radius, scene); } @@ -52,7 +51,9 @@ void Particles::update_data(const ParticlesInfo &info, const Scene &scene) { void Particles::init_particles(AppContext *app_context, int vertices_count) { RenderableConfig config = { vertices_count, + 1, vertices_count, + 1, sizeof(UniformBufferObject), 1, app_context->config.package_path + "/shaders/Particles_vk_vert.spv", diff --git a/taichi/ui/backends/vulkan/renderables/set_image.cpp b/taichi/ui/backends/vulkan/renderables/set_image.cpp index 3d42b0c30677b..316297849396e 100644 --- a/taichi/ui/backends/vulkan/renderables/set_image.cpp +++ b/taichi/ui/backends/vulkan/renderables/set_image.cpp @@ -1,7 +1,10 @@ #include "set_image.h" +#include "taichi/program/program.h" #include "taichi/ui/utils/utils.h" +using taichi::lang::Program; + TI_UI_NAMESPACE_BEGIN namespace vulkan { @@ -25,6 +28,9 @@ void SetImage::update_ubo(float x_factor, float y_factor) { } void SetImage::update_data(const SetImageInfo &info) { + Program *prog = app_context_->prog(); + prog->synchronize(); + const FieldInfo &img = info.img; int new_width = get_correct_dimension(img.shape[0]); @@ -43,8 +49,7 @@ void SetImage::update_data(const SetImageInfo &info) { app_context_->device().image_transition(texture_, ImageLayout::shader_read, ImageLayout::transfer_dst); - Program &program = get_current_program(); - DevicePtr img_dev_ptr = get_device_ptr(&program, img.snode); + DevicePtr img_dev_ptr = get_device_ptr(prog, img.snode); uint64_t img_size = pixels * 4; Device::MemcpyCapability memcpy_cap = Device::check_memcpy_capability( @@ -82,6 +87,8 @@ void SetImage::init_set_image(AppContext *app_context, int img_width, int img_height) { RenderableConfig config = { + 6, + 6, 6, 6, sizeof(UniformBufferObject), @@ -124,8 +131,9 @@ void SetImage::create_texture() { cpu_staging_buffer_ = app_context_->device().allocate_memory(cpu_staging_buffer_params); - Device::AllocParams gpu_staging_buffer_params{image_size, false, false, true, - AllocUsage::Uniform}; + Device::AllocParams gpu_staging_buffer_params{ + image_size, false, false, app_context_->requires_export_sharing(), + AllocUsage::Uniform}; gpu_staging_buffer_ = app_context_->device().allocate_memory(gpu_staging_buffer_params); } diff --git a/taichi/ui/backends/vulkan/renderables/triangles.cpp b/taichi/ui/backends/vulkan/renderables/triangles.cpp index c3d1b7b901400..af802c0d0cadb 100644 --- a/taichi/ui/backends/vulkan/renderables/triangles.cpp +++ b/taichi/ui/backends/vulkan/renderables/triangles.cpp @@ -19,6 +19,8 @@ void Triangles::init_triangles(AppContext *app_context, int vertices_count, int indices_count) { RenderableConfig config = { + vertices_count, + indices_count, vertices_count, indices_count, sizeof(UniformBufferObject), diff --git a/taichi/ui/backends/vulkan/renderer.cpp b/taichi/ui/backends/vulkan/renderer.cpp index f7d10b0daadce..c95d92547cf30 100644 --- a/taichi/ui/backends/vulkan/renderer.cpp +++ b/taichi/ui/backends/vulkan/renderer.cpp @@ -1,6 +1,8 @@ #include "renderer.h" #include "taichi/ui/utils/utils.h" +using taichi::lang::Program; + TI_UI_NAMESPACE_BEGIN namespace vulkan { @@ -8,8 +10,10 @@ namespace vulkan { using namespace taichi::lang; using namespace taichi::lang::vulkan; -void Renderer::init(GLFWwindow *window, const AppConfig &config) { - app_context_.init(window, config); +void Renderer::init(Program *prog, + TaichiWindow *window, + const AppConfig &config) { + app_context_.init(prog, window, config); swap_chain_.init(&app_context_); } @@ -80,12 +84,23 @@ void Renderer::scene(Scene *scene) { } float aspect_ratio = swap_chain_.width() / (float)swap_chain_.height(); scene->update_ubo(aspect_ratio); - for (int i = 0; i < scene->mesh_infos_.size(); ++i) { - mesh(scene->mesh_infos_[i], scene); - } - for (int i = 0; i < scene->particles_infos_.size(); ++i) { - particles(scene->particles_infos_[i], scene); + + int object_count = scene->mesh_infos_.size() + scene->particles_infos_.size(); + int mesh_id = 0; + int particles_id = 0; + for (int i = 0; i < object_count; ++i) { + if (mesh_id < scene->mesh_infos_.size() && + scene->mesh_infos_[mesh_id].object_id == i) { + mesh(scene->mesh_infos_[mesh_id], scene); + ++mesh_id; + } + if (particles_id < scene->particles_infos_.size() && + scene->particles_infos_[particles_id].object_id == i) { + particles(scene->particles_infos_[particles_id], scene); + ++particles_id; + } } + scene->next_object_id_ = 0; scene->mesh_infos_.clear(); scene->particles_infos_.clear(); scene->point_lights_.clear(); diff --git a/taichi/ui/backends/vulkan/renderer.h b/taichi/ui/backends/vulkan/renderer.h index 9edb952881c3b..362489dd8bf18 100644 --- a/taichi/ui/backends/vulkan/renderer.h +++ b/taichi/ui/backends/vulkan/renderer.h @@ -30,13 +30,19 @@ #include "renderables/circles.h" #include "renderables/lines.h" +namespace taichi { +namespace lang { +class Program; +} // namespace lang +} // namespace taichi + TI_UI_NAMESPACE_BEGIN namespace vulkan { -class Renderer { +class TI_DLL_EXPORT Renderer { public: - void init(GLFWwindow *window, const AppConfig &config); + void init(lang::Program *prog, TaichiWindow *window, const AppConfig &config); void cleanup(); void prepare_for_next_frame(); diff --git a/taichi/ui/backends/vulkan/swap_chain.cpp b/taichi/ui/backends/vulkan/swap_chain.cpp index fd508059a346f..2084348a0a70d 100644 --- a/taichi/ui/backends/vulkan/swap_chain.cpp +++ b/taichi/ui/backends/vulkan/swap_chain.cpp @@ -1,6 +1,7 @@ #include "taichi/ui/utils/utils.h" #include "taichi/ui/backends/vulkan/app_context.h" #include "taichi/ui/backends/vulkan/swap_chain.h" +#include "taichi/util/image_io.h" TI_UI_NAMESPACE_BEGIN @@ -13,8 +14,9 @@ void SwapChain::init(class AppContext *app_context) { app_context_ = app_context; SurfaceConfig config; config.vsync = app_context_->config.vsync; - config.window_handle = app_context_->glfw_window(); - + config.window_handle = app_context_->taichi_window(); + config.width = app_context_->config.width; + config.height = app_context_->config.height; surface_ = app_context_->device().create_surface(config); auto [w, h] = surface_->get_size(); curr_width_ = w; @@ -62,6 +64,30 @@ taichi::lang::Surface &SwapChain::surface() { return *(surface_.get()); } +void SwapChain::write_image(const std::string &filename) { + auto [w, h] = surface_->get_size(); + DeviceAllocation img_buffer = surface_->get_image_data(); + unsigned char *ptr = (unsigned char *)app_context_->device().map(img_buffer); + auto format = surface_->image_format(); + if (format == BufferFormat::bgra8 || format == BufferFormat::bgra8srgb) { + TI_TRACE("Converting BGRA8 to RGBA for file output"); + std::vector converted(w * h); + uint32_t *u32ptr = (uint32_t *)ptr; + for (int j = 0; j < h; j++) { + for (int i = 0; i < w; i++) { + auto pixel = u32ptr[j * w + i]; + converted[j * w + i] = ((pixel << 16) & 0xFF0000) | + (pixel & 0x0000FF00) | ((pixel >> 16) & 0xFF) | + (pixel & 0xFF000000); + } + } + imwrite(filename, (size_t)converted.data(), w, h, 4); + } else { + imwrite(filename, (size_t)ptr, w, h, 4); + } + app_context_->device().unmap(img_buffer); +} + } // namespace vulkan TI_UI_NAMESPACE_END diff --git a/taichi/ui/backends/vulkan/swap_chain.h b/taichi/ui/backends/vulkan/swap_chain.h index 1da586a1310ed..c12261063a845 100644 --- a/taichi/ui/backends/vulkan/swap_chain.h +++ b/taichi/ui/backends/vulkan/swap_chain.h @@ -5,7 +5,7 @@ TI_UI_NAMESPACE_BEGIN namespace vulkan { -class SwapChain { +class TI_DLL_EXPORT SwapChain { public: void init(class AppContext *app_context); uint32_t width(); @@ -16,6 +16,8 @@ class SwapChain { void resize(uint32_t width, uint32_t height); + void write_image(const std::string &filename); + void cleanup(); private: diff --git a/taichi/ui/backends/vulkan/vertex.h b/taichi/ui/backends/vulkan/vertex.h index b4b228b5cca5e..7e53a9c9d5a28 100644 --- a/taichi/ui/backends/vulkan/vertex.h +++ b/taichi/ui/backends/vulkan/vertex.h @@ -9,6 +9,12 @@ struct Vertex { float y; float z; }; + struct vec4 { + float x; + float y; + float z; + float w; + }; struct vec2 { float x; float y; @@ -16,7 +22,7 @@ struct Vertex { vec3 pos; vec3 normal; vec2 texCoord; - vec3 color; + vec4 color; }; } // namespace ui diff --git a/taichi/ui/backends/vulkan/window.cpp b/taichi/ui/backends/vulkan/window.cpp index 5b1ad9237af55..397a41a3ddb6e 100644 --- a/taichi/ui/backends/vulkan/window.cpp +++ b/taichi/ui/backends/vulkan/window.cpp @@ -1,26 +1,34 @@ #include "taichi/ui/backends/vulkan/window.h" +#include "taichi/program/callable.h" + +using taichi::lang::Program; TI_UI_NAMESPACE_BEGIN namespace vulkan { -Window::Window(const AppConfig &config) : WindowBase(config) { - init(config); +Window::Window(Program *prog, const AppConfig &config) : WindowBase(config) { + init(prog, config); } -void Window::init(const AppConfig &config) { - glfwSetFramebufferSizeCallback(glfw_window_, framebuffer_resize_callback); +void Window::init(Program *prog, const AppConfig &config) { + if (config_.show_window) { + glfwSetFramebufferSizeCallback(glfw_window_, framebuffer_resize_callback); + } renderer_ = std::make_unique(); - renderer_->init(glfw_window_, config); + renderer_->init(prog, glfw_window_, config); canvas_ = std::make_unique(renderer_.get()); - gui_ = std::make_unique(&renderer_->app_context(), glfw_window_); + gui_ = std::make_unique(&renderer_->app_context(), + &renderer_->swap_chain(), glfw_window_); prepare_for_next_frame(); } void Window::show() { - draw_frame(); + if (!drawn_frame_) { + draw_frame(); + } present_frame(); WindowBase::show(); prepare_for_next_frame(); @@ -29,6 +37,7 @@ void Window::show() { void Window::prepare_for_next_frame() { renderer_->prepare_for_next_frame(); gui_->prepare_for_next_frame(); + drawn_frame_ = false; } CanvasBase *Window::get_canvas() { @@ -67,6 +76,7 @@ void Window::resize() { void Window::draw_frame() { renderer_->draw_frame(gui_.get()); + drawn_frame_ = true; } void Window::present_frame() { @@ -76,7 +86,19 @@ void Window::present_frame() { Window::~Window() { gui_->cleanup(); renderer_->cleanup(); - glfwTerminate(); + if (config_.show_window) { + glfwTerminate(); + } +} + +void Window::write_image(const std::string &filename) { + if (!drawn_frame_) { + draw_frame(); + } + renderer_->swap_chain().write_image(filename); + if (!config_.show_window) { + prepare_for_next_frame(); + } } } // namespace vulkan diff --git a/taichi/ui/backends/vulkan/window.h b/taichi/ui/backends/vulkan/window.h index db603d73078f7..651e35da76d98 100644 --- a/taichi/ui/backends/vulkan/window.h +++ b/taichi/ui/backends/vulkan/window.h @@ -22,27 +22,36 @@ #include "taichi/ui/common/window_base.h" #include "taichi/ui/backends/vulkan/gui.h" +namespace taichi { +namespace lang { +class Program; +} // namespace lang +} // namespace taichi + TI_UI_NAMESPACE_BEGIN namespace vulkan { class Window final : public WindowBase { public: - Window(const AppConfig &config); + Window(lang::Program *prog, const AppConfig &config); virtual void show() override; virtual CanvasBase *get_canvas() override; virtual GuiBase *GUI() override; + void write_image(const std::string &filename) override; + ~Window(); private: std::unique_ptr canvas_; std::unique_ptr gui_; std::unique_ptr renderer_; + bool drawn_frame_{false}; private: - void init(const AppConfig &config); + void init(lang::Program *prog, const AppConfig &config); void prepare_for_next_frame(); diff --git a/taichi/ui/common/app_config.h b/taichi/ui/common/app_config.h index ac8e5716da200..59b2cf896e07e 100644 --- a/taichi/ui/common/app_config.h +++ b/taichi/ui/common/app_config.h @@ -2,18 +2,21 @@ #include #include "taichi/ui/utils/utils.h" -#include "taichi/program/arch.h" +#include "taichi/backends/arch.h" -TI_UI_NAMESPACE_BEGIN +namespace taichi { +namespace ui { struct AppConfig { std::string name; int width{0}; int height{0}; bool vsync{false}; + bool show_window{true}; std::string package_path; - taichi::lang::Arch ti_arch; + Arch ti_arch; bool is_packed_mode{false}; }; -TI_UI_NAMESPACE_END +} // namespace ui +} // namespace taichi diff --git a/taichi/ui/common/field_info.cpp b/taichi/ui/common/field_info.cpp index db84184b3b882..900a60e9b22e7 100644 --- a/taichi/ui/common/field_info.cpp +++ b/taichi/ui/common/field_info.cpp @@ -26,7 +26,7 @@ DevicePtr get_device_ptr(taichi::lang::Program *program, SNode *snode) { int tree_id = root->get_snode_tree_id(); DevicePtr root_ptr = program->get_snode_tree_device_ptr(tree_id); - size_t offset = 0; + int64 offset = 0; int child_id = root->child_id(dense_parent); diff --git a/taichi/ui/common/field_info.h b/taichi/ui/common/field_info.h index 0f791d9c7ca8c..cedb56ec8071d 100644 --- a/taichi/ui/common/field_info.h +++ b/taichi/ui/common/field_info.h @@ -27,8 +27,16 @@ struct FieldInfo { DEFINE_PROPERTY(FieldSource, field_source); DEFINE_PROPERTY(taichi::lang::DataType, dtype); + // 'snode' is used by default if a Program is currently present. This + // is the default behavior and is used automatically when executing + // Taichi Kernels from Python or with an active Program. + // 'dev_alloc' is only used when no Program is currently present, for + // example when loading Taichi AOT modules in an external application + // and need to provide some information from those kernels to the GUI + // internal structures. using SNodePtr = taichi::lang::SNode *; DEFINE_PROPERTY(SNodePtr, snode); + DEFINE_PROPERTY(taichi::lang::DeviceAllocation, dev_alloc); FieldInfo() { valid = false; diff --git a/taichi/ui/common/scene_base.h b/taichi/ui/common/scene_base.h index ab6cd5999c217..e7a0c46f79e63 100644 --- a/taichi/ui/common/scene_base.h +++ b/taichi/ui/common/scene_base.h @@ -17,12 +17,14 @@ struct MeshInfo { RenderableInfo renderable_info; glm::vec3 color; bool two_sided{false}; + int object_id; }; struct ParticlesInfo { RenderableInfo renderable_info; glm::vec3 color; float radius; + int object_id; }; class SceneBase { @@ -33,9 +35,11 @@ class SceneBase { void mesh(const MeshInfo &info) { mesh_infos_.push_back(info); + mesh_infos_.back().object_id = next_object_id_++; } void particles(const ParticlesInfo &info) { particles_infos_.push_back(info); + particles_infos_.back().object_id = next_object_id_++; } void point_light(glm::vec3 pos, glm::vec3 color) { point_lights_.push_back({glm::vec4(pos, 1.0), glm::vec4(color, 1.0)}); @@ -51,6 +55,7 @@ class SceneBase { std::vector point_lights_; std::vector mesh_infos_; std::vector particles_infos_; + int next_object_id_ = 0; }; TI_UI_NAMESPACE_END diff --git a/taichi/ui/common/window_base.cpp b/taichi/ui/common/window_base.cpp index 449f524f3012b..f38a14776ce85 100644 --- a/taichi/ui/common/window_base.cpp +++ b/taichi/ui/common/window_base.cpp @@ -2,12 +2,18 @@ TI_UI_NAMESPACE_BEGIN +#define CHECK_WINDOW_SHOWING \ + TI_ERROR_IF(!config_.show_window, \ + "show_window must be True to use this method") + WindowBase ::WindowBase(AppConfig config) : config_(config) { - glfw_window_ = create_glfw_window_(config_.name, config_.width, - config_.height, config_.vsync); - glfwSetWindowUserPointer(glfw_window_, this); - set_callbacks(); - last_record_time_ = glfwGetTime(); + if (config_.show_window) { + glfw_window_ = create_glfw_window_(config_.name, config_.width, + config_.height, config_.vsync); + glfwSetWindowUserPointer(glfw_window_, this); + set_callbacks(); + last_record_time_ = glfwGetTime(); + } } void WindowBase::set_callbacks() { @@ -16,17 +22,27 @@ void WindowBase::set_callbacks() { glfwSetMouseButtonCallback(glfw_window_, mouse_button_callback); input_handler_.add_key_callback([&](int key, int action) { - if (action == GLFW_PRESS) { - events_.push_back({EventType::Press, button_id_to_name(key)}); - } else if (action == GLFW_RELEASE) { - events_.push_back({EventType::Release, button_id_to_name(key)}); + // Catch exception from button_id_to_name(). + try { + if (action == GLFW_PRESS) { + events_.push_back({EventType::Press, button_id_to_name(key)}); + } else if (action == GLFW_RELEASE) { + events_.push_back({EventType::Release, button_id_to_name(key)}); + } + } catch (const std::runtime_error &e) { + TI_TRACE("Input: {}.", e.what()); } }); input_handler_.add_mouse_button_callback([&](int key, int action) { - if (action == GLFW_PRESS) { - events_.push_back({EventType::Press, button_id_to_name(key)}); - } else if (action == GLFW_RELEASE) { - events_.push_back({EventType::Release, button_id_to_name(key)}); + // Catch exception from button_id_to_name(). + try { + if (action == GLFW_PRESS) { + events_.push_back({EventType::Press, button_id_to_name(key)}); + } else if (action == GLFW_RELEASE) { + events_.push_back({EventType::Release, button_id_to_name(key)}); + } + } catch (const std::runtime_error &e) { + TI_TRACE("Input: {}.", e.what()); } }); } @@ -36,6 +52,7 @@ CanvasBase *WindowBase::get_canvas() { } void WindowBase::show() { + CHECK_WINDOW_SHOWING; ++frames_since_last_record_; double current_time = glfwGetTime(); @@ -55,19 +72,32 @@ void WindowBase::show() { } bool WindowBase::is_pressed(std::string button) { - int button_id = buttom_name_to_id(button); + int button_id; + // Catch exception from buttom_name_to_id(). + try { + button_id = buttom_name_to_id(button); + } catch (const std::runtime_error &e) { + TI_TRACE("Pressed: {}.", e.what()); + return false; + } return input_handler_.is_pressed(button_id) > 0; } bool WindowBase::is_running() { - return !glfwWindowShouldClose(glfw_window_); + if (config_.show_window) { + return !glfwWindowShouldClose(glfw_window_); + } + return true; } void WindowBase::set_is_running(bool value) { - glfwSetWindowShouldClose(glfw_window_, !value); + if (config_.show_window) { + glfwSetWindowShouldClose(glfw_window_, !value); + } } std::pair WindowBase::get_cursor_pos() { + CHECK_WINDOW_SHOWING; float x = input_handler_.last_x(); float y = input_handler_.last_y(); @@ -77,6 +107,7 @@ std::pair WindowBase::get_cursor_pos() { } std::vector WindowBase::get_events(EventType tag) { + CHECK_WINDOW_SHOWING; glfwPollEvents(); std::vector result; std::list::iterator i = events_.begin(); @@ -92,6 +123,7 @@ std::vector WindowBase::get_events(EventType tag) { } bool WindowBase::get_event(EventType tag) { + CHECK_WINDOW_SHOWING; glfwPollEvents(); if (events_.size() == 0) { return false; @@ -115,14 +147,18 @@ bool WindowBase::get_event(EventType tag) { // these 2 are used to export the `current_event` field to python Event WindowBase::get_current_event() { + CHECK_WINDOW_SHOWING; return current_event_; } void WindowBase::set_current_event(const Event &event) { + CHECK_WINDOW_SHOWING; current_event_ = event; } WindowBase::~WindowBase() { - glfwDestroyWindow(glfw_window_); + if (config_.show_window) { + glfwDestroyWindow(glfw_window_); + } } GuiBase *WindowBase::GUI() { diff --git a/taichi/ui/common/window_base.h b/taichi/ui/common/window_base.h index d960bffbb59c1..ae6fe248fb1dc 100644 --- a/taichi/ui/common/window_base.h +++ b/taichi/ui/common/window_base.h @@ -39,6 +39,8 @@ class WindowBase { virtual void show(); + virtual void write_image(const std::string &filename) = 0; + virtual GuiBase *GUI(); virtual ~WindowBase(); diff --git a/taichi/ui/utils/utils.h b/taichi/ui/utils/utils.h index f898a11ca1844..43311fc6e9b34 100644 --- a/taichi/ui/utils/utils.h +++ b/taichi/ui/utils/utils.h @@ -31,8 +31,11 @@ #ifdef _WIN64 #define VK_USE_PLATFORM_WIN32_KHR 1 #endif -#include + +#include "taichi/backends/vulkan/vulkan_common.h" +#if !defined(ANDROID) #include +#endif #include @@ -51,6 +54,7 @@ TI_UI_NAMESPACE_BEGIN +#if !defined(ANDROID) inline void initGLFW() { if (!glfwInit()) { printf("cannot initialize GLFW\n"); @@ -66,6 +70,7 @@ inline GLFWwindow *create_glfw_window_(const std::string &name, GLFWwindow *window; glfwWindowHint(GLFW_CLIENT_API, GLFW_NO_API); + glfwWindowHint(GLFW_VISIBLE, GLFW_TRUE); window = glfwCreateWindow(screenWidth, screenHeight, name.c_str(), nullptr, nullptr); @@ -169,10 +174,11 @@ inline std::string button_id_to_name(int id) { if (keys.find(id) != keys.end()) { return keys.at(id); } else { - throw std::runtime_error(std::string("unrecognized id: \n") + + throw std::runtime_error(std::string("unrecognized id: ") + std::to_string(id)); } } +#endif inline int next_power_of_2(int n) { int count = 0; diff --git a/taichi/util/action_recorder.cpp b/taichi/util/action_recorder.cpp index 6eb9002490cbc..8827d2fc719c9 100644 --- a/taichi/util/action_recorder.cpp +++ b/taichi/util/action_recorder.cpp @@ -3,16 +3,15 @@ TI_NAMESPACE_BEGIN -std::string ActionArg::serialize() const { - std::string ret = key + ": "; +void ActionArg::serialize(std::ostream &ss) const { + ss << key << ": "; if (type == argument_type::str) { - ret += lang::c_quoted(val_str); + ss << lang::c_quoted(val_str); } else if (type == argument_type::int64) { - ret += std::to_string(val_int64); + ss << std::to_string(val_int64); } else { - ret += std::to_string(val_float64); + ss << std::to_string(val_float64); } - return ret; } ActionRecorder &ActionRecorder::get_instance() { @@ -25,31 +24,33 @@ ActionRecorder::ActionRecorder() { void ActionRecorder::start_recording(const std::string &fn) { TI_INFO("ActionRecorder: start recording to [{}]", fn); - TI_ASSERT(!running); - running = true; - ofs.open(fn); + TI_ASSERT(!running_); + running_ = true; + ofs_.open(fn); } void ActionRecorder::stop_recording() { TI_INFO("ActionRecorder: stop recording"); - TI_ASSERT(running); - running = false; - ofs.close(); + TI_ASSERT(running_); + running_ = false; + ofs_.close(); } bool ActionRecorder::is_recording() { - return running; + return running_; } void ActionRecorder::record(const std::string &content, const std::vector &arguments) { - if (!running) + if (!running_) return; - ofs << "- action: \"" << content << "\"" << std::endl; + ofs_ << "- action: \"" << content << "\"" << std::endl; for (auto &arg : arguments) { - ofs << " " << arg.serialize() << std::endl; + ofs_ << " "; + arg.serialize(ofs_); + ofs_ << std::endl; } - ofs.flush(); + ofs_.flush(); } TI_NAMESPACE_END diff --git a/taichi/util/action_recorder.h b/taichi/util/action_recorder.h index 8de8f04ed7349..b7eabe0754ef2 100644 --- a/taichi/util/action_recorder.h +++ b/taichi/util/action_recorder.h @@ -27,7 +27,7 @@ struct ActionArg { : key(key), val_float64(val), type(argument_type::float64) { } - std::string serialize() const; + void serialize(std::ostream &ss) const; std::string key; @@ -56,9 +56,9 @@ class ActionRecorder { private: ActionRecorder(); - std::ofstream ofs; + std::ofstream ofs_; - bool running{false}; + bool running_{false}; }; TI_NAMESPACE_END diff --git a/taichi/util/file_sequence_writer.cpp b/taichi/util/file_sequence_writer.cpp index ea8fdb06ff8e0..ce21d55c7dd8e 100644 --- a/taichi/util/file_sequence_writer.cpp +++ b/taichi/util/file_sequence_writer.cpp @@ -1,5 +1,7 @@ +#ifdef TI_WITH_LLVM #include "llvm/IR/Module.h" #include "llvm/Support/raw_ostream.h" +#endif #include "taichi/util/file_sequence_writer.h" @@ -7,15 +9,19 @@ TLANG_NAMESPACE_BEGIN FileSequenceWriter::FileSequenceWriter(std::string filename_template, std::string file_type) - : counter(0), filename_template(filename_template), file_type(file_type) { + : counter_(0), + filename_template_(filename_template), + file_type_(file_type) { } +#ifdef TI_WITH_LLVM std::string FileSequenceWriter::write(llvm::Module *module) { std::string str; llvm::raw_string_ostream ros(str); module->print(ros, nullptr); return write(str); } +#endif std::string FileSequenceWriter::write(const std::string &str) { auto [ofs, fn] = create_new_file(); @@ -30,9 +36,9 @@ std::string FileSequenceWriter::write(IRNode *irnode) { } std::pair FileSequenceWriter::create_new_file() { - auto fn = fmt::format(filename_template, counter); - TI_INFO("Saving {} to {}", file_type, fn); - counter++; + auto fn = fmt::format(filename_template_, counter_); + TI_INFO("Saving {} to {}", file_type_, fn); + counter_++; return {std::ofstream(fn), fn}; } diff --git a/taichi/util/file_sequence_writer.h b/taichi/util/file_sequence_writer.h index b45e14486235a..b75db02c70625 100644 --- a/taichi/util/file_sequence_writer.h +++ b/taichi/util/file_sequence_writer.h @@ -2,7 +2,9 @@ #include "taichi/lang_util.h" #include "taichi/ir/transforms.h" +#ifdef TI_WITH_LLVM #include "taichi/llvm/llvm_fwd.h" +#endif TLANG_NAMESPACE_BEGIN @@ -10,17 +12,19 @@ class FileSequenceWriter { public: FileSequenceWriter(std::string filename_template, std::string file_type); +#ifdef TI_WITH_LLVM // returns filename std::string write(llvm::Module *module); +#endif std::string write(IRNode *irnode); std::string write(const std::string &str); private: - int counter; - std::string filename_template; - std::string file_type; + int counter_; + std::string filename_template_; + std::string file_type_; std::pair create_new_file(); }; diff --git a/taichi/util/str.h b/taichi/util/str.h index a0021058f1f1e..9b9388185c634 100644 --- a/taichi/util/str.h +++ b/taichi/util/str.h @@ -1,5 +1,7 @@ #pragma once +#include +#include #include #include diff --git a/taichi/util/zip.cpp b/taichi/util/zip.cpp index 67f8696155499..f95312d6cc187 100644 --- a/taichi/util/zip.cpp +++ b/taichi/util/zip.cpp @@ -5,7 +5,8 @@ #ifndef _FILE_OFFSET_BITS #define _FILE_OFFSET_BITS 64 #endif -#ifndef _LARGEFILE64_SOURCE +#if !defined(_LARGEFILE64_SOURCE) && defined(TI_PLATFORM_LINUX) +// Only Linux has large file extension #define _LARGEFILE64_SOURCE 1 #endif #endif diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000..8b137891791fe --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/cpp/analysis/alias_analysis_test.cpp b/tests/cpp/analysis/alias_analysis_test.cpp index 8647c20d3d7d8..ac6dfd4492b90 100644 --- a/tests/cpp/analysis/alias_analysis_test.cpp +++ b/tests/cpp/analysis/alias_analysis_test.cpp @@ -94,6 +94,58 @@ TEST(AliasAnalysis, GlobalPtr_DiffSNodes) { // TODO(#2193): Add tests for other edge cases and other kinds of statements. +TEST(AliasAnalysis, ExternalPtr_Same) { + IRBuilder builder; + auto *arg1 = builder.create_arg_load(1, PrimitiveType::i32, true); + auto *arg2 = builder.create_arg_load(2, PrimitiveType::i32, false); + const auto indices = std::vector{arg2, arg2}; + auto *eptr1 = builder.create_external_ptr(arg1, indices); + auto *eptr2 = builder.create_external_ptr(arg1, indices); + + const auto aa = alias_analysis(eptr1, eptr2); + EXPECT_EQ(aa, AliasResult::same); +} + +TEST(AliasAnalysis, ExternalPtr_Different) { + IRBuilder builder; + auto *arg1 = builder.create_arg_load(1, PrimitiveType::i32, true); + auto *arg2 = builder.create_arg_load(2, PrimitiveType::i32, false); + const auto indices1 = std::vector{arg2, builder.get_int32(1)}; + const auto indices2 = std::vector{arg2, builder.get_int32(2)}; + auto *eptr1 = builder.create_external_ptr(arg1, indices1); + auto *eptr2 = builder.create_external_ptr(arg1, indices2); + + const auto aa = alias_analysis(eptr1, eptr2); + EXPECT_EQ(aa, AliasResult::different); +} + +TEST(AliasAnalysis, ExternalPtr_Uncertain) { + IRBuilder builder; + auto *arg1 = builder.create_arg_load(1, PrimitiveType::i32, true); + auto *arg2 = builder.create_arg_load(2, PrimitiveType::i32, false); + auto *arg3 = builder.create_arg_load(3, PrimitiveType::i32, false); + const auto indices1 = std::vector{arg2, arg2}; + const auto indices2 = std::vector{arg2, arg3}; + auto *eptr1 = builder.create_external_ptr(arg1, indices1); + auto *eptr2 = builder.create_external_ptr(arg1, indices2); + + const auto aa = alias_analysis(eptr1, eptr2); + EXPECT_EQ(aa, AliasResult::uncertain); +} + +TEST(AliasAnalysis, ExternalPtr_DiffPtr) { + IRBuilder builder; + auto *arg1 = builder.create_arg_load(1, PrimitiveType::i32, true); + auto *arg2 = builder.create_arg_load(2, PrimitiveType::i32, true); + auto *arg3 = builder.create_arg_load(3, PrimitiveType::i32, false); + const auto indices = std::vector{arg3, arg3}; + auto *eptr1 = builder.create_external_ptr(arg1, indices); + auto *eptr2 = builder.create_external_ptr(arg2, indices); + + const auto aa = alias_analysis(eptr1, eptr2); + EXPECT_EQ(aa, AliasResult::different); +} + } // namespace analysis } // namespace irpass } // namespace lang diff --git a/tests/cpp/analysis/same_statements_test.cpp b/tests/cpp/analysis/same_statements_test.cpp index 05391e7d783c3..3c20bd2cbfcf3 100644 --- a/tests/cpp/analysis/same_statements_test.cpp +++ b/tests/cpp/analysis/same_statements_test.cpp @@ -158,7 +158,7 @@ TEST(SameStatements, TestSameLoopIndex) { auto range_for = block ->push_back(zero, four, std::make_unique(), 1, 1, - 1, 1, false) + 1, false) ->as(); auto loop_index_a = range_for->body->push_back(range_for, 0); auto loop_index_b = range_for->body->push_back(range_for, 0); diff --git a/tests/cpp/analysis/value_diff_test.cpp b/tests/cpp/analysis/value_diff_test.cpp index 2d2ceb29c0f54..5962ec0012556 100644 --- a/tests/cpp/analysis/value_diff_test.cpp +++ b/tests/cpp/analysis/value_diff_test.cpp @@ -70,7 +70,7 @@ TEST(DiffRangeTest, Add) { auto diff = value_diff_loop_index(loop_idx2, for_stmt.get(), /*index=*/0); - EXPECT_TRUE(diff.related_()); + EXPECT_TRUE(diff.related()); EXPECT_EQ(diff.coeff, 1); EXPECT_EQ(diff.low, 4); EXPECT_EQ(diff.high, 5); @@ -88,7 +88,7 @@ TEST(DiffRangeTest, Sub) { auto diff = value_diff_loop_index(loop_idx2, for_stmt.get(), /*index=*/0); - EXPECT_TRUE(diff.related_()); + EXPECT_TRUE(diff.related()); EXPECT_EQ(diff.coeff, 1); EXPECT_EQ(diff.low, -4); EXPECT_EQ(diff.high, -3); @@ -106,7 +106,7 @@ TEST(DiffRangeTest, Mul) { auto diff = value_diff_loop_index(loop_idx2, for_stmt.get(), /*index=*/0); - EXPECT_TRUE(diff.related_()); + EXPECT_TRUE(diff.related()); EXPECT_EQ(diff.coeff, 4); EXPECT_EQ(diff.low, 0); EXPECT_EQ(diff.high, 1); @@ -124,7 +124,7 @@ TEST(DiffRangeTest, Shl) { auto diff = value_diff_loop_index(loop_idx2, for_stmt.get(), /*index=*/0); - EXPECT_TRUE(diff.related_()); + EXPECT_TRUE(diff.related()); EXPECT_EQ(diff.coeff, 4); EXPECT_EQ(diff.low, 0); EXPECT_EQ(diff.high, 1); diff --git a/tests/cpp/aot/aot_save_load_test.cpp b/tests/cpp/aot/aot_save_load_test.cpp new file mode 100644 index 0000000000000..45885d7e28bf8 --- /dev/null +++ b/tests/cpp/aot/aot_save_load_test.cpp @@ -0,0 +1,153 @@ +#include "gtest/gtest.h" +#include "taichi/ir/ir_builder.h" +#include "taichi/ir/statements.h" +#include "taichi/program/program.h" +#ifdef TI_WITH_VULKAN +#include "taichi/backends/vulkan/aot_module_loader_impl.h" +#include "taichi/backends/device.h" +#include "taichi/backends/vulkan/vulkan_device.h" +#include "taichi/backends/vulkan/vulkan_device_creator.h" +#include "taichi/backends/vulkan/vulkan_loader.h" +#include "taichi/backends/vulkan/vulkan_utils.h" +#endif + +using namespace taichi; +using namespace lang; + +[[maybe_unused]] static void aot_save() { + auto program = Program(Arch::vulkan); + + program.config.advanced_optimization = false; + + int n = 10; + + auto *root = new SNode(0, SNodeType::root); + auto *pointer = &root->dense(Axis(0), n, false); + auto *place = &pointer->insert_children(SNodeType::place); + place->dt = PrimitiveType::i32; + program.add_snode_tree(std::unique_ptr(root), /*compile_only=*/true); + + auto aot_builder = program.make_aot_module_builder(Arch::vulkan); + + std::unique_ptr kernel_init, kernel_ret; + + { + /* + @ti.kernel + def init(): + for index in range(n): + place[index] = index + */ + IRBuilder builder; + auto *zero = builder.get_int32(0); + auto *n_stmt = builder.get_int32(n); + auto *loop = builder.create_range_for(zero, n_stmt, 1, 0, 4); + { + auto _ = builder.get_loop_guard(loop); + auto *index = builder.get_loop_index(loop); + auto *ptr = builder.create_global_ptr(place, {index}); + builder.create_global_store(ptr, index); + } + + kernel_init = + std::make_unique(program, builder.extract_ir(), "init"); + } + + { + /* + @ti.kernel + def ret(): + sum = 0 + for index in place: + sum = sum + place[index]; + return sum + */ + IRBuilder builder; + auto *sum = builder.create_local_var(PrimitiveType::i32); + auto *loop = builder.create_struct_for(pointer, 1, 0, 4); + { + auto _ = builder.get_loop_guard(loop); + auto *index = builder.get_loop_index(loop); + auto *sum_old = builder.create_local_load(sum); + auto *place_index = + builder.create_global_load(builder.create_global_ptr(place, {index})); + builder.create_local_store(sum, builder.create_add(sum_old, place_index)); + } + builder.create_return(builder.create_local_load(sum)); + + kernel_ret = std::make_unique(program, builder.extract_ir(), "ret"); + kernel_ret->insert_ret(PrimitiveType::i32); + } + + aot_builder->add_field("place", place, true, place->dt, {n}, 1, 1); + aot_builder->add("init", kernel_init.get()); + aot_builder->add("ret", kernel_ret.get()); + aot_builder->dump(".", ""); +} + +#ifdef TI_WITH_VULKAN +TEST(AotSaveLoad, Vulkan) { + // Otherwise will segfault on macOS VM, + // where Vulkan is installed but no devices are present + if (!vulkan::is_vulkan_api_available()) { + return; + } + + aot_save(); + + // API based on proposal https://github.com/taichi-dev/taichi/issues/3642 + // Initialize Vulkan program + taichi::uint64 *result_buffer{nullptr}; + taichi::lang::RuntimeContext host_ctx; + auto memory_pool = + std::make_unique(Arch::vulkan, nullptr); + result_buffer = (taichi::uint64 *)memory_pool->allocate( + sizeof(taichi::uint64) * taichi_result_buffer_entries, 8); + + // Create Taichi Device for computation + lang::vulkan::VulkanDeviceCreator::Params evd_params; + evd_params.api_version = + taichi::lang::vulkan::VulkanEnvSettings::kApiVersion(); + auto embedded_device = + std::make_unique(evd_params); + + // Create Vulkan runtime + vulkan::VkRuntime::Params params; + params.host_result_buffer = result_buffer; + params.device = embedded_device->device(); + auto vulkan_runtime = + std::make_unique(std::move(params)); + + // Run AOT module loader + vulkan::AotModuleParams mod_params; + mod_params.module_path = "."; + mod_params.runtime = vulkan_runtime.get(); + + std::unique_ptr vk_module = + aot::Module::load(".", Arch::vulkan, mod_params); + EXPECT_TRUE(vk_module); + + // Retrieve kernels/fields/etc from AOT module to initialize runtime + auto root_size = vk_module->get_root_size(); + EXPECT_EQ(root_size, 64); + vulkan_runtime->add_root_buffer(root_size); + + auto init_kernel = vk_module->get_kernel("init"); + EXPECT_TRUE(init_kernel); + + auto ret_kernel = vk_module->get_kernel("ret"); + EXPECT_TRUE(ret_kernel); + + auto ret2_kernel = vk_module->get_kernel("ret2"); + EXPECT_FALSE(ret2_kernel); + + // Run kernels + init_kernel->launch(&host_ctx); + ret_kernel->launch(&host_ctx); + vulkan_runtime->synchronize(); + + // auto x_field = vk_module.get_field("x"); + // EXPECT_TRUE(x_field); + // x_field.copy_to(/*dst=*/x.get()); +} +#endif diff --git a/tests/cpp/backends/dx11_device_test.cpp b/tests/cpp/backends/dx11_device_test.cpp new file mode 100644 index 0000000000000..9fd12735cd72e --- /dev/null +++ b/tests/cpp/backends/dx11_device_test.cpp @@ -0,0 +1,154 @@ +#include "gtest/gtest.h" + +#ifdef TI_WITH_DX11 + +#include "taichi/ir/ir_builder.h" +#include "taichi/backends/dx/dx_device.h" +#include "taichi/backends/dx/dx_info_queue.h" +#include "taichi/backends/dx/dx_program.h" +#include "taichi/system/memory_pool.h" +#include "tests/cpp/program/test_program.h" + +namespace taichi { +namespace lang { +namespace directx11 { + +TEST(Dx11DeviceCreationTest, CreateDeviceAndAllocateMemory) { + std::unique_ptr device = + std::make_unique(); + + // Should not crash + EXPECT_TRUE(device != nullptr); + + // Should have one object of each of the following types: + // ID3D11Device + // ID3D11Context + // ID3DDeviceContextState + // ID3D11BlendState + // ID3D11DepthStencilState + // ID3D11RasterizerState + // ID3D11Sampler + // ID3D11Query + int count0, count1, count2; + if (kD3d11DebugEnabled) { + count0 = device->live_dx11_object_count(); + EXPECT_EQ(count0, 8); + } + + taichi::lang::Device::AllocParams params; + params.size = 1048576; + const taichi::lang::DeviceAllocation device_alloc = + device->allocate_memory(params); + if (kD3d11DebugEnabled) { + count1 = device->live_dx11_object_count(); + // Should have allocated an UAV and a Buffer, so 2 more objects. + EXPECT_EQ(count1 - count0, 2); + } + + // Map to CPU, write some values, then check those values + void *mapped = device->map(device_alloc); + int *mapped_int = reinterpret_cast(mapped); + for (int i = 0; i < 100; i++) { + mapped_int[i] = i; + } + device->unmap(device_alloc); + + mapped = device->map(device_alloc); + mapped_int = reinterpret_cast(mapped); + for (int i = 0; i < 100; i++) { + EXPECT_EQ(mapped_int[i], i); + } + device->unmap(device_alloc); + + // The 2 objects should have been released. + device->dealloc_memory(device_alloc); + if (kD3d11DebugEnabled) { + count2 = device->live_dx11_object_count(); + EXPECT_EQ(count2 - count1, -2); + } +} + +TEST(Dx11InfoQueueTest, ParseReferenceCount) { + const std::vector messages = { + "Create ID3D11Context: Name=\"unnamed\", Addr=0x0000018F6678E080, " + "ExtRef=1, IntRef=0", + "Create ID3DDeviceContextState: Name=\"unnamed\", " + "Addr=0x0000018F6686CE10, " + "ExtRef=1, IntRef=0", + "Create ID3D11BlendState: Name=\"unnamed\", Addr=0x0000018F667F6DB0, " + "ExtRef=1, IntRef=0", + "Create ID3D11DepthStencilState: Name=\"unnamed\", " + "Addr=0x0000018F667F6BC0, ExtRef=1, IntRef=0", + "Create ID3D11RasterizerState: Name=\"unnamed\", " + "Addr=0x0000018F64891420, " + "ExtRef=1, IntRef=0", + "Create ID3D11Sampler: Name=\"unnamed\", Addr=0x0000018F667F6FA0, " + "ExtRef=1, IntRef=0", + "Create ID3D11Query: Name=\"unnamed\", Addr=0x0000018F64E81DA0, " + "ExtRef=1, IntRef=0", + "Create ID3D11Fence: Name=\"unnamed\", Addr=0x0000018F64FF7380, " + "ExtRef=1, IntRef=0", + "Destroy ID3D11Fence: Name=\"unnamed\", Addr=0x0000018F64FF7380", + "Live ID3D11Device at 0x0000018F66782250, Refcount: 5", + "Live ID3D11Context at 0x0000018F6678E080, Refcount: 1, IntRef: 1", + "Live ID3DDeviceContextState at 0x0000018F6686CE10, Refcount: 0, IntRef: " + "1", + "Live ID3D11BlendState at 0x0000018F667F6DB0, Refcount: 0, " + "IntRef: 1", + "Live ID3D11DepthStencilState at 0x0000018F667F6BC0, Refcount: 0, " + "IntRef: 1", + "Live ID3D11RasterizerState at 0x0000018F64891420, Refcount: 0, " + "IntRef: 1", + "Live ID3D11Sampler at 0x0000018F667F6FA0, Refcount: 0, IntRef: 1", + "Live ID3D11Query at 0x0000018F64E81DA0, Refcount: 0, IntRef: 1"}; + std::vector entries = + directx11::Dx11InfoQueue::parse_reference_count(messages); + EXPECT_EQ(entries.size(), 8); +} + +TEST(Dx11StreamTest, CommandListTest) { + std::unique_ptr device = + std::make_unique(); + std::unique_ptr stream = + std::make_unique(device.get()); + stream->new_command_list(); +} + +TEST(Dx11ProgramTest, MaterializeRuntimeTest) { + std::unique_ptr device = + std::make_unique(); + std::unique_ptr pool = + std::make_unique(Arch::dx11, device.get()); + std::unique_ptr program = + std::make_unique(default_compile_config); + /* + This test needs allocate_memory because of the call stack here: + Dx11ProgramImpl::materialize_runtime + - VkRuntime::VkRuntime + - VkRuntime::init_buffers + - Dx11Device::allocate_memory_unique + - Dx11Device::get_compute_stream + - Dx11Stream::new_command_list + - Dx11Stream::buffer_fill + - Dx11Stream::submit_synced + */ + uint64_t *result_buffer; + program->materialize_runtime(pool.get(), nullptr, &result_buffer); + + TestProgram test_prog; + test_prog.setup(); + + IRBuilder builder; + auto *lhs = builder.get_int32(42); + + auto block = builder.extract_ir(); + test_prog.prog()->config.arch = Arch::dx11; + auto ker = std::make_unique(*test_prog.prog(), std::move(block)); + program->compile(ker.get(), nullptr); +} + +} // namespace directx11 +} // namespace lang +} // namespace taichi + +#endif diff --git a/tests/cpp/codegen/refine_coordinates_test.cpp b/tests/cpp/codegen/refine_coordinates_test.cpp index 24a7bf1d3c2b8..f405d8727d41c 100644 --- a/tests/cpp/codegen/refine_coordinates_test.cpp +++ b/tests/cpp/codegen/refine_coordinates_test.cpp @@ -1,3 +1,4 @@ +#ifdef TI_WITH_LLVM #include "gtest/gtest.h" #include @@ -6,12 +7,12 @@ #include "llvm/IR/Type.h" #include "llvm/IR/BasicBlock.h" -#include "taichi/program/arch.h" -#include "taichi/program/program.h" -#include "taichi/struct/struct_llvm.h" +#include "taichi/backends/arch.h" #include "taichi/ir/snode.h" -#include "taichi/program/compile_config.h" #include "taichi/llvm/llvm_codegen_utils.h" +#include "taichi/program/compile_config.h" +#include "taichi/program/program.h" +#include "taichi/struct/struct_llvm.h" namespace taichi { @@ -117,7 +118,8 @@ class RefineCoordinatesTest : public ::testing::Test { leaf_snode.dt = PrimitiveType::f32; auto sc = std::make_unique( - arch_, &config_, tlctx_, tlctx_->clone_runtime_module()); + arch_, &config_, tlctx_, tlctx_->clone_runtime_module(), + /*snode_tree_id=*/0); sc->run(*root_snode_); } @@ -164,3 +166,4 @@ TEST_F(RefineCoordinatesTest, Basic) { } // namespace } // namespace lang } // namespace taichi +#endif // #ifdef TI_WITH_LLVM diff --git a/tests/cpp/ir/frontend_type_inference_test.cpp b/tests/cpp/ir/frontend_type_inference_test.cpp new file mode 100644 index 0000000000000..deb775996fa03 --- /dev/null +++ b/tests/cpp/ir/frontend_type_inference_test.cpp @@ -0,0 +1,183 @@ +#include +#include "gtest/gtest.h" + +#include "taichi/ir/frontend_ir.h" +#include "taichi/program/program.h" + +namespace taichi { +namespace lang { + +TEST(FrontendTypeInference, Const) { + auto const_i64 = Expr::make(1LL << 63); + const_i64->type_check(nullptr); + EXPECT_EQ(const_i64->ret_type, PrimitiveType::i64); +} + +TEST(FrontendTypeInference, ArgLoad) { + auto arg_load_u64 = Expr::make(2, PrimitiveType::u64); + arg_load_u64->type_check(nullptr); + EXPECT_EQ(arg_load_u64->ret_type, PrimitiveType::u64); +} + +TEST(FrontendTypeInference, Rand) { + auto rand_f16 = Expr::make(PrimitiveType::f16); + rand_f16->type_check(nullptr); + EXPECT_EQ(rand_f16->ret_type, PrimitiveType::f16); +} + +TEST(FrontendTypeInference, Id) { + auto prog = std::make_unique(Arch::x64); + auto func = []() {}; + auto kernel = std::make_unique(*prog, func, "fake_kernel"); + Callable::CurrentCallableGuard _(kernel->program, kernel.get()); + auto const_i32 = Expr::make(-(1 << 20)); + const_i32->type_check(nullptr); + auto id_i32 = prog->current_ast_builder()->make_var(const_i32); + EXPECT_EQ(id_i32->ret_type, PrimitiveType::i32); +} + +TEST(FrontendTypeInference, BinaryOp) { + auto prog = std::make_unique(Arch::x64); + prog->config.default_fp = PrimitiveType::f64; + auto const_i32 = Expr::make(-(1 << 20)); + const_i32->type_check(nullptr); + auto const_f32 = Expr::make(5.0); + const_f32->type_check(nullptr); + auto truediv_f64 = expr_truediv(const_i32, const_f32); + truediv_f64->type_check(&prog->config); + EXPECT_EQ(truediv_f64->ret_type, PrimitiveType::f64); +} + +TEST(FrontendTypeInference, UnaryOp) { + auto const_i16 = Expr::make(-(1 << 10)); + const_i16->type_check(nullptr); + EXPECT_EQ(const_i16->ret_type, PrimitiveType::i16); + auto cast_i8 = cast(const_i16, PrimitiveType::i8); + cast_i8->type_check(nullptr); + EXPECT_EQ(cast_i8->ret_type, PrimitiveType::i8); + auto bit_not_i16 = ~const_i16; + bit_not_i16->type_check(nullptr); + EXPECT_EQ(bit_not_i16->ret_type, PrimitiveType::i16); +} + +TEST(FrontendTypeInference, TernaryOp) { + auto const_i16 = Expr::make(-(1 << 10)); + const_i16->type_check(nullptr); + EXPECT_EQ(const_i16->ret_type, PrimitiveType::i16); + auto cast_i8 = cast(const_i16, PrimitiveType::i8); + cast_i8->type_check(nullptr); + EXPECT_EQ(cast_i8->ret_type, PrimitiveType::i8); + auto const_f32 = Expr::make(5.0); + const_f32->type_check(nullptr); + EXPECT_EQ(const_f32->ret_type, PrimitiveType::f32); + auto ternary_f32 = expr_select(const_i16, cast_i8, const_f32); + ternary_f32->type_check(nullptr); + EXPECT_EQ(ternary_f32->ret_type, PrimitiveType::f32); +} + +TEST(FrontendTypeInference, GlobalPtr_GlobalVariable) { + auto snode = std::make_unique(0, SNodeType::root); + snode->dt = PrimitiveType::u8; + auto global_var = Expr::make(snode.get()); + auto index = Expr::make(2); + index->type_check(nullptr); + auto global_ptr = + Expr::make(global_var, ExprGroup(index)); + global_ptr->type_check(nullptr); + EXPECT_EQ(global_ptr->ret_type, PrimitiveType::u8); +} + +TEST(FrontendTypeInference, GlobalPtr_ExternalTensor) { + auto index = Expr::make(2); + index->type_check(nullptr); + auto external_tensor = + Expr::make(PrimitiveType::u16, 1, 0, 0); + auto global_ptr = + Expr::make(external_tensor, ExprGroup(index)); + EXPECT_THROW(global_ptr->type_check(nullptr), TaichiTypeError); +} + +TEST(FrontendTypeInference, TensorElement) { + auto prog = std::make_unique(Arch::x64); + auto func = []() {}; + auto kernel = std::make_unique(*prog, func, "fake_kernel"); + Callable::CurrentCallableGuard _(kernel->program, kernel.get()); + const std::vector shape{3}; + auto var = Expr(std::make_shared()); + prog->current_ast_builder()->insert(std::make_unique( + std::static_pointer_cast(var.expr)->id, shape, + PrimitiveType::u32)); + var->ret_type = prog->current_ast_builder()->get_last_stmt()->ret_type; + auto index = Expr::make(2); + index->type_check(nullptr); + auto tensor_element = + Expr::make(var, ExprGroup(index), shape, 1); + tensor_element->type_check(nullptr); + EXPECT_EQ(tensor_element->ret_type, PrimitiveType::u32); +} + +TEST(FrontendTypeInference, AtomicOp) { + auto const_i32 = Expr::make(-(1 << 20)); + const_i32->type_check(nullptr); + auto const_f32 = Expr::make(5.0); + const_f32->type_check(nullptr); + auto atomic_add_i32 = + Expr::make(AtomicOpType::add, const_i32, const_f32); + atomic_add_i32->type_check(nullptr); + EXPECT_EQ(atomic_add_i32->ret_type, PrimitiveType::i32); +} + +TEST(FrontendTypeInference, SNodeOp) { + auto snode = std::make_unique(0, SNodeType::root); + snode->dt = PrimitiveType::u8; + auto index = Expr::make(2); + index->type_check(nullptr); + auto snode_op = Expr::make( + snode.get(), SNodeOpType::get_addr, ExprGroup(index)); + snode_op->type_check(nullptr); + EXPECT_EQ(snode_op->ret_type, PrimitiveType::u64); +} + +TEST(FrontendTypeInference, ExternalTensorShapeAlongAxis) { + auto external_tensor = + Expr::make(PrimitiveType::u64, 1, 0, 0); + auto shape = + Expr::make(external_tensor, 0); + shape->type_check(nullptr); + EXPECT_EQ(shape->ret_type, PrimitiveType::i32); +} + +TEST(FrontendTypeInference, RangeAssumption) { + auto const_f32_a = Expr::make(5.0); + const_f32_a->type_check(nullptr); + auto const_f32_b = Expr::make(5.0); + const_f32_b->type_check(nullptr); + auto valid = + Expr::make(const_f32_a, const_f32_b, 0, 1); + valid->type_check(nullptr); + EXPECT_EQ(valid->ret_type, PrimitiveType::f32); + auto const_f64 = Expr::make(5.0); + const_f64->type_check(nullptr); + auto invalid = + Expr::make(const_f32_a, const_f64, 0, 1); + EXPECT_THROW(invalid->type_check(nullptr), TaichiTypeError); +} + +TEST(FrontendTypeInference, LoopUnique) { + auto const_i64 = Expr::make(5); + const_i64->type_check(nullptr); + auto loop_unique = + Expr::make(const_i64, std::vector{}); + loop_unique->type_check(nullptr); + EXPECT_EQ(loop_unique->ret_type, PrimitiveType::i64); +} + +TEST(FrontendTypeInference, InternalFuncCall) { + auto internal_func_call = + Expr::make("do_nothing", std::vector{}); + internal_func_call->type_check(nullptr); + EXPECT_EQ(internal_func_call->ret_type, PrimitiveType::i32); +} + +} // namespace lang +} // namespace taichi diff --git a/tests/cpp/ir/ir_builder_test.cpp b/tests/cpp/ir/ir_builder_test.cpp index 1b67cd9d29976..cef117651871a 100644 --- a/tests/cpp/ir/ir_builder_test.cpp +++ b/tests/cpp/ir/ir_builder_test.cpp @@ -108,9 +108,10 @@ TEST(IRBuilder, ExternalPtr) { builder.create_global_store(a2ptr, a0plusa2); // a[2] = a[0] + a[2] auto block = builder.extract_ir(); auto ker = std::make_unique(*test_prog.prog(), std::move(block)); - ker->insert_arg(get_data_type(), /*is_external_array=*/true); + ker->insert_arg(get_data_type(), /*is_array=*/true); auto launch_ctx = ker->make_launch_context(); - launch_ctx.set_arg_external_array(/*arg_id=*/0, (uint64)array.get(), size); + launch_ctx.set_arg_external_array(/*arg_id=*/0, (uint64)array.get(), size, + /*is_device_allocation=*/false); (*ker)(launch_ctx); EXPECT_EQ(array[0], 2); EXPECT_EQ(array[1], 1); diff --git a/tests/cpp/ir/type_test.cpp b/tests/cpp/ir/type_test.cpp index 997da16bf5a7a..620117700e1f8 100644 --- a/tests/cpp/ir/type_test.cpp +++ b/tests/cpp/ir/type_test.cpp @@ -8,6 +8,9 @@ namespace taichi { namespace lang { TEST(Type, BitTypes) { + auto f16 = + TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::f16); + EXPECT_EQ(f16->to_string(), "f16"); auto i32 = TypeFactory::get_instance() .get_primitive_type(PrimitiveTypeID::i32) ->as(); diff --git a/tests/cpp/program/test_program.cpp b/tests/cpp/program/test_program.cpp index 464aa8147cd0e..3d1a0858165bd 100644 --- a/tests/cpp/program/test_program.cpp +++ b/tests/cpp/program/test_program.cpp @@ -6,7 +6,8 @@ namespace lang { void TestProgram::setup() { prog_ = std::make_unique(Arch::x64); prog_->materialize_runtime(); - prog_->add_snode_tree(std::make_unique(/*depth=*/0, SNodeType::root)); + prog_->add_snode_tree(std::make_unique(/*depth=*/0, SNodeType::root), + /*compile_only=*/false); } } // namespace lang diff --git a/tests/cpp/transforms/binary_op_simplify_test.cpp b/tests/cpp/transforms/binary_op_simplify_test.cpp index 3f29cadf10ae9..6ba50e376b56b 100644 --- a/tests/cpp/transforms/binary_op_simplify_test.cpp +++ b/tests/cpp/transforms/binary_op_simplify_test.cpp @@ -50,7 +50,7 @@ TEST_F(BinaryOpSimplifyTest, MultiplyPOT) { EXPECT_EQ(bin_op->op_type, BinaryOpType::bit_shl); EXPECT_EQ(bin_op->rhs, const_stmt); EXPECT_TRUE(ir_block->statements[3]->is()); - EXPECT_EQ(ir_block->statements[3]->as()->value, bin_op); + EXPECT_EQ(ir_block->statements[3]->as()->values[0], bin_op); } TEST_F(BinaryOpSimplifyTest, ModPOT) { @@ -98,7 +98,7 @@ TEST_F(BinaryOpSimplifyTest, ModPOT) { EXPECT_EQ(bin_op->op_type, BinaryOpType::bit_and); EXPECT_EQ(bin_op->rhs, const_stmt); EXPECT_TRUE(ir_block->statements[3]->is()); - EXPECT_EQ(ir_block->statements[3]->as()->value, bin_op); + EXPECT_EQ(ir_block->statements[3]->as()->values[0], bin_op); } } // namespace lang diff --git a/tests/cpp/transforms/extract_constant_test.cpp b/tests/cpp/transforms/extract_constant_test.cpp index b579f9178a429..4e4693d3bc501 100644 --- a/tests/cpp/transforms/extract_constant_test.cpp +++ b/tests/cpp/transforms/extract_constant_test.cpp @@ -26,7 +26,7 @@ TEST_F(ExtractConstantTest, ExtractConstant) { builder.set_insertion_point_to_loop_begin(for_stmt); auto *x = builder.create_local_var(get_data_type()); auto *x_v = builder.create_local_load(x); - auto *sum = builder.create_add(x_v, builder.get_int32(1)); + builder.create_add(x_v, builder.get_int32(1)); auto ir = builder.extract_ir(); ASSERT_TRUE(ir->is()); diff --git a/tests/cpp/transforms/inlining_test.cpp b/tests/cpp/transforms/inlining_test.cpp index f88c8d8ba3f95..6e77c4bf0f561 100644 --- a/tests/cpp/transforms/inlining_test.cpp +++ b/tests/cpp/transforms/inlining_test.cpp @@ -34,7 +34,7 @@ TEST_F(InliningTest, ArgLoadOfArgLoad) { auto *func = prog_->create_function( FunctionKey("test_func", /*func_id=*/0, /*instance_id=*/0)); - func->insert_arg(get_data_type(), /*is_external_array=*/false); + func->insert_arg(get_data_type(), /*is_array=*/false); func->insert_ret(get_data_type()); func->set_function_body(std::move(func_body)); diff --git a/tests/python/bls_test_template.py b/tests/python/bls_test_template.py index 8d243cdd44331..879b27981da6c 100644 --- a/tests/python/bls_test_template.py +++ b/tests/python/bls_test_template.py @@ -15,7 +15,7 @@ def bls_test_template(dim, dense=False): x, y, y2 = ti.field(ti.i32), ti.field(ti.i32), ti.field(ti.i32) - index = ti.indices(*range(dim)) + index = ti.axes(*range(dim)) mismatch = ti.field(ti.i32, shape=()) if not isinstance(bs, (tuple, list)): @@ -177,8 +177,7 @@ def insert(): base = ti.Vector([ int(ti.floor(x[i][0] * N) - grid_offset[0]), int(ti.floor(x[i][1] * N) - grid_offset[1]) - ], - dt=ti.i32) + ]) base_p = ti.rescale_index(m1, pid, base) ti.append(pid.parent(), base_p, i) diff --git a/tests/python/ell.json b/tests/python/ell.json new file mode 100644 index 0000000000000..e1e56018920f0 --- /dev/null +++ b/tests/python/ell.json @@ -0,0 +1,120 @@ +{ +"num_patches" : 8, + "elements" : [ +{"order" : 0, +"num" : 20, +"max_num_per_patch" : 32, +"owned_offsets" : [0,0,3,5,7,10,12,14,20], +"total_offsets" : [0,20,40,60,80,100,120,140,160], +"l2g_mapping" : [12,4,0,3,13,1,5,11,15,9,2,8,7,10,19,14,6,17,16,18,2,5,10,13,1,4,9,12,0,3,15,7,8,11,17,19,6,14,16,18,1,13,12,4,0,9,5,2,10,3,8,15,11,17,19,7,6,14,16,18,4,7,12,0,3,13,1,5,15,6,9,8,11,17,19,14,2,10,16,18,16,17,18,19,11,12,14,15,3,0,8,4,9,13,1,6,7,5,10,2,8,9,12,0,11,1,13,10,4,3,15,17,19,16,18,14,5,2,7,6,0,12,4,3,9,8,1,11,13,15,17,19,5,7,6,14,10,2,16,18,3,6,11,14,15,19,12,4,0,7,16,18,8,17,9,13,1,5,10,2], +"g2r_mapping" : [12,3,0,14,5,1,15,6,10,11,2,16,13,4,17,18,7,8,9,19], +"l2r_mapping" : [13,5,12,14,4,3,1,16,18,11,0,10,6,2,19,17,15,8,7,9,0,1,2,4,3,5,11,13,12,14,18,6,10,16,8,19,15,17,7,9,3,4,13,5,12,11,1,0,2,14,10,18,16,8,19,6,15,17,7,9,5,6,13,12,14,4,3,1,18,15,11,10,16,8,19,17,0,2,7,9,7,8,9,19,16,13,17,18,14,12,10,5,11,4,3,15,6,1,2,0,10,11,13,12,16,3,4,2,5,14,18,8,19,7,9,17,1,0,6,15,12,13,5,14,11,10,3,16,4,18,8,19,1,6,15,17,2,0,7,9,14,15,16,17,18,19,13,5,12,6,7,9,10,8,11,4,3,1,2,0] +} +,{"order" : 2, +"num" : 66, +"max_num_per_patch" : 96, +"owned_offsets" : [0,8,17,23,31,42,48,56,66], +"total_offsets" : [0,66,132,198,264,330,396,462,528], +"l2g_mapping" : [0,3,8,10,11,15,17,18,1,2,9,16,14,43,44,45,27,25,26,20,12,19,7,13,34,32,33,4,48,49,53,54,47,24,46,37,35,36,5,6,39,38,64,57,63,42,55,56,21,22,23,60,58,59,65,50,40,41,62,30,61,52,29,51,31,28,16,32,33,34,35,36,37,38,39,15,17,8,9,11,19,20,10,12,13,14,18,0,2,3,25,27,46,47,5,6,7,1,4,26,44,45,49,53,57,58,60,64,43,48,21,23,24,41,42,54,55,22,40,56,63,65,28,29,30,50,51,59,61,62,31,52,9,12,13,14,19,20,8,10,15,16,18,32,33,35,36,38,11,17,34,37,39,1,2,3,4,6,7,26,27,44,45,49,53,57,58,60,64,0,25,46,47,5,43,48,21,23,24,41,42,54,55,22,40,56,63,65,28,29,30,50,51,59,61,62,31,52,2,22,23,24,25,27,46,47,0,3,8,10,11,15,17,18,1,4,6,7,9,12,14,20,26,44,45,49,53,57,58,60,64,5,13,43,48,21,41,42,54,55,16,19,32,34,35,37,39,33,36,38,40,56,63,65,28,29,30,50,51,59,61,62,31,52,28,29,30,31,51,52,58,59,60,61,62,50,57,63,64,65,43,44,45,48,49,53,54,55,56,1,2,3,4,6,7,9,10,11,12,14,20,26,27,40,41,42,21,22,24,25,47,0,23,46,5,13,18,8,15,17,19,36,37,16,32,34,35,39,33,38,4,5,6,7,48,49,12,13,19,20,36,37,1,2,3,9,10,11,14,26,27,44,45,53,57,58,60,64,0,18,43,28,29,30,50,51,54,55,56,59,61,63,8,15,16,32,33,35,38,17,34,39,25,46,47,21,23,24,41,42,22,40,65,62,31,52,1,26,43,44,45,53,57,64,0,3,4,5,7,13,14,18,48,2,6,9,10,11,12,20,27,49,58,60,8,15,17,25,46,47,21,23,24,41,42,54,55,19,36,37,16,32,33,35,38,28,29,30,50,51,56,59,61,63,34,39,22,40,65,62,31,52,21,40,41,42,50,54,55,56,63,65,0,1,2,23,24,25,26,43,45,46,22,28,29,30,44,48,49,51,53,57,58,59,61,52,27,47,64,31,60,62,3,4,6,7,9,10,11,12,14,20,8,15,17,18,5,13,19,36,37,16,32,34,35,39,33,38], +"g2r_mapping" : [0,48,23,1,42,43,44,45,2,17,3,4,18,19,20,5,8,6,7,21,22,56,24,25,26,27,49,28,31,32,33,34,9,10,11,12,13,14,15,16,57,58,59,50,51,52,29,30,46,47,60,35,36,53,61,62,63,54,37,38,39,40,41,64,55,65], +"l2r_mapping" : [0,1,2,3,4,5,6,7,48,23,17,8,20,50,51,52,28,27,49,22,18,21,45,19,11,9,10,42,46,47,53,61,30,26,29,14,12,13,43,44,16,15,55,54,64,59,62,63,56,24,25,39,37,38,65,60,57,58,41,33,40,36,32,35,34,31,8,9,10,11,12,13,14,15,16,5,6,2,17,4,21,22,3,18,19,20,7,0,23,1,27,28,29,30,43,44,45,48,42,49,51,52,47,53,54,37,39,55,50,46,56,25,26,58,59,61,62,24,57,63,64,65,31,32,33,60,35,38,40,41,34,36,17,18,19,20,21,22,2,3,5,8,7,9,10,12,13,15,4,6,11,14,16,48,23,1,42,44,45,49,28,51,52,47,53,54,37,39,55,0,27,29,30,43,50,46,56,25,26,58,59,61,62,24,57,63,64,65,31,32,33,60,35,38,40,41,34,36,23,24,25,26,27,28,29,30,0,1,2,3,4,5,6,7,48,42,44,45,17,18,20,22,49,51,52,47,53,54,37,39,55,43,19,50,46,56,58,59,61,62,8,21,9,11,12,14,16,10,13,15,57,63,64,65,31,32,33,60,35,38,40,41,34,36,31,32,33,34,35,36,37,38,39,40,41,60,54,64,55,65,50,51,52,46,47,53,61,62,63,48,23,1,42,44,45,17,3,4,18,20,22,49,28,57,58,59,56,24,26,27,30,0,25,29,43,19,7,2,5,6,21,13,14,8,9,11,12,16,10,15,42,43,44,45,46,47,18,19,21,22,13,14,48,23,1,17,3,4,20,49,28,51,52,53,54,37,39,55,0,7,50,31,32,33,60,35,61,62,63,38,40,64,2,5,8,9,10,12,15,6,11,16,27,29,30,56,25,26,58,59,24,57,65,41,34,36,48,49,50,51,52,53,54,55,0,1,42,43,45,19,20,7,46,23,44,17,3,4,18,22,28,47,37,39,2,5,6,27,29,30,56,25,26,58,59,61,62,21,13,14,8,9,10,12,15,31,32,33,60,35,63,38,40,64,11,16,24,57,65,41,34,36,56,57,58,59,60,61,62,63,64,65,0,48,23,25,26,27,49,50,52,29,24,31,32,33,51,46,47,35,53,54,37,38,40,36,28,30,55,34,39,41,1,42,44,45,17,3,4,18,20,22,2,5,6,7,43,19,21,13,14,8,9,11,12,16,10,15] +} +,{"order" : 1, +"num" : 61, +"max_num_per_patch" : 64, +"owned_offsets" : [0,4,13,19,26,36,42,49,61], +"total_offsets" : [0,61,122,183,244,305,366,427,488], +"l2g_mapping" : [3,13,14,19,0,1,2,4,5,11,12,15,18,20,16,46,47,48,28,29,25,6,21,17,7,36,37,38,9,10,50,54,23,49,22,39,40,41,8,42,58,60,32,44,43,53,24,26,27,55,56,57,52,45,59,30,34,51,31,35,33,18,20,36,37,38,39,40,41,42,19,11,13,15,21,12,14,16,17,0,3,4,29,49,6,7,8,1,2,9,28,46,56,58,5,10,48,22,25,26,43,47,23,27,44,54,60,24,50,32,34,35,53,57,55,59,30,31,52,45,51,33,11,12,15,16,17,21,14,20,38,41,13,18,37,40,0,1,2,6,9,28,46,56,58,3,4,19,29,49,5,7,10,48,8,39,36,42,22,25,26,43,47,50,23,27,44,54,60,32,34,35,53,57,55,59,30,31,52,24,45,51,33,0,4,22,23,24,29,49,3,13,14,19,1,2,6,9,11,12,28,46,56,58,5,7,10,16,48,25,26,43,47,15,18,21,37,40,17,20,38,41,36,27,44,54,60,45,8,39,50,32,34,35,53,57,55,59,30,31,52,51,42,33,30,31,33,34,35,51,55,56,57,59,32,52,58,60,46,47,48,50,53,54,0,1,2,6,9,11,12,28,43,44,45,23,25,27,29,4,5,22,26,3,7,10,16,8,13,14,19,49,17,21,39,15,18,37,40,20,38,41,24,36,42,6,7,8,9,10,50,17,21,39,0,1,2,11,12,28,46,56,58,3,5,16,48,32,34,35,47,53,54,57,14,15,20,38,41,13,18,37,40,42,4,19,29,49,22,25,26,43,23,27,44,60,55,59,30,31,52,33,51,45,36,24,1,2,5,28,46,48,58,3,7,10,16,0,6,9,11,12,56,4,13,14,19,29,49,22,25,26,43,47,8,17,21,39,50,15,20,38,41,32,34,35,53,54,57,18,37,40,23,27,44,60,55,59,30,31,52,36,24,45,51,42,33,25,26,27,32,43,44,45,47,52,53,54,60,2,4,5,22,24,34,35,46,48,50,57,51,23,28,29,30,31,55,58,0,1,6,9,11,12,56,3,13,14,19,49,7,10,16,33,59,8,17,21,39,15,18,37,40,20,38,41,36,42], +"g2r_mapping" : [19,42,43,0,20,44,36,37,38,39,40,13,14,1,2,15,16,17,4,3,5,18,21,22,23,49,50,51,45,24,26,27,52,28,29,30,6,7,8,9,10,11,12,53,54,55,46,56,47,25,41,31,57,58,59,32,33,34,48,35,60], +"l2r_mapping" : [0,1,2,3,19,42,43,20,44,13,14,15,4,5,16,46,56,47,45,24,49,36,18,17,37,6,7,8,39,40,41,59,22,25,21,9,10,11,38,12,48,60,52,54,53,58,23,50,51,32,33,34,57,55,35,26,29,31,27,30,28,4,5,6,7,8,9,10,11,12,3,13,1,15,18,14,2,16,17,19,0,20,24,25,36,37,38,42,43,39,45,46,33,48,44,40,47,21,49,50,53,56,22,51,54,59,60,23,41,52,29,30,58,34,32,35,26,27,57,55,31,28,13,14,15,16,17,18,2,5,8,11,1,4,7,10,19,42,43,36,39,45,46,33,48,0,20,3,24,25,44,37,40,47,38,9,6,12,21,49,50,53,56,41,22,51,54,59,60,52,29,30,58,34,32,35,26,27,57,23,55,31,28,19,20,21,22,23,24,25,0,1,2,3,42,43,36,39,13,14,45,46,33,48,44,37,40,16,47,49,50,53,56,15,4,18,7,10,17,5,8,11,6,51,54,59,60,55,38,9,41,52,29,30,58,34,32,35,26,27,57,31,12,28,26,27,28,29,30,31,32,33,34,35,52,57,48,60,46,56,47,41,58,59,19,42,43,36,39,13,14,45,53,54,55,22,49,51,24,20,44,21,50,0,37,40,16,38,1,2,3,25,17,18,9,15,4,7,10,5,8,11,23,6,12,36,37,38,39,40,41,17,18,9,19,42,43,13,14,45,46,33,48,0,44,16,47,52,29,30,56,58,59,34,2,15,5,8,11,1,4,7,10,12,20,3,24,25,21,49,50,53,22,51,54,60,32,35,26,27,57,28,31,55,6,23,42,43,44,45,46,47,48,0,37,40,16,19,36,39,13,14,33,20,1,2,3,24,25,21,49,50,53,56,38,17,18,9,41,15,5,8,11,52,29,30,58,59,34,4,7,10,22,51,54,60,32,35,26,27,57,6,23,55,31,12,28,49,50,51,52,53,54,55,56,57,58,59,60,43,20,44,21,23,29,30,46,47,41,34,31,22,45,24,26,27,32,48,19,42,36,39,13,14,33,0,1,2,3,25,37,40,16,28,35,38,17,18,9,15,4,7,10,5,8,11,6,12] +} +,{"order" : 3, +"num" : 24, +"max_num_per_patch" : 32, +"owned_offsets" : [0,4,7,9,12,16,18,21,24], +"total_offsets" : [0,24,48,72,96,120,144,168,192], +"l2g_mapping" : [0,2,4,5,14,8,6,3,10,16,18,15,11,1,12,22,19,7,20,23,13,21,17,9,10,11,12,4,6,2,3,5,1,0,16,14,8,18,15,22,19,7,20,23,13,21,17,9,3,6,1,2,5,11,16,4,0,12,14,10,8,18,15,22,19,7,20,23,13,21,17,9,7,8,15,0,13,18,5,14,19,22,2,3,16,23,20,4,6,1,17,21,10,11,9,12,9,17,20,21,23,22,19,18,13,8,14,7,0,15,16,5,1,2,3,4,6,10,11,12,1,16,3,14,6,5,0,18,2,11,8,22,19,4,12,15,20,23,13,10,7,21,17,9,14,18,22,0,8,16,19,20,23,5,15,1,13,21,17,2,3,7,9,4,6,10,11,12,13,19,23,7,17,18,22,15,9,8,14,20,21,0,16,5,1,2,3,4,6,10,11,12], +"g2r_mapping" : [0,16,1,7,2,3,8,9,10,12,4,5,6,21,18,11,17,13,19,22,14,15,20,23], +"l2r_mapping" : [0,1,2,3,18,10,8,7,4,17,19,11,5,16,6,20,22,9,14,23,21,15,13,12,4,5,6,2,8,1,7,3,16,0,17,18,10,19,11,20,22,9,14,23,21,15,13,12,7,8,16,1,3,5,17,2,0,6,18,4,10,19,11,20,22,9,14,23,21,15,13,12,9,10,11,0,21,19,3,18,22,20,1,7,17,23,14,2,8,16,13,15,4,5,12,6,12,13,14,15,23,20,22,19,21,10,18,9,0,11,17,3,16,1,7,2,8,4,5,6,16,17,7,18,8,3,0,19,1,5,10,20,22,2,6,11,14,23,21,4,9,15,13,12,18,19,20,0,10,17,22,14,23,3,11,16,21,15,13,1,7,9,12,2,8,4,5,6,21,22,23,9,13,19,20,11,12,10,18,14,15,0,17,3,16,1,7,2,8,4,5,6] +} + ], + "relations" : [ +{"from_order" : 0, +"to_order" : 0, +"offset" : [0,0,4,8,12,12,20,27,27,35,39,39,43,47,51,51,55,61,61,68,79,79,87,91,101,107,115,122], +"value" : [1,3,4,2,3,5,4,0,6,3,4,0,2,3,1,4,5,6,7,8,2,3,0,6,5,7,8,2,3,4,5,6,7,8,1,4,8,9,0,3,2,4,1,3,5,4,0,3,0,4,6,1,2,3,4,2,3,0,5,6,7,1,2,3,4,5,6,7,2,0,3,4,5,8,6,9,7,10,11,6,7,8,9,4,1,3,2,9,0,4,3,5,10,11,6,0,8,12,3,4,13,0,4,1,11,5,2,9,0,1,6,7,3,2,5,10,11,2,3,13,6,4] +} +,{"from_order" : 0, +"to_order" : 1, +"offset" : [0,0,4,8,12,12,20,27,27,35,39,39,43,47,51,51,55,61,61,68,79,79,87,91,101,107,115,122], +"value" : [2,3,4,8,0,9,1,2,5,6,7,8,1,6,2,3,4,7,8,9,0,10,2,11,5,12,13,0,7,1,8,9,10,5,6,2,3,4,6,0,2,3,9,6,7,8,9,1,2,4,5,2,3,4,5,0,1,2,6,7,8,0,7,2,8,9,10,5,11,0,1,12,13,14,15,3,4,16,6,12,13,14,15,0,1,4,7,16,1,2,6,3,17,18,19,7,20,21,9,10,22,4,5,6,23,8,9,24,0,2,25,26,5,10,11,27,28,3,8,29,30,11] +} +,{"from_order" : 0, +"to_order" : 2, +"offset" : [0,0,5,10,15,15,30,42,42,55,60,60,65,70,75,75,80,89,89,101,123,123,138,143,162,171,186,198], +"value" : [1,2,3,7,8,9,0,10,2,3,4,5,6,7,8,6,0,7,1,2,3,8,9,10,4,11,12,13,14,15,6,0,16,9,17,4,5,11,18,13,19,20,8,0,9,10,11,12,13,14,15,4,5,6,7,1,2,3,6,7,0,2,3,9,10,6,7,8,9,10,0,1,3,4,5,0,1,2,4,5,1,2,3,6,7,8,9,10,11,8,0,9,10,11,12,13,14,15,2,3,16,0,17,9,10,18,12,19,20,21,22,14,23,1,24,3,4,25,5,6,26,27,7,10,11,12,0,13,14,15,16,2,3,17,18,19,5,6,0,20,13,1,2,21,22,23,17,24,18,25,26,4,27,28,5,6,7,29,30,31,32,8,1,2,3,4,27,33,6,7,9,0,20,14,15,16,34,1,3,35,28,5,7,8,36,9,22,23,37,4,33,29,31,38,39,8,36,9] +} +,{"from_order" : 3, +"to_order" : 3, +"offset" : [0,3,6,8,11,11,13,15,17,17,20,23,23,25,28,30,30,32,34,36,38,38,40,42,42,45,49,52,52,54,57,60], +"value" : [3,4,5,2,6,3,1,8,0,1,7,3,2,4,2,0,1,2,1,4,3,0,5,2,4,3,2,5,0,1,1,3,0,4,5,3,0,2,2,1,3,0,3,5,1,4,0,2,6,1,7,8,1,3,0,5,2,4,1,6] +} +,{"from_order" : 3, +"to_order" : 2, +"value" : [1,0,8,9,4,2,10,3,6,2,5,11,1,3,12,7,8,13,14,15,16,9,17,18,19,10,20,21,22,20,23,12,24,11,25,26,14,27,28,29,18,15,30,31,32,33,17,34,35,21,36,37,22,27,38,39,40,25,36,41,42,30,43,44,45,31,46,47,33,48,49,50,51,43,52,53,54,55,47,44,45,48,56,57,58,59,53,60,61,62,55,63,64,65,62,59,3,0,1,2,6,14,4,5,8,1,4,7,10,11,9,0,15,12,17,14,13,11,12,16,30,17,18,19,23,16,19,20,30,32,28,29,23,21,31,22,34,32,43,36,31,42,34,35,25,22,24,33,33,35,37,49,27,46,24,26,41,37,38,54,48,49,50,53,46,44,51,45,40,38,39,61,55,59,53,54,48,44,52,47,63,58,61,62,65,57,59,60,64,56,57,58,26,1,2,3,5,0,1,4,26,24,41,25,16,6,0,7,23,7,3,10,19,4,13,14,29,24,43,31,17,6,8,9,23,37,21,22,20,11,13,15,21,42,29,30,18,9,11,12,28,22,38,27,27,30,32,49,40,46,38,39,36,32,33,54,48,49,50,53,46,44,51,45,35,33,34,61,55,59,53,54,48,44,52,47,63,58,61,62,65,57,59,60,64,56,57,58,3,37,1,2,5,0,4,24,7,3,4,6,9,8,16,0,39,37,52,38,24,26,28,40,9,11,22,15,16,35,25,26,39,40,41,53,32,28,29,54,12,10,20,11,19,21,34,22,25,17,36,27,55,59,53,54,31,29,30,61,14,10,13,42,23,20,21,43,19,17,33,18,65,57,59,60,63,58,61,62,45,42,44,49,47,43,46,50,64,56,57,58,48,44,46,51,3,0,1,2,5,1,11,4,8,12,6,7,10,2,7,9,15,11,24,13,14,21,12,13,41,22,23,24,37,18,21,22,41,42,39,40,38,26,45,37,25,16,17,18,44,42,43,48,27,47,25,26,46,44,45,49,17,28,19,20,27,32,35,52,30,28,50,29,33,53,31,32,30,34,51,35,55,53,54,59,36,31,34,56,61,59,60,64,58,56,62,57,63,60,62,65,3,0,1,2,21,0,4,5,3,6,7,18,12,30,21,22,9,15,6,8,14,16,18,29,14,28,12,13,19,22,23,36,17,42,15,16,11,8,47,10,20,13,52,19,27,23,24,41,59,36,37,38,49,42,43,44,51,45,47,48,54,57,52,53,26,24,25,39,62,34,38,41,59,55,61,58,50,44,45,46,57,55,60,56,63,33,39,40,65,32,34,35,64,31,32,33,0,2,3,4,1,4,5,39,7,5,6,57,9,8,0,17,24,17,31,1,3,10,16,25,38,39,40,54,27,6,26,55,62,52,54,57,9,20,14,15,33,36,31,32,12,10,11,18,38,34,61,37,63,51,55,56,65,50,52,53,21,28,19,20,12,22,13,14,36,34,60,35,64,49,50,51,30,28,29,44,23,19,22,41,58,44,45,46,43,41,47,42,59,45,47,48,3,0,1,2,3,5,6,7,9,4,7,8,14,0,20,13,33,22,4,27,16,18,28,5,36,28,29,8,35,14,15,19,37,21,22,23,34,12,15,16,11,17,24,18,38,29,30,31,39,23,31,32,40,10,11,12,24,41,25,26,40,45,48,53,43,41,54,42,46,50,44,45,43,47,55,48,52,50,51,59,49,44,47,56,61,59,60,64,58,56,62,57,63,60,62,65] +} +,{"from_order" : 3, +"to_order" : 1, +"value" : [4,5,0,6,7,8,4,9,1,10,2,11,1,12,3,2,11,13,0,4,5,10,2,14,6,5,8,15,16,17,18,4,19,6,7,20,9,21,22,10,11,23,5,21,24,10,14,23,25,12,26,11,13,27,15,5,17,28,29,30,20,6,18,15,16,31,19,32,33,7,34,20,22,35,36,11,23,37,21,5,24,38,28,29,26,36,39,11,27,37,40,18,41,42,15,31,43,20,44,16,45,31,34,32,20,46,47,48,49,50,40,42,15,51,52,43,41,42,45,31,44,20,43,47,48,53,54,55,49,42,56,51,57,58,52,42,59,45,55,58,60,42,56,59,2,0,3,12,1,4,13,5,6,12,17,7,3,6,8,12,4,7,11,0,9,15,12,1,10,23,13,14,12,17,18,10,11,14,15,12,26,23,24,14,16,17,19,18,26,14,15,16,23,26,24,25,28,34,18,26,19,27,20,33,30,26,35,28,34,47,27,26,33,30,40,35,29,18,21,27,20,37,37,27,29,30,40,44,21,41,22,20,36,37,32,29,45,48,30,44,43,37,39,40,51,44,36,41,37,46,38,42,53,31,32,48,30,52,57,43,45,48,51,44,39,37,43,38,42,58,54,55,53,48,49,52,59,56,57,48,50,51,55,56,60,48,49,50,15,17,29,1,3,4,0,17,5,1,2,4,17,15,29,32,18,30,14,0,10,1,6,2,23,14,15,1,6,3,5,33,13,2,4,9,20,15,31,18,30,41,10,11,25,6,2,7,14,15,23,16,24,28,12,13,35,2,8,9,16,15,28,20,40,31,34,11,12,2,7,8,19,14,26,16,24,37,37,16,19,20,40,45,26,42,27,24,36,37,22,19,46,47,20,45,44,37,39,40,50,45,36,42,37,57,38,43,52,21,22,47,20,51,56,44,46,47,50,45,39,37,44,38,43,58,53,54,52,47,48,51,59,55,56,47,49,50,54,55,60,47,48,49,2,3,26,4,27,40,17,0,5,12,1,26,5,3,6,1,2,26,0,11,7,12,1,21,28,26,41,27,40,44,26,12,17,18,29,42,7,0,11,16,9,24,12,11,21,18,29,25,41,26,28,29,51,42,20,17,43,48,18,42,0,15,8,16,9,30,11,13,22,16,24,35,18,11,25,14,23,47,57,41,43,48,51,42,53,19,20,48,18,52,8,31,10,9,30,36,15,13,32,16,30,35,13,11,22,45,14,23,58,56,57,48,50,51,54,55,53,48,49,52,39,31,33,30,36,37,32,46,34,30,35,38,55,56,60,48,49,50,33,34,59,30,37,38,0,1,2,10,3,4,5,1,11,10,4,18,6,7,12,10,14,8,9,0,6,10,3,8,11,29,13,10,18,19,12,27,13,10,14,19,29,32,28,15,18,19,32,22,27,14,15,19,28,32,29,38,33,30,27,20,34,22,35,32,22,21,36,14,15,16,37,31,32,58,38,33,20,21,39,22,35,36,34,31,47,35,37,32,14,21,16,24,41,17,39,20,21,26,45,42,23,21,40,43,24,41,20,25,44,26,45,51,21,23,40,26,42,48,44,52,46,45,51,55,25,23,49,26,51,48,59,52,53,51,55,56,49,50,54,51,48,57,53,54,60,51,56,57,0,10,1,2,3,4,15,10,21,3,4,5,10,0,1,13,20,6,11,10,19,15,25,21,12,0,7,13,30,6,18,9,10,13,29,20,9,10,18,11,39,19,44,11,14,15,25,27,9,12,34,13,29,30,7,8,37,30,6,33,14,9,41,11,39,44,17,14,50,22,15,27,49,44,46,25,26,27,34,35,40,29,30,31,36,37,38,30,32,33,41,47,42,39,43,44,51,16,17,22,15,28,55,49,50,22,26,27,46,44,49,45,48,58,59,35,36,30,31,32,43,47,44,60,45,48,52,53,51,22,23,28,57,54,55,22,24,26,53,54,56,22,23,24,1,0,2,4,27,5,24,1,3,4,27,41,6,3,49,37,4,41,11,0,7,1,17,2,3,11,21,1,17,24,4,0,5,13,9,32,48,24,26,27,40,41,50,16,6,37,4,42,54,48,49,37,40,41,7,11,0,15,19,10,21,46,22,17,23,24,12,0,8,28,13,9,26,24,48,25,47,57,51,52,50,37,38,42,58,53,54,37,39,40,11,14,18,15,19,33,0,12,8,15,10,29,23,46,24,56,25,47,52,53,60,37,38,39,18,43,20,19,33,34,14,12,30,15,33,29,55,43,44,33,34,35,30,31,45,33,29,36,44,45,59,33,35,36,4,0,5,1,2,6,5,0,4,7,9,10,8,5,11,3,9,10,15,24,0,16,1,2,23,28,8,3,18,9,0,12,25,19,7,10,30,25,11,3,19,10,26,24,42,13,15,0,27,28,46,3,17,18,25,31,26,12,13,0,12,32,14,19,7,20,29,37,30,3,19,22,47,27,29,3,17,22,31,32,38,12,13,14,19,32,20,34,44,21,38,31,32,36,40,45,33,32,43,48,34,44,31,35,39,36,40,52,32,33,43,36,45,49,39,53,41,40,52,56,35,33,50,36,52,49,59,53,54,52,56,57,50,51,55,52,49,58,54,55,60,52,57,58] +} +,{"from_order" : 3, +"to_order" : 0, +"value" : [0,1,2,3,0,1,4,5,4,1,6,5,2,1,0,5,0,3,2,7,0,8,1,3,0,4,9,5,2,0,9,5,6,10,4,5,0,7,2,11,8,3,0,7,1,8,12,3,9,4,13,5,9,0,2,11,4,10,13,5,0,14,8,7,8,15,3,7,12,3,8,16,17,14,0,7,14,15,8,7,3,15,8,16,17,18,14,7,19,15,14,7,14,18,19,7,1,0,3,4,6,3,2,4,3,0,2,4,3,5,1,4,7,3,6,4,7,5,3,4,8,7,6,4,8,5,7,4,6,7,8,12,7,5,8,9,7,13,8,12,7,9,8,13,7,10,5,9,10,9,7,13,5,10,11,9,7,15,10,13,10,17,9,13,11,9,10,16,14,15,7,13,15,17,10,13,9,17,10,16,14,18,15,13,19,17,15,13,15,18,19,13,4,2,5,0,2,1,5,0,5,2,4,10,2,3,1,0,4,3,2,0,5,1,8,0,2,12,4,10,1,3,6,0,2,3,4,9,1,7,8,0,2,9,4,12,6,7,1,0,2,11,3,9,11,9,2,12,3,11,15,9,2,14,11,12,11,17,9,12,15,9,11,16,13,14,2,12,14,17,11,12,9,17,11,16,13,18,14,12,19,17,14,12,14,18,19,12,1,4,8,9,2,8,0,4,0,8,1,4,2,0,3,4,4,15,8,9,8,4,2,12,3,0,2,6,2,4,3,12,8,15,4,12,2,14,8,12,2,0,5,6,3,2,10,6,2,12,3,11,14,15,8,12,13,14,2,12,5,0,7,6,2,5,10,6,10,2,3,11,19,15,14,12,13,18,14,12,7,16,5,6,10,5,17,6,14,18,19,12,5,16,17,6,3,0,2,4,2,6,3,4,1,3,5,4,1,0,3,4,3,6,7,4,5,3,7,4,7,6,8,4,7,8,5,4,8,6,7,15,5,7,11,8,5,8,9,4,16,8,7,15,5,11,9,8,11,7,16,8,5,4,9,10,9,11,5,14,12,5,9,10,5,11,13,14,9,5,12,14,13,11,17,14,5,13,12,14,17,19,13,14,12,13,18,14,13,19,18,14,1,2,3,0,2,4,3,0,3,2,1,5,2,9,3,4,2,6,1,5,3,8,2,5,2,8,3,9,10,9,2,4,2,8,6,5,1,6,7,5,2,10,8,9,2,12,10,4,10,15,9,4,6,8,16,5,6,17,7,5,8,10,18,9,11,12,2,4,12,15,10,4,9,15,10,19,16,17,6,5,18,9,10,19,11,13,12,4,14,15,12,4,12,13,14,4,1,3,0,7,9,3,1,7,1,11,9,7,1,2,0,3,1,9,2,3,1,7,0,5,9,15,3,7,10,11,1,7,11,15,9,7,0,2,1,6,2,9,13,3,4,1,0,5,3,15,9,14,10,18,11,7,19,15,11,7,1,2,8,6,0,1,4,6,13,3,9,14,11,18,19,7,8,2,12,6,1,8,4,6,12,17,8,6,4,8,16,6,8,17,16,6,0,3,4,1,4,3,0,2,5,3,4,2,9,0,4,1,11,3,5,2,4,0,6,2,6,5,4,2,7,4,9,0,5,10,11,2,6,4,7,0,6,0,8,2,13,5,6,2,13,10,5,2,6,7,8,0,6,2,8,12,8,7,6,16,14,6,8,12,6,7,15,16,8,6,14,16,15,7,17,16,6,15,14,16,17,19,15,16,14,15,18,16,15,19,18,16] +} +,{"from_order" : 2, +"to_order" : 3, +"offset" : [0,1,3,5,7,8,9,10,11,11,13,15,16,17,19,20,21,22,23,23,25,27,28,30,32,33,33,35,36,37,39,41,42,43,44,44,45,47,49,50,51,52,53,55,56,57,58,58,60,61,62,64,65,66,66,68,70,71,73,75,77,79,80,80,82,83,84,86,88,90,91,93,95,96], +"value" : [0,0,3,1,2,1,3,1,2,2,3,3,0,0,2,0,0,1,2,1,1,2,2,3,1,0,1,0,0,4,1,5,1,3,1,0,0,0,2,1,2,1,2,2,0,0,1,0,3,0,1,1,2,2,3,2,3,3,0,1,0,0,0,2,1,1,3,0,4,1,0,0,5,0,1,1,2,7,2,2,3,0,0,0,0,1,4,2,5,1,1,1,2,6,2,2] +} +,{"from_order" : 2, +"to_order" : 2, +"offset" : [0,7,18,28,38,46,52,57,64,64,73,82,87,92,101,107,112,117,122,122,132,141,148,158,168,175,175,185,190,195,204,214,223,229,234,234,239,248,257,262,268,273,281,290,295,300,305,305,314,319,325,335,340,348,348,358,369,376,388,399,410,422,429,429,438,444,449,459,469,480,488,497,507,514], +"value" : [1,7,9,17,34,8,13,9,3,4,16,8,27,22,12,14,0,7,4,6,3,5,7,10,11,21,25,36,9,1,4,16,10,20,12,2,5,7,9,1,3,16,10,19,2,6,6,2,3,7,11,26,2,4,11,24,5,0,1,2,3,5,23,12,10,3,11,12,14,1,4,9,2,3,8,11,12,0,14,4,2,7,3,9,0,1,7,2,0,10,1,8,6,8,11,12,0,14,1,5,7,6,17,18,14,4,7,14,15,5,4,8,8,1,2,4,5,1,3,4,6,7,16,5,7,1,3,6,9,4,11,13,25,26,5,0,7,3,2,4,14,41,26,3,10,1,4,14,21,23,24,26,29,0,7,1,2,10,5,19,6,0,9,11,13,1,2,14,0,16,25,26,1,4,19,9,11,12,5,16,24,26,8,4,6,3,7,2,37,52,3,6,1,37,38,2,6,1,7,37,4,24,39,40,5,7,8,0,6,37,3,24,39,40,24,28,32,0,9,11,12,4,7,7,8,0,4,2,3,4,5,1,3,6,3,2,9,1,4,3,5,2,11,12,7,13,0,4,3,10,1,11,12,7,13,0,9,2,10,1,5,0,5,0,1,11,23,24,4,1,3,11,15,8,17,18,20,21,12,7,9,8,10,1,2,11,12,13,6,9,7,10,6,12,14,10,0,2,6,7,9,2,3,7,8,12,14,3,18,21,2,5,1,4,3,7,2,0,4,3,6,9,1,0,5,2,6,9,12,14,0,18,21,1,7,30,21,0,1,5,21,22,23,24,25,0,2,4,9,10,12,14,3,17,1,4,8,2,24,5,7,0,17,4,34,36,31,38,39,8,0,4,39,40,3,16,0,9,10,12,14,4,25,5,6,26,2,16,0,17,1,3,25,5,6,26,2,39,40,1,24,7,3,4,25,6,26,39,54,57,27,7,50,51,52,55,57,3,4,25,5,26,6,27,1,24,5,57,62,14,15,16,3,5,13,2,20,1,3,7,9,0,20,2,3,6,0,13,1,2,6,0,14,15,16,5,1,7,9,33,9,22,23,29,31,8,27,6,7,0,14,15,16,3,17,18,6,28,7,8,2,3,17,18,5,4,27,7,1,3,9,4,27,6,28,5,8,36,9,22,23,4,29,31,28,5,7,4,33,1,3,7,8,36] +} +,{"from_order" : 2, +"to_order" : 1, +"value" : [0,7,8,4,5,0,1,2,11,4,10,2,4,9,1,3,2,13,1,12,3,0,2,14,5,6,8,4,6,7,9,10,11,12,11,13,5,10,14,8,16,17,5,15,17,6,15,16,18,4,19,19,7,20,18,6,20,9,21,22,21,10,23,22,11,23,21,5,24,24,14,23,25,12,26,26,11,27,25,13,27,5,28,29,17,29,30,15,28,30,18,15,31,20,16,31,19,32,33,34,32,20,33,7,34,22,35,36,36,11,37,35,23,37,24,38,29,21,38,28,26,36,39,39,27,37,40,18,41,40,42,15,41,42,31,44,20,43,44,16,45,43,45,31,20,47,48,32,46,48,34,46,47,49,50,40,50,15,51,49,42,51,52,43,41,52,42,45,43,48,53,44,47,53,54,55,49,55,42,56,54,56,51,57,58,52,58,42,59,57,59,45,55,58,60,60,56,59,0,12,1,3,12,4,2,1,4,2,0,3,6,12,7,5,17,7,13,5,6,8,4,7,3,6,8,9,15,1,11,0,9,11,15,12,10,14,12,18,10,11,13,12,17,10,23,13,18,14,15,23,14,17,24,16,17,26,14,16,19,15,16,19,20,33,18,27,20,18,26,19,21,20,37,29,18,21,22,20,36,21,41,22,24,25,34,23,25,28,23,26,24,26,27,33,26,28,34,29,27,37,26,30,35,27,30,40,30,28,47,29,30,44,32,48,30,31,30,52,53,31,32,32,29,45,33,40,35,35,34,47,37,38,42,36,46,38,36,41,37,39,38,58,39,37,43,37,40,44,39,40,51,41,46,42,43,42,58,43,51,44,45,48,44,57,43,45,60,49,50,56,48,50,55,48,49,57,48,51,59,50,51,53,48,52,54,49,52,54,55,53,55,56,60,59,56,57,0,1,2,17,1,4,29,3,4,15,1,3,5,2,4,0,17,5,10,6,2,14,1,6,25,6,7,11,2,7,23,6,3,12,2,8,34,7,8,13,2,9,33,4,9,35,8,9,14,0,10,10,11,25,34,11,12,5,33,13,12,13,35,15,16,28,14,16,24,14,15,23,15,18,30,17,32,18,17,15,29,19,16,37,19,14,26,15,20,31,16,20,40,20,18,41,19,20,45,22,47,20,21,20,51,52,21,22,22,19,46,23,24,28,26,24,37,27,24,36,26,42,27,29,32,30,28,40,31,31,30,41,37,38,43,36,57,38,36,42,37,39,38,58,39,37,44,37,40,45,39,40,50,42,57,43,44,43,58,44,50,45,46,47,45,56,44,46,60,48,49,55,47,49,54,47,48,56,47,50,59,49,50,52,47,51,53,48,51,53,54,52,54,55,60,59,55,56,0,12,1,3,4,40,2,4,27,2,3,26,5,1,26,17,0,5,6,1,2,5,3,6,7,1,21,0,11,7,8,9,30,0,16,9,0,15,8,10,9,36,8,31,10,7,9,24,11,12,21,11,14,23,13,45,14,13,11,22,15,16,30,13,16,35,11,16,24,15,13,32,17,12,26,11,18,25,12,18,29,18,14,47,17,18,42,20,48,18,19,18,52,53,19,20,20,17,43,22,45,23,22,24,35,21,29,25,25,23,47,26,27,40,28,27,44,28,26,41,26,29,42,28,29,51,31,30,36,32,30,35,33,30,37,39,31,33,34,30,38,32,46,34,33,34,59,39,36,37,46,35,38,59,37,38,41,40,44,41,51,42,43,48,42,57,41,43,60,49,50,56,48,50,55,48,49,57,48,51,58,50,51,53,48,52,54,49,52,54,55,53,55,56,60,58,56,57,2,3,4,1,10,4,0,10,3,0,1,2,5,4,18,5,1,11,7,14,8,6,10,8,6,7,12,9,3,8,9,0,6,11,10,18,12,10,14,13,10,19,12,27,13,11,29,13,36,15,16,21,14,16,22,14,15,16,41,17,14,24,17,27,14,19,32,15,19,28,15,18,29,18,19,21,22,36,20,22,35,20,21,39,21,24,41,23,43,24,23,21,40,25,26,51,20,26,45,20,25,44,23,26,48,21,26,42,25,23,49,27,22,32,27,20,34,29,33,30,28,38,30,28,32,29,32,38,33,31,58,33,37,31,32,34,35,32,34,31,47,39,35,36,37,58,38,47,35,37,40,43,41,40,42,48,39,45,42,44,45,51,46,45,55,44,52,46,49,51,48,50,48,57,49,50,54,52,51,55,53,51,56,59,52,53,54,51,57,53,54,60,59,55,56,60,56,57,10,3,4,1,2,4,0,2,3,0,10,1,21,4,5,15,3,5,0,13,6,1,20,6,7,30,6,12,0,7,8,6,33,7,8,37,10,11,19,9,11,39,9,10,18,12,13,30,9,13,29,9,12,34,10,13,20,14,11,44,14,9,41,10,15,21,11,15,25,14,15,27,17,22,15,16,15,28,51,16,17,17,14,50,18,39,19,18,29,20,19,25,21,56,23,24,54,22,24,53,22,23,55,22,26,57,24,26,44,25,27,46,25,26,49,26,27,51,22,28,52,23,28,50,22,27,34,29,30,40,29,31,35,30,31,36,30,32,59,31,32,37,30,33,38,32,33,34,35,40,59,35,36,36,37,38,41,39,44,42,39,43,41,47,42,44,45,48,43,60,45,43,47,44,46,45,58,46,44,49,47,60,48,49,48,58,55,49,50,52,53,51,53,54,56,57,54,55,0,1,2,3,1,24,2,27,5,0,4,5,1,4,27,3,4,41,6,37,4,6,3,49,7,17,2,11,0,7,0,13,9,8,28,9,12,0,8,8,10,29,0,15,10,7,19,10,5,9,32,11,1,17,12,28,13,14,15,33,11,15,19,11,14,18,12,15,29,14,12,30,3,11,21,4,13,32,16,4,42,50,16,6,18,19,33,20,19,34,18,43,20,21,17,24,22,17,23,21,46,22,24,25,47,23,56,25,23,46,24,26,25,57,26,24,48,24,27,41,26,27,40,30,33,29,31,29,36,30,31,45,43,33,34,44,33,35,55,34,35,45,33,36,59,35,36,60,38,39,53,37,39,52,37,38,54,37,40,58,39,40,48,40,41,50,37,42,51,38,42,49,37,41,55,43,44,44,45,59,46,56,47,48,47,57,54,48,49,51,52,50,52,53,60,58,53,54,0,1,2,5,2,6,4,1,6,4,0,5,8,3,9,0,7,10,4,7,9,5,9,10,11,3,10,8,5,11,38,13,14,32,12,14,31,12,13,15,16,1,15,24,0,26,13,0,25,12,0,14,7,20,12,19,7,42,13,15,24,16,2,46,17,18,28,3,18,27,3,17,32,19,20,20,44,21,19,34,21,23,18,9,25,19,10,30,3,19,37,19,22,29,3,22,47,17,22,23,28,8,25,31,26,26,24,42,30,25,11,27,28,46,29,37,30,47,27,29,31,32,38,32,34,44,33,48,34,33,32,43,35,36,52,31,36,40,31,35,39,33,36,49,32,36,45,35,33,50,39,40,52,41,40,56,39,53,41,38,40,45,43,48,44,43,45,49,50,52,49,51,49,58,50,51,55,53,52,56,54,52,57,59,53,54,55,52,58,54,55,60,59,56,57,60,57,58] +} +,{"from_order" : 2, +"to_order" : 0, +"value" : [1,2,3,0,1,2,1,4,5,0,1,5,0,1,4,1,6,5,4,1,6,2,1,5,0,2,3,0,1,3,0,4,5,4,6,5,2,0,5,3,2,7,0,2,7,0,3,7,0,8,1,8,1,3,0,8,3,0,4,9,0,9,5,4,9,5,9,0,2,2,9,5,6,10,4,10,4,5,6,10,5,0,2,11,7,2,11,0,7,11,8,0,7,8,3,7,1,8,12,12,3,8,1,12,3,9,4,13,4,13,5,9,13,5,9,2,11,9,0,11,4,10,13,10,13,5,0,14,8,14,0,7,14,8,7,3,15,8,15,3,7,8,15,7,3,8,16,12,8,16,12,3,16,17,14,0,17,0,7,17,14,7,14,15,8,15,14,7,15,8,16,3,15,16,17,18,14,14,18,7,17,18,7,19,15,14,14,19,7,19,15,7,14,18,19,18,19,7,3,1,4,0,3,4,1,0,4,1,0,3,3,2,4,6,2,4,6,3,2,0,2,4,3,0,2,5,1,4,3,5,1,5,3,4,7,3,4,7,5,3,3,6,4,7,3,6,7,5,4,7,6,4,8,6,4,8,7,4,8,5,4,5,8,9,7,5,9,7,5,8,10,5,9,7,10,5,5,11,9,5,10,11,6,8,12,6,7,12,6,7,8,7,8,9,7,8,12,7,10,9,7,8,13,7,9,13,7,13,12,10,7,13,15,7,13,14,7,13,14,15,7,7,15,10,9,8,13,13,8,12,9,10,16,11,9,16,11,9,10,9,17,16,9,17,10,10,9,13,17,9,13,11,10,16,17,10,16,10,17,13,15,10,13,15,17,10,18,19,13,15,19,13,15,18,13,17,15,13,19,17,13,14,15,13,14,18,13,14,18,15,15,18,19,19,17,15,2,1,0,2,5,0,4,5,0,4,2,0,1,5,0,2,1,5,3,1,0,2,3,0,3,6,0,1,6,0,4,3,0,7,1,0,6,7,0,1,8,0,5,8,0,7,8,0,2,3,1,1,3,6,6,7,1,5,1,8,1,7,8,2,4,9,2,3,9,2,3,4,2,4,10,5,2,10,5,2,4,2,11,9,2,11,3,2,4,12,2,9,12,2,12,10,11,2,12,14,2,12,13,2,12,13,14,2,2,14,11,3,4,9,11,3,9,3,15,9,3,11,15,5,4,10,9,4,12,12,4,10,9,11,16,15,9,16,15,9,11,9,17,16,9,17,11,11,9,12,17,9,12,15,11,16,17,11,16,11,17,12,14,11,12,14,17,11,18,19,12,14,19,12,14,18,12,17,14,12,19,17,12,13,14,12,13,18,12,13,18,14,14,18,19,19,17,14,2,0,4,1,8,9,1,4,9,1,4,8,8,0,4,2,8,0,0,1,4,0,8,1,0,3,4,2,0,3,0,5,6,2,0,6,2,0,5,0,7,6,5,0,7,3,0,6,2,3,4,2,3,11,10,2,11,10,2,3,2,5,6,2,10,6,3,2,6,2,5,10,2,8,4,2,3,12,2,4,12,2,12,11,8,2,12,14,2,12,13,2,12,13,14,2,2,14,8,10,3,11,3,10,6,4,3,12,12,3,11,4,8,9,4,15,9,4,15,8,8,4,12,15,4,12,5,7,6,5,10,6,16,5,6,7,16,5,5,17,6,10,5,17,5,16,17,7,16,6,10,17,6,16,17,6,15,8,9,8,15,12,14,8,12,14,15,8,18,19,12,14,19,12,14,18,12,15,14,12,19,15,12,13,14,12,13,18,12,13,18,14,14,18,19,19,15,14,0,2,4,3,2,4,3,0,4,3,0,2,2,6,4,2,6,3,1,5,4,1,3,4,1,3,5,1,0,4,1,0,3,6,3,4,3,5,4,3,7,4,5,3,7,3,6,7,8,9,4,5,9,4,5,8,4,4,9,10,5,4,10,7,5,4,7,8,4,6,8,4,7,6,4,5,9,8,5,11,8,5,11,9,5,9,10,12,5,10,12,5,9,5,13,14,5,11,14,5,11,13,5,12,14,9,5,14,5,13,12,5,7,8,5,7,11,6,7,15,8,6,15,8,6,7,8,7,15,16,7,15,16,8,7,7,11,8,11,7,16,11,9,8,16,8,15,11,16,8,12,9,10,9,12,14,9,11,14,11,13,14,11,17,14,13,11,17,13,12,14,12,18,14,12,13,18,13,17,14,19,13,14,17,19,13,13,18,14,13,19,18,17,19,14,19,18,14,2,3,0,1,3,0,1,2,0,1,2,3,4,3,0,2,4,0,2,1,5,3,1,5,6,1,5,2,6,1,1,7,5,1,6,7,2,3,9,2,8,9,2,8,3,2,6,5,2,8,5,2,8,6,3,2,5,2,10,9,2,10,8,2,3,4,2,9,4,10,2,4,12,2,4,11,2,4,11,12,2,2,12,10,8,3,9,3,8,5,9,3,4,13,14,4,12,14,4,12,13,4,15,12,4,14,15,4,10,9,4,15,9,4,10,15,4,11,12,4,11,13,4,12,10,4,8,6,5,8,16,5,6,16,5,17,6,5,16,17,5,6,7,5,17,7,5,6,8,16,16,17,6,6,17,7,10,8,9,8,18,9,8,10,18,9,10,19,18,9,19,18,9,10,9,15,19,9,15,10,18,10,19,15,10,19,12,15,10,11,13,12,12,13,14,14,15,12,1,0,3,1,9,3,3,0,7,1,0,7,1,3,7,9,1,7,11,1,7,1,11,9,2,0,3,1,2,0,1,0,5,4,0,5,4,1,0,0,4,6,0,1,6,0,2,6,7,0,5,1,2,3,4,1,5,1,8,6,1,2,6,1,2,8,1,4,6,1,8,4,1,9,2,1,7,5,10,1,7,10,11,1,2,8,6,2,12,6,8,2,12,9,2,3,2,13,3,2,9,13,3,9,14,13,3,14,13,3,9,3,15,14,3,15,9,9,3,7,15,3,7,8,4,6,4,16,6,4,8,16,8,12,6,17,8,6,12,17,6,8,16,6,17,16,6,18,19,7,11,19,7,11,18,7,15,11,7,19,15,7,9,15,7,10,11,7,10,18,7,11,9,7,12,17,8,8,17,16,13,9,14,15,9,14,11,15,9,10,18,11,11,18,19,19,15,11,0,4,1,3,4,1,0,3,1,0,3,4,3,5,2,4,0,2,3,0,2,4,3,2,5,4,2,5,3,4,7,8,0,6,8,0,6,7,0,9,0,1,9,0,4,4,7,0,6,4,0,0,8,2,6,0,2,7,9,0,9,4,1,10,11,2,5,11,2,5,10,2,6,8,2,2,8,12,6,2,12,11,3,2,4,6,2,5,6,2,13,6,2,13,5,2,13,10,2,11,3,5,6,4,7,7,4,9,6,5,4,5,10,11,13,5,6,13,10,5,6,7,8,6,8,12,14,6,12,14,6,8,6,15,16,6,7,16,6,7,15,6,14,16,8,6,16,6,15,14,7,15,16,7,17,16,15,7,17,8,7,16,14,8,12,8,14,16,15,14,16,14,18,16,14,15,18,15,17,16,19,15,16,17,19,15,15,18,16,15,19,18,17,19,16,19,18,16] +} +,{"from_order" : 1, +"to_order" : 3, +"offset" : [0,2,4,7,8,8,10,12,13,15,17,18,20,22,23,23,25,29,35,37,40,42,42,46,49,51,53,54,56,57,57,59,61,62,64,66,67,69,70,72,73,73,76,78,79,81,83,84,84,90,94,96,99,104,106,108,108,114,116,118,124,126,129,130,133,135,138,142,144], +"value" : [0,3,1,2,1,2,3,2,3,0,3,0,0,0,2,0,2,1,1,2,1,2,2,3,1,3,0,4,1,3,7,1,11,5,9,0,4,0,1,5,1,5,3,10,6,1,3,1,2,0,2,0,2,0,1,2,2,0,3,0,1,0,0,3,0,1,1,2,3,2,2,3,3,0,2,4,0,2,0,0,1,0,1,1,3,11,16,9,0,5,3,4,0,1,3,0,4,1,2,0,5,1,7,2,0,5,7,2,3,9,0,7,5,1,3,0,3,0,8,4,11,12,6,2,0,1,0,1,2,0,10,5,1,4,2,4,1,2,5,1,6,2,6,2] +} +,{"from_order" : 1, +"to_order" : 2, +"offset" : [0,3,6,10,12,12,15,18,20,23,26,28,31,34,36,36,39,43,49,52,56,59,59,64,68,71,74,76,79,81,81,84,87,89,92,95,97,100,102,105,107,107,111,114,116,119,122,124,124,130,134,137,141,147,150,153,153,159,162,165,171,174,178,180,184,187,191,195,198], +"value" : [0,1,7,2,4,6,2,3,5,7,5,6,0,10,3,9,0,2,2,3,1,3,8,1,2,7,5,6,4,6,8,4,5,7,7,8,0,16,5,0,7,1,3,6,0,9,4,11,13,2,3,10,1,2,4,14,4,5,19,0,9,11,12,5,8,0,4,6,2,3,6,1,3,7,1,2,4,5,7,6,7,2,3,10,1,3,5,0,3,0,2,9,0,1,4,4,5,7,8,10,6,8,6,7,9,9,10,2,3,6,9,1,3,7,1,2,0,2,5,0,1,4,4,5,0,9,10,12,14,3,0,17,1,4,8,0,2,1,24,5,7,3,4,25,5,6,26,2,3,16,6,27,7,0,14,15,16,3,5,0,13,2,0,20,1,22,23,4,29,31,8,2,3,6,1,3,7,9,1,2,17,18,5,6,4,33,9,4,27,6,7,28,5,7,8,8,36,9] +} +,{"from_order" : 1, +"to_order" : 1, +"offset" : [0,13,26,40,50,50,59,69,75,84,94,102,111,121,127,127,143,160,173,186,198,209,209,226,240,250,260,266,280,290,290,299,308,314,326,338,346,355,368,380,386,386,401,412,420,433,442,454,454,470,487,500,517,536,551,567,567,581,591,601,616,628,640,648,664,675,689,705,718], +"value" : [4,7,1,2,3,19,33,5,8,24,29,14,17,4,0,7,2,3,19,33,9,11,12,22,26,36,4,0,7,1,3,19,33,10,11,14,23,13,27,37,4,0,7,1,2,19,33,12,13,25,10,11,12,13,3,6,9,1,2,0,9,2,14,15,12,16,17,4,7,0,9,1,3,4,8,2,4,8,10,11,12,0,13,6,2,3,8,14,15,12,16,17,1,7,23,24,25,17,13,6,7,8,10,11,12,0,13,3,5,7,8,5,6,8,14,15,12,16,17,1,4,2,3,4,5,6,7,14,15,16,17,18,1,19,20,21,22,10,2,11,5,12,13,14,15,16,17,18,0,19,20,21,22,6,2,3,4,7,8,9,0,10,11,5,12,13,1,6,3,4,7,8,9,15,23,28,29,30,31,1,6,2,4,7,8,9,17,29,32,5,33,1,6,2,3,7,8,9,0,10,2,11,12,13,17,29,32,4,33,11,12,13,14,15,16,17,18,19,20,7,1,8,9,10,5,6,0,7,8,9,10,5,6,12,21,2,26,27,28,29,3,4,6,12,1,21,26,27,28,29,2,4,6,26,40,17,5,41,42,43,2,3,6,27,40,44,3,26,40,17,41,42,43,0,7,1,8,9,10,6,0,7,1,8,9,10,5,2,3,4,1,10,11,6,12,13,2,3,9,0,10,11,6,12,13,2,4,5,0,3,9,1,4,5,0,2,9,10,4,14,15,16,17,18,19,8,1,2,5,10,3,14,15,16,17,18,19,8,1,2,4,28,29,30,11,18,7,8,9,0,1,10,11,12,13,6,8,9,20,21,22,23,24,25,26,27,14,12,6,7,9,10,3,4,14,15,16,17,18,19,6,7,8,0,2,3,1,2,6,7,8,9,10,11,3,12,13,14,15,16,17,0,2,6,7,8,10,18,19,4,20,21,0,1,6,7,8,3,4,5,9,10,11,0,12,13,14,15,16,17,2,4,5,10,18,19,1,20,21,2,3,5,22,23,24,15,25,21,26,27,28,2,3,4,11,1,12,13,14,15,3,4,16,6,7,2,8,9,10,5,11,0,12,13,14,15,3,4,16,6,17,2,23,24,25,26,27,0,7,8,9,10,5,1,17,23,24,25,26,27,11,0,1,12,13,14,15,4,16,6,46,24,47,21,48,41,49,11,0,1,12,13,14,15,3,16,6,37,38,39,27,5,32,40,41,42,0,7,2,8,9,10,37,38,39,4,27,32,40,41,42,52,53,37,54,50,49,11,0,1,12,13,14,15,3,4,16,12,13,14,15,1,4,7,24,2,25,26,5,10,11,12,13,14,15,0,4,7,16,2,6,24,0,25,26,5,10,11,16,1,6,27,28,8,29,30,11,17,18,19,7,20,21,9,10,22,12,13,14,15,0,1,7,5,6,23,8,9,4,6,23,8,9,24,0,2,25,26,10,11,4,5,23,8,9,16,1,2,12,13,14,15,0,1,4,3,17,18,19,20,21,9,10,22,4,5,6,23,9,27,28,3,29,30,11,4,5,6,23,8,3,17,18,19,7,20,21,10,22,24,0,2,25,26,5,11,3,17,18,19,7,20,21,9,22,27,28,3,8,29,30,24,0,2,25,26,5,10] +} +,{"from_order" : 1, +"to_order" : 0, +"value" : [1,2,1,4,1,5,1,6,0,1,0,2,0,3,1,3,2,3,0,4,0,5,4,5,4,6,6,5,2,5,0,7,3,7,2,7,0,8,8,1,3,8,9,0,4,9,9,5,9,2,6,10,10,4,10,5,0,11,2,11,7,11,8,7,12,8,1,12,12,3,9,13,4,13,13,5,9,11,10,13,14,0,14,8,14,7,15,8,3,15,15,7,12,16,3,16,8,16,17,14,17,0,17,7,15,14,15,16,17,18,14,18,18,7,19,15,14,19,19,7,18,19,3,1,1,4,1,0,0,3,0,4,6,2,3,2,2,4,0,2,5,1,7,3,5,3,3,4,3,6,7,4,5,4,8,4,6,4,7,5,5,8,5,9,10,5,5,11,6,7,6,8,6,12,7,8,7,9,7,12,7,10,7,13,14,7,15,7,8,9,8,12,8,13,11,9,9,10,9,16,9,17,9,13,11,10,10,16,17,10,10,13,15,10,11,16,13,12,15,13,18,13,19,13,17,13,14,13,14,15,14,18,15,18,15,19,17,15,17,16,19,17,18,19,2,1,2,0,1,0,4,0,5,0,1,5,3,0,6,0,7,0,8,0,3,1,1,6,7,1,1,8,2,3,2,4,2,9,5,2,2,10,2,11,2,12,13,2,14,2,3,4,3,9,3,6,11,3,3,15,4,9,5,4,4,10,4,12,5,10,5,8,6,7,7,8,15,9,9,11,9,16,9,17,9,12,12,10,15,11,11,16,17,11,11,12,14,11,14,12,18,12,19,12,17,12,13,12,13,14,13,18,14,18,14,19,17,14,15,16,17,16,19,17,18,19,2,0,0,4,1,4,1,8,1,9,8,0,0,1,0,3,0,5,0,6,0,7,2,3,2,4,10,2,2,11,2,5,2,6,2,8,2,12,13,2,14,2,3,4,10,3,3,11,3,6,3,12,4,8,4,9,4,15,4,12,5,6,5,7,5,10,16,5,5,17,10,6,7,6,16,6,17,6,7,16,8,9,15,8,8,12,14,8,15,9,10,11,10,17,12,11,14,12,18,12,19,12,15,12,13,12,13,14,13,18,14,18,14,19,15,14,19,15,16,17,18,19,3,0,3,2,0,2,0,4,2,4,2,6,1,3,1,5,1,4,1,0,3,4,6,3,3,5,3,7,5,4,8,4,9,4,4,10,6,4,7,4,5,11,5,9,5,8,12,5,5,10,5,13,5,14,5,7,8,6,6,7,6,15,16,7,8,7,7,15,7,11,11,8,9,8,16,8,8,15,11,9,12,9,9,10,9,14,12,10,11,13,11,14,11,17,11,16,12,14,13,12,12,18,13,14,13,17,19,13,13,18,17,14,19,14,18,14,16,15,17,19,19,18,1,2,1,3,1,0,2,0,3,0,4,0,1,5,6,1,1,7,2,8,2,3,2,9,2,6,2,5,2,10,2,4,11,2,12,2,8,3,3,9,3,5,3,4,12,4,13,4,14,4,9,4,15,4,10,4,11,4,8,5,6,5,16,5,17,5,7,5,8,6,6,16,17,6,6,7,17,7,8,9,8,16,10,8,8,18,18,9,9,10,9,19,9,15,18,10,10,19,15,10,12,10,11,12,11,13,12,13,12,14,15,12,13,14,14,15,15,19,16,17,18,19,1,0,1,3,0,3,1,9,1,7,0,7,11,1,2,0,4,0,0,5,0,6,1,2,4,1,1,5,1,8,1,6,10,1,2,3,2,8,2,6,2,12,9,2,2,13,13,3,3,9,3,14,3,15,3,7,4,5,4,6,8,4,4,16,7,5,8,6,12,6,17,6,16,6,11,7,18,7,19,7,15,7,9,7,10,7,8,12,17,8,8,16,13,9,9,14,15,9,11,9,10,11,10,18,11,18,11,19,15,11,12,17,13,14,15,14,19,15,17,16,18,19,0,4,0,1,4,1,5,2,0,3,3,4,3,1,0,2,3,5,3,2,4,2,5,4,6,0,7,0,8,0,9,0,9,1,10,2,11,2,6,2,8,2,2,12,13,2,11,3,9,4,6,4,4,7,5,10,5,11,13,5,5,6,6,7,6,8,14,6,6,12,6,15,6,16,13,6,7,8,7,15,7,16,7,17,7,9,14,8,8,12,8,16,10,11,13,10,14,12,14,16,15,14,14,18,15,16,15,17,19,15,15,18,17,16,19,16,18,16,17,19,19,18] +} +,{"from_order" : 0, +"to_order" : 3, +"offset" : [0,0,2,4,6,6,14,20,20,26,28,28,30,32,34,34,36,40,40,46,58,58,66,68,78,82,90,96], +"value" : [0,2,3,0,1,2,3,0,7,4,1,11,5,9,3,7,1,11,5,9,3,10,15,6,1,2,0,2,0,3,2,3,0,1,0,1,0,2,4,9,3,11,16,9,0,5,3,11,15,16,9,20,4,0,5,1,7,2,13,3,9,0,10,7,5,1,3,0,8,10,14,4,5,1,11,12,6,2,0,4,1,2,3,9,0,7,5,1,6,2,8,4,11,12,6,2] +} + ], +"attrs" : { + "x" : [ +0,0,0,1,0,0,2,0,0,0,1,0,1,1,0,2,1,0,0,2,0,1,2,0,0,0,1,1,0,1,2,0,1,0,1,1,1,1,1,2,1,1,0,2,1,1,2,1,0,1,3,1,1,3,0,2,3,1,2,3 ]} +} \ No newline at end of file diff --git a/tests/python/examples/__init__.py b/tests/python/examples/__init__.py new file mode 100644 index 0000000000000..8b137891791fe --- /dev/null +++ b/tests/python/examples/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/python/examples/algorithm/test_laplace.py b/tests/python/examples/algorithm/test_laplace.py new file mode 100644 index 0000000000000..e957b9ea8c486 --- /dev/null +++ b/tests/python/examples/algorithm/test_laplace.py @@ -0,0 +1,10 @@ +def test_laplace(): + from taichi.examples.algorithm.laplace import laplace, x, y + + for i in range(10): + x[i, i + 1] = 1.0 + + laplace() + + for i in range(10): + assert y[i, i + 1] == (4.0 if i % 3 == 1 else 0.0) diff --git a/tests/python/examples/algorithm/test_print_offset.py b/tests/python/examples/algorithm/test_print_offset.py new file mode 100644 index 0000000000000..093eff5c96e15 --- /dev/null +++ b/tests/python/examples/algorithm/test_print_offset.py @@ -0,0 +1,54 @@ +import argparse + +from taichi.lang import impl + +import taichi as ti + +FRAMES = 100 + + +def test_print_offset(): + from taichi.examples.algorithm.print_offset import fill + fill() + + +def video_print_offset(result_dir): + from taichi.examples.algorithm.print_offset import a, fill, m, n + video_manager = ti.VideoManager(output_dir=result_dir, + framerate=24, + automatic_build=False) + + fill() + + gui = ti.GUI('layout', + res=(256, 512), + background_color=0xFFFFFF, + show_gui=False) + + for f in range(FRAMES): + for i in range(1, m): + gui.line(begin=(0, i / m), + end=(1, i / m), + radius=2, + color=0x000000) + for i in range(1, n): + gui.line(begin=(i / n, 0), + end=(i / n, 1), + radius=2, + color=0x000000) + for i in range(n): + for j in range(m): + gui.text(f'{a[i, j]}', ((i + 0.3) / n, (j + 0.75) / m), + font_size=30, + color=0x0) + video_manager.write_frame(gui.get_image()) + gui.clear() + + video_manager.make_video(mp4=True, gif=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Generate print_offset video') + parser.add_argument('output_directory', + help='output directory of generated video') + video_print_offset(parser.parse_args().output_directory) diff --git a/tests/python/examples/autodiff/__init__.py b/tests/python/examples/autodiff/__init__.py new file mode 100644 index 0000000000000..8b137891791fe --- /dev/null +++ b/tests/python/examples/autodiff/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/python/examples/autodiff/test_minimization.py b/tests/python/examples/autodiff/test_minimization.py new file mode 100644 index 0000000000000..272949eedaa79 --- /dev/null +++ b/tests/python/examples/autodiff/test_minimization.py @@ -0,0 +1,22 @@ +import random + +import pytest + +import taichi as ti + + +def test_minimization(): + from taichi.examples.autodiff.minimization import (L, gradient_descent, n, + reduce, x, y) + + for i in range(n): + x[i] = random.random() + y[i] = random.random() + + for k in range(100): + with ti.Tape(loss=L): + reduce() + gradient_descent() + + for i in range(n): + assert x[i] == pytest.approx(y[i], rel=1e-2) diff --git a/tests/python/examples/autodiff/test_regression.py b/tests/python/examples/autodiff/test_regression.py new file mode 100644 index 0000000000000..82e92334917ee --- /dev/null +++ b/tests/python/examples/autodiff/test_regression.py @@ -0,0 +1,49 @@ +import argparse + +from tests import test_utils + + +def test_regression(): + from taichi.examples.autodiff.regression import initialize, regress_raw + initialize() + regress_raw() + + +def pic_regression(result_dir): + import numpy as np + from matplotlib import pyplot as plt + from taichi.examples.autodiff.regression import (coeffs, initialize, + number_coeffs, + regress_raw, xs, ys) + + initialize() + regress_raw() + + curve_xs = np.arange(-2.5, 2.5, 0.01) + curve_ys = curve_xs * 0 + for i in range(number_coeffs): + curve_ys += coeffs[i] * np.power(curve_xs, i) + + plt.title( + 'Nonlinear Regression with Gradient Descent (3rd order polynomial)') + ax = plt.gca() + ax.scatter(xs, ys, label='data', color='r') + ax.plot(curve_xs, curve_ys, label='fitted') + ax.legend() + ax.grid(True) + ax.spines['left'].set_position('zero') + ax.spines['right'].set_color('none') + ax.spines['bottom'].set_position('zero') + ax.spines['top'].set_color('none') + print(result_dir + '/output.png') + + # Create new directory + test_utils.mkdir_p(result_dir) + plt.savefig(result_dir + '/output.png') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Generate regression pic') + parser.add_argument('output_directory', + help='output directory of generated pic') + pic_regression(parser.parse_args().output_directory) diff --git a/tests/python/examples/autodiff/test_simple_derivative.py b/tests/python/examples/autodiff/test_simple_derivative.py new file mode 100644 index 0000000000000..314635ef303c9 --- /dev/null +++ b/tests/python/examples/autodiff/test_simple_derivative.py @@ -0,0 +1,40 @@ +import argparse + +from tests import test_utils + + +def test_simple_derivative(): + from taichi.examples.autodiff.simple_derivative import initialize + + initialize() + + +def pic_simple_derivative(result_dir): + from matplotlib import pyplot as plt + from taichi.examples.autodiff.simple_derivative import (grad_xs, + initialize, xs, ys) + + initialize() + + plt.title('Auto Diff') + ax = plt.gca() + ax.plot(xs, ys, label='f(x)') + ax.plot(xs, grad_xs, label='f\'(x)') + ax.legend() + ax.grid(True) + ax.spines['left'].set_position('zero') + ax.spines['right'].set_color('none') + ax.spines['bottom'].set_position('zero') + ax.spines['top'].set_color('none') + + # Create new directory + test_utils.mkdir_p(result_dir) + plt.savefig(result_dir + '/output.png') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Generate simple_derivative pic') + parser.add_argument('output_directory', + help='output directory of generated pic') + pic_simple_derivative(parser.parse_args().output_directory) diff --git a/tests/python/examples/rendering/test_cornell_box.py b/tests/python/examples/rendering/test_cornell_box.py new file mode 100644 index 0000000000000..152d7f2a73bc6 --- /dev/null +++ b/tests/python/examples/rendering/test_cornell_box.py @@ -0,0 +1,43 @@ +import argparse + +import taichi as ti + +FRAMES = 200 + + +def test_cornell_box(): + from taichi.examples.rendering.cornell_box import render, tonemap + for i in range(FRAMES): + render() + interval = 10 + if i % interval == 0: + tonemap(i) + + +def video_cornell_box(result_dir): + from taichi.examples.rendering.cornell_box import (render, tonemap, + tonemapped_buffer) + video_manager = ti.VideoManager(output_dir=result_dir, + framerate=24, + automatic_build=False) + gui = ti.GUI("Taichi Cornell Box", + res=800, + background_color=0x112F41, + show_gui=False) + for i in range(FRAMES): + render() + interval = 10 + if i % interval == 0: + tonemap(i) + + gui.set_image(tonemapped_buffer) + video_manager.write_frame(gui.get_image()) + gui.clear() + video_manager.make_video(mp4=True, gif=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Generate cornell_box video') + parser.add_argument('output_directory', + help='output directory of generated video') + video_cornell_box(parser.parse_args().output_directory) diff --git a/tests/python/examples/rendering/test_taichi_logo.py b/tests/python/examples/rendering/test_taichi_logo.py new file mode 100644 index 0000000000000..2a8e8539599da --- /dev/null +++ b/tests/python/examples/rendering/test_taichi_logo.py @@ -0,0 +1,32 @@ +import argparse + +import taichi as ti + +FRAMES = 100 + + +def test_taichi_logo(): + from taichi.examples.rendering.taichi_logo import paint + paint() + + +def video_taichi_logo(result_dir): + from taichi.examples.rendering.taichi_logo import n, paint, x + video_manager = ti.VideoManager(output_dir=result_dir, + framerate=24, + automatic_build=False) + paint() + gui = ti.GUI('Logo', (n, n), show_gui=False) + for i in range(FRAMES): + gui.set_image(x) + video_manager.write_frame(gui.get_image()) + gui.clear() + + video_manager.make_video(mp4=True, gif=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Generate taichi_logo video') + parser.add_argument('output_directory', + help='output directory of generated video') + video_taichi_logo(parser.parse_args().output_directory) diff --git a/tests/python/examples/simulation/test_ad_gravity.py b/tests/python/examples/simulation/test_ad_gravity.py new file mode 100644 index 0000000000000..9eba1f53c20f6 --- /dev/null +++ b/tests/python/examples/simulation/test_ad_gravity.py @@ -0,0 +1,40 @@ +import argparse + +import taichi as ti + +FRAMES = 100 + + +def test_ad_gravity(): + from taichi.examples.simulation.ad_gravity import init, substep + + init() + for _ in range(FRAMES): + for _ in range(50): + substep() + + +def video_ad_gravity(result_dir): + import numpy as np + from taichi.examples.simulation.ad_gravity import init, substep, x + + video_manager = ti.tools.VideoManager(output_dir=result_dir, + framerate=24, + automatic_build=False) + + gui = ti.GUI('Autodiff gravity', show_gui=False) + init() + for _ in range(FRAMES): + for _ in range(50): + substep() + gui.circles(x.to_numpy(), radius=3) + video_manager.write_frame(gui.get_image()) + gui.clear() + video_manager.make_video(mp4=True, gif=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Generate ad_gravity video') + parser.add_argument('output_directory', + help='output directory of generated video') + video_ad_gravity(parser.parse_args().output_directory) diff --git a/tests/python/examples/simulation/test_game_of_life.py b/tests/python/examples/simulation/test_game_of_life.py new file mode 100644 index 0000000000000..1632bff235c0f --- /dev/null +++ b/tests/python/examples/simulation/test_game_of_life.py @@ -0,0 +1,43 @@ +import argparse + +import taichi as ti + +FRAMES = 100 + + +def test_game_of_life(): + from taichi.examples.simulation.game_of_life import init, run + + init() + for i in range(FRAMES): + run() + + +def video_game_of_life(result_dir): + import numpy as np + from taichi.examples.simulation.game_of_life import (alive, img_size, init, + run) + + video_manager = ti.tools.VideoManager(output_dir=result_dir, + framerate=24, + automatic_build=False) + + gui = ti.GUI('Game of Life', (img_size, img_size), show_gui=False) + gui.fps_limit = 15 + + init() + for i in range(FRAMES): + run() + + gui.set_image( + ti.tools.imresize(alive, img_size).astype(np.uint8) * 255) + video_manager.write_frame(gui.get_image()) + gui.clear() + video_manager.make_video(mp4=True, gif=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Generate game_of_life video') + parser.add_argument('output_directory', + help='output directory of generated video') + video_game_of_life(parser.parse_args().output_directory) diff --git a/tests/python/examples/simulation/test_mpm99.py b/tests/python/examples/simulation/test_mpm99.py new file mode 100644 index 0000000000000..297fe023608f9 --- /dev/null +++ b/tests/python/examples/simulation/test_mpm99.py @@ -0,0 +1,45 @@ +import argparse + +import taichi as ti + +FRAMES = 100 + + +def test_mpm99(): + from taichi.examples.simulation.mpm99 import dt, initialize, substep + + initialize() + for i in range(FRAMES): + for s in range(int(2e-3 // dt)): + substep() + + +def video_mpm99(result_dir): + from taichi.examples.simulation.mpm99 import (dt, initialize, material, + substep, x) + + video_manager = ti.VideoManager(output_dir=result_dir, + framerate=24, + automatic_build=False) + initialize() + gui = ti.GUI("Taichi MLS-MPM-99", + res=512, + background_color=0x112F41, + show_gui=False) + for i in range(FRAMES): + for s in range(int(2e-3 // dt)): + substep() + gui.circles(x.to_numpy(), + radius=1.5, + palette=[0x068587, 0xED553B, 0xEEEEF0], + palette_indices=material) + video_manager.write_frame(gui.get_image()) + gui.clear() + video_manager.make_video(mp4=True, gif=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Generate mpm99 video') + parser.add_argument('output_directory', + help='output directory of generated video') + video_mpm99(parser.parse_args().output_directory) diff --git a/tests/python/examples/simulation/test_nbody.py b/tests/python/examples/simulation/test_nbody.py new file mode 100644 index 0000000000000..eb90b4ad9517a --- /dev/null +++ b/tests/python/examples/simulation/test_nbody.py @@ -0,0 +1,45 @@ +import argparse + +import taichi as ti + +FRAMES = 100 + + +def test_nbody(): + from taichi.examples.simulation.nbody import (compute_force, initialize, + substepping, update) + + initialize() + for i in range(FRAMES): + for i in range(substepping): + compute_force() + update() + + +def video_nbody(result_dir): + from taichi.examples.simulation.nbody import (compute_force, initialize, + planet_radius, pos, + substepping, update) + + video_manager = ti.tools.VideoManager(output_dir=result_dir, + framerate=24, + automatic_build=False) + + initialize() + gui = ti.GUI('N-body problem', (800, 800), show_gui=False) + for i in range(FRAMES): + for i in range(substepping): + compute_force() + update() + + gui.circles(pos.to_numpy(), color=0xffffff, radius=planet_radius) + video_manager.write_frame(gui.get_image()) + gui.clear() + video_manager.make_video(mp4=True, gif=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Generate nbody video') + parser.add_argument('output_directory', + help='output directory of generated video') + video_nbody(parser.parse_args().output_directory) diff --git a/tests/python/expected/test_geometry_2d.png b/tests/python/expected/test_geometry_2d.png new file mode 100644 index 0000000000000..ecd189e91f0cd Binary files /dev/null and b/tests/python/expected/test_geometry_2d.png differ diff --git a/tests/python/expected/test_geometry_3d.png b/tests/python/expected/test_geometry_3d.png new file mode 100644 index 0000000000000..9a666b0aa1f6c Binary files /dev/null and b/tests/python/expected/test_geometry_3d.png differ diff --git a/tests/python/expected/test_imgui.png b/tests/python/expected/test_imgui.png new file mode 100644 index 0000000000000..8b51a2ab2fcb7 Binary files /dev/null and b/tests/python/expected/test_imgui.png differ diff --git a/tests/python/expected/test_set_image.png b/tests/python/expected/test_set_image.png new file mode 100644 index 0000000000000..26c83636ea420 Binary files /dev/null and b/tests/python/expected/test_set_image.png differ diff --git a/tests/python/fuse_test_template.py b/tests/python/fuse_test_template.py index 1a8b1a8738ed1..db808509b8d15 100644 --- a/tests/python/fuse_test_template.py +++ b/tests/python/fuse_test_template.py @@ -3,11 +3,11 @@ import taichi as ti -def template_fuse_dense_x2y2z(size=1024**3, - repeat=10, - first_n=100, - benchmark=0, - benchmark_repeat=50): +def template_fuse_dense_x2y2z( + size=1024**3, + repeat=10, + first_n=100, +): x = ti.field(ti.i32, shape=(size, )) y = ti.field(ti.i32, shape=(size, )) z = ti.field(ti.i32, shape=(size, )) @@ -30,39 +30,32 @@ def x_to_y_to_z(): for i in range(first_n): x[i] = i * 10 - if benchmark: - ti.benchmark(x_to_y_to_z, repeat=benchmark_repeat) - else: - # Simply test - for _ in range(repeat): - t = time.time() - x_to_y() - ti.sync() - print('x_to_y', time.time() - t) - - for _ in range(repeat): - t = time.time() - y_to_z() - ti.sync() - print('y_to_z', time.time() - t) - - for _ in range(repeat): - t = time.time() - x_to_y_to_z() - ti.sync() - print('fused x->y->z', time.time() - t) + # Simply test + for _ in range(repeat): + t = time.time() + x_to_y() + ti.sync() + print('x_to_y', time.time() - t) - for i in range(first_n): - assert x[i] == i * 10 - assert y[i] == x[i] + 1 - assert z[i] == x[i] + 5 + for _ in range(repeat): + t = time.time() + y_to_z() + ti.sync() + print('y_to_z', time.time() - t) + + for _ in range(repeat): + t = time.time() + x_to_y_to_z() + ti.sync() + print('fused x->y->z', time.time() - t) + + for i in range(first_n): + assert x[i] == i * 10 + assert y[i] == x[i] + 1 + assert z[i] == x[i] + 5 -def template_fuse_reduction(size=1024**3, - repeat=10, - first_n=100, - benchmark=0, - benchmark_repeat=50): +def template_fuse_reduction(size=1024**3, repeat=10, first_n=100): x = ti.field(ti.i32, shape=(size, )) first_n = min(first_n, size) @@ -76,33 +69,23 @@ def inc(): for i in x: x[i] = x[i] + 1 - if benchmark: - - def repeated_inc(): - for _ in range(repeat): - inc() - - ti.benchmark(repeated_inc, repeat=benchmark_repeat) - else: - # Simply test - reset() - ti.sync() - for _ in range(repeat): - t = time.time() - inc() - ti.sync() - print('single inc', time.time() - t) - - reset() - ti.sync() + # Simply test + reset() + ti.sync() + for _ in range(repeat): t = time.time() - for _ in range(repeat): - inc() + inc() ti.sync() - duration = time.time() - t - print( - f'fused {repeat} inc: total={duration} average={duration / repeat}' - ) + print('single inc', time.time() - t) - for i in range(first_n): - assert x[i] == i * 10 + repeat + reset() + ti.sync() + t = time.time() + for _ in range(repeat): + inc() + ti.sync() + duration = time.time() - t + print(f'fused {repeat} inc: total={duration} average={duration / repeat}') + + for i in range(first_n): + assert x[i] == i * 10 + repeat diff --git a/tests/python/py38_only.py b/tests/python/py38_only.py new file mode 100644 index 0000000000000..85ea9b827bd4d --- /dev/null +++ b/tests/python/py38_only.py @@ -0,0 +1,19 @@ +import taichi as ti +from tests import test_utils + +# The walrus operator is not supported until python 3.8, +# and pytest cannot handle files containing walrus operators when python version is below 3.8. +# So, we moved this test to the directory "python38". +# Tests in this directory will not be executed when python version is below 3.8. +# See https://github.com/taichi-dev/taichi/issues/3425 for more information. + + +@test_utils.test() +def test_namedexpr(): + @ti.kernel + def foo() -> ti.i32: + b = 2 + (a := 5) + b += a + return b + + assert foo() == 12 diff --git a/tests/python/test_abs.py b/tests/python/test_abs.py index 7ca262e58259e..a2ae2a6bab11e 100644 --- a/tests/python/test_abs.py +++ b/tests/python/test_abs.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_abs(): x = ti.field(ti.f32) y = ti.field(ti.f32) @@ -15,7 +16,7 @@ def test_abs(): @ti.kernel def func(): for i in range(N): - x[i] = ti.abs(y[i]) + x[i] = abs(y[i]) for i in range(N): y[i] = i - 10 diff --git a/tests/python/test_ad_atomic.py b/tests/python/test_ad_atomic.py index 3539bf6b809c2..fa317a79ef045 100644 --- a/tests/python/test_ad_atomic.py +++ b/tests/python/test_ad_atomic.py @@ -1,8 +1,8 @@ import taichi as ti -from taichi import approx +from tests import test_utils -@ti.test() +@test_utils.test() def test_ad_reduce(): N = 16 @@ -23,6 +23,6 @@ def func(): func() func.grad() - assert total_loss == approx(loss[None]) + assert total_loss == test_utils.approx(loss[None]) for i in range(N): - assert x.grad[i] == approx(i * 2) + assert x.grad[i] == test_utils.approx(i * 2) diff --git a/tests/python/test_ad_basics.py b/tests/python/test_ad_basics.py index 1e98686c33119..a870e643a6d9f 100644 --- a/tests/python/test_ad_basics.py +++ b/tests/python/test_ad_basics.py @@ -1,9 +1,10 @@ import functools +import numpy as np import pytest import taichi as ti -from taichi import approx +from tests import test_utils has_autograd = False @@ -29,9 +30,11 @@ def wrapper(*args, **kwargs): def grad_test(tifunc, npfunc=None): npfunc = npfunc or tifunc - print(f'arch={ti.cfg.arch} default_fp={ti.cfg.default_fp}') - x = ti.field(ti.cfg.default_fp) - y = ti.field(ti.cfg.default_fp) + print( + f'arch={ti.lang.impl.current_cfg().arch} default_fp={ti.lang.impl.current_cfg().default_fp}' + ) + x = ti.field(ti.lang.impl.current_cfg().default_fp) + y = ti.field(ti.lang.impl.current_cfg().default_fp) ti.root.dense(ti.i, 1).place(x, x.grad, y, y.grad) @@ -47,12 +50,12 @@ def func(): func() func.grad() - assert y[0] == approx(npfunc(v), rel=1e-4) - assert x.grad[0] == approx(grad(npfunc)(v), rel=1e-4) + assert y[0] == test_utils.approx(npfunc(v), rel=1e-4) + assert x.grad[0] == test_utils.approx(grad(npfunc)(v), rel=1e-4) @if_has_autograd -@ti.test() +@test_utils.test() def test_size1(): x = ti.field(ti.i32) @@ -74,7 +77,7 @@ def test_size1(): lambda x: (x - 3) * (x - 1) + x * x, ]) @if_has_autograd -@ti.test() +@test_utils.test() def test_poly(tifunc): grad_test(tifunc) @@ -87,7 +90,7 @@ def test_poly(tifunc): (lambda x: ti.asin(x), lambda x: np.arcsin(x)), ]) @if_has_autograd -@ti.test(exclude=[ti.vulkan]) +@test_utils.test(exclude=[ti.vulkan]) def test_trigonometric(tifunc, npfunc): grad_test(tifunc, npfunc) @@ -98,7 +101,7 @@ def test_trigonometric(tifunc, npfunc): lambda x: (x + 1) * (x + 2) / ((x - 1) * (x + 3)), ]) @if_has_autograd -@ti.test() +@test_utils.test() def test_frac(tifunc): grad_test(tifunc) @@ -109,7 +112,7 @@ def test_frac(tifunc): (lambda x: ti.log(x), lambda x: np.log(x)), ]) @if_has_autograd -@ti.test() +@test_utils.test() def test_unary(tifunc, npfunc): grad_test(tifunc, npfunc) @@ -125,13 +128,13 @@ def test_unary(tifunc, npfunc): (lambda x: ti.max(1, x), lambda x: np.maximum(1, x)), ]) @if_has_autograd -@ti.test() +@test_utils.test() def test_minmax(tifunc, npfunc): grad_test(tifunc, npfunc) @if_has_autograd -@ti.test() +@test_utils.test() def test_mod(): x = ti.field(ti.i32) y = ti.field(ti.i32) @@ -159,7 +162,7 @@ def func2(): (lambda y: ti.atan2(y, 0.4), lambda y: np.arctan2(y, 0.4)), ]) @if_has_autograd -@ti.test() +@test_utils.test() def test_atan2(tifunc, npfunc): grad_test(tifunc, npfunc) @@ -169,7 +172,7 @@ def test_atan2(tifunc, npfunc): (lambda y: ti.atan2(y, 0.4), lambda y: np.arctan2(y, 0.4)), ]) @if_has_autograd -@ti.test(require=ti.extension.data64, default_fp=ti.f64) +@test_utils.test(require=ti.extension.data64, default_fp=ti.f64) def test_atan2_f64(tifunc, npfunc): grad_test(tifunc, npfunc) @@ -179,7 +182,7 @@ def test_atan2_f64(tifunc, npfunc): (lambda y: y**0.4, lambda y: np.power(y, 0.4)), ]) @if_has_autograd -@ti.test() +@test_utils.test() def test_pow(tifunc, npfunc): grad_test(tifunc, npfunc) @@ -189,12 +192,12 @@ def test_pow(tifunc, npfunc): (lambda y: y**0.4, lambda y: np.power(y, 0.4)), ]) @if_has_autograd -@ti.test(require=ti.extension.data64, default_fp=ti.f64) +@test_utils.test(require=ti.extension.data64, default_fp=ti.f64) def test_pow_f64(tifunc, npfunc): grad_test(tifunc, npfunc) -@ti.test() +@test_utils.test() def test_obey_kernel_simplicity(): x = ti.field(ti.f32) y = ti.field(ti.f32) @@ -216,10 +219,10 @@ def func(): func() func.grad() - assert x.grad[0] == approx((42 - 5) * 3) + assert x.grad[0] == test_utils.approx((42 - 5) * 3) -@ti.test() +@test_utils.test() def test_violate_kernel_simplicity1(): x = ti.field(ti.f32) y = ti.field(ti.f32) @@ -238,7 +241,7 @@ def func(): func.grad() -@ti.test() +@test_utils.test() def test_violate_kernel_simplicity2(): x = ti.field(ti.f32) y = ti.field(ti.f32) @@ -257,7 +260,7 @@ def func(): func.grad() -@ti.test(require=ti.extension.data64) +@test_utils.test(require=ti.extension.data64) def test_cast(): @ti.kernel def func(): @@ -266,7 +269,7 @@ def func(): func() -@ti.test(require=ti.extension.data64) +@test_utils.test(require=ti.extension.data64) def test_ad_precision_1(): loss = ti.field(ti.f32, shape=()) x = ti.field(ti.f64, shape=()) @@ -283,7 +286,7 @@ def func(): assert x.grad[None] == 1 -@ti.test(require=ti.extension.data64) +@test_utils.test(require=ti.extension.data64) def test_ad_precision_2(): loss = ti.field(ti.f64, shape=()) x = ti.field(ti.f32, shape=()) @@ -300,7 +303,7 @@ def func(): assert x.grad[None] == 1 -@ti.test() +@test_utils.test() def test_ad_rand(): loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True) x = ti.field(dtype=ti.f32, shape=(), needs_grad=True) @@ -314,3 +317,37 @@ def work(): with ti.Tape(loss): work() assert 'RandStmt not supported' in e.value.args[0] + + +@test_utils.test(exclude=[ti.cc, ti.vulkan, ti.opengl]) +def test_ad_frac(): + @ti.func + def frac(x): + fractional = x - ti.floor(x) if x > 0. else x - ti.ceil(x) + return fractional + + @ti.kernel + def ti_frac(input_field: ti.template(), output_field: ti.template()): + for i in input_field: + output_field[i] = frac(input_field[i])**2 + + @ti.kernel + def calc_loss(input_field: ti.template(), loss: ti.template()): + for i in input_field: + loss[None] += input_field[i] + + n = 10 + field0 = ti.field(dtype=ti.f32, shape=(n, ), needs_grad=True) + randoms = np.random.randn(10).astype(np.float32) + field0.from_numpy(randoms) + field1 = ti.field(dtype=ti.f32, shape=(n, ), needs_grad=True) + loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True) + + with ti.Tape(loss): + ti_frac(field0, field1) + calc_loss(field1, loss) + + grads = field0.grad.to_numpy() + expected = np.modf(randoms)[0] * 2 + for i in range(n): + assert grads[i] == test_utils.approx(expected[i], rel=1e-4) diff --git a/tests/python/test_ad_demote_dense.py b/tests/python/test_ad_demote_dense.py index b9f55c6f621ad..4253453852142 100644 --- a/tests/python/test_ad_demote_dense.py +++ b/tests/python/test_ad_demote_dense.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test(exclude=[ti.metal, ti.opengl]) +@test_utils.test(exclude=[ti.metal, ti.opengl]) def test_ad_demote_dense(): a = ti.field(ti.f32, shape=(7, 3, 19)) diff --git a/tests/python/test_ad_for.py b/tests/python/test_ad_for.py index a40d110a5ac29..b0b3cb6d3bb10 100644 --- a/tests/python/test_ad_for.py +++ b/tests/python/test_ad_for.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.adstack) +@test_utils.test(require=ti.extension.adstack) def test_ad_sum(): N = 10 a = ti.field(ti.f32, shape=N, needs_grad=True) @@ -32,7 +33,7 @@ def compute_sum(): assert a.grad[i] == b[i] -@ti.test(require=ti.extension.adstack) +@test_utils.test(require=ti.extension.adstack) def test_ad_sum_local_atomic(): N = 10 a = ti.field(ti.f32, shape=N, needs_grad=True) @@ -63,7 +64,7 @@ def compute_sum(): assert a.grad[i] == b[i] -@ti.test(require=ti.extension.adstack) +@test_utils.test(require=ti.extension.adstack) def test_ad_power(): N = 10 a = ti.field(ti.f32, shape=N, needs_grad=True) @@ -94,7 +95,7 @@ def power(): assert a.grad[i] == b[i] * 3**(b[i] - 1) -@ti.test(require=ti.extension.adstack) +@test_utils.test(require=ti.extension.adstack) def test_ad_fibonacci(): N = 15 a = ti.field(ti.f32, shape=N, needs_grad=True) @@ -132,7 +133,7 @@ def fib(): assert b.grad[i] == f[i] -@ti.test(require=ti.extension.adstack) +@test_utils.test(require=ti.extension.adstack) def test_ad_fibonacci_index(): N = 5 M = 10 @@ -164,7 +165,7 @@ def fib(): assert b[i] == is_fib * N -@ti.test(require=ti.extension.adstack) +@test_utils.test(require=ti.extension.adstack) def test_ad_global_ptr(): N = 5 a = ti.field(ti.f32, shape=N, needs_grad=True) @@ -194,7 +195,7 @@ def task(): assert a.grad[i] == 2 * i * N -@ti.test(require=ti.extension.adstack) +@test_utils.test(require=ti.extension.adstack) def test_integer_stack(): N = 5 a = ti.field(ti.f32, shape=N, needs_grad=True) @@ -233,7 +234,7 @@ def int_stack(): t = t * 10 + 1 -@ti.test(require=ti.extension.adstack) +@test_utils.test(require=ti.extension.adstack) def test_double_for_loops(): N = 5 a = ti.field(ti.f32, shape=N, needs_grad=True) @@ -271,7 +272,7 @@ def double_for(): assert b.grad[i] == 2 * i -@ti.test(require=ti.extension.adstack) +@test_utils.test(require=ti.extension.adstack) def test_double_for_loops_more_nests(): N = 6 a = ti.field(ti.f32, shape=N, needs_grad=True) @@ -317,7 +318,7 @@ def double_for(): assert b.grad[i] == total_grad_b -@ti.test(require=[ti.extension.adstack, ti.extension.data64]) +@test_utils.test(require=[ti.extension.adstack, ti.extension.data64]) def test_complex_body(): N = 5 a = ti.field(ti.f32, shape=N, needs_grad=True) @@ -357,7 +358,7 @@ def complex(): assert a.grad[i] == g[i] -@ti.test(require=[ti.extension.adstack, ti.extension.bls], dynamic_index=False) +@test_utils.test(require=[ti.extension.adstack, ti.extension.bls]) def test_triple_for_loops_bls(): N = 8 M = 3 @@ -399,3 +400,378 @@ def triple_for(): for i in range(N): assert b.grad[i * 2] == min(min(N - i - 1, i + 1), M) * N assert b.grad[i * 2 + 1] == min(min(N - i - 1, i + 1), M) * N + + +@test_utils.test(require=ti.extension.adstack) +def test_mixed_inner_loops(): + x = ti.field(dtype=ti.f32, shape=(), needs_grad=True) + arr = ti.field(dtype=ti.f32, shape=(5)) + loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True) + + @ti.kernel + def mixed_inner_loops(): + for i in arr: + loss[None] += ti.sin(x[None]) + for j in range(2): + loss[None] += ti.sin(x[None]) + 1.0 + + loss.grad[None] = 1.0 + x[None] = 0.0 + mixed_inner_loops() + mixed_inner_loops.grad() + + assert loss[None] == 10.0 + assert x.grad[None] == 15.0 + + +@test_utils.test(require=ti.extension.adstack) +def test_mixed_inner_loops_tape(): + x = ti.field(dtype=ti.f32, shape=(), needs_grad=True) + arr = ti.field(dtype=ti.f32, shape=(5)) + loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True) + + @ti.kernel + def mixed_inner_loops_tape(): + for i in arr: + loss[None] += ti.sin(x[None]) + for j in range(2): + loss[None] += ti.sin(x[None]) + 1.0 + + x[None] = 0.0 + with ti.Tape(loss=loss): + mixed_inner_loops_tape() + + assert loss[None] == 10.0 + assert x.grad[None] == 15.0 + + +@test_utils.test(require=ti.extension.adstack, ad_stack_size=32) +def test_inner_loops_local_variable_fixed_stack_size_tape(): + x = ti.field(dtype=float, shape=(), needs_grad=True) + arr = ti.field(dtype=float, shape=(2), needs_grad=True) + loss = ti.field(dtype=float, shape=(), needs_grad=True) + + @ti.kernel + def test_inner_loops_local_variable(): + for i in arr: + for j in range(3): + s = 0.0 + t = 0.0 + for k in range(3): + s += ti.sin(x[None]) + 1.0 + t += ti.sin(x[None]) + loss[None] += s + t + + x[None] = 0.0 + with ti.Tape(loss=loss): + test_inner_loops_local_variable() + + assert loss[None] == 18.0 + assert x.grad[None] == 36.0 + + +@test_utils.test(require=ti.extension.adstack, ad_stack_size=32) +def test_inner_loops_local_variable_fixed_stack_size_kernel_grad(): + x = ti.field(dtype=float, shape=(), needs_grad=True) + arr = ti.field(dtype=float, shape=(2), needs_grad=True) + loss = ti.field(dtype=float, shape=(), needs_grad=True) + + @ti.kernel + def test_inner_loops_local_variable(): + for i in arr: + for j in range(3): + s = 0.0 + t = 0.0 + for k in range(3): + s += ti.sin(x[None]) + 1.0 + t += ti.sin(x[None]) + loss[None] += s + t + + loss.grad[None] = 1.0 + x[None] = 0.0 + test_inner_loops_local_variable() + test_inner_loops_local_variable.grad() + + assert loss[None] == 18.0 + assert x.grad[None] == 36.0 + + +@test_utils.test(require=ti.extension.adstack, ad_stack_size=0) +def test_inner_loops_local_variable_adaptive_stack_size_tape(): + x = ti.field(dtype=float, shape=(), needs_grad=True) + arr = ti.field(dtype=float, shape=(2), needs_grad=True) + loss = ti.field(dtype=float, shape=(), needs_grad=True) + + @ti.kernel + def test_inner_loops_local_variable(): + for i in arr: + for j in range(3): + s = 0.0 + t = 0.0 + for k in range(3): + s += ti.sin(x[None]) + 1.0 + t += ti.sin(x[None]) + loss[None] += s + t + + x[None] = 0.0 + with ti.Tape(loss=loss): + test_inner_loops_local_variable() + + assert loss[None] == 18.0 + assert x.grad[None] == 36.0 + + +@test_utils.test(require=ti.extension.adstack, ad_stack_size=0) +def test_inner_loops_local_variable_adaptive_stack_size_kernel_grad(): + x = ti.field(dtype=float, shape=(), needs_grad=True) + arr = ti.field(dtype=float, shape=(2), needs_grad=True) + loss = ti.field(dtype=float, shape=(), needs_grad=True) + + @ti.kernel + def test_inner_loops_local_variable(): + for i in arr: + for j in range(3): + s = 0.0 + t = 0.0 + for k in range(3): + s += ti.sin(x[None]) + 1.0 + t += ti.sin(x[None]) + loss[None] += s + t + + loss.grad[None] = 1.0 + x[None] = 0.0 + test_inner_loops_local_variable() + test_inner_loops_local_variable.grad() + + assert loss[None] == 18.0 + assert x.grad[None] == 36.0 + + +@test_utils.test(require=ti.extension.adstack, ad_stack_size=0) +def test_more_inner_loops_local_variable_adaptive_stack_size_tape(): + x = ti.field(dtype=float, shape=(), needs_grad=True) + arr = ti.field(dtype=float, shape=(2), needs_grad=True) + loss = ti.field(dtype=float, shape=(), needs_grad=True) + + @ti.kernel + def test_more_inner_loops_local_variable(): + for i in arr: + for j in range(2): + s = 0.0 + for k in range(3): + u = 0.0 + s += ti.sin(x[None]) + 1.0 + for l in range(2): + u += ti.sin(x[None]) + loss[None] += u + loss[None] += s + + x[None] = 0.0 + with ti.Tape(loss=loss): + test_more_inner_loops_local_variable() + + assert loss[None] == 12.0 + assert x.grad[None] == 36.0 + + +@test_utils.test(require=ti.extension.adstack, ad_stack_size=32) +def test_more_inner_loops_local_variable_fixed_stack_size_tape(): + x = ti.field(dtype=float, shape=(), needs_grad=True) + arr = ti.field(dtype=float, shape=(2), needs_grad=True) + loss = ti.field(dtype=float, shape=(), needs_grad=True) + + @ti.kernel + def test_more_inner_loops_local_variable(): + for i in arr: + for j in range(2): + s = 0.0 + for k in range(3): + u = 0.0 + s += ti.sin(x[None]) + 1.0 + for l in range(2): + u += ti.sin(x[None]) + loss[None] += u + loss[None] += s + + x[None] = 0.0 + with ti.Tape(loss=loss): + test_more_inner_loops_local_variable() + + assert loss[None] == 12.0 + assert x.grad[None] == 36.0 + + +@test_utils.test(require=ti.extension.adstack, + ad_stack_size=32, + arch=[ti.cpu, ti.gpu]) +def test_stacked_inner_loops_local_variable_fixed_stack_size_kernel_grad(): + x = ti.field(dtype=float, shape=(), needs_grad=True) + arr = ti.field(dtype=float, shape=(2), needs_grad=True) + loss = ti.field(dtype=float, shape=(), needs_grad=True) + + @ti.kernel + def test_stacked_inner_loops_local_variable(): + for i in arr: + loss[None] += ti.sin(x[None]) + for j in range(3): + s = 0.0 + for k in range(3): + s += ti.sin(x[None]) + 1.0 + loss[None] += s + for j in range(3): + s = 0.0 + for k in range(3): + s += ti.sin(x[None]) + 1.0 + loss[None] += s + + loss.grad[None] = 1.0 + x[None] = 0.0 + test_stacked_inner_loops_local_variable() + test_stacked_inner_loops_local_variable.grad() + + assert loss[None] == 36.0 + assert x.grad[None] == 38.0 + + +@test_utils.test(require=ti.extension.adstack, + ad_stack_size=32, + arch=[ti.cpu, ti.gpu]) +def test_stacked_mixed_ib_and_non_ib_inner_loops_local_variable_fixed_stack_size_kernel_grad( +): + x = ti.field(dtype=float, shape=(), needs_grad=True) + arr = ti.field(dtype=float, shape=(2), needs_grad=True) + loss = ti.field(dtype=float, shape=(), needs_grad=True) + + @ti.kernel + def test_stacked_mixed_ib_and_non_ib_inner_loops_local_variable(): + for i in arr: + loss[None] += ti.sin(x[None]) + for j in range(3): + for k in range(3): + loss[None] += ti.sin(x[None]) + 1.0 + for j in range(3): + s = 0.0 + for k in range(3): + s += ti.sin(x[None]) + 1.0 + loss[None] += s + for j in range(3): + for k in range(3): + loss[None] += ti.sin(x[None]) + 1.0 + + loss.grad[None] = 1.0 + x[None] = 0.0 + test_stacked_mixed_ib_and_non_ib_inner_loops_local_variable() + test_stacked_mixed_ib_and_non_ib_inner_loops_local_variable.grad() + + assert loss[None] == 54.0 + assert x.grad[None] == 56.0 + + +@test_utils.test(require=ti.extension.adstack, + ad_stack_size=0, + arch=[ti.cpu, ti.gpu]) +def test_stacked_inner_loops_local_variable_adaptive_stack_size_kernel_grad(): + x = ti.field(dtype=float, shape=(), needs_grad=True) + arr = ti.field(dtype=float, shape=(2), needs_grad=True) + loss = ti.field(dtype=float, shape=(), needs_grad=True) + + @ti.kernel + def test_stacked_inner_loops_local_variable(): + for i in arr: + loss[None] += ti.sin(x[None]) + for j in range(3): + s = 0.0 + for k in range(3): + s += ti.sin(x[None]) + 1.0 + loss[None] += s + for j in range(3): + s = 0.0 + for k in range(3): + s += ti.sin(x[None]) + 1.0 + loss[None] += s + + loss.grad[None] = 1.0 + x[None] = 0.0 + test_stacked_inner_loops_local_variable() + test_stacked_inner_loops_local_variable.grad() + + assert loss[None] == 36.0 + assert x.grad[None] == 38.0 + + +@test_utils.test(require=ti.extension.adstack, + ad_stack_size=0, + arch=[ti.cpu, ti.gpu]) +def test_stacked_mixed_ib_and_non_ib_inner_loops_local_variable_adaptive_stack_size_kernel_grad( +): + x = ti.field(dtype=float, shape=(), needs_grad=True) + arr = ti.field(dtype=float, shape=(2), needs_grad=True) + loss = ti.field(dtype=float, shape=(), needs_grad=True) + + @ti.kernel + def test_stacked_mixed_ib_and_non_ib_inner_loops_local_variable(): + for i in arr: + loss[None] += ti.sin(x[None]) + for j in range(3): + for k in range(3): + loss[None] += ti.sin(x[None]) + 1.0 + for j in range(3): + s = 0.0 + for k in range(3): + s += ti.sin(x[None]) + 1.0 + loss[None] += s + for j in range(3): + for k in range(3): + loss[None] += ti.sin(x[None]) + 1.0 + + loss.grad[None] = 1.0 + x[None] = 0.0 + test_stacked_mixed_ib_and_non_ib_inner_loops_local_variable() + test_stacked_mixed_ib_and_non_ib_inner_loops_local_variable.grad() + + assert loss[None] == 54.0 + assert x.grad[None] == 56.0 + + +@test_utils.test(require=ti.extension.adstack, + ad_stack_size=0, + arch=[ti.cpu, ti.gpu]) +def test_large_for_loops_adaptive_stack_size(): + x = ti.field(dtype=float, shape=(), needs_grad=True) + arr = ti.field(dtype=float, shape=(2), needs_grad=True) + loss = ti.field(dtype=float, shape=(), needs_grad=True) + + @ti.kernel + def test_large_loop(): + for i in range(5): + for j in range(2000): + for k in range(1000): + loss[None] += ti.sin(x[None]) + 1.0 + + with ti.Tape(loss=loss): + test_large_loop() + + assert loss[None] == 1e7 + assert x.grad[None] == 1e7 + + +@test_utils.test(require=ti.extension.adstack, + ad_stack_size=1, + arch=[ti.cpu, ti.gpu]) +def test_large_for_loops_fixed_stack_size(): + x = ti.field(dtype=float, shape=(), needs_grad=True) + arr = ti.field(dtype=float, shape=(2), needs_grad=True) + loss = ti.field(dtype=float, shape=(), needs_grad=True) + + @ti.kernel + def test_large_loop(): + for i in range(5): + for j in range(2000): + for k in range(1000): + loss[None] += ti.sin(x[None]) + 1.0 + + with ti.Tape(loss=loss): + test_large_loop() + + assert loss[None] == 1e7 + assert x.grad[None] == 1e7 diff --git a/tests/python/test_ad_if.py b/tests/python/test_ad_if.py index 0b4e0cae16196..b76ef60289a19 100644 --- a/tests/python/test_ad_if.py +++ b/tests/python/test_ad_if.py @@ -1,7 +1,11 @@ +from taichi.lang import impl +from taichi.lang.misc import get_host_arch_list + import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.adstack) +@test_utils.test(require=ti.extension.adstack) def test_ad_if_simple(): x = ti.field(ti.f32, shape=()) y = ti.field(ti.f32, shape=()) @@ -22,7 +26,7 @@ def func(): assert x.grad[None] == 1 -@ti.test(require=ti.extension.adstack) +@test_utils.test(require=ti.extension.adstack) def test_ad_if(): x = ti.field(ti.f32, shape=2) y = ti.field(ti.f32, shape=2) @@ -50,7 +54,7 @@ def func(i: ti.i32): assert x.grad[1] == 1 -@ti.test(require=ti.extension.adstack) +@test_utils.test(require=ti.extension.adstack) def test_ad_if_nested(): n = 20 x = ti.field(ti.f32, shape=n) @@ -88,7 +92,7 @@ def func(): assert z.grad[i] == i % 4 -@ti.test(require=ti.extension.adstack) +@test_utils.test(require=ti.extension.adstack) def test_ad_if_mutable(): x = ti.field(ti.f32, shape=2) y = ti.field(ti.f32, shape=2) @@ -117,7 +121,7 @@ def func(i: ti.i32): assert x.grad[1] == 1 -@ti.test(require=ti.extension.adstack) +@test_utils.test(require=ti.extension.adstack) def test_ad_if_parallel(): x = ti.field(ti.f32, shape=2) y = ti.field(ti.f32, shape=2) @@ -145,8 +149,8 @@ def func(): assert x.grad[1] == 1 -@ti.test(require=[ti.extension.adstack, ti.extension.data64], - default_fp=ti.f64) +@test_utils.test(require=[ti.extension.adstack, ti.extension.data64], + default_fp=ti.f64) def test_ad_if_parallel_f64(): x = ti.field(ti.f64, shape=2) y = ti.field(ti.f64, shape=2) @@ -174,7 +178,7 @@ def func(): assert x.grad[1] == 1 -@ti.test(require=ti.extension.adstack) +@test_utils.test(require=ti.extension.adstack) def test_ad_if_parallel_complex(): x = ti.field(ti.f32, shape=2) y = ti.field(ti.f32, shape=2) @@ -202,8 +206,8 @@ def func(): assert x.grad[1] == -0.25 -@ti.test(require=[ti.extension.adstack, ti.extension.data64], - default_fp=ti.f64) +@test_utils.test(require=[ti.extension.adstack, ti.extension.data64], + default_fp=ti.f64) def test_ad_if_parallel_complex_f64(): x = ti.field(ti.f64, shape=2) y = ti.field(ti.f64, shape=2) @@ -231,10 +235,10 @@ def func(): assert x.grad[1] == -0.25 -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_stack(): @ti.kernel def func(): - ti.call_internal("test_stack") + impl.call_internal("test_stack") func() diff --git a/tests/python/test_ad_offload.py b/tests/python/test_ad_offload.py index 24cbea4d8303c..945dba9c83d20 100644 --- a/tests/python/test_ad_offload.py +++ b/tests/python/test_ad_offload.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_offload_order(): n = 128 x = ti.field(ti.f32, shape=n, needs_grad=True) diff --git a/tests/python/test_aot.py b/tests/python/test_aot.py index 802db2f2953e9..194077bd08cff 100644 --- a/tests/python/test_aot.py +++ b/tests/python/test_aot.py @@ -1,10 +1,16 @@ +import json import os +import sys import tempfile +import numpy as np +import pytest + import taichi as ti +from tests import test_utils -@ti.test(arch=ti.cc) +@test_utils.test(arch=ti.cc) def test_record(): with tempfile.TemporaryDirectory() as tmpdir: recorded_file = os.path.join(tmpdir, 'record.yml') @@ -26,3 +32,531 @@ def compute_loss(): # Make sure kernel info is in the file with open(recorded_file, 'r') as f: assert 'compute_loss' in ''.join(f.readlines()) + + +@test_utils.test(arch=ti.opengl, max_block_dim=32) +def test_opengl_max_block_dim(): + density = ti.field(float, shape=(8, 8)) + + @ti.kernel + def init(): + for i, j in density: + density[i, j] = 1 + + with tempfile.TemporaryDirectory() as tmpdir: + m = ti.aot.Module(ti.opengl) + m.add_field('density', density) + m.add_kernel(init) + m.save(tmpdir, '') + with open(os.path.join(tmpdir, 'metadata.json')) as json_file: + res = json.load(json_file) + gl_file_path = res['aot_data']['kernels']['init']['tasks'][0][ + 'source_path'] + with open(gl_file_path) as gl_file: + s = 'layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;\n' + assert s in gl_file.readlines() + + +@test_utils.test(arch=[ti.opengl, ti.vulkan]) +def test_aot_field_range_hint(): + density = ti.field(float, shape=(8, 8)) + + @ti.kernel + def init(): + for i, j in density: + density[i, j] = 1 + + with tempfile.TemporaryDirectory() as tmpdir: + m = ti.aot.Module(ti.opengl) + m.add_field('density', density) + m.add_kernel(init) + m.save(tmpdir, '') + with open(os.path.join(tmpdir, 'metadata.json')) as json_file: + res = json.load(json_file) + range_hint = res['aot_data']['kernels']['init']['tasks'][0][ + 'range_hint'] + assert range_hint == '64' + + +@test_utils.test(arch=ti.opengl) +def test_aot_ndarray_range_hint(): + density = ti.ndarray(dtype=ti.f32, shape=(8, 8)) + + @ti.kernel + def init(density: ti.any_arr()): + for i, j in density: + density[i, j] = 1 + + with tempfile.TemporaryDirectory() as tmpdir: + m = ti.aot.Module(ti.opengl) + m.add_kernel(init, (density, )) + m.save(tmpdir, '') + with open(os.path.join(tmpdir, 'metadata.json')) as json_file: + res = json.load(json_file) + range_hint = res['aot_data']['kernels']['init']['tasks'][0][ + 'range_hint'] + assert range_hint == 'arg 0' + + +@test_utils.test(arch=ti.opengl) +def test_element_size_alignment(): + a = ti.field(ti.f32, shape=()) + b = ti.Matrix.field(2, 3, ti.f32, shape=(2, 4)) + c = ti.field(ti.i32, shape=()) + + with tempfile.TemporaryDirectory() as tmpdir: + s = ti.aot.Module(ti.lang.impl.current_cfg().arch) + s.add_field('a', a) + s.add_field('b', b) + s.add_field('c', c) + s.save(tmpdir, '') + with open(os.path.join(tmpdir, 'metadata.json')) as json_file: + res = json.load(json_file) + offsets = (res['aot_data']['fields'][0]['mem_offset_in_parent'], + res['aot_data']['fields'][1]['mem_offset_in_parent'], + res['aot_data']['fields'][2]['mem_offset_in_parent']) + assert 0 in offsets and 4 in offsets and 24 in offsets + assert res['aot_data']['root_buffer_size'] == 216 + + +@test_utils.test(arch=[ti.opengl, ti.vulkan]) +def test_save(): + density = ti.field(float, shape=(4, 4)) + + @ti.kernel + def init(): + for i, j in density: + density[i, j] = 1 + + with tempfile.TemporaryDirectory() as tmpdir: + # note ti.aot.Module(ti.opengl) is no-op according to its docstring. + m = ti.aot.Module(ti.lang.impl.current_cfg().arch) + m.add_field('density', density) + m.add_kernel(init) + m.save(tmpdir, '') + with open(os.path.join(tmpdir, 'metadata.json')) as json_file: + json.load(json_file) + + +@test_utils.test(arch=[ti.opengl, ti.vulkan]) +def test_save_template_kernel(): + density = ti.field(float, shape=(4, 4)) + + @ti.kernel + def foo(n: ti.template()): + for i in range(n): + density[0, 0] += 1 + + with tempfile.TemporaryDirectory() as tmpdir: + # note ti.aot.Module(ti.opengl) is no-op according to its docstring. + m = ti.aot.Module(ti.lang.impl.current_cfg().arch) + m.add_field('density', density) + with m.add_kernel_template(foo) as kt: + kt.instantiate(n=6) + kt.instantiate(n=8) + m.save(tmpdir, '') + with open(os.path.join(tmpdir, 'metadata.json')) as json_file: + json.load(json_file) + + +@test_utils.test(arch=[ti.opengl, ti.vulkan]) +def test_non_dense_snode(): + n = 8 + x = ti.field(dtype=ti.f32) + y = ti.field(dtype=ti.f32) + blk = ti.root.dense(ti.i, n) + blk.place(x) + blk.dense(ti.i, n).place(y) + + with pytest.raises(RuntimeError, match='AOT: only supports dense field'): + m = ti.aot.Module(ti.lang.impl.current_cfg().arch) + m.add_field('x', x) + m.add_field('y', y) + + +@test_utils.test(arch=[ti.opengl, ti.vulkan]) +def test_mpm88_aot(): + n_particles = 8192 + n_grid = 128 + dx = 1 / n_grid + dt = 2e-4 + + p_rho = 1 + p_vol = (dx * 0.5)**2 + p_mass = p_vol * p_rho + gravity = 9.8 + bound = 3 + E = 400 + + x = ti.Vector.field(2, float, n_particles) + v = ti.Vector.field(2, float, n_particles) + C = ti.Matrix.field(2, 2, float, n_particles) + J = ti.field(float, n_particles) + + grid_v = ti.Vector.field(2, float, (n_grid, n_grid)) + grid_m = ti.field(float, (n_grid, n_grid)) + + @ti.kernel + def substep(): + for i, j in grid_m: + grid_v[i, j] = [0, 0] + grid_m[i, j] = 0 + for p in x: + Xp = x[p] / dx + base = int(Xp - 0.5) + fx = Xp - base + w = [0.5 * (1.5 - fx)**2, 0.75 - (fx - 1)**2, 0.5 * (fx - 0.5)**2] + stress = -dt * 4 * E * p_vol * (J[p] - 1) / dx**2 + affine = ti.Matrix([[stress, 0], [0, stress]]) + p_mass * C[p] + for i, j in ti.static(ti.ndrange(3, 3)): + offset = ti.Vector([i, j]) + dpos = (offset - fx) * dx + weight = w[i].x * w[j].y + grid_v[base + + offset] += weight * (p_mass * v[p] + affine @ dpos) + grid_m[base + offset] += weight * p_mass + for i, j in grid_m: + if grid_m[i, j] > 0: + grid_v[i, j] /= grid_m[i, j] + grid_v[i, j].y -= dt * gravity + if i < bound and grid_v[i, j].x < 0: + grid_v[i, j].x = 0 + if i > n_grid - bound and grid_v[i, j].x > 0: + grid_v[i, j].x = 0 + if j < bound and grid_v[i, j].y < 0: + grid_v[i, j].y = 0 + if j > n_grid - bound and grid_v[i, j].y > 0: + grid_v[i, j].y = 0 + for p in x: + Xp = x[p] / dx + base = int(Xp - 0.5) + fx = Xp - base + w = [0.5 * (1.5 - fx)**2, 0.75 - (fx - 1)**2, 0.5 * (fx - 0.5)**2] + new_v = ti.Vector.zero(float, 2) + new_C = ti.Matrix.zero(float, 2, 2) + for i, j in ti.static(ti.ndrange(3, 3)): + offset = ti.Vector([i, j]) + dpos = (offset - fx) * dx + weight = w[i].x * w[j].y + g_v = grid_v[base + offset] + new_v += weight * g_v + new_C += 4 * weight * g_v.outer_product(dpos) / dx**2 + v[p] = new_v + x[p] += dt * v[p] + J[p] *= 1 + dt * new_C.trace() + C[p] = new_C + + @ti.kernel + def init(): + for i in range(n_particles): + x[i] = [ti.random() * 0.4 + 0.2, ti.random() * 0.4 + 0.2] + v[i] = [0, -1] + J[i] = 1 + + with tempfile.TemporaryDirectory() as tmpdir: + m = ti.aot.Module(ti.lang.impl.current_cfg().arch) + m.add_field("x", x) + m.add_field("v", v) + m.add_field("C", C) + m.add_field("J", J) + m.add_field("grid_v", grid_v) + m.add_field("grid_m", grid_m) + m.add_kernel(substep) + m.add_kernel(init) + m.save(tmpdir, '') + with open(os.path.join(tmpdir, 'metadata.json')) as json_file: + json.load(json_file) + + +@test_utils.test(arch=ti.opengl) +def test_opengl_8_ssbo(): + # 6 ndarrays + gtmp + args + n = 4 + density1 = ti.ndarray(dtype=ti.f32, shape=(4, 4)) + density2 = ti.ndarray(dtype=ti.f32, shape=(4, 4)) + density3 = ti.ndarray(dtype=ti.f32, shape=(4, 4)) + density4 = ti.ndarray(dtype=ti.f32, shape=(4, 4)) + density5 = ti.ndarray(dtype=ti.f32, shape=(4, 4)) + density6 = ti.ndarray(dtype=ti.f32, shape=(4, 4)) + + @ti.kernel + def init(d: ti.i32, density1: ti.any_arr(), density2: ti.any_arr(), + density3: ti.any_arr(), density4: ti.any_arr(), + density5: ti.any_arr(), density6: ti.any_arr()): + for i, j in density1: + density1[i, j] = d + 1 + density2[i, j] = d + 2 + density3[i, j] = d + 3 + density4[i, j] = d + 4 + density5[i, j] = d + 5 + density6[i, j] = d + 6 + + init(0, density1, density2, density3, density4, density5, density6) + assert (density1.to_numpy() == (np.zeros(shape=(n, n)) + 1)).all() + assert (density2.to_numpy() == (np.zeros(shape=(n, n)) + 2)).all() + assert (density3.to_numpy() == (np.zeros(shape=(n, n)) + 3)).all() + assert (density4.to_numpy() == (np.zeros(shape=(n, n)) + 4)).all() + assert (density5.to_numpy() == (np.zeros(shape=(n, n)) + 5)).all() + assert (density6.to_numpy() == (np.zeros(shape=(n, n)) + 6)).all() + + +@test_utils.test(arch=ti.opengl) +def test_opengl_exceed_max_ssbo(): + # 8 ndarrays + args > 8 (maximum allowed) + n = 4 + density1 = ti.ndarray(dtype=ti.f32, shape=(n, n)) + density2 = ti.ndarray(dtype=ti.f32, shape=(n, n)) + density3 = ti.ndarray(dtype=ti.f32, shape=(n, n)) + density4 = ti.ndarray(dtype=ti.f32, shape=(n, n)) + density5 = ti.ndarray(dtype=ti.f32, shape=(n, n)) + density6 = ti.ndarray(dtype=ti.f32, shape=(n, n)) + density7 = ti.ndarray(dtype=ti.f32, shape=(n, n)) + density8 = ti.ndarray(dtype=ti.f32, shape=(n, n)) + + @ti.kernel + def init(d: ti.i32, density1: ti.any_arr(), density2: ti.any_arr(), + density3: ti.any_arr(), density4: ti.any_arr(), + density5: ti.any_arr(), density6: ti.any_arr(), + density7: ti.any_arr(), density8: ti.any_arr()): + for i, j in density1: + density1[i, j] = d + 1 + density2[i, j] = d + 2 + density3[i, j] = d + 3 + density4[i, j] = d + 4 + density5[i, j] = d + 5 + density6[i, j] = d + 6 + density7[i, j] = d + 7 + density8[i, j] = d + 8 + + with pytest.raises(RuntimeError): + init(0, density1, density2, density3, density4, density5, density6, + density7, density8) + + +@test_utils.test(arch=[ti.opengl, ti.vulkan]) +def test_mpm99_aot(): + quality = 1 # Use a larger value for higher-res simulations + n_particles, n_grid = 9000 * quality**2, 128 * quality + dx, inv_dx = 1 / n_grid, float(n_grid) + dt = 1e-4 / quality + p_vol, p_rho = (dx * 0.5)**2, 1 + p_mass = p_vol * p_rho + E, nu = 0.1e4, 0.2 # Young's modulus and Poisson's ratio + mu_0, lambda_0 = E / (2 * (1 + nu)), E * nu / ( + (1 + nu) * (1 - 2 * nu)) # Lame parameters + x = ti.Vector.field(2, dtype=float, shape=n_particles) # position + v = ti.Vector.field(2, dtype=float, shape=n_particles) # velocity + C = ti.Matrix.field(2, 2, dtype=float, + shape=n_particles) # affine velocity field + F = ti.Matrix.field(2, 2, dtype=float, + shape=n_particles) # deformation gradient + material = ti.field(dtype=int, shape=n_particles) # material id + Jp = ti.field(dtype=float, shape=n_particles) # plastic deformation + grid_v = ti.Vector.field(2, dtype=float, + shape=(n_grid, + n_grid)) # grid node momentum/velocity + grid_m = ti.field(dtype=float, shape=(n_grid, n_grid)) # grid node mass + grid_v_int = ti.Vector.field(2, dtype=int, + shape=(n_grid, + n_grid)) # grid node momentum/velocity + grid_m_int = ti.field(dtype=int, shape=(n_grid, n_grid)) # grid node mass + + v_exp = 24 + m_exp = 40 + + @ti.kernel + def substep(): + for i, j in grid_m: + grid_v[i, j] = [0, 0] + grid_m[i, j] = 0 + grid_v_int[i, j] = [0, 0] + grid_m_int[i, j] = 0 + for p in x: # Particle state update and scatter to grid (P2G) + base = (x[p] * inv_dx - 0.5).cast(int) + fx = x[p] * inv_dx - base.cast(float) + # Quadratic kernels [http://mpm.graphics Eqn. 123, with x=fx, fx-1,fx-2] + w = [0.5 * (1.5 - fx)**2, 0.75 - (fx - 1)**2, 0.5 * (fx - 0.5)**2] + F[p] = (ti.Matrix.identity(float, 2) + + dt * C[p]) @ F[p] # deformation gradient update + h = ti.exp( + 10 * (1.0 - Jp[p]) + ) # Hardening coefficient: snow gets harder when compressed + if material[p] == 1: # jelly, make it softer + h = 0.3 + mu, la = mu_0 * h, lambda_0 * h + if material[p] == 0: # liquid + mu = 0.0 + U, sig, V = ti.svd(F[p]) + J = 1.0 + for d in ti.static(range(2)): + new_sig = sig[d, d] + if material[p] == 2: # Snow + new_sig = min(max(sig[d, d], 1 - 2.5e-2), + 1 + 4.5e-3) # Plasticity + Jp[p] *= sig[d, d] / new_sig + sig[d, d] = new_sig + J *= new_sig + if material[ + p] == 0: # Reset deformation gradient to avoid numerical instability + F[p] = ti.Matrix.identity(float, 2) * ti.sqrt(J) + elif material[p] == 2: + F[p] = U @ sig @ V.transpose( + ) # Reconstruct elastic deformation gradient after plasticity + stress = 2 * mu * (F[p] - U @ V.transpose()) @ F[p].transpose( + ) + ti.Matrix.identity(float, 2) * la * J * (J - 1) + stress = (-dt * p_vol * 4 * inv_dx * inv_dx) * stress + affine = stress + p_mass * C[p] + for i, j in ti.static(ti.ndrange( + 3, 3)): # Loop over 3x3 grid node neighborhood + offset = ti.Vector([i, j]) + dpos = (offset.cast(float) - fx) * dx + weight = w[i][0] * w[j][1] + grid_v_int[base + offset] += int( + ti.floor(0.5 + weight * (p_mass * v[p] + affine @ dpos) * + (2.0**v_exp))) + grid_m_int[base + offset] += int( + ti.floor(0.5 + weight * p_mass * (2.0**m_exp))) + for i, j in grid_m: + if grid_m_int[i, j] > 0: # No need for epsilon here + # grid_v[i, j] = (1.0 / grid_m[i, j]) * grid_v[i, j] # Momentum to velocity + grid_v[i, j] = (2**(m_exp - v_exp) / grid_m_int[i, j] + ) * grid_v_int[i, j] # Momentum to velocity + grid_v[i, j][1] -= dt * 50 # gravity + if i < 3 and grid_v[i, j][0] < 0: + grid_v[i, j][0] = 0 # Boundary conditions + if i > n_grid - 3 and grid_v[i, j][0] > 0: grid_v[i, j][0] = 0 + if j < 3 and grid_v[i, j][1] < 0: grid_v[i, j][1] = 0 + if j > n_grid - 3 and grid_v[i, j][1] > 0: grid_v[i, j][1] = 0 + for p in x: # grid to particle (G2P) + base = (x[p] * inv_dx - 0.5).cast(int) + fx = x[p] * inv_dx - base.cast(float) + w = [ + 0.5 * (1.5 - fx)**2, 0.75 - (fx - 1.0)**2, 0.5 * (fx - 0.5)**2 + ] + new_v = ti.Vector.zero(float, 2) + new_C = ti.Matrix.zero(float, 2, 2) + for i, j in ti.static(ti.ndrange( + 3, 3)): # loop over 3x3 grid node neighborhood + dpos = ti.Vector([i, j]).cast(float) - fx + g_v = grid_v[base + ti.Vector([i, j])] + weight = w[i][0] * w[j][1] + new_v += weight * g_v + new_C += 4 * inv_dx * weight * g_v.outer_product(dpos) + v[p], C[p] = new_v, new_C + x[p] += dt * v[p] # advection + + group_size = n_particles // 3 + + @ti.kernel + def initialize(): + for i in range(n_particles): + x[i] = [ + ti.random() * 0.2 + 0.3 + 0.10 * (i // group_size), + ti.random() * 0.2 + 0.05 + 0.32 * (i // group_size) + ] + material[i] = i // group_size # 0: fluid 1: jelly 2: snow + v[i] = ti.Matrix([0, 0]) + F[i] = ti.Matrix([[1, 0], [0, 1]]) + Jp[i] = 1 + + with tempfile.TemporaryDirectory() as tmpdir: + m = ti.aot.Module(ti.lang.impl.current_cfg().arch) + m.add_field('x', x) + m.add_field('v', v) + m.add_field('C', C) + m.add_field('J', Jp) + m.add_field('grid_v', grid_v) + m.add_field('grid_m', grid_m) + m.add_field('grid_v_int', grid_v_int) + m.add_field('grid_m_int', grid_m_int) + m.add_field('material', material) + m.add_kernel(initialize) + m.add_kernel(substep) + + m.save(tmpdir, '') + with open(os.path.join(tmpdir, 'metadata.json')) as json_file: + json.load(json_file) + + +@test_utils.test(arch=ti.opengl) +def test_mpm88_ndarray(): + dim = 2 + N = 64 + n_particles = N * N + n_grid = 128 + dx = 1 / n_grid + inv_dx = 1 / dx + dt = 2.0e-4 + p_vol = (dx * 0.5)**2 + p_rho = 1 + p_mass = p_vol * p_rho + E = 400 + + @ti.kernel + def substep(x: ti.any_arr(element_dim=1), v: ti.any_arr(element_dim=1), + C: ti.any_arr(element_dim=2), J: ti.any_arr(), + grid_v: ti.any_arr(element_dim=1), grid_m: ti.any_arr()): + for p in x: + base = (x[p] * inv_dx - 0.5).cast(int) + fx = x[p] * inv_dx - base.cast(float) + w = [0.5 * (1.5 - fx)**2, 0.75 - (fx - 1)**2, 0.5 * (fx - 0.5)**2] + stress = -dt * p_vol * (J[p] - 1) * 4 * inv_dx * inv_dx * E + affine = ti.Matrix([[stress, 0], [0, stress]]) + p_mass * C[p] + for i in ti.static(range(3)): + for j in ti.static(range(3)): + offset = ti.Vector([i, j]) + dpos = (offset.cast(float) - fx) * dx + weight = w[i][0] * w[j][1] + ti.atomic_add(grid_v[base + offset], + weight * (p_mass * v[p] + affine @ dpos)) + ti.atomic_add(grid_m[base + offset], weight * p_mass) + + for i, j in grid_m: + if grid_m[i, j] > 0: + bound = 3 + inv_m = 1 / grid_m[i, j] + grid_v[i, j] = inv_m * grid_v[i, j] + grid_v[i, j][1] -= dt * 9.8 + if i < bound and grid_v[i, j][0] < 0: + grid_v[i, j][0] = 0 + if i > n_grid - bound and grid_v[i, j][0] > 0: + grid_v[i, j][0] = 0 + if j < bound and grid_v[i, j][1] < 0: + grid_v[i, j][1] = 0 + if j > n_grid - bound and grid_v[i, j][1] > 0: + grid_v[i, j][1] = 0 + + for p in x: + base = (x[p] * inv_dx - 0.5).cast(int) + fx = x[p] * inv_dx - base.cast(float) + w = [ + 0.5 * (1.5 - fx)**2, 0.75 - (fx - 1.0)**2, 0.5 * (fx - 0.5)**2 + ] + new_v = ti.Vector.zero(ti.f32, 2) + new_C = ti.Matrix.zero(ti.f32, 2, 2) + for i in ti.static(range(3)): + for j in ti.static(range(3)): + dpos = ti.Vector([i, j]).cast(float) - fx + g_v = grid_v[base + ti.Vector([i, j])] + weight = w[i][0] * w[j][1] + new_v += weight * g_v + new_C += 4 * weight * g_v.outer_product(dpos) * inv_dx + v[p] = new_v + x[p] += dt * v[p] + J[p] *= 1 + dt * new_C.trace() + C[p] = new_C + + x = ti.Vector.ndarray(dim, ti.f32, n_particles) + v = ti.Vector.ndarray(dim, ti.f32, n_particles) + C = ti.Matrix.ndarray(dim, dim, ti.f32, n_particles) + J = ti.ndarray(ti.f32, n_particles) + grid_v = ti.Vector.ndarray(dim, ti.f32, (n_grid, n_grid)) + grid_m = ti.ndarray(ti.f32, (n_grid, n_grid)) + + with tempfile.TemporaryDirectory() as tmpdir: + m = ti.aot.Module(ti.opengl) + m.add_kernel(substep, (x, v, C, J, grid_v, grid_m)) + + m.save(tmpdir, '') + with open(os.path.join(tmpdir, 'metadata.json')) as json_file: + json.load(json_file) diff --git a/tests/python/test_api.py b/tests/python/test_api.py new file mode 100644 index 0000000000000..bd732a8152146 --- /dev/null +++ b/tests/python/test_api.py @@ -0,0 +1,90 @@ +import sys + +import pytest + +import taichi as ti +from tests import test_utils + +user_api = {} +user_api[ti] = [ + 'CRITICAL', 'DEBUG', 'ERROR', 'Field', 'FieldsBuilder', 'GUI', 'INFO', + 'Layout', 'Matrix', 'MatrixField', 'MatrixNdarray', 'Mesh', 'Ndarray', + 'SNode', 'ScalarField', 'ScalarNdarray', 'Struct', 'StructField', 'TRACE', + 'TaichiCompilationError', 'TaichiNameError', 'TaichiRuntimeError', + 'TaichiRuntimeTypeError', 'TaichiSyntaxError', 'TaichiTypeError', 'Tape', + 'TetMesh', 'TriMesh', 'Vector', 'VectorNdarray', 'WARN', 'abs', 'acos', + 'activate', 'ad', 'any_arr', 'aot', 'append', 'arm64', 'asin', + 'assume_in_range', 'atan2', 'atomic_add', 'atomic_and', 'atomic_max', + 'atomic_min', 'atomic_or', 'atomic_sub', 'atomic_xor', 'axes', 'bit_cast', + 'bit_shr', 'block_dim', 'block_local', 'cache_read_only', 'cast', 'cc', + 'ceil', 'clear_all_gradients', 'cos', 'cpu', 'cuda', 'data_oriented', + 'deactivate', 'deactivate_all_snodes', 'dx11', 'eig', 'exp', + 'experimental', 'ext_arr', 'extension', 'f16', 'f32', 'f64', 'field', + 'float16', 'float32', 'float64', 'floor', 'func', 'get_addr', + 'global_thread_idx', 'gpu', 'grouped', 'hex_to_rgb', 'i', 'i16', 'i32', + 'i64', 'i8', 'ij', 'ijk', 'ijkl', 'ijl', 'ik', 'ikl', 'il', 'init', + 'int16', 'int32', 'int64', 'int8', 'is_active', 'is_logging_effective', + 'j', 'jk', 'jkl', 'jl', 'k', 'kernel', 'kl', 'l', 'lang', 'length', + 'linalg', 'log', 'max', 'mesh_local', 'mesh_patch_idx', 'metal', 'min', + 'ndarray', 'ndrange', 'no_activate', 'one', 'opengl', 'parallelize', + 'polar_decompose', 'pow', 'profiler', 'randn', 'random', 'raw_div', + 'raw_mod', 'rescale_index', 'reset', 'rgb_to_hex', 'root', 'round', + 'rsqrt', 'select', 'set_logging_level', 'sin', 'sparse_matrix_builder', + 'sqrt', 'static', 'static_assert', 'static_print', 'stop_grad', 'svd', + 'sym_eig', 'sync', 'tan', 'tanh', 'template', 'tools', 'types', 'u16', + 'u32', 'u64', 'u8', 'ui', 'uint16', 'uint32', 'uint64', 'uint8', 'vulkan', + 'wasm', 'x64', 'x86_64', 'zero' +] +user_api[ti.Field] = [ + 'copy_from', 'dtype', 'fill', 'from_numpy', 'from_torch', 'parent', + 'shape', 'snode', 'to_numpy', 'to_torch' +] +user_api[ti.FieldsBuilder] = [ + 'bit_array', 'bit_struct', 'bitmasked', 'deactivate_all', 'dense', + 'dynamic', 'finalize', 'lazy_grad', 'place', 'pointer' +] +user_api[ti.Matrix] = [ + 'all', 'any', 'cast', 'cols', 'cross', 'determinant', 'diag', 'dot', + 'field', 'fill', 'identity', 'inverse', 'max', 'min', 'ndarray', 'norm', + 'norm_inv', 'norm_sqr', 'normalized', 'one', 'outer_product', 'rotation2d', + 'rows', 'sum', 'to_list', 'to_numpy', 'trace', 'transpose', 'unit', 'w', + 'x', 'y', 'z', 'zero' +] +user_api[ti.MatrixField] = [ + 'copy_from', 'dtype', 'fill', 'from_numpy', 'from_torch', + 'get_scalar_field', 'parent', 'shape', 'snode', 'to_numpy', 'to_torch' +] +user_api[ti.MatrixNdarray] = [ + 'copy_from', 'element_shape', 'fill', 'from_numpy', 'to_numpy' +] +user_api[ti.Ndarray] = ['copy_from', 'element_shape', 'fill'] +user_api[ti.SNode] = [ + 'bit_array', 'bit_struct', 'bitmasked', 'deactivate_all', 'dense', + 'dynamic', 'lazy_grad', 'parent', 'place', 'pointer', 'shape' +] +user_api[ti.ScalarField] = [ + 'copy_from', 'dtype', 'fill', 'from_numpy', 'from_torch', 'parent', + 'shape', 'snode', 'to_numpy', 'to_torch' +] +user_api[ti.ScalarNdarray] = [ + 'copy_from', 'element_shape', 'fill', 'from_numpy', 'to_numpy' +] +user_api[ti.Struct] = ['field', 'fill', 'items', 'keys', 'to_dict'] +user_api[ti.StructField] = [ + 'copy_from', 'dtype', 'fill', 'from_numpy', 'from_torch', + 'get_member_field', 'keys', 'parent', 'shape', 'snode', 'to_numpy', + 'to_torch' +] +user_api[ti.VectorNdarray] = [ + 'copy_from', 'element_shape', 'fill', 'from_numpy', 'to_numpy' +] + + +@pytest.mark.parametrize('src', user_api.keys()) +@test_utils.test(arch=ti.cpu) +def test_api(src): + # When Python version is below 3.7, deprecated names are + # handled as normal names, which will fail this test. + assert sys.version_info < (3, 7) or [ + s for s in dir(src) if not s.startswith('_') + ] == user_api[src] diff --git a/tests/python/test_arg_alignment.py b/tests/python/test_arg_alignment.py new file mode 100644 index 0000000000000..4320f219ee6ae --- /dev/null +++ b/tests/python/test_arg_alignment.py @@ -0,0 +1,23 @@ +import taichi as ti +from tests import test_utils + + +@test_utils.test(exclude=[ti.opengl]) +def test_ret_write(): + @ti.kernel + def func(a: ti.i16) -> ti.f32: + return 3.0 + + assert func(255) == 3.0 + + +@test_utils.test(exclude=[ti.opengl]) +def test_arg_read(): + x = ti.field(ti.i32, shape=()) + + @ti.kernel + def func(a: ti.i8, b: ti.i32): + x[None] = b + + func(255, 2) + assert x[None] == 2 diff --git a/tests/python/test_arg_check.py b/tests/python/test_arg_check.py index 889c652213284..35bd2d5ebd56a 100644 --- a/tests/python/test_arg_check.py +++ b/tests/python/test_arg_check.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_argument_error(): x = ti.field(ti.i32) @@ -12,7 +13,7 @@ def test_argument_error(): @ti.kernel def set_i32_notype(v): pass - except ti.KernelDefError: + except ti.TaichiSyntaxError: pass try: @@ -20,7 +21,7 @@ def set_i32_notype(v): @ti.kernel def set_i32_args(*args): pass - except ti.KernelDefError: + except ti.TaichiSyntaxError: pass try: @@ -28,7 +29,7 @@ def set_i32_args(*args): @ti.kernel def set_i32_kwargs(**kwargs): pass - except ti.KernelDefError: + except ti.TaichiSyntaxError: pass @ti.kernel diff --git a/tests/python/test_arg_load.py b/tests/python/test_arg_load.py index ef3c262697fcf..427b1a64c2241 100644 --- a/tests/python/test_arg_load.py +++ b/tests/python/test_arg_load.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_arg_load(): x = ti.field(ti.i32) y = ti.field(ti.f32) @@ -29,7 +30,7 @@ def set_f32(v: ti.f32): assert y[None] == 1.5 -@ti.test(require=ti.extension.data64) +@test_utils.test(require=ti.extension.data64) def test_arg_load_f64(): x = ti.field(ti.i32) y = ti.field(ti.f32) @@ -51,7 +52,7 @@ def set_i64(v: ti.i64): assert y[None] == 2.5 -@ti.test() +@test_utils.test() def test_ext_arr(): N = 128 x = ti.field(ti.f32) diff --git a/tests/python/test_argument.py b/tests/python/test_argument.py new file mode 100644 index 0000000000000..7fc38a7077b7d --- /dev/null +++ b/tests/python/test_argument.py @@ -0,0 +1,53 @@ +import pytest + +import taichi as ti +from tests import test_utils + + +@test_utils.test(arch=[ti.opengl, ti.cc]) +def test_exceed_max_eight(): + @ti.kernel + def foo1(a: ti.i32, b: ti.i32, c: ti.i32, d: ti.i32, e: ti.i32, f: ti.i32, + g: ti.i32, h: ti.i32) -> ti.i32: + return a + b + c + d + e + f + g + h + + assert foo1(1, 2, 3, 4, 5, 6, 7, 8) == 36 + + @ti.kernel + def foo2(a: ti.i32, b: ti.i32, c: ti.i32, d: ti.i32, e: ti.i32, f: ti.i32, + g: ti.i32, h: ti.i32, i: ti.i32) -> ti.i32: + return a + b + c + d + e + f + g + h + i + + with pytest.raises( + ti.TaichiRuntimeError, + match= + f"The number of elements in kernel arguments is too big! Do not exceed 8 on {ti.lang._ti_core.arch_name(ti.lang.impl.current_cfg().arch)} backend." + ): + foo2(1, 2, 3, 4, 5, 6, 7, 8, 9) + + +@test_utils.test(exclude=[ti.opengl, ti.cc]) +def test_exceed_max_64(): + N = 64 + + @ti.kernel + def foo1(a: ti.types.vector(N, ti.i32)) -> ti.i32: + return a.sum() + + A = ti.Vector([1] * N) + assert foo1(A) == 64 + + N = 65 + + @ti.kernel + def foo2(a: ti.types.vector(N, ti.i32)) -> ti.i32: + return a.sum() + + A = ti.Vector([1] * N) + + with pytest.raises( + ti.TaichiRuntimeError, + match= + f"The number of elements in kernel arguments is too big! Do not exceed 64 on {ti.lang._ti_core.arch_name(ti.lang.impl.current_cfg().arch)} backend." + ): + foo2(A) diff --git a/tests/python/test_assert.py b/tests/python/test_assert.py index 3094f5c1db6a0..dd18b6d12ab94 100644 --- a/tests/python/test_assert.py +++ b/tests/python/test_assert.py @@ -1,12 +1,12 @@ import pytest +from taichi.lang.misc import get_host_arch_list import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) +@test_utils.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) def test_assert_minimal(): - ti.set_gdb_trigger(False) - @ti.kernel def func(): assert 0 @@ -21,7 +21,7 @@ def func2(): func2() -@ti.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) +@test_utils.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) def test_assert_basic(): @ti.kernel def func(): @@ -32,7 +32,7 @@ def func(): func() -@ti.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) +@test_utils.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) def test_assert_message(): @ti.kernel def func(): @@ -43,7 +43,7 @@ def func(): func() -@ti.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) +@test_utils.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) def test_assert_message_formatted(): x = ti.field(dtype=int, shape=16) x[10] = 42 @@ -70,7 +70,7 @@ def assert_float(): assert_formatted() -@ti.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) +@test_utils.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) def test_assert_ok(): @ti.kernel def func(): @@ -80,7 +80,7 @@ def func(): func() -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_static_assert_is_static(): @ti.kernel def func(): @@ -90,8 +90,7 @@ def func(): func() -@ti.test(arch=ti.get_host_arch_list()) -@ti.must_throw(AssertionError) +@test_utils.test(arch=get_host_arch_list()) def test_static_assert_message(): x = 3 @@ -99,10 +98,11 @@ def test_static_assert_message(): def func(): ti.static_assert(x == 4, "Oh, no!") - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_static_assert_vector_n_ok(): x = ti.Vector.field(4, ti.f32, ()) @@ -113,7 +113,7 @@ def func(): func() -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_static_assert_data_type_ok(): x = ti.field(ti.f32, ()) diff --git a/tests/python/test_assign.py b/tests/python/test_assign.py index d0297c1f63e64..182f4bdfbaf59 100644 --- a/tests/python/test_assign.py +++ b/tests/python/test_assign.py @@ -1,9 +1,10 @@ import pytest import taichi as ti +from tests import test_utils -@ti.test(debug=True) +@test_utils.test(debug=True) def test_assign_basic(): @ti.kernel def func_basic(): @@ -13,7 +14,7 @@ def func_basic(): func_basic() -@ti.test(debug=True) +@test_utils.test(debug=True) def test_assign_unpack(): @ti.kernel def func_unpack(): @@ -24,7 +25,7 @@ def func_unpack(): func_unpack() -@ti.test(debug=True) +@test_utils.test(debug=True) def test_assign_chained(): @ti.kernel def func_chained(): @@ -35,7 +36,7 @@ def func_chained(): func_chained() -@ti.test(debug=True) +@test_utils.test(debug=True) def test_assign_chained_unpack(): @ti.kernel def func_chained_unpack(): @@ -48,7 +49,7 @@ def func_chained_unpack(): func_chained_unpack() -@ti.test(debug=True) +@test_utils.test(debug=True) def test_assign_assign(): @ti.kernel def func_assign(): @@ -57,3 +58,27 @@ def func_assign(): assert a == 1 func_assign() + + +@test_utils.test(debug=True) +def test_assign_ann(): + @ti.kernel + def func_ann(): + a: ti.i32 = 1 + b: ti.f32 = a + assert a == 1 + assert b == 1.0 + + func_ann() + + +@test_utils.test() +def test_assign_ann_over(): + @ti.kernel + def func_ann_over(): + my_int = ti.i32 + d: my_int = 2 + d: ti.f32 = 2.0 + + with pytest.raises(ti.TaichiCompilationError): + func_ann_over() diff --git a/tests/python/test_ast_refactor.py b/tests/python/test_ast_refactor.py new file mode 100644 index 0000000000000..fd123cd09473b --- /dev/null +++ b/tests/python/test_ast_refactor.py @@ -0,0 +1,1021 @@ +import sys + +import numpy as np +import pytest +from taichi.lang import impl +from taichi.lang.util import has_pytorch + +import taichi as ti +from tests import test_utils + +if sys.version_info >= (3, 8): + # Import the test case only if the Python version is >= 3.8 + from .py38_only import test_namedexpr # noqa + + +@test_utils.test() +def test_binop(): + @ti.kernel + def foo(x: ti.i32, y: ti.i32, a: ti.template()): + a[0] = x + y + a[1] = x - y + a[2] = x * y + a[3] = impl.ti_float(x) / y + a[4] = x // y + a[5] = x % y + a[6] = x**y + a[7] = x << y + a[8] = x >> y + a[9] = x | y + a[10] = x ^ y + a[11] = x & y + + x = 37 + y = 3 + a = ti.field(ti.f32, shape=(12, )) + b = ti.field(ti.f32, shape=(12, )) + + a[0] = x + y + a[1] = x - y + a[2] = x * y + a[3] = x / y + a[4] = x // y + a[5] = x % y + a[6] = x**y + a[7] = x << y + a[8] = x >> y + a[9] = x | y + a[10] = x ^ y + a[11] = x & y + + foo(x, y, b) + + for i in range(12): + assert a[i] == test_utils.approx(b[i]) + + +@test_utils.test() +def test_augassign(): + @ti.kernel + def foo(x: ti.i32, y: ti.i32, a: ti.template(), b: ti.template()): + for i in a: + a[i] = x + a[0] += y + a[1] -= y + a[2] *= y + a[3] //= y + a[4] %= y + a[5] **= y + a[6] <<= y + a[7] >>= y + a[8] |= y + a[9] ^= y + a[10] &= y + b[0] = x + b[0] /= y + + x = 37 + y = 3 + a = ti.field(ti.i32, shape=(11, )) + b = ti.field(ti.i32, shape=(11, )) + c = ti.field(ti.f32, shape=(1, )) + d = ti.field(ti.f32, shape=(1, )) + + a[0] = x + y + a[1] = x - y + a[2] = x * y + a[3] = x // y + a[4] = x % y + a[5] = x**y + a[6] = x << y + a[7] = x >> y + a[8] = x | y + a[9] = x ^ y + a[10] = x & y + c[0] = x / y + + foo(x, y, b, d) + + for i in range(11): + assert a[i] == b[i] + assert c[0] == test_utils.approx(d[0]) + + +@test_utils.test() +def test_unaryop(): + @ti.kernel + def foo(x: ti.i32, a: ti.template()): + a[0] = +x + a[1] = -x + a[2] = not x + a[3] = ~x + + x = 1234 + a = ti.field(ti.i32, shape=(4, )) + b = ti.field(ti.i32, shape=(4, )) + + a[0] = +x + a[1] = -x + a[2] = not x + a[3] = ~x + + foo(x, b) + + for i in range(4): + assert a[i] == b[i] + + +@test_utils.test() +def test_boolop(): + @ti.kernel + def foo(a: ti.template()): + a[0] = 0 and 0 + a[1] = 0 and 1 + a[2] = 1 and 0 + a[3] = 1 and 1 + a[4] = 0 or 0 + a[5] = 0 or 1 + a[6] = 1 or 0 + a[7] = 1 or 1 + a[8] = 1 and 1 and 1 and 1 + a[9] = 1 and 1 and 1 and 0 + a[10] = 0 or 0 or 0 or 0 + a[11] = 0 or 0 or 1 or 0 + + a = ti.field(ti.i32, shape=(12, )) + b = ti.field(ti.i32, shape=(12, )) + + a[0] = 0 and 0 + a[1] = 0 and 1 + a[2] = 1 and 0 + a[3] = 1 and 1 + a[4] = 0 or 0 + a[5] = 0 or 1 + a[6] = 1 or 0 + a[7] = 1 or 1 + a[8] = 1 and 1 and 1 and 1 + a[9] = 1 and 1 and 1 and 0 + a[10] = 0 or 0 or 0 or 0 + a[11] = 0 or 0 or 1 or 0 + + foo(b) + + for i in range(12): + assert a[i] == b[i] + + +@test_utils.test() +def test_compare_fail(): + with pytest.raises(ti.TaichiCompilationError, + match='"Is" is only supported inside `ti.static`.'): + + @ti.kernel + def foo(): + 1 is [1] + + foo() + + +@test_utils.test() +def test_single_compare(): + @ti.kernel + def foo(a: ti.template(), b: ti.template(), c: ti.template()): + for i in ti.static(range(3)): + c[i * 6] = a[i] == b[i] + c[i * 6 + 1] = a[i] != b[i] + c[i * 6 + 2] = a[i] < b[i] + c[i * 6 + 3] = a[i] <= b[i] + c[i * 6 + 4] = a[i] > b[i] + c[i * 6 + 5] = a[i] >= b[i] + + a = ti.Vector([1, 1, 2]) + b = ti.Vector([2, 1, 1]) + c = ti.field(ti.i32, shape=(18, )) + d = ti.field(ti.i32, shape=(18, )) + + for i in range(3): + c[i * 6] = a[i] == b[i] + c[i * 6 + 1] = a[i] != b[i] + c[i * 6 + 2] = a[i] < b[i] + c[i * 6 + 3] = a[i] <= b[i] + c[i * 6 + 4] = a[i] > b[i] + c[i * 6 + 5] = a[i] >= b[i] + + foo(a, b, d) + for i in range(18): + assert c[i] == d[i] + + +@test_utils.test() +def test_chain_compare(): + @ti.kernel + def foo(a: ti.i32, b: ti.i32, c: ti.template()): + c[0] = a == b == a + c[1] = a == b != a + c[2] = a != b == a + c[3] = a < b > a + c[4] = a > b < a + c[5] = a < b < a + c[6] = a > b > a + c[7] = a == a == a == a + c[8] = a == a == a != a + c[9] = a < b > a < b + c[10] = a > b > a < b + + a = 1 + b = 2 + c = ti.field(ti.i32, shape=(11, )) + d = ti.field(ti.i32, shape=(11, )) + + c[0] = a == b == a + c[1] = a == b != a + c[2] = a != b == a + c[3] = a < b > a + c[4] = a > b < a + c[5] = a < b < a + c[6] = a > b > a + c[7] = a == a == a == a + c[8] = a == a == a != a + c[9] = a < b > a < b + c[10] = a > b > a < b + + foo(a, b, d) + for i in range(11): + assert c[i] == d[i] + + +@test_utils.test() +def test_return(): + @ti.kernel + def foo(x: ti.i32) -> ti.i32: + return x + 1 + + assert foo(1) == 2 + + +@test_utils.test() +def test_format_print(): + a = ti.field(ti.i32, shape=(10, )) + + @ti.kernel + def foo(): + a[0] = 1.0 + a[5] = 2.0 + print('Test if the string.format and fstring print works') + print('string.format: a[0]={}, a[5]={}'.format(a[0], a[5])) + print(f'fstring: a[0]={a[0]}, a[5]={a[5]}') + + +@test_utils.test(print_preprocessed_ir=True) +def test_if(): + @ti.kernel + def foo(x: ti.i32) -> ti.i32: + ret = 0 + if x: + ret = 1 + else: + ret = 0 + return ret + + assert foo(1) + assert not foo(0) + + +@test_utils.test(print_preprocessed_ir=True) +def test_static_if(): + @ti.kernel + def foo(x: ti.template()) -> ti.i32: + ret = 0 + if ti.static(x): + ret = 1 + else: + ret = 0 + return ret + + assert foo(1) + assert not foo(0) + + +@test_utils.test(print_preprocessed_ir=True) +def test_struct_for(): + a = ti.field(ti.i32, shape=(10, )) + + @ti.kernel + def foo(x: ti.i32): + for i in a: + a[i] = x + + x = 5 + foo(x) + for i in range(10): + assert a[i] == 5 + + +@test_utils.test(print_preprocessed_ir=True) +def test_grouped_struct_for(): + a = ti.field(ti.i32, shape=(4, 4)) + + @ti.kernel + def foo(x: ti.i32): + for I in ti.grouped(a): + a[I] = x + + x = 5 + foo(x) + for i in range(4): + for j in range(4): + assert a[i, j] == 5 + + +@test_utils.test(print_preprocessed_ir=True) +def test_static_for(): + a = ti.field(ti.i32, shape=(10, )) + + @ti.kernel + def foo(x: ti.i32): + for i in ti.static(range(10)): + a[i] = x + + x = 5 + foo(x) + for i in range(10): + assert a[i] == 5 + + +@test_utils.test(print_preprocessed_ir=True) +def test_static_grouped_for(): + a = ti.field(ti.i32, shape=(4, 4)) + + @ti.kernel + def foo(x: ti.i32): + for i in ti.static(ti.grouped(ti.ndrange((1, 3), (1, 3)))): + a[i] = x + + x = 5 + foo(x) + for i in range(4): + for j in range(4): + if 1 <= i < 3 and 1 <= j < 3: + assert a[i, j] == 5 + else: + assert a[i, j] == 0 + + +@test_utils.test(print_preprocessed_ir=True) +def test_range_for_single_argument(): + a = ti.field(ti.i32, shape=(10, )) + + @ti.kernel + def foo(x: ti.i32): + for i in range(5): + a[i] = x + + x = 5 + foo(x) + for i in range(10): + if i < 5: + assert a[i] == 5 + else: + assert a[i] == 0 + + +@test_utils.test(print_preprocessed_ir=True) +def test_range_for_two_arguments(): + a = ti.field(ti.i32, shape=(10, )) + + @ti.kernel + def foo(x: ti.i32): + for i in range(3, 7): + a[i] = x + + x = 5 + foo(x) + for i in range(10): + if 3 <= i < 7: + assert a[i] == 5 + else: + assert a[i] == 0 + + +@test_utils.test() +def test_range_for_three_arguments(): + a = ti.field(ti.i32, shape=(10, )) + + with pytest.raises(ti.TaichiCompilationError, + match='Range should have 1 or 2 arguments, found 3'): + + @ti.kernel + def foo(x: ti.i32): + for i in range(3, 7, 2): + a[i] = x + + x = 5 + foo(x) + + +@test_utils.test(print_preprocessed_ir=True) +def test_ndrange_for(): + x = ti.field(ti.f32, shape=(16, 32, 64)) + + @ti.kernel + def func(): + for i, j, k in ti.ndrange((4, 10), (3, 8), 17): + x[i, j, k] = i + j * 10 + k * 100 + + func() + for i in range(16): + for j in range(32): + for k in range(64): + if 4 <= i < 10 and 3 <= j < 8 and k < 17: + assert x[i, j, k] == i + j * 10 + k * 100 + else: + assert x[i, j, k] == 0 + + +@test_utils.test(print_preprocessed_ir=True) +def test_grouped_ndrange_for(): + x = ti.field(ti.i32, shape=(6, 6, 6)) + y = ti.field(ti.i32, shape=(6, 6, 6)) + + @ti.kernel + def func(): + lower = ti.Vector([0, 1, 2]) + upper = ti.Vector([3, 4, 5]) + for I in ti.grouped( + ti.ndrange((lower[0], upper[0]), (lower[1], upper[1]), + (lower[2], upper[2]))): + x[I] = I[0] + I[1] + I[2] + for i in range(0, 3): + for j in range(1, 4): + for k in range(2, 5): + y[i, j, k] = i + j + k + + func() + + for i in range(6): + for j in range(6): + for k in range(6): + assert x[i, j, k] == y[i, j, k] + + +@test_utils.test(print_preprocessed_ir=True) +def test_static_for_break(): + n = 10 + + @ti.kernel + def foo(a: ti.template()): + for i in ti.static(range(n)): + a[i] = 3 + if ti.static(i >= 5): + break + a[i] = 10 + a[i] = 5 + + a = ti.field(ti.i32, shape=(n, )) + foo(a) + for i in range(n): + if i < 5: + assert a[i] == 5 + elif i == 5: + assert a[i] == 3 + else: + assert a[i] == 0 + + +@test_utils.test(print_preprocessed_ir=True) +def test_static_grouped_for_break(): + n = 4 + + @ti.kernel + def foo(a: ti.template()): + for I in ti.static(ti.grouped(ti.ndrange(n, n))): + a[I] = 3 + if ti.static(I[0] >= 3): + break + a[I] = 10 + a[I] = 5 + + a = ti.field(ti.i32, shape=(n, n)) + foo(a) + for i in range(n): + for j in range(n): + if i < 3: + assert a[i, j] == 5 + elif i == 3 and j == 0: + assert a[i, j] == 3 + else: + assert a[i, j] == 0 + + +@test_utils.test(print_preprocessed_ir=True) +def test_static_for_continue(): + n = 10 + + @ti.kernel + def foo(a: ti.template()): + for i in ti.static(range(n)): + a[i] = 3 + if ti.static(i >= 5): + continue + a[i] = 10 + a[i] = 5 + + a = ti.field(ti.i32, shape=(n, )) + foo(a) + for i in range(n): + if i < 5: + assert a[i] == 5 + else: + assert a[i] == 3 + + +@test_utils.test(print_preprocessed_ir=True) +def test_static_grouped_for_continue(): + n = 4 + + @ti.kernel + def foo(a: ti.template()): + for I in ti.static(ti.grouped(ti.ndrange(n, n))): + a[I] = 3 + if ti.static(I[0] >= 3): + continue + a[I] = 10 + a[I] = 5 + + a = ti.field(ti.i32, shape=(n, n)) + foo(a) + for i in range(n): + for j in range(n): + if i < 3: + assert a[i, j] == 5 + else: + assert a[i, j] == 3 + + +@test_utils.test(print_preprocessed_ir=True) +def test_for_break(): + n = 4 + + @ti.kernel + def foo(a: ti.template()): + for i in range(n): + for j in range(n): + a[i, j] = 3 + if i >= 3: + break + a[i, j] = 10 + a[i, j] = 5 + + a = ti.field(ti.i32, shape=(n, n)) + foo(a) + for i in range(n): + for j in range(n): + if i < 3: + assert a[i, j] == 5 + elif i == 3 and j == 0: + assert a[i, j] == 3 + else: + assert a[i, j] == 0 + + +@test_utils.test(print_preprocessed_ir=True) +def test_for_continue(): + n = 4 + + @ti.kernel + def foo(a: ti.template()): + for i in range(n): + for j in range(n): + a[i, j] = 3 + if i >= 3: + continue + a[i, j] = 10 + a[i, j] = 5 + + a = ti.field(ti.i32, shape=(n, n)) + foo(a) + for i in range(n): + for j in range(n): + if i < 3: + assert a[i, j] == 5 + else: + assert a[i, j] == 3 + + +@test_utils.test() +def test_while(): + x = ti.field(ti.f32) + + N = 1 + + ti.root.dense(ti.i, N).place(x) + + @ti.kernel + def func(): + i = 0 + s = 0 + while i < 10: + s += i + i += 1 + x[0] = s + + func() + assert x[0] == 45 + + +@test_utils.test() +def test_while_break(): + ret = ti.field(ti.i32, shape=()) + + @ti.kernel + def func(): + i = 0 + s = 0 + while True: + s += i + i += 1 + if i > 10: + break + ret[None] = s + + func() + assert ret[None] == 55 + + +@test_utils.test() +def test_while_continue(): + ret = ti.field(ti.i32, shape=()) + + @ti.kernel + def func(): + i = 0 + s = 0 + while i < 10: + i += 1 + if i % 2 == 0: + continue + s += i + ret[None] = s + + func() + assert ret[None] == 25 + + +@test_utils.test(print_preprocessed_ir=True) +def test_func(): + @ti.func + def bar(x): + return x * x, -x + + a = ti.field(ti.i32, shape=(10, )) + b = ti.field(ti.i32, shape=(10, )) + + @ti.kernel + def foo(): + for i in a: + a[i], b[i] = bar(i) + + foo() + for i in range(10): + assert a[i] == i * i + assert b[i] == -i + + +@test_utils.test(print_preprocessed_ir=True) +def test_func_in_python_func(): + @ti.func + def bar(x: ti.template()): + if ti.static(x): + mat = bar(x // 2) + mat = mat @ mat + if ti.static(x % 2): + mat = mat @ ti.Matrix([[1, 1], [1, 0]]) + return mat + else: + return ti.Matrix([[1, 0], [0, 1]]) + + def fibonacci(x): + return impl.subscript(bar(x), 1, 0) + + @ti.kernel + def foo(x: ti.template()) -> ti.i32: + return fibonacci(x) + + fib = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34] + + for i in range(10): + assert foo(i) == fib[i] + + +@test_utils.test(print_preprocessed_ir=True) +def test_ifexp(): + @ti.kernel + def foo(x: ti.i32) -> ti.i32: + return 1 if x else 0 + + assert foo(1) == 1 + assert foo(0) == 0 + + +@test_utils.test(print_preprocessed_ir=True) +def test_static_ifexp(): + @ti.kernel + def foo(x: ti.template()) -> ti.i32: + return 1 if ti.static(x) else 0 + + assert foo(1) == 1 + assert foo(0) == 0 + + +@test_utils.test() +def test_static_assign(): + a = ti.field(ti.i32, shape=(1, )) + b = ti.field(ti.i32, shape=(1, )) + + @ti.kernel + def foo(xx: ti.template(), yy: ti.template()) -> ti.i32: + x, y = ti.static(xx, yy) + x[0] -= 1 + y[0] -= 1 + return x[0] + y[0] + + a[0] = 2 + b[0] = 3 + assert foo(a, b) == 3 + + +@test_utils.test() +def test_static_assign_element(): + with pytest.raises( + ti.TaichiCompilationError, + match='Static assign cannot be used on elements in arrays'): + + @ti.kernel + def foo(): + a = ti.static([1, 2, 3]) + a[0] = ti.static(2) + + foo() + + +@test_utils.test() +def test_recreate_variable(): + with pytest.raises(ti.TaichiCompilationError, + match='Recreating variables is not allowed'): + + @ti.kernel + def foo(): + a = 1 + a = ti.static(2) + + foo() + + +@test_utils.test() +def test_taichi_other_than_ti(): + import taichi as tc + + @tc.func + def bar(x: tc.template()): + if tc.static(x): + mat = bar(x // 2) + mat = mat @ mat + if tc.static(x % 2): + mat = mat @ tc.Matrix([[1, 1], [1, 0]]) + return mat + else: + return tc.Matrix([[1, 0], [0, 1]]) + + def fibonacci(x): + return impl.subscript(bar(x), 1, 0) + + @tc.kernel + def foo(x: tc.template()) -> tc.i32: + return fibonacci(x) + + fib = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34] + + for i in range(10): + assert foo(i) == fib[i] + + +@test_utils.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) +def test_assert_message(): + @ti.kernel + def func(): + x = 20 + assert 10 <= x < 20, 'Foo bar' + + with pytest.raises(RuntimeError, match='Foo bar'): + func() + + +@test_utils.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) +def test_assert_message_formatted(): + x = ti.field(dtype=int, shape=16) + x[10] = 42 + + @ti.kernel + def assert_formatted(): + for i in x: + assert x[i] == 0, 'x[%d] expect=%d got=%d' % (i, 0, x[i]) + + @ti.kernel + def assert_float(): + y = 0.5 + assert y < 0, 'y = %f' % y + + with pytest.raises(RuntimeError, match=r'x\[10\] expect=0 got=42'): + assert_formatted() + # TODO: note that we are not fully polished to be able to recover from + # assertion failures... + with pytest.raises(RuntimeError, match=r'y = 0.5'): + assert_float() + + # success case + x[10] = 0 + assert_formatted() + + +@test_utils.test() +def test_dict(): + @ti.kernel + def foo(x: ti.template()) -> ti.i32: + a = {1: 2, 3: 4} + b = {5: 6, **a} + return b[x] + + assert foo(1) == 2 + with pytest.raises(ti.TaichiCompilationError): + foo(2) + + +@test_utils.test() +def test_listcomp(): + @ti.func + def identity(dt, n: ti.template()): + return ti.Matrix([[ti.cast(int(i == j), dt) for j in range(n)] + for i in range(n)]) + + @ti.kernel + def foo(n: ti.template()) -> ti.i32: + a = identity(ti.i32, n) + b = [j for i in a for j in i] + ret = 0 + for i in ti.static(range(n)): + for j in ti.static(range(n)): + ret += i * j * b[i * n + j] + return ret + + assert foo(5) == 1 + 4 + 9 + 16 + + +@test_utils.test() +def test_dictcomp(): + @ti.kernel + def foo(n: ti.template()) -> ti.i32: + a = {i: i * i for i in range(n) if i % 3 if i % 2} + ret = 0 + for i in ti.static(range(n)): + if ti.static(i % 3): + if ti.static(i % 2): + ret += a[i] + return ret + + assert foo(10) == 1 * 1 + 5 * 5 + 7 * 7 + + +@test_utils.test() +def test_dictcomp_fail(): + @ti.kernel + def foo(n: ti.template(), m: ti.template()) -> ti.i32: + a = {i: i * i for i in range(n) if i % 3 if i % 2} + return a[m] + + with pytest.raises(ti.TaichiCompilationError): + foo(5, 2) + + with pytest.raises(ti.TaichiCompilationError): + foo(5, 3) + + +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(arch=[ti.cpu, ti.cuda, ti.opengl]) +def test_ndarray(): + n = 4 + m = 7 + + @ti.kernel + def run(x: ti.any_arr(element_dim=2, layout=ti.Layout.AOS), + y: ti.any_arr()): + for i in ti.static(range(n)): + for j in ti.static(range(m)): + x[i, j][0, 0] += i + j + y[i, j] + + a = ti.Matrix.ndarray(1, 1, ti.i32, shape=(n, m)) + for i in range(n): + for j in range(m): + a[i, j][0, 0] = i * j + b = np.ones((n, m), dtype=np.int32) + run(a, b) + for i in range(n): + for j in range(m): + assert a[i, j][0, 0] == i * j + i + j + 1 + + +@test_utils.test(arch=ti.cpu) +def test_sparse_matrix_builder(): + n = 8 + Abuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) + + @ti.kernel + def fill(Abuilder: ti.types.sparse_matrix_builder()): + for i, j in ti.static(ti.ndrange(n, n)): + Abuilder[i, j] += i + j + + fill(Abuilder) + A = Abuilder.build() + for i in range(n): + for j in range(n): + assert A[i, j] == i + j + + +@test_utils.test() +def test_func_default_value(): + @ti.func + def bar(s, t=1): + return s + t + + @ti.kernel + def foo() -> ti.i32: + return bar(1) + + assert foo() == 2 + + +@test_utils.test() +def test_func_default_value_fail(): + with pytest.raises(ti.TaichiCompilationError): + + @ti.func + def bar(s, t=1): + return s + t + + @ti.kernel + def foo() -> ti.i32: + return bar(1, 2, 3) + + foo() + + +@test_utils.test() +def test_raise(): + dim = 1 + m = ti.Matrix.field(dim, dim, ti.f32) + ti.root.place(m) + + with pytest.raises( + ti.TaichiCompilationError, + match="Polar decomposition only supports 2D and 3D matrices."): + + @ti.kernel + def foo(): + ti.polar_decompose(m, ti.f32) + + foo() + + +@test_utils.test() +def test_scalar_argument(): + @ti.kernel + def add(a: ti.f32, b: ti.f32) -> ti.f32: + a = a + b + return a + + assert add(1.0, 2.0) == test_utils.approx(3.0) + + +@test_utils.test() +def test_default_template_args_on_func(): + @ti.func + def bar(a: ti.template() = 123): + return a + + @ti.kernel + def foo() -> ti.i32: + return bar() + + assert foo() == 123 + + +@test_utils.test() +def test_grouped_static_for_cast(): + @ti.kernel + def foo() -> ti.f32: + ret = 0. + for I in ti.static(ti.grouped(ti.ndrange((4, 5), (3, 5), 5))): + tmp = I.cast(float) + ret += tmp[2] / 2 + return ret + + assert foo() == test_utils.approx(10) diff --git a/tests/python/test_async.py b/tests/python/test_async.py index bc2e8867f8dd6..18f5c5d562f9e 100644 --- a/tests/python/test_async.py +++ b/tests/python/test_async.py @@ -1,9 +1,10 @@ import numpy as np import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.async_mode, async_mode=True) +@test_utils.test(require=ti.extension.async_mode, async_mode=True) def test_simple(): n = 32 @@ -20,7 +21,7 @@ def double(): assert x[i] == i * 2 -@ti.test(require=ti.extension.async_mode, async_mode=True) +@test_utils.test(require=ti.extension.async_mode, async_mode=True) def test_numpy(): n = 10000 @@ -37,7 +38,7 @@ def inc(a: ti.ext_arr()): assert x[i] == i * 10 -@ti.test(require=ti.extension.async_mode, async_mode=True) +@test_utils.test(require=ti.extension.async_mode, async_mode=True) def test_listgen_opt_with_offsets(): x = ti.field(dtype=ti.i32) @@ -52,4 +53,5 @@ def inc(): inc() ti.sync() - assert ti.get_kernel_stats().get_counters()['launched_tasks_list_gen'] <= 2 + assert ti.tools.async_utils.get_kernel_stats().get_counters( + )['launched_tasks_list_gen'] <= 2 diff --git a/tests/python/test_atomic.py b/tests/python/test_atomic.py index 84fe528255016..73896fcfa0dc5 100644 --- a/tests/python/test_atomic.py +++ b/tests/python/test_atomic.py @@ -1,5 +1,5 @@ import taichi as ti -from taichi import approx +from tests import test_utils n = 128 @@ -35,19 +35,45 @@ def func(): assert valproc(ya) == e -@ti.test() +@test_utils.test() def test_atomic_add_global_i32(): run_atomic_add_global_case(ti.i32, 42) -@ti.test() +@test_utils.test() def test_atomic_add_global_f32(): - run_atomic_add_global_case(ti.f32, - 4.2, - valproc=lambda x: approx(x, rel=1e-5)) + run_atomic_add_global_case( + ti.f32, 4.2, valproc=lambda x: test_utils.approx(x, rel=1e-5)) -@ti.test() +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_atomic_min_max_uint(): + x = ti.field(ti.u64, shape=100) + + @ti.kernel + def test0(): + for I in x: + x[I] = 0 + x[1] = ti.cast(1, ti.u64) << 63 + for I in x: + ti.atomic_max(x[0], x[I]) + + test0() + assert x[0] == 9223372036854775808 + + @ti.kernel + def test1(): + for I in x: + x[I] = ti.cast(1, ti.u64) << 63 + x[1] = 100 + for I in x: + ti.atomic_min(x[0], x[I]) + + test1() + assert x[0] == 100 + + +@test_utils.test() def test_atomic_add_expr_evaled(): c = ti.field(ti.i32) step = 42 @@ -65,7 +91,7 @@ def func(): assert c[None] == n * step -@ti.test() +@test_utils.test() def test_atomic_add_demoted(): # Ensure demoted atomics do not crash the program. x = ti.field(ti.i32) @@ -80,7 +106,7 @@ def func(): s = i # Both adds should get demoted. x[i] = ti.atomic_add(s, step) - y[i] = s.atomic_add(step) + y[i] = ti.atomic_add(s, step) func() @@ -89,7 +115,7 @@ def func(): assert y[i] == i + step -@ti.test() +@test_utils.test() def test_atomic_add_with_local_store_simplify1(): # Test for the following LocalStoreStmt simplification case: # @@ -122,7 +148,7 @@ def func(): assert y[i] == i -@ti.test() +@test_utils.test() def test_atomic_add_with_local_store_simplify2(): # Test for the following LocalStoreStmt simplification case: # @@ -148,7 +174,7 @@ def func(): assert x[i] == i -@ti.test() +@test_utils.test() def test_atomic_add_with_if_simplify(): # Make sure IfStmt simplification doesn't move stmts depending on the result # of atomic_add() @@ -166,13 +192,13 @@ def func(): # A sequence of commands designed such that atomic_add() is the only # thing to decide whether the if branch can be simplified. s = i - j = s.atomic_add(s) + j = ti.atomic_add(s, s) k = j + s x[i] = k else: # If we look at the IR, this branch should be simplified, since nobody # is using atomic_add's result. - x[i].atomic_add(i) + ti.atomic_add(x[i], i) x[i] += step func() @@ -182,7 +208,7 @@ def func(): assert x[i] == expect -@ti.test() +@test_utils.test() def test_local_atomic_with_if(): ret = ti.field(dtype=ti.i32, shape=()) @@ -197,7 +223,7 @@ def test(): assert ret[None] == 1 -@ti.test() +@test_utils.test() def test_atomic_sub_expr_evaled(): c = ti.field(ti.i32) step = 42 @@ -215,7 +241,7 @@ def func(): assert c[None] == -n * step -@ti.test() +@test_utils.test() def test_atomic_max_expr_evaled(): c = ti.field(ti.i32) step = 42 @@ -233,7 +259,7 @@ def func(): assert c[None] == (n - 1) * step -@ti.test() +@test_utils.test() def test_atomic_min_expr_evaled(): c = ti.field(ti.i32) step = 42 @@ -252,7 +278,7 @@ def func(): assert c[None] == 0 -@ti.test() +@test_utils.test() def test_atomic_and_expr_evaled(): c = ti.field(ti.i32) step = 42 @@ -273,7 +299,7 @@ def func(): assert c[None] == 0 -@ti.test() +@test_utils.test() def test_atomic_or_expr_evaled(): c = ti.field(ti.i32) step = 42 @@ -292,7 +318,7 @@ def func(): assert c[None] == 1023 -@ti.test() +@test_utils.test() def test_atomic_xor_expr_evaled(): c = ti.field(ti.i32) step = 42 diff --git a/tests/python/test_basics.py b/tests/python/test_basics.py index ff474e7c14a87..cd452526c13e0 100644 --- a/tests/python/test_basics.py +++ b/tests/python/test_basics.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_simple(): n = 128 x = ti.field(ti.i32, shape=n) @@ -19,7 +20,7 @@ def func(): assert x[i] == 0 -@ti.test() +@test_utils.test() def test_range_loops(): n = 128 x = ti.field(ti.i32, shape=n) @@ -35,7 +36,7 @@ def func(): assert x[i] == i + 123 -@ti.test() +@test_utils.test() def test_python_access(): n = 128 x = ti.field(ti.i32, shape=n) @@ -46,7 +47,7 @@ def test_python_access(): assert x[4] == 456 -@ti.test() +@test_utils.test() def test_if(): x = ti.field(ti.f32, shape=16) @@ -77,7 +78,7 @@ def if_test2(): assert x[i] == i -@ti.test() +@test_utils.test() def test_if_global_load(): x = ti.field(ti.i32, shape=16) @@ -99,7 +100,7 @@ def fill(): assert x[i] == i -@ti.test() +@test_utils.test() def test_while_global_load(): x = ti.field(ti.i32, shape=16) y = ti.field(ti.i32, shape=()) @@ -116,3 +117,12 @@ def run(): run() assert y[None] == 3 + + +@test_utils.test() +def test_datatype_string(): + for ty in [ + ti.u8, ti.u16, ti.u32, ti.u64, ti.i8, ti.i16, ti.i32, ti.f32, + ti.f64 + ]: + assert ty.to_string() == str(ty) diff --git a/tests/python/test_binding.py b/tests/python/test_binding.py index 7447ac9eb38cc..6a7b197d29e4d 100644 --- a/tests/python/test_binding.py +++ b/tests/python/test_binding.py @@ -3,10 +3,10 @@ def test_binding(): ti.init() - taichi_lang = ti.core + taichi_lang = ti._lib.core print(taichi_lang.BinaryOpType.mul) - one = taichi_lang.make_const_expr_i32(1) - two = taichi_lang.make_const_expr_i32(2) + one = taichi_lang.make_const_expr_int(ti.i32, 1) + two = taichi_lang.make_const_expr_int(ti.i32, 2) expr = taichi_lang.make_binary_op_expr(taichi_lang.BinaryOpType.add, one, two) print(expr.serialize()) diff --git a/tests/python/test_bit_array.py b/tests/python/test_bit_array.py index ed894f297c7a9..d5426ab3c9f8c 100644 --- a/tests/python/test_bit_array.py +++ b/tests/python/test_bit_array.py @@ -1,11 +1,12 @@ import numpy as np import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.quant, debug=True) +@test_utils.test(require=ti.extension.quant, debug=True) def test_1D_bit_array(): - cu1 = ti.quant.int(1, False) + cu1 = ti.types.quantized_types.quant.int(1, False) x = ti.field(dtype=cu1) @@ -27,9 +28,9 @@ def verify_val(): verify_val() -@ti.test(require=ti.extension.quant, debug=True) +@test_utils.test(require=ti.extension.quant, debug=True) def test_2D_bit_array(): - ci1 = ti.quant.int(1, False) + ci1 = ti.types.quantized_types.quant.int(1, False) x = ti.field(dtype=ci1) diff --git a/tests/python/test_bit_array_vectorization.py b/tests/python/test_bit_array_vectorization.py index 9fad21763e919..02afb38ec2f3c 100644 --- a/tests/python/test_bit_array_vectorization.py +++ b/tests/python/test_bit_array_vectorization.py @@ -1,11 +1,14 @@ -import numpy as np +from taichi.lang.impl import get_runtime import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.quant, debug=True, cfg_optimization=False) +@test_utils.test(require=ti.extension.quant, + debug=True, + cfg_optimization=False) def test_vectorized_struct_for(): - cu1 = ti.quant.int(1, False) + cu1 = ti.types.quantized_types.quant.int(1, False) x = ti.field(dtype=cu1) y = ti.field(dtype=cu1) @@ -29,7 +32,7 @@ def init(): @ti.kernel def assign_vectorized(): - ti.bit_vectorize(32) + get_runtime().prog.current_ast_builder().bit_vectorize(32) for i, j in x: y[i, j] = x[i, j] @@ -44,9 +47,9 @@ def verify(): verify() -@ti.test(require=ti.extension.quant) +@test_utils.test(require=ti.extension.quant) def test_offset_load(): - ci1 = ti.quant.int(1, False) + ci1 = ti.types.quantized_types.quant.int(1, False) x = ti.field(dtype=ci1) y = ti.field(dtype=ci1) @@ -74,7 +77,7 @@ def init(): @ti.kernel def assign_vectorized(dx: ti.template(), dy: ti.template()): - ti.bit_vectorize(32) + get_runtime().prog.current_ast_builder().bit_vectorize(32) for i, j in x: y[i, j] = x[i + dx, j + dy] z[i, j] = x[i + dx, j + dy] @@ -104,9 +107,9 @@ def verify(dx: ti.template(), dy: ti.template()): verify(-1, 1) -@ti.test(require=ti.extension.quant, debug=True) +@test_utils.test(require=ti.extension.quant, debug=True) def test_evolve(): - ci1 = ti.quant.int(1, False) + ci1 = ti.types.quantized_types.quant.int(1, False) x = ti.field(dtype=ci1) y = ti.field(dtype=ci1) @@ -134,7 +137,7 @@ def init(): @ti.kernel def evolve_vectorized(x: ti.template(), y: ti.template()): - ti.bit_vectorize(32) + get_runtime().prog.current_ast_builder().bit_vectorize(32) for i, j in x: num_active_neighbors = 0 num_active_neighbors += ti.cast(x[i - 1, j - 1], ti.u32) @@ -145,8 +148,8 @@ def evolve_vectorized(x: ti.template(), y: ti.template()): num_active_neighbors += ti.cast(x[i + 1, j - 1], ti.u32) num_active_neighbors += ti.cast(x[i + 1, j], ti.u32) num_active_neighbors += ti.cast(x[i + 1, j + 1], ti.u32) - y[i, j] = (num_active_neighbors == 3) or (num_active_neighbors == 2 - and x[i, j] == 1) + y[i, j] = (num_active_neighbors == 3) | \ + ((num_active_neighbors == 2) & (x[i, j] == 1)) @ti.kernel def evolve_naive(x: ti.template(), y: ti.template()): diff --git a/tests/python/test_bit_operations.py b/tests/python/test_bit_operations.py index 28956f1a2f19a..5c11194dd69ce 100644 --- a/tests/python/test_bit_operations.py +++ b/tests/python/test_bit_operations.py @@ -4,10 +4,10 @@ import pytest import taichi as ti -from taichi import allclose +from tests import test_utils -@ti.test() +@test_utils.test() def test_bit_shl(): @ti.kernel def shl(a: ti.i32, b: ti.i32) -> ti.i32: @@ -17,7 +17,7 @@ def shl(a: ti.i32, b: ti.i32) -> ti.i32: assert shl(3, i) == 3 * 2**i -@ti.test() +@test_utils.test() def test_bit_sar(): @ti.kernel def sar(a: ti.i32, b: ti.i32) -> ti.i32: @@ -33,7 +33,7 @@ def sar(a: ti.i32, b: ti.i32) -> ti.i32: assert sar(neg_test_num, i) == -2**(n - i) -@ti.test() +@test_utils.test() def test_bit_shr(): @ti.kernel def shr(a: ti.i32, b: ti.i32) -> ti.i32: diff --git a/tests/python/test_bit_struct.py b/tests/python/test_bit_struct.py index 8d0d96d475925..dde9792304103 100644 --- a/tests/python/test_bit_struct.py +++ b/tests/python/test_bit_struct.py @@ -2,12 +2,13 @@ from pytest import approx import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.quant_basic, debug=True) +@test_utils.test(require=ti.extension.quant_basic, debug=True) def test_simple_array(): - ci13 = ti.quant.int(13, True) - cu19 = ti.quant.int(19, False) + ci13 = ti.types.quantized_types.quant.int(13, True) + cu19 = ti.types.quantized_types.quant.int(19, False) x = ti.field(dtype=ci13) y = ti.field(dtype=cu19) @@ -37,11 +38,13 @@ def verify_val(): # TODO: remove excluding of ti.metal -@ti.test(require=ti.extension.quant_basic, exclude=[ti.metal], debug=True) +@test_utils.test(require=ti.extension.quant_basic, + exclude=[ti.metal], + debug=True) def test_custom_int_load_and_store(): - ci13 = ti.quant.int(13, True) - cu14 = ti.quant.int(14, False) - ci5 = ti.quant.int(5, True) + ci13 = ti.types.quantized_types.quant.int(13, True) + cu14 = ti.types.quantized_types.quant.int(14, False) + ci5 = ti.types.quantized_types.quant.int(5, True) x = ti.field(dtype=ci13) y = ti.field(dtype=cu14) @@ -78,9 +81,9 @@ def verify_val(idx: ti.i32): verify_val.__wrapped__(idx) -@ti.test(require=ti.extension.quant_basic) +@test_utils.test(require=ti.extension.quant_basic) def test_custom_int_full_struct(): - cit = ti.quant.int(32, True) + cit = ti.types.quantized_types.quant.int(32, True) x = ti.field(dtype=cit) ti.root.dense(ti.i, 1).bit_struct(num_bits=32).place(x) @@ -96,9 +99,12 @@ def test_single_bit_struct(physical_type, compute_type, custom_bits, test_case): ti.init(arch=ti.cpu, debug=True) - cit1 = ti.quant.int(custom_bits[0], True, compute_type) - cit2 = ti.quant.int(custom_bits[1], False, compute_type) - cit3 = ti.quant.int(custom_bits[2], True, compute_type) + cit1 = ti.types.quantized_types.quant.int(custom_bits[0], True, + compute_type) + cit2 = ti.types.quantized_types.quant.int(custom_bits[1], False, + compute_type) + cit3 = ti.types.quantized_types.quant.int(custom_bits[2], True, + compute_type) a = ti.field(dtype=cit1) b = ti.field(dtype=cit2) @@ -139,12 +145,13 @@ def verify_val(test_val: ti.ext_arr()): test_single_bit_struct(32, ti.i32, [10, 10, 12], np.array([11, 19, 2020])) -@ti.test(require=[ti.extension.quant_basic, ti.extension.sparse], debug=True) +@test_utils.test(require=[ti.extension.quant_basic, ti.extension.sparse], + debug=True) def test_bit_struct_struct_for(): block_size = 16 N = 64 cell = ti.root.pointer(ti.i, N // block_size) - fixed32 = ti.quant.fixed(frac=32, range=1024) + fixed32 = ti.types.quantized_types.quant.fixed(frac=32, num_range=1024) x = ti.field(dtype=fixed32) cell.dense(ti.i, block_size).bit_struct(32).place(x) diff --git a/tests/python/test_bitmasked.py b/tests/python/test_bitmasked.py index 4d5887e4140e7..e31b4eb71f1e1 100644 --- a/tests/python/test_bitmasked.py +++ b/tests/python/test_bitmasked.py @@ -1,4 +1,5 @@ import taichi as ti +from tests import test_utils def _test_basic(): @@ -29,17 +30,18 @@ def sum(): assert s[None] == 42 -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_basic(): _test_basic() -@ti.test(require=[ti.extension.sparse, ti.extension.packed], packed=True) +@test_utils.test(require=[ti.extension.sparse, ti.extension.packed], + packed=True) def test_basic_packed(): _test_basic() -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_bitmasked_then_dense(): x = ti.field(ti.f32) s = ti.field(ti.i32) @@ -63,7 +65,7 @@ def func(): assert s[None] == 256 -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_bitmasked_bitmasked(): x = ti.field(ti.f32) s = ti.field(ti.i32) @@ -87,7 +89,7 @@ def func(): assert s[None] == 4 -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_huge_bitmasked(): # Mainly for testing Metal listgen's grid-stride loop implementation. x = ti.field(ti.f32) @@ -114,7 +116,7 @@ def count(): assert s[None] == (n * n * 2) // 32 -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_bitmasked_listgen_bounded(): # Mainly for testing Metal's listgen is bounded by the actual number of # elements possible for that SNode. Note that 1) SNode's size is padded @@ -145,7 +147,7 @@ def count(): assert c[None] == n -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_deactivate(): # https://github.com/taichi-dev/taichi/issues/778 a = ti.field(ti.i32) @@ -209,17 +211,18 @@ def run(): assert s[None] == 42 -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_sparsity_changes(): _test_sparsity_changes() -@ti.test(require=[ti.extension.sparse, ti.extension.packed], packed=True) +@test_utils.test(require=[ti.extension.sparse, ti.extension.packed], + packed=True) def test_sparsity_changes_packed(): _test_sparsity_changes() -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_bitmasked_offset_child(): x = ti.field(ti.i32) x2 = ti.field(ti.i32) @@ -258,7 +261,7 @@ def func(): assert s[None] == 7 -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_bitmasked_2d_power_of_two(): some_val = ti.field(dtype=float) width, height = 10, 10 @@ -286,7 +289,7 @@ def run(): assert num_active[None] == total -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_root_deactivate(): a = ti.field(ti.i32) a_a = ti.root.bitmasked(ti.i, 4) diff --git a/tests/python/test_bls.py b/tests/python/test_bls.py index 065be8ea4001e..b4116bbf01633 100644 --- a/tests/python/test_bls.py +++ b/tests/python/test_bls.py @@ -1,32 +1,8 @@ -import pytest - import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.bls, dynamic_index=True) -def test_bls_with_dynamic_index(): - x, y = ti.field(ti.f32), ti.field(ti.f32) - - N = 64 - bs = 16 - - ti.root.pointer(ti.i, N // bs).dense(ti.i, bs).place(x, y) - - @ti.kernel - def populate(): - for i in range(N): - x[i] = i - - @ti.kernel - def call_block_local(): - ti.block_local(x) - - populate() - with pytest.raises(ti.InvalidOperationError): - call_block_local() - - -@ti.test(require=ti.extension.bls, dynamic_index=False) +@test_utils.test(require=ti.extension.bls) def test_simple_1d(): x, y = ti.field(ti.f32), ti.field(ti.f32) @@ -53,7 +29,7 @@ def copy(): assert y[i] == i -@ti.test(require=ti.extension.bls, dynamic_index=False) +@test_utils.test(require=ti.extension.bls) def test_simple_2d(): x, y = ti.field(ti.f32), ti.field(ti.f32) @@ -86,43 +62,43 @@ def _test_bls_stencil(*args, **kwargs): bls_test_template(*args, **kwargs) -@ti.test(require=ti.extension.bls, dynamic_index=False) +@test_utils.test(require=ti.extension.bls) def test_gather_1d_trivial(): # y[i] = x[i] _test_bls_stencil(1, 128, bs=32, stencil=((0, ), )) -@ti.test(require=ti.extension.bls, dynamic_index=False) +@test_utils.test(require=ti.extension.bls) def test_gather_1d(): # y[i] = x[i - 1] + x[i] _test_bls_stencil(1, 128, bs=32, stencil=((-1, ), (0, ))) -@ti.test(require=ti.extension.bls, dynamic_index=False) +@test_utils.test(require=ti.extension.bls) def test_gather_2d(): stencil = [(0, 0), (0, -1), (0, 1), (1, 0)] _test_bls_stencil(2, 128, bs=16, stencil=stencil) -@ti.test(require=ti.extension.bls, dynamic_index=False) +@test_utils.test(require=ti.extension.bls) def test_gather_2d_nonsquare(): stencil = [(0, 0), (0, -1), (0, 1), (1, 0)] _test_bls_stencil(2, 128, bs=(4, 16), stencil=stencil) -@ti.test(require=ti.extension.bls, dynamic_index=False) +@test_utils.test(require=ti.extension.bls) def test_gather_3d(): stencil = [(-1, -1, -1), (2, 0, 1)] _test_bls_stencil(3, 64, bs=(4, 8, 16), stencil=stencil) -@ti.test(require=ti.extension.bls, dynamic_index=False) +@test_utils.test(require=ti.extension.bls) def test_scatter_1d_trivial(): # y[i] = x[i] _test_bls_stencil(1, 128, bs=32, stencil=((0, ), ), scatter=True) -@ti.test(require=ti.extension.bls, dynamic_index=False) +@test_utils.test(require=ti.extension.bls) def test_scatter_1d(): _test_bls_stencil(1, 128, bs=32, stencil=( (1, ), @@ -130,13 +106,13 @@ def test_scatter_1d(): ), scatter=True) -@ti.test(require=ti.extension.bls, dynamic_index=False) +@test_utils.test(require=ti.extension.bls) def test_scatter_2d(): stencil = [(0, 0), (0, -1), (0, 1), (1, 0)] _test_bls_stencil(2, 128, bs=16, stencil=stencil, scatter=True) -@ti.test(require=ti.extension.bls, dynamic_index=False) +@test_utils.test(require=ti.extension.bls) def test_multiple_inputs(): x, y, z, w, w2 = ti.field(ti.i32), ti.field(ti.i32), ti.field( ti.i32), ti.field(ti.i32), ti.field(ti.i32) @@ -171,7 +147,7 @@ def copy(bls: ti.template(), w: ti.template()): assert w[i, j] == w2[i, j] -@ti.test(require=ti.extension.bls, dynamic_index=False) +@test_utils.test(require=ti.extension.bls) def test_bls_large_block(): n = 2**10 block_size = 32 diff --git a/tests/python/test_bls_assume_in_range.py b/tests/python/test_bls_assume_in_range.py index dcd2dc3b613ec..c7dc0cbe3ceab 100644 --- a/tests/python/test_bls_assume_in_range.py +++ b/tests/python/test_bls_assume_in_range.py @@ -1,10 +1,11 @@ import taichi as ti +from tests import test_utils from .bls_test_template import bls_particle_grid -@ti.test(require=ti.extension.bls) -def _test_scattering(): +@test_utils.test(require=ti.extension.bls) +def test_scattering(): bls_particle_grid(N=128, ppc=10, block_size=8, @@ -12,8 +13,8 @@ def _test_scattering(): use_offset=False) -@ti.test(require=ti.extension.bls) -def _test_scattering_offset(): +@test_utils.test(require=ti.extension.bls) +def test_scattering_offset(): bls_particle_grid(N=128, ppc=10, block_size=8, @@ -21,8 +22,8 @@ def _test_scattering_offset(): use_offset=True) -@ti.test(require=ti.extension.bls) -def _test_scattering_two_pointer_levels(): +@test_utils.test(require=ti.extension.bls) +def test_scattering_two_pointer_levels(): bls_particle_grid(N=128, ppc=10, block_size=8, @@ -31,7 +32,7 @@ def _test_scattering_two_pointer_levels(): use_offset=False) -@ti.test(require=ti.extension.bls, dynamic_index=False) +@test_utils.test(require=ti.extension.bls) def test_gathering(): bls_particle_grid(N=128, ppc=10, @@ -40,8 +41,8 @@ def test_gathering(): use_offset=False) -@ti.test(require=ti.extension.bls) -def _test_gathering_offset(): +@test_utils.test(require=ti.extension.bls) +def test_gathering_offset(): bls_particle_grid(N=128, ppc=10, block_size=8, diff --git a/tests/python/test_bool_op.py b/tests/python/test_bool_op.py new file mode 100644 index 0000000000000..d15160e68f1ab --- /dev/null +++ b/tests/python/test_bool_op.py @@ -0,0 +1,68 @@ +import taichi as ti +from tests import test_utils + + +@test_utils.test(debug=True, short_circuit_operators=True) +def test_and_shorted(): + a = ti.field(ti.i32, shape=10) + + @ti.func + def explode() -> ti.i32: + return a[-1] + + @ti.kernel + def func() -> ti.i32: + return False and explode() + + assert func() == 0 + + +@test_utils.test(debug=True, short_circuit_operators=True) +def test_and_not_shorted(): + @ti.kernel + def func() -> ti.i32: + return True and False + + assert func() == 0 + + +@test_utils.test(debug=True, short_circuit_operators=True) +def test_or_shorted(): + a = ti.field(ti.i32, shape=10) + + @ti.func + def explode() -> ti.i32: + return a[-1] + + @ti.kernel + def func() -> ti.i32: + return True or explode() + + assert func() == 1 + + +@test_utils.test(debug=True, short_circuit_operators=True) +def test_or_not_shorted(): + @ti.kernel + def func() -> ti.i32: + return False or True + + assert func() == 1 + + +@test_utils.test(debug=True) +def test_static_or(): + @ti.kernel + def func() -> ti.i32: + return ti.static(0 or 3 or 5) + + assert func() == 3 + + +@test_utils.test(debug=True) +def test_static_and(): + @ti.kernel + def func() -> ti.i32: + return ti.static(5 and 2 and 0) + + assert func() == 0 diff --git a/tests/python/test_callable_template_mapper.py b/tests/python/test_callable_template_mapper.py index 21adb989deca0..d52be405dd57f 100644 --- a/tests/python/test_callable_template_mapper.py +++ b/tests/python/test_callable_template_mapper.py @@ -1,9 +1,10 @@ from taichi.lang.kernel_impl import TaichiCallableTemplateMapper import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_callable_template_mapper(): x = ti.field(ti.i32) y = ti.field(ti.f32) @@ -33,7 +34,7 @@ def test_callable_template_mapper(): assert mapper.lookup((0, x, 1))[0] == 0 -@ti.test() +@test_utils.test() def test_callable_template_mapper_numpy(): x = ti.field(ti.i32) y = ti.field(ti.f32) diff --git a/tests/python/test_cast.py b/tests/python/test_cast.py index bf8f680ddac14..0bcca774a2b3a 100644 --- a/tests/python/test_cast.py +++ b/tests/python/test_cast.py @@ -1,9 +1,30 @@ import pytest import taichi as ti +from tests import test_utils -@ti.test() +@pytest.mark.parametrize('dtype', [ti.u8, ti.u16, ti.u32]) +@test_utils.test(exclude=ti.opengl) +def test_cast_uint_to_float(dtype): + @ti.kernel + def func(a: dtype) -> ti.f32: + return ti.cast(a, ti.f32) + + assert func(255) == 255 + + +@pytest.mark.parametrize('dtype', [ti.u8, ti.u16, ti.u32]) +@test_utils.test(exclude=ti.opengl) +def test_cast_float_to_uint(dtype): + @ti.kernel + def func(a: ti.f32) -> dtype: + return ti.cast(a, dtype) + + assert func(255) == 255 + + +@test_utils.test() def test_cast_f32(): z = ti.field(ti.i32, shape=()) @@ -15,7 +36,7 @@ def func(): assert z[None] == 1000 -@ti.test(require=ti.extension.data64) +@test_utils.test(require=ti.extension.data64) def test_cast_f64(): z = ti.field(ti.i32, shape=()) @@ -54,7 +75,7 @@ def func(x: float, y: float) -> int: assert func(233, large) == 233 * large -@ti.test() +@test_utils.test() def test_cast_within_while(): ret = ti.field(ti.i32, shape=()) @@ -69,7 +90,7 @@ def func(): func() -@ti.test() +@test_utils.test() def test_bit_cast(): x = ti.field(ti.i32, shape=()) y = ti.field(ti.f32, shape=()) @@ -89,7 +110,7 @@ def func2(): assert z[None] == 2333 -@ti.test(arch=ti.cpu) +@test_utils.test(arch=ti.cpu) def test_int_extension(): x = ti.field(dtype=ti.i32, shape=2) y = ti.field(dtype=ti.u32, shape=2) @@ -119,13 +140,13 @@ def run_cast_u32(): assert y[1] == 128 -@ti.test(arch=ti.cpu) +@test_utils.test(arch=ti.cpu) def test_custom_int_extension(): x = ti.field(dtype=ti.i32, shape=2) y = ti.field(dtype=ti.u32, shape=2) - ci5 = ti.quant.int(5, True, ti.i16) - cu7 = ti.quant.int(7, False, ti.u16) + ci5 = ti.types.quantized_types.quant.int(5, True, ti.i16) + cu7 = ti.types.quantized_types.quant.int(7, False, ti.u16) a = ti.field(dtype=ci5) b = ti.field(dtype=cu7) diff --git a/tests/python/test_cell_size_inspection.py b/tests/python/test_cell_size_inspection.py deleted file mode 100644 index 39b07d1987439..0000000000000 --- a/tests/python/test_cell_size_inspection.py +++ /dev/null @@ -1,43 +0,0 @@ -import taichi as ti - - -@ti.test(arch=ti.cpu) -def test_primitives(): - x = ti.field(dtype=ti.i16) - y = ti.field(dtype=ti.f32) - z = ti.field(dtype=ti.f64) - - p = ti.field(dtype=ti.f32) - q = ti.field(dtype=ti.f32) - r = ti.field(dtype=ti.f64) - - n1 = ti.root.dense(ti.i, 32) - n1.place(x) - - n2 = ti.root.dense(ti.i, 32) - n2.place(y, z) - - n3 = ti.root.dense(ti.i, 1) - n3.place(p, q, r) - - assert n1.cell_size_bytes == 2 - assert 12 <= n2.cell_size_bytes <= 16 - assert n3.cell_size_bytes == 16 - - -@ti.test(arch=ti.cpu) -def test_bit_struct(): - cit = ti.quant.int(16, False) - x = ti.field(dtype=cit) - y = ti.field(dtype=ti.type_factory.custom_float(significand_type=cit)) - z = ti.field(dtype=ti.f32) - - n1 = ti.root.dense(ti.i, 32) - n1.bit_struct(num_bits=32).place(x) - - n2 = ti.root.dense(ti.i, 4) - n2.bit_struct(num_bits=32).place(y) - n2.place(z) - - assert n1.cell_size_bytes == 4 - assert n2.cell_size_bytes == 8 diff --git a/tests/python/test_classfunc.py b/tests/python/test_classfunc.py index 6c7c8d249620f..95c4b2c6f34ea 100644 --- a/tests/python/test_classfunc.py +++ b/tests/python/test_classfunc.py @@ -1,7 +1,10 @@ +from taichi.lang.misc import get_host_arch_list + import taichi as ti +from tests import test_utils -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_classfunc(): @ti.data_oriented class Foo: diff --git a/tests/python/test_clear_all_gradients.py b/tests/python/test_clear_all_gradients.py index 8624d6ec52060..f93e780494b4d 100644 --- a/tests/python/test_clear_all_gradients.py +++ b/tests/python/test_clear_all_gradients.py @@ -1,7 +1,10 @@ +from taichi.lang import impl + import taichi as ti +from tests import test_utils -@ti.test(exclude=[ti.vulkan]) +@test_utils.test(exclude=[ti.vulkan]) def test_clear_all_gradients(): x = ti.field(ti.f32) y = ti.field(ti.f32) @@ -23,7 +26,7 @@ def test_clear_all_gradients(): w.grad[i, j] = 6 ti.clear_all_gradients() - assert ti.get_runtime().get_num_compiled_functions() == 3 + assert impl.get_runtime().get_num_compiled_functions() == 3 assert x.grad[None] == 0 for i in range(n): @@ -34,4 +37,4 @@ def test_clear_all_gradients(): ti.clear_all_gradients() # No more kernel compilation - assert ti.get_runtime().get_num_compiled_functions() == 3 + assert impl.get_runtime().get_num_compiled_functions() == 3 diff --git a/tests/python/test_cli.py b/tests/python/test_cli.py index 5ca33d1fdf220..ec61d43f75129 100644 --- a/tests/python/test_cli.py +++ b/tests/python/test_cli.py @@ -6,7 +6,7 @@ from unittest.mock import patch import pytest -from taichi.main import TaichiMain +from taichi._main import TaichiMain import taichi as ti @@ -207,27 +207,6 @@ def test_cli_benchmark(): assert args.threads == "4" -def test_cli_test(): - with patch_sys_argv_helper( - ["ti", "test", "cli", "atomic", "-c", "-v", "-r2", - "-t4"]) as custom_argv: - cli = TaichiMain(test_mode=True) - args = cli() - assert args.files == ["cli", "atomic"] - assert args.cpp == True - assert args.verbose == True - assert args.rerun == "2" - assert args.threads == "4" - - with patch_sys_argv_helper( - ["ti", "test", "cli", "atomic", "-c", "-v", "-r2", - "-t4"]) as custom_argv: - with patch.object(TaichiMain, 'test', return_value=1) as mock_method: - cli = TaichiMain(test_mode=False) - return_code = cli() - assert return_code == 1 - - def test_cli_debug(): with patch_sys_argv_helper(["ti", "debug", "a.py"]) as custom_argv: cli = TaichiMain(test_mode=True) @@ -240,12 +219,3 @@ def test_cli_run(): cli = TaichiMain(test_mode=True) args = cli() assert args.filename == "a.py" - - -def test_cli_task(): - with patch_sys_argv_helper(["ti", "task", "test_task", "arg1", - "arg2"]) as custom_argv: - cli = TaichiMain(test_mode=True) - args = cli() - assert args.taskname == "test_task" - assert args.taskargs == ["arg1", "arg2"] diff --git a/tests/python/test_compare.py b/tests/python/test_compare.py index 9b00e5853c0a6..ea0a45b941782 100644 --- a/tests/python/test_compare.py +++ b/tests/python/test_compare.py @@ -1,7 +1,11 @@ +import pytest +from taichi.lang import impl + import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_compare_basics(): a = ti.field(ti.i32) ti.root.dynamic(ti.i, 256).place(a) @@ -40,7 +44,7 @@ def func(): assert a[11] -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_compare_equality(): a = ti.field(ti.i32) ti.root.dynamic(ti.i, 256).place(a) @@ -79,7 +83,7 @@ def func(): assert not a[11] -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_no_duplicate_eval(): a = ti.field(ti.i32) ti.root.dynamic(ti.i, 256).place(a) @@ -94,7 +98,7 @@ def func(): assert a[2] # ti.append returns 0 -@ti.test() +@test_utils.test() def test_no_duplicate_eval_func(): a = ti.field(ti.i32, ()) b = ti.field(ti.i32, ()) @@ -104,7 +108,7 @@ def why_this_foo_fail(n): return ti.atomic_add(b[None], n) def foo(n): - return ti.atomic_add(ti.subscript(b, None), n) + return ti.atomic_add(impl.subscript(b, None), n) @ti.kernel def func(): @@ -115,7 +119,7 @@ def func(): assert b[None] == 2 -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_chain_compare(): a = ti.field(ti.i32) ti.root.dynamic(ti.i, 256).place(a) @@ -135,3 +139,62 @@ def func(): func() assert a[0] assert not a[1] + + +@test_utils.test() +def test_static_in(): + @ti.kernel + def foo(a: ti.template()) -> ti.i32: + b = 0 + if ti.static(a in [ti.i32, ti.u32]): + b = 1 + elif ti.static(a not in [ti.f32, ti.f64]): + b = 2 + return b + + assert foo(ti.u32) == 1 + assert foo(ti.i64) == 2 + assert foo(ti.f32) == 0 + + +@test_utils.test() +def test_non_static_in(): + with pytest.raises(ti.TaichiCompilationError, + match='"In" is only supported inside `ti.static`.'): + + @ti.kernel + def foo(a: ti.template()) -> ti.i32: + b = 0 + if a in [ti.i32, ti.u32]: + b = 1 + return b + + foo(ti.i32) + + +@test_utils.test() +def test_static_is(): + @ti.kernel + def is_f32(tp: ti.template()) -> ti.i32: + return ti.static(tp is ti.f32) + + @ti.kernel + def is_not_f32(tp: ti.template()) -> ti.i32: + return ti.static(tp is not ti.f32) + + assert is_f32(ti.f32) == 1 + assert is_f32(ti.i32) == 0 + assert is_not_f32(ti.f32) == 0 + assert is_not_f32(ti.i32) == 1 + + +@test_utils.test() +def test_non_static_is(): + with pytest.raises(ti.TaichiCompilationError, + match='"Is" is only supported inside `ti.static`.'): + + @ti.kernel + def is_f32(tp: ti.template()) -> ti.i32: + return tp is ti.f32 + + is_f32(ti.f32) diff --git a/tests/python/test_complex_struct.py b/tests/python/test_complex_struct.py index 3a180a3a9e9db..cc698e8ebc05c 100644 --- a/tests/python/test_complex_struct.py +++ b/tests/python/test_complex_struct.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_complex_dense(): a = ti.field(ti.i32, shape=(4, 4)) b = ti.field(ti.i32, shape=(16, 16)) @@ -87,7 +88,7 @@ def set_d(): assert d[i, j, k] == 4 -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_complex_pointer(): a = ti.field(ti.i32, shape=(4, 4)) b = ti.field(ti.i32, shape=(16, 16)) diff --git a/tests/python/test_constant_fold.py b/tests/python/test_constant_fold.py index 2be2106aaf288..98aca30e9823e 100644 --- a/tests/python/test_constant_fold.py +++ b/tests/python/test_constant_fold.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.async_mode, async_mode=True) +@test_utils.test(require=ti.extension.async_mode, async_mode=True) def test_constant_fold(): n = 100 diff --git a/tests/python/test_constant_func.py b/tests/python/test_constant_func.py deleted file mode 100644 index 59eea2ea1ffdc..0000000000000 --- a/tests/python/test_constant_func.py +++ /dev/null @@ -1,98 +0,0 @@ -import operator as ops - -import numpy as np -import pytest - -import taichi as ti -from taichi import allclose - -binary_func_table = [ - (ops.add, ) * 2, - (ops.sub, ) * 2, - (ops.mul, ) * 2, - (ops.truediv, ) * 2, - (ops.floordiv, ) * 2, - (ops.mod, ) * 2, - (ops.pow, ) * 2, - (ops.and_, ) * 2, - (ops.or_, ) * 2, - (ops.xor, ) * 2, - (ops.eq, ) * 2, - (ops.ne, ) * 2, - (ops.lt, ) * 2, - (ops.le, ) * 2, - (ops.gt, ) * 2, - (ops.ge, ) * 2, - (ti.max, np.maximum), - (ti.min, np.minimum), - (ti.atan2, np.arctan2), -] - -unary_func_table = [ - (ops.neg, ) * 2, - (ops.invert, ) * 2, - (ti.logical_not, np.logical_not), - (ti.abs, np.abs), - (ti.exp, np.exp), - (ti.log, np.log), - (ti.sin, np.sin), - (ti.cos, np.cos), - (ti.tan, np.tan), - (ti.asin, np.arcsin), - (ti.acos, np.arccos), - (ti.tanh, np.tanh), - (ti.floor, np.floor), - (ti.ceil, np.ceil), -] - - -@pytest.mark.parametrize('ti_func,np_func', binary_func_table) -def test_python_scope_vector_binary(ti_func, np_func): - ti.init() - x = ti.Vector([2, 3]) - y = ti.Vector([5, 4]) - - result = ti_func(x, y).to_numpy() - if ti_func in [ops.eq, ops.ne, ops.lt, ops.le, ops.gt, ops.ge]: - result = result.astype(bool) - expected = np_func(x.to_numpy(), y.to_numpy()) - assert allclose(result, expected) - - -@pytest.mark.parametrize('ti_func,np_func', unary_func_table) -def test_python_scope_vector_unary(ti_func, np_func): - ti.init() - x = ti.Vector([2, 3] if ti_func in - [ops.invert, ti.logical_not] else [0.2, 0.3]) - - result = ti_func(x).to_numpy() - if ti_func in [ti.logical_not]: - result = result.astype(bool) - expected = np_func(x.to_numpy()) - assert allclose(result, expected) - - -def test_python_scope_matmul(): - ti.init() - a = np.array([[1, 2], [3, 4]]) - b = np.array([[5, 6], [7, 8]]) - x = ti.Vector(a) - y = ti.Vector(b) - - result = (x @ y).to_numpy() - expected = a @ b - assert allclose(result, expected) - - -def test_python_scope_linalg(): - ti.init() - a = np.array([3, 4, -2]) - b = np.array([-5, 0, 6]) - x = ti.Vector(a) - y = ti.Vector(b) - - assert allclose(x.dot(y), np.dot(a, b)) - assert allclose(x.norm(), np.sqrt(np.dot(a, a))) - assert allclose(x.normalized(), a / np.sqrt(np.dot(a, a))) - assert x.any() == 1 # To match that of Taichi IR, we return -1 for True - assert y.all() == 0 diff --git a/tests/python/test_continue.py b/tests/python/test_continue.py index 752644973d648..cdb13d54cd73a 100644 --- a/tests/python/test_continue.py +++ b/tests/python/test_continue.py @@ -1,9 +1,10 @@ import taichi as ti +from tests import test_utils n = 1000 -@ti.test() +@test_utils.test() def test_for_continue(): x = ti.field(ti.i32, shape=n) @@ -23,7 +24,7 @@ def run(): assert xs[i] == expect -@ti.test() +@test_utils.test() def test_while_continue(): x = ti.field(ti.i32, shape=n) @@ -46,7 +47,7 @@ def run(): assert xs[i] == expect -@ti.test() +@test_utils.test() def test_kernel_continue(): x = ti.field(ti.i32, shape=n) @@ -65,7 +66,7 @@ def run(): assert xs[i] == expect -@ti.test() +@test_utils.test() def test_unconditional_continue(): x = ti.field(ti.i32, shape=n) @@ -84,7 +85,7 @@ def run(): assert xs[i] == 0 -@ti.test() +@test_utils.test() def test_kernel_continue_in_nested_if(): x = ti.field(ti.i32, shape=n) @@ -106,7 +107,7 @@ def run(a: ti.i32): assert x[0] == 0 -@ti.test() +@test_utils.test() def test_kernel_continue_in_nested_if_2(): x = ti.field(ti.i32, shape=n) @@ -127,7 +128,7 @@ def run(a: ti.i32): assert x[0] == 0 -@ti.test() +@test_utils.test() def test_kernel_continue_in_nested_if_3(): x = ti.field(ti.i32, shape=n) diff --git a/tests/python/test_copy_from.py b/tests/python/test_copy_from.py index 646d0c4809870..f546477e1e826 100644 --- a/tests/python/test_copy_from.py +++ b/tests/python/test_copy_from.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_scalar(): n = 16 diff --git a/tests/python/test_cuda_internals.py b/tests/python/test_cuda_internals.py index dede77a7656ba..1c2e3e3aee691 100644 --- a/tests/python/test_cuda_internals.py +++ b/tests/python/test_cuda_internals.py @@ -1,34 +1,37 @@ +from taichi.lang import impl + import taichi as ti +from tests import test_utils # TODO: these are not really tests... -@ti.test(arch=ti.cuda) +@test_utils.test(arch=ti.cuda) def test_do_nothing(): @ti.kernel def test(): for i in range(10): - ti.call_internal("do_nothing") + impl.call_internal("do_nothing") test() -@ti.test(arch=ti.cuda) +@test_utils.test(arch=ti.cuda) def test_active_mask(): @ti.kernel def test(): for i in range(48): if i % 2 == 0: - ti.call_internal("test_active_mask") + impl.call_internal("test_active_mask") test() -@ti.test(arch=ti.cuda) +@test_utils.test(arch=ti.cuda) def test_shfl_down(): @ti.kernel def test(): for i in range(32): - ti.call_internal("test_shfl") + impl.call_internal("test_shfl") test() diff --git a/tests/python/test_custom_float.py b/tests/python/test_custom_float.py index b2330869bdbd1..aae23362038ab 100644 --- a/tests/python/test_custom_float.py +++ b/tests/python/test_custom_float.py @@ -1,14 +1,14 @@ import math from pytest import approx -from taichi.lang import expr, impl import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.quant_basic) +@test_utils.test(require=ti.extension.quant_basic) def test_custom_float(): - cft = ti.quant.fixed(frac=32, range=2) + cft = ti.types.quantized_types.quant.fixed(frac=32, num_range=2) x = ti.field(dtype=cft) ti.root.bit_struct(num_bits=32).place(x) @@ -27,9 +27,9 @@ def foo(): assert x[None] == approx(0.66) -@ti.test(require=ti.extension.quant_basic) +@test_utils.test(require=ti.extension.quant_basic) def test_custom_matrix_rotation(): - cft = ti.quant.fixed(frac=16, range=1.2) + cft = ti.types.quantized_types.quant.fixed(frac=16, num_range=1.2) x = ti.Matrix.field(2, 2, dtype=cft) @@ -45,8 +45,7 @@ def rotate_18_degrees(): angle = math.pi / 10 x[None] = x[None] @ ti.Matrix( [[ti.cos(angle), ti.sin(angle)], [-ti.sin(angle), - ti.cos(angle)]], - dt=impl.get_runtime().default_fp) + ti.cos(angle)]]) for i in range(5): rotate_18_degrees() @@ -56,10 +55,11 @@ def rotate_18_degrees(): assert x[None][1, 1] == approx(0, abs=1e-4) -@ti.test(require=ti.extension.quant_basic) +@test_utils.test(require=ti.extension.quant_basic) def test_custom_float_implicit_cast(): - ci13 = ti.quant.int(bits=13) - cft = ti.type_factory.custom_float(significand_type=ci13, scale=0.1) + ci13 = ti.types.quantized_types.quant.int(bits=13) + cft = ti.types.quantized_types.type_factory.custom_float( + significand_type=ci13, scale=0.1) x = ti.field(dtype=cft) ti.root.bit_struct(num_bits=32).place(x) @@ -72,10 +72,11 @@ def foo(): assert x[None] == approx(10.0) -@ti.test(require=ti.extension.quant_basic) +@test_utils.test(require=ti.extension.quant_basic) def test_cache_read_only(): - ci15 = ti.quant.int(bits=15) - cft = ti.type_factory.custom_float(significand_type=ci15, scale=0.1) + ci15 = ti.types.quantized_types.quant.int(bits=15) + cft = ti.types.quantized_types.type_factory.custom_float( + significand_type=ci15, scale=0.1) x = ti.field(dtype=cft) ti.root.bit_struct(num_bits=32).place(x) diff --git a/tests/python/test_custom_float_exponents.py b/tests/python/test_custom_float_exponents.py index 6ecd324e94e75..6ede9f03a6498 100644 --- a/tests/python/test_custom_float_exponents.py +++ b/tests/python/test_custom_float_exponents.py @@ -3,15 +3,15 @@ from pytest import approx import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.quant) +@test_utils.test(require=ti.extension.quant) def test_custom_float_unsigned(): - cu13 = ti.quant.int(13, False) - exp = ti.quant.int(6, False) - cft = ti.type_factory.custom_float(significand_type=cu13, - exponent_type=exp, - scale=1) + cu13 = ti.types.quantized_types.quant.int(13, False) + exp = ti.types.quantized_types.quant.int(6, False) + cft = ti.types.quantized_types.type_factory.custom_float( + significand_type=cu13, exponent_type=exp, scale=1) x = ti.field(dtype=cft) ti.root.bit_struct(num_bits=32).place(x) @@ -28,13 +28,12 @@ def test_custom_float_unsigned(): assert x[None] == v -@ti.test(require=ti.extension.quant) +@test_utils.test(require=ti.extension.quant) def test_custom_float_signed(): - cu13 = ti.quant.int(13, True) - exp = ti.quant.int(6, False) - cft = ti.type_factory.custom_float(significand_type=cu13, - exponent_type=exp, - scale=1) + cu13 = ti.types.quantized_types.quant.int(13, True) + exp = ti.types.quantized_types.quant.int(6, False) + cft = ti.types.quantized_types.type_factory.custom_float( + significand_type=cu13, exponent_type=exp, scale=1) x = ti.field(dtype=cft) ti.root.bit_struct(num_bits=32).place(x) @@ -60,13 +59,12 @@ def test_custom_float_signed(): @pytest.mark.parametrize('digits_bits', [23, 24]) -@ti.test(require=ti.extension.quant) +@test_utils.test(require=ti.extension.quant) def test_custom_float_precision(digits_bits): - cu24 = ti.quant.int(digits_bits, True) - exp = ti.quant.int(8, False) - cft = ti.type_factory.custom_float(significand_type=cu24, - exponent_type=exp, - scale=1) + cu24 = ti.types.quantized_types.quant.int(digits_bits, True) + exp = ti.types.quantized_types.quant.int(8, False) + cft = ti.types.quantized_types.type_factory.custom_float( + significand_type=cu24, exponent_type=exp, scale=1) x = ti.field(dtype=cft) ti.root.bit_struct(num_bits=32).place(x) @@ -85,13 +83,12 @@ def test_custom_float_precision(digits_bits): @pytest.mark.parametrize('signed', [True, False]) -@ti.test(require=ti.extension.quant) +@test_utils.test(require=ti.extension.quant) def test_custom_float_truncation(signed): - cit = ti.quant.int(2, signed) - exp = ti.quant.int(5, False) - cft = ti.type_factory.custom_float(significand_type=cit, - exponent_type=exp, - scale=1) + cit = ti.types.quantized_types.quant.int(2, signed) + exp = ti.types.quantized_types.quant.int(5, False) + cft = ti.types.quantized_types.type_factory.custom_float( + significand_type=cit, exponent_type=exp, scale=1) x = ti.field(dtype=cft) ti.root.bit_struct(num_bits=32).place(x) @@ -117,13 +114,12 @@ def test_custom_float_truncation(signed): assert x[None] == 1.75 -@ti.test(require=ti.extension.quant) +@test_utils.test(require=ti.extension.quant) def test_custom_float_atomic_demotion(): - cit = ti.quant.int(2, True) - exp = ti.quant.int(5, False) - cft = ti.type_factory.custom_float(significand_type=cit, - exponent_type=exp, - scale=1) + cit = ti.types.quantized_types.quant.int(2, True) + exp = ti.types.quantized_types.quant.int(5, False) + cft = ti.types.quantized_types.type_factory.custom_float( + significand_type=cit, exponent_type=exp, scale=1) x = ti.field(dtype=cft) ti.root.bit_struct(num_bits=32).place(x) diff --git a/tests/python/test_custom_float_shared_exp.py b/tests/python/test_custom_float_shared_exp.py index d27fad67622d9..02e9da00b8dec 100644 --- a/tests/python/test_custom_float_shared_exp.py +++ b/tests/python/test_custom_float_shared_exp.py @@ -2,20 +2,19 @@ from pytest import approx import taichi as ti +from tests import test_utils @pytest.mark.parametrize('exponent_bits', [5, 6, 7, 8]) -@ti.test(require=ti.extension.quant) +@test_utils.test(require=ti.extension.quant) def test_shared_exponents(exponent_bits): - exp = ti.quant.int(exponent_bits, False) - cit1 = ti.quant.int(10, False) - cit2 = ti.quant.int(14, False) - cft1 = ti.type_factory.custom_float(significand_type=cit1, - exponent_type=exp, - scale=1) - cft2 = ti.type_factory.custom_float(significand_type=cit2, - exponent_type=exp, - scale=1) + exp = ti.types.quantized_types.quant.int(exponent_bits, False) + cit1 = ti.types.quantized_types.quant.int(10, False) + cit2 = ti.types.quantized_types.quant.int(14, False) + cft1 = ti.types.quantized_types.type_factory.custom_float( + significand_type=cit1, exponent_type=exp, scale=1) + cft2 = ti.types.quantized_types.type_factory.custom_float( + significand_type=cit2, exponent_type=exp, scale=1) a = ti.field(dtype=cft1) b = ti.field(dtype=cft2) ti.root.bit_struct(num_bits=32).place(a, b, shared_exponent=True) @@ -75,17 +74,15 @@ def foo(x: ti.f32, y: ti.f32): @pytest.mark.parametrize('exponent_bits', [5, 6, 7, 8]) -@ti.test(require=ti.extension.quant) +@test_utils.test(require=ti.extension.quant) def test_shared_exponent_add(exponent_bits): - exp = ti.quant.int(exponent_bits, False) - cit1 = ti.quant.int(10, False) - cit2 = ti.quant.int(14, False) - cft1 = ti.type_factory.custom_float(significand_type=cit1, - exponent_type=exp, - scale=1) - cft2 = ti.type_factory.custom_float(significand_type=cit2, - exponent_type=exp, - scale=1) + exp = ti.types.quantized_types.quant.int(exponent_bits, False) + cit1 = ti.types.quantized_types.quant.int(10, False) + cit2 = ti.types.quantized_types.quant.int(14, False) + cft1 = ti.types.quantized_types.type_factory.custom_float( + significand_type=cit1, exponent_type=exp, scale=1) + cft2 = ti.types.quantized_types.type_factory.custom_float( + significand_type=cit2, exponent_type=exp, scale=1) a = ti.field(dtype=cft1) b = ti.field(dtype=cft2) ti.root.bit_struct(num_bits=32).place(a, b, shared_exponent=True) @@ -115,17 +112,15 @@ def foo(x: ti.f32, y: ti.f32): @pytest.mark.parametrize('exponent_bits', [5, 6, 7, 8]) -@ti.test(require=ti.extension.quant) +@test_utils.test(require=ti.extension.quant) def test_shared_exponent_borrow(exponent_bits): - exp = ti.quant.int(exponent_bits, False) - cit1 = ti.quant.int(10, False) - cit2 = ti.quant.int(14, False) - cft1 = ti.type_factory.custom_float(significand_type=cit1, - exponent_type=exp, - scale=1) - cft2 = ti.type_factory.custom_float(significand_type=cit2, - exponent_type=exp, - scale=1) + exp = ti.types.quantized_types.quant.int(exponent_bits, False) + cit1 = ti.types.quantized_types.quant.int(10, False) + cit2 = ti.types.quantized_types.quant.int(14, False) + cft1 = ti.types.quantized_types.type_factory.custom_float( + significand_type=cit1, exponent_type=exp, scale=1) + cft2 = ti.types.quantized_types.type_factory.custom_float( + significand_type=cit2, exponent_type=exp, scale=1) a = ti.field(dtype=cft1) b = ti.field(dtype=cft2) ti.root.bit_struct(num_bits=32).place(a, b, shared_exponent=True) @@ -148,17 +143,15 @@ def inc(): @pytest.mark.parametrize('exponent_bits', [5, 6, 7, 8]) -@ti.test(require=ti.extension.quant) +@test_utils.test(require=ti.extension.quant) def test_negative(exponent_bits): - exp = ti.quant.int(exponent_bits, False) - cit1 = ti.quant.int(10, False) - cit2 = ti.quant.int(14, True) - cft1 = ti.type_factory.custom_float(significand_type=cit1, - exponent_type=exp, - scale=1) - cft2 = ti.type_factory.custom_float(significand_type=cit2, - exponent_type=exp, - scale=1) + exp = ti.types.quantized_types.quant.int(exponent_bits, False) + cit1 = ti.types.quantized_types.quant.int(10, False) + cit2 = ti.types.quantized_types.quant.int(14, True) + cft1 = ti.types.quantized_types.type_factory.custom_float( + significand_type=cit1, exponent_type=exp, scale=1) + cft2 = ti.types.quantized_types.type_factory.custom_float( + significand_type=cit2, exponent_type=exp, scale=1) a = ti.field(dtype=cft1) b = ti.field(dtype=cft2) ti.root.bit_struct(num_bits=32).place(a, b, shared_exponent=True) diff --git a/tests/python/test_custom_float_time_integration.py b/tests/python/test_custom_float_time_integration.py index aa342feb467cd..00906efec1a8e 100644 --- a/tests/python/test_custom_float_time_integration.py +++ b/tests/python/test_custom_float_time_integration.py @@ -4,20 +4,20 @@ from pytest import approx import taichi as ti +from tests import test_utils @pytest.mark.parametrize('use_cft,use_exponent,use_shared_exp', [(False, False, False), (True, False, False), (True, True, False), (True, True, True)]) -@ti.test(require=ti.extension.quant) +@test_utils.test(require=ti.extension.quant) def test_custom_float_time_integration(use_cft, use_exponent, use_shared_exp): if use_cft: if use_exponent: - exp = ti.quant.int(6, False) - cit = ti.quant.int(13, True) - cft = ti.type_factory.custom_float(significand_type=cit, - exponent_type=exp, - scale=1) + exp = ti.types.quantized_types.quant.int(6, False) + cit = ti.types.quantized_types.quant.int(13, True) + cft = ti.types.quantized_types.type_factory.custom_float( + significand_type=cit, exponent_type=exp, scale=1) x = ti.Vector.field(2, dtype=cft) if use_shared_exp: ti.root.bit_struct(num_bits=32).place(x, shared_exponent=True) @@ -25,9 +25,9 @@ def test_custom_float_time_integration(use_cft, use_exponent, use_shared_exp): ti.root.bit_struct(num_bits=32).place(x.get_scalar_field(0)) ti.root.bit_struct(num_bits=32).place(x.get_scalar_field(1)) else: - cit = ti.quant.int(16, True) - cft = ti.type_factory.custom_float(significand_type=cit, - scale=1 / 2**14) + cit = ti.types.quantized_types.quant.int(16, True) + cft = ti.types.quantized_types.type_factory.custom_float( + significand_type=cit, scale=1 / 2**14) x = ti.Vector.field(2, dtype=cft) ti.root.bit_struct(num_bits=32).place(x) else: @@ -35,7 +35,7 @@ def test_custom_float_time_integration(use_cft, use_exponent, use_shared_exp): @ti.func def v_at(p): - return ti.Vector([-p[1], p[0]], ti.f32) + return ti.Vector([-p[1], p[0]]) @ti.kernel def advance(dt: ti.f32): diff --git a/tests/python/test_custom_int.py b/tests/python/test_custom_int.py index 874653ce67276..b75d366b8d064 100644 --- a/tests/python/test_custom_int.py +++ b/tests/python/test_custom_int.py @@ -1,9 +1,10 @@ import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.quant_basic) +@test_utils.test(require=ti.extension.quant_basic) def test_custom_int_implicit_cast(): - ci13 = ti.quant.int(13, True) + ci13 = ti.types.quantized_types.quant.int(13, True) x = ti.field(dtype=ci13) ti.root.bit_struct(num_bits=32).place(x) diff --git a/tests/python/test_custom_struct.py b/tests/python/test_custom_struct.py index ebd84df244e86..657e8f6820fd7 100644 --- a/tests/python/test_custom_struct.py +++ b/tests/python/test_custom_struct.py @@ -2,9 +2,10 @@ from pytest import approx import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_struct_member_access(): n = 32 @@ -37,7 +38,7 @@ def run_python_scope(): assert y[i].b == i * 2 + 1 -@ti.test() +@test_utils.test() def test_struct_whole_access(): n = 32 @@ -76,7 +77,7 @@ def run_python_scope(): assert y[i].b == int(1.01 * i) -@ti.test() +@test_utils.test() def test_struct_fill(): n = 32 @@ -113,7 +114,7 @@ def fill_elements(): assert np.allclose(x[i].b.to_numpy(), int(x[i].a)) -@ti.test() +@test_utils.test() def test_matrix_type(): n = 32 vec2f = ti.types.vector(2, ti.f32) @@ -141,7 +142,7 @@ def run_python_scope(): assert np.allclose(x[i].to_numpy(), np.array([i + 1, i, i])) -@ti.test() +@test_utils.test() def test_struct_type(): n = 32 vec3f = ti.types.vector(3, float) @@ -203,7 +204,7 @@ def run_python_scope(): assert x[i].line.length == 5.0 -@ti.test() +@test_utils.test() def test_struct_assign(): n = 32 vec3f = ti.types.vector(3, float) @@ -242,7 +243,7 @@ def run_python_scope(): assert x[i].line.length == i + 0.5 -@ti.test() +@test_utils.test() def test_compound_type_implicit_cast(): vec2i = ti.types.vector(2, int) vec2f = ti.types.vector(2, float) @@ -277,7 +278,7 @@ def i2f_python_scope(): assert type(float_value) == float and float_value == approx(6.0, rel=1e-4) -@ti.test() +@test_utils.test() def test_local_struct_assign(): n = 32 vec3f = ti.types.vector(3, float) @@ -299,3 +300,58 @@ def run_python_scope(): run_taichi_scope() run_python_scope() + + +@test_utils.test(debug=True) +def test_copy_python_scope_struct_to_taichi_scope(): + a = ti.Struct({'a': 2, 'b': 3}) + + @ti.kernel + def test(): + b = a + assert b.a == 2 + assert b.b == 3 + b = ti.Struct({'a': 3, 'b': 4}) + assert b.a == 3 + assert b.b == 4 + + test() + + +@test_utils.test(debug=True) +def test_copy_struct_field_element_to_taichi_scope(): + a = ti.Struct.field({'a': ti.i32, 'b': ti.i32}, shape=()) + a[None].a = 2 + a[None].b = 3 + + @ti.kernel + def test(): + b = a[None] + assert b.a == 2 + assert b.b == 3 + b.a = 5 + b.b = 9 + assert b.a == 5 + assert b.b == 9 + assert a[None].a == 2 + assert a[None].b == 3 + + test() + + +@test_utils.test(debug=True) +def test_copy_struct_in_taichi_scope(): + @ti.kernel + def test(): + a = ti.Struct({'a': 2, 'b': 3}) + b = a + assert b.a == 2 + assert b.b == 3 + b.a = 5 + b.b = 9 + assert b.a == 5 + assert b.b == 9 + assert a.a == 2 + assert a.b == 3 + + test() diff --git a/tests/python/test_custom_type_atomics.py b/tests/python/test_custom_type_atomics.py index e6af4f0e2aaf9..1e810963b6c2a 100644 --- a/tests/python/test_custom_type_atomics.py +++ b/tests/python/test_custom_type_atomics.py @@ -1,14 +1,17 @@ from pytest import approx import taichi as ti +from tests import test_utils # TODO: remove excluding of ti.metal. -@ti.test(require=ti.extension.quant_basic, exclude=[ti.metal], debug=True) +@test_utils.test(require=ti.extension.quant_basic, + exclude=[ti.metal], + debug=True) def test_custom_int_atomics(): - ci13 = ti.quant.int(13, True) - ci5 = ti.quant.int(5, True) - cu2 = ti.quant.int(2, False) + ci13 = ti.types.quantized_types.quant.int(13, True) + ci5 = ti.types.quantized_types.quant.int(5, True) + cu2 = ti.types.quantized_types.quant.int(2, False) x = ti.field(dtype=ci13) y = ti.field(dtype=ci5) @@ -38,9 +41,10 @@ def foo(): assert z[None] == 3 -@ti.test(require=[ti.extension.quant_basic, ti.extension.data64], debug=True) +@test_utils.test(require=[ti.extension.quant_basic, ti.extension.data64], + debug=True) def test_custom_int_atomics_b64(): - ci13 = ti.quant.int(13, True) + ci13 = ti.types.quantized_types.quant.int(13, True) x = ti.field(dtype=ci13) @@ -62,12 +66,14 @@ def foo(): assert x[2] == 315 -@ti.test(require=ti.extension.quant_basic, debug=True) +@test_utils.test(require=ti.extension.quant_basic, debug=True) def test_custom_float_atomics(): - ci13 = ti.quant.int(13, True) - ci19 = ti.quant.int(19, False) - cft13 = ti.type_factory.custom_float(significand_type=ci13, scale=0.1) - cft19 = ti.type_factory.custom_float(significand_type=ci19, scale=0.1) + ci13 = ti.types.quantized_types.quant.int(13, True) + ci19 = ti.types.quantized_types.quant.int(19, False) + cft13 = ti.types.quantized_types.type_factory.custom_float( + significand_type=ci13, scale=0.1) + cft19 = ti.types.quantized_types.type_factory.custom_float( + significand_type=ci19, scale=0.1) x = ti.field(dtype=cft13) y = ti.field(dtype=cft19) diff --git a/tests/python/test_customized_grad.py b/tests/python/test_customized_grad.py index c116a3fe2b10c..d813c6a0ed1e3 100644 --- a/tests/python/test_customized_grad.py +++ b/tests/python/test_customized_grad.py @@ -1,7 +1,10 @@ +import pytest + import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_customized_kernels_tape(): x = ti.field(ti.f32) total = ti.field(ti.f32) @@ -31,7 +34,7 @@ def backward(mul): assert x.grad[0] == 4 -@ti.test() +@test_utils.test() def test_customized_kernels_grad(): x = ti.field(ti.f32) total = ti.field(ti.f32) @@ -62,7 +65,7 @@ def backward(mul): assert x.grad[0] == 4 -@ti.test() +@test_utils.test() def test_customized_kernels_indirect(): x = ti.field(ti.f32) total = ti.field(ti.f32) @@ -95,7 +98,7 @@ def backward(mul): assert x.grad[0] == 4 -@ti.test() +@test_utils.test() def test_customized_kernels_oop(): @ti.data_oriented class A: @@ -130,7 +133,7 @@ def backward(self, mul): assert a.x.grad[0] == 4 -@ti.test() +@test_utils.test() def test_customized_kernels_oop2(): @ti.data_oriented class A: @@ -168,8 +171,7 @@ def backward(self, mul): assert a.x.grad[0] == 4 -@ti.test() -@ti.must_throw(RuntimeError) +@test_utils.test() def test_decorated_primal_is_taichi_kernel(): x = ti.field(ti.f32) total = ti.field(ti.f32) @@ -185,16 +187,17 @@ def func(mul: ti.f32): for i in range(n): ti.atomic_add(total[None], x[i] * mul) - @ti.ad.grad_for(func) - def backward(mul): - func.grad(mul) + with pytest.raises(RuntimeError): + + @ti.ad.grad_for(func) + def backward(mul): + func.grad(mul) with ti.Tape(loss=total): func(4) -@ti.test() -@ti.must_throw(RuntimeError) +@test_utils.test() def test_decorated_primal_missing_decorator(): x = ti.field(ti.f32) total = ti.field(ti.f32) @@ -214,9 +217,11 @@ def foward(mul): func(mul) func(mul) - @ti.ad.grad_for(func) - def backward(mul): - func.grad(mul) + with pytest.raises(RuntimeError): + + @ti.ad.grad_for(func) + def backward(mul): + func.grad(mul) with ti.Tape(loss=total): func(4) diff --git a/tests/python/test_debug.py b/tests/python/test_debug.py index bff25f5008ebe..e60abf74e80a8 100644 --- a/tests/python/test_debug.py +++ b/tests/python/test_debug.py @@ -1,6 +1,7 @@ import pytest import taichi as ti +from tests import test_utils def test_cpu_debug_snode_reader(): @@ -12,47 +13,38 @@ def test_cpu_debug_snode_reader(): assert x[None] == 10.0 -@ti.test(require=ti.extension.assertion, debug=True) +@test_utils.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) def test_cpu_debug_snode_writer_out_of_bound(): - ti.set_gdb_trigger(False) - x = ti.field(ti.f32, shape=3) with pytest.raises(RuntimeError): x[3] = 10.0 -@ti.test(require=ti.extension.assertion, debug=True) +@test_utils.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) def test_cpu_debug_snode_writer_out_of_bound_negative(): - ti.set_gdb_trigger(False) - x = ti.field(ti.f32, shape=3) with pytest.raises(RuntimeError): x[-1] = 10.0 -@ti.test(require=ti.extension.assertion, debug=True) +@test_utils.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) def test_cpu_debug_snode_reader_out_of_bound(): - ti.set_gdb_trigger(False) - x = ti.field(ti.f32, shape=3) with pytest.raises(RuntimeError): a = x[3] -@ti.test(require=ti.extension.assertion, debug=True) +@test_utils.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) def test_cpu_debug_snode_reader_out_of_bound_negative(): - ti.set_gdb_trigger(False) - x = ti.field(ti.f32, shape=3) with pytest.raises(RuntimeError): a = x[-1] -@ti.test(require=ti.extension.assertion, debug=True) +@test_utils.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) def test_out_of_bound(): - ti.set_gdb_trigger(False) x = ti.field(ti.i32, shape=(8, 16)) @ti.kernel @@ -63,9 +55,8 @@ def func(): func() -@ti.test(require=ti.extension.assertion, debug=True) +@test_utils.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) def test_not_out_of_bound(): - ti.set_gdb_trigger(False) x = ti.field(ti.i32, shape=(8, 16)) @ti.kernel @@ -75,9 +66,8 @@ def func(): func() -@ti.test(require=ti.extension.assertion, debug=True) +@test_utils.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) def test_out_of_bound_dynamic(): - ti.set_gdb_trigger(False) x = ti.field(ti.i32) ti.root.dynamic(ti.i, 16, 4).place(x) @@ -90,9 +80,8 @@ def func(): func() -@ti.test(require=ti.extension.assertion, debug=True) +@test_utils.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) def test_not_out_of_bound_dynamic(): - ti.set_gdb_trigger(False) x = ti.field(ti.i32) ti.root.dynamic(ti.i, 16, 4).place(x) @@ -104,10 +93,8 @@ def func(): func() -@ti.test(require=ti.extension.assertion, debug=True) +@test_utils.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) def test_out_of_bound_with_offset(): - ti.init(debug=True) - ti.set_gdb_trigger(False) x = ti.field(ti.i32, shape=(8, 16), offset=(-8, -8)) @ti.kernel @@ -119,9 +106,8 @@ def func(): func() -@ti.test(require=ti.extension.assertion, debug=True) +@test_utils.test(require=ti.extension.assertion, debug=True, gdb_trigger=False) def test_not_out_of_bound_with_offset(): - ti.set_gdb_trigger(False) x = ti.field(ti.i32, shape=(8, 16), offset=(-4, -8)) @ti.kernel diff --git a/tests/python/test_delay_modify.py b/tests/python/test_delay_modify.py new file mode 100644 index 0000000000000..bb61d5f8218f1 --- /dev/null +++ b/tests/python/test_delay_modify.py @@ -0,0 +1,18 @@ +import taichi as ti +from tests import test_utils + + +@test_utils.test() +def test_simplify_bug(): + @ti.kernel + def foo() -> ti.types.vector(4, dtype=ti.i32): + a = ti.Vector([0, 0, 0, 0]) + for i in range(5): + for k in ti.static(range(4)): + if i == 3: + a[k] = 1 + return a + + a = foo() + + assert (a == ti.Vector([1, 1, 1, 1])).all() == 1 diff --git a/tests/python/test_div.py b/tests/python/test_div.py index 65c8fb7c9750f..0da0c24100eae 100644 --- a/tests/python/test_div.py +++ b/tests/python/test_div.py @@ -1,7 +1,10 @@ +from taichi.lang import impl + import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def _test_floor_div(arg1, a, arg2, b, arg3, c): z = ti.field(arg3, shape=()) @@ -13,7 +16,7 @@ def func(x: arg1, y: arg2): assert z[None] == c -@ti.test() +@test_utils.test() def _test_true_div(arg1, a, arg2, b, arg3, c): z = ti.field(arg3, shape=()) @@ -56,9 +59,9 @@ def test_true_div(): _test_true_div(ti.f32, -3, ti.i32, 2, ti.i32, -1) -@ti.test() +@test_utils.test() def test_div_default_ip(): - ti.get_runtime().set_default_ip(ti.i64) + impl.get_runtime().set_default_ip(ti.i64) z = ti.field(ti.f32, shape=()) @ti.kernel @@ -70,7 +73,7 @@ def func(): assert z[None] == 100000 -@ti.test() +@test_utils.test() def test_floor_div_pythonic(): z = ti.field(ti.i32, shape=()) diff --git a/tests/python/test_dynamic.py b/tests/python/test_dynamic.py index 7d473c14a3899..a443a4d78e5e5 100644 --- a/tests/python/test_dynamic.py +++ b/tests/python/test_dynamic.py @@ -1,17 +1,11 @@ import pytest +from taichi.lang.misc import serialize import taichi as ti +from tests import test_utils -def ti_support_dynamic(test): - return ti.archs_excluding(ti.cc, ti.vulkan)(test) - - -def ti_support_non_top_dynamic(test): - return ti.archs_excluding(ti.opengl, ti.cc, ti.vulkan)(test) - - -@ti.test(exclude=[ti.cc, ti.vulkan]) +@test_utils.test(require=ti.extension.sparse) def test_dynamic(): x = ti.field(ti.f32) n = 128 @@ -29,7 +23,7 @@ def func(): assert x[i] == i -@ti.test(exclude=[ti.cc, ti.vulkan]) +@test_utils.test(require=ti.extension.sparse) def test_dynamic2(): x = ti.field(ti.f32) n = 128 @@ -47,7 +41,7 @@ def func(): assert x[i] == i -@ti.test(exclude=[ti.cc, ti.vulkan]) +@test_utils.test(require=ti.extension.sparse) def test_dynamic_matrix(): x = ti.Matrix.field(2, 1, dtype=ti.i32) n = 8192 @@ -56,7 +50,7 @@ def test_dynamic_matrix(): @ti.kernel def func(): - ti.serialize() + serialize() for i in range(n // 4): x[i * 4][1, 0] = i @@ -70,7 +64,7 @@ def func(): assert b == 0 -@ti.test(exclude=[ti.cc, ti.vulkan]) +@test_utils.test(require=ti.extension.sparse) def test_append(): x = ti.field(ti.i32) n = 128 @@ -92,7 +86,7 @@ def func(): assert elements[i] == i -@ti.test(exclude=[ti.cc, ti.vulkan]) +@test_utils.test(require=ti.extension.sparse) def test_length(): x = ti.field(ti.i32) y = ti.field(ti.f32, shape=()) @@ -116,7 +110,7 @@ def get_len(): assert y[None] == n -@ti.test(exclude=[ti.cc, ti.vulkan]) +@test_utils.test(require=ti.extension.sparse) def test_append_ret_value(): x = ti.field(ti.i32) y = ti.field(ti.i32) @@ -141,13 +135,13 @@ def func(): assert x[i] + 3 == z[i] -@ti.test(exclude=[ti.opengl, ti.cc, ti.vulkan]) +@test_utils.test(require=ti.extension.sparse) def test_dense_dynamic(): # The spin lock implementation has triggered a bug in CUDA, the end result # being that appending to Taichi's dynamic node messes up its length. See # https://stackoverflow.com/questions/65995357/cuda-spinlock-implementation-with-independent-thread-scheduling-supported # CUDA 11.2 didn't fix this bug, unfortunately. - if ti.cfg.arch == ti.cuda: + if ti.lang.impl.current_cfg().arch == ti.cuda: pytest.skip('CUDA spinlock bug') n = 128 @@ -158,7 +152,7 @@ def test_dense_dynamic(): @ti.kernel def func(): - ti.serialize() + serialize() for i in range(n): for j in range(n): ti.append(x.parent(), j, i) @@ -172,7 +166,7 @@ def func(): assert l[i] == n -@ti.test(exclude=[ti.opengl, ti.cc, ti.vulkan]) +@test_utils.test(require=ti.extension.sparse) def test_dense_dynamic_len(): n = 128 x = ti.field(ti.i32) @@ -191,9 +185,8 @@ def func(): assert l[i] == 0 -@ti.test(exclude=[ti.cc, ti.vulkan]) +@test_utils.test(require=ti.extension.sparse) def test_dynamic_activate(): - ti.init(arch=ti.metal) # record the lengths l = ti.field(ti.i32, 3) x = ti.field(ti.i32) diff --git a/tests/python/test_eig.py b/tests/python/test_eig.py index b0d50b503d3ed..f0e883f01f485 100644 --- a/tests/python/test_eig.py +++ b/tests/python/test_eig.py @@ -2,6 +2,7 @@ import pytest import taichi as ti +from tests import test_utils def _eigen_vector_equal(v1, v2, tol): @@ -111,9 +112,10 @@ def test_eig2x2(): for func in [_test_eig2x2_real, _test_eig2x2_complex]: for fp in [ti.f32, ti.f64]: - @ti.test(require=ti.extension.data64 if fp == ti.f64 else [], - default_fp=fp, - fast_math=False) + @test_utils.test( + require=ti.extension.data64 if fp == ti.f64 else [], + default_fp=fp, + fast_math=False) def wrapped(): func(fp) @@ -124,9 +126,10 @@ def test_sym_eig2x2(): for func in [_test_sym_eig2x2]: for fp in [ti.f32, ti.f64]: - @ti.test(require=ti.extension.data64 if fp == ti.f64 else [], - default_fp=fp, - fast_math=False) + @test_utils.test( + require=ti.extension.data64 if fp == ti.f64 else [], + default_fp=fp, + fast_math=False) def wrapped(): func(fp) diff --git a/tests/python/test_element_wise.py b/tests/python/test_element_wise.py index 7a6eb01ad332c..59354c96e2c25 100644 --- a/tests/python/test_element_wise.py +++ b/tests/python/test_element_wise.py @@ -2,7 +2,7 @@ import pytest import taichi as ti -from taichi import allclose +from tests import test_utils def _c_mod(a, b): @@ -11,7 +11,7 @@ def _c_mod(a, b): @pytest.mark.parametrize('lhs_is_mat,rhs_is_mat', [(True, True), (True, False), (False, True)]) -@ti.test(fast_math=False, exclude=[ti.vulkan]) +@test_utils.test(fast_math=False, exclude=[ti.vulkan]) def test_binary_f(lhs_is_mat, rhs_is_mat): x = ti.Matrix.field(3, 2, ti.f32, 16) if lhs_is_mat: @@ -55,27 +55,27 @@ def func(): x = x.to_numpy() y = y.to_numpy() z = z.to_numpy() - assert allclose(x[0], y + z) - assert allclose(x[1], y - z) - assert allclose(x[2], y * z) - assert allclose(x[3], y / z) - assert allclose(x[4], y // z) - assert allclose(x[5], y % z) - assert allclose(x[6], y**z) - assert allclose(x[7], y == z) - assert allclose(x[8], y != z) - assert allclose(x[9], y > z) - assert allclose(x[10], y >= z) - assert allclose(x[11], y < z) - assert allclose(x[12], y <= z) - assert allclose(x[13], np.arctan2(y, z)) - assert allclose(x[14], np.minimum(y, z)) - assert allclose(x[15], np.maximum(y, z)) + assert test_utils.allclose(x[0], y + z) + assert test_utils.allclose(x[1], y - z) + assert test_utils.allclose(x[2], y * z) + assert test_utils.allclose(x[3], y / z) + assert test_utils.allclose(x[4], y // z) + assert test_utils.allclose(x[5], y % z) + assert test_utils.allclose(x[6], y**z) + assert test_utils.allclose(x[7], y == z) + assert test_utils.allclose(x[8], y != z) + assert test_utils.allclose(x[9], y > z) + assert test_utils.allclose(x[10], y >= z) + assert test_utils.allclose(x[11], y < z) + assert test_utils.allclose(x[12], y <= z) + assert test_utils.allclose(x[13], np.arctan2(y, z)) + assert test_utils.allclose(x[14], np.minimum(y, z)) + assert test_utils.allclose(x[15], np.maximum(y, z)) @pytest.mark.parametrize('is_mat', [(True, True), (True, False), (False, True)]) -@ti.test() +@test_utils.test() def test_binary_i(is_mat): lhs_is_mat, rhs_is_mat = is_mat @@ -125,30 +125,30 @@ def func(): x = x.to_numpy() y = y.to_numpy() z = z.to_numpy() - assert allclose(x[0], y + z) - assert allclose(x[1], y - z) - assert allclose(x[2], y * z) - assert allclose(x[3], y // z) - assert allclose(x[4], y // z) - assert allclose(x[5], y % z) - assert allclose(x[6], y % z) - assert allclose(x[7], y**z) - assert allclose(x[8], y == z) - assert allclose(x[9], y != z) - assert allclose(x[10], y > z) - assert allclose(x[11], y >= z) - assert allclose(x[12], y < z) - assert allclose(x[13], y <= z) - assert allclose(x[14], y & z) - assert allclose(x[15], y ^ z) - assert allclose(x[16], y | z) - assert allclose(x[17], np.minimum(y, z)) - assert allclose(x[18], np.maximum(y, z)) - assert allclose(x[19], y << z) + assert test_utils.allclose(x[0], y + z) + assert test_utils.allclose(x[1], y - z) + assert test_utils.allclose(x[2], y * z) + assert test_utils.allclose(x[3], y // z) + assert test_utils.allclose(x[4], y // z) + assert test_utils.allclose(x[5], y % z) + assert test_utils.allclose(x[6], y % z) + assert test_utils.allclose(x[7], y**z, rel=1e-5) + assert test_utils.allclose(x[8], y == z) + assert test_utils.allclose(x[9], y != z) + assert test_utils.allclose(x[10], y > z) + assert test_utils.allclose(x[11], y >= z) + assert test_utils.allclose(x[12], y < z) + assert test_utils.allclose(x[13], y <= z) + assert test_utils.allclose(x[14], y & z) + assert test_utils.allclose(x[15], y ^ z) + assert test_utils.allclose(x[16], y | z) + assert test_utils.allclose(x[17], np.minimum(y, z)) + assert test_utils.allclose(x[18], np.maximum(y, z)) + assert test_utils.allclose(x[19], y << z) @pytest.mark.parametrize('rhs_is_mat', [True, False]) -@ti.test(fast_math=False) +@test_utils.test(fast_math=False) def test_writeback_binary_f(rhs_is_mat): x = ti.Matrix.field(3, 2, ti.f32, 9) y = ti.Matrix.field(3, 2, ti.f32, ()) @@ -184,18 +184,18 @@ def func(): x = x.to_numpy() y = y.to_numpy() z = z.to_numpy() - assert allclose(x[1], y + z) - assert allclose(x[2], y - z) - assert allclose(x[3], y * z) - assert allclose(x[4], y / z) - assert allclose(x[5], y // z) - assert allclose(x[6], y % z) - assert allclose(x[7], np.minimum(y, z)) - assert allclose(x[8], np.maximum(y, z)) + assert test_utils.allclose(x[1], y + z) + assert test_utils.allclose(x[2], y - z) + assert test_utils.allclose(x[3], y * z) + assert test_utils.allclose(x[4], y / z) + assert test_utils.allclose(x[5], y // z) + assert test_utils.allclose(x[6], y % z) + assert test_utils.allclose(x[7], np.minimum(y, z)) + assert test_utils.allclose(x[8], np.maximum(y, z)) @pytest.mark.parametrize('rhs_is_mat', [(True, True), (True, False)]) -@ti.test() +@test_utils.test() def test_writeback_binary_i(rhs_is_mat): x = ti.Matrix.field(3, 2, ti.i32, 12) y = ti.Matrix.field(3, 2, ti.i32, ()) @@ -230,23 +230,23 @@ def func(): x = x.to_numpy() y = y.to_numpy() z = z.to_numpy() - assert allclose(x[1], y + z) - assert allclose(x[2], y - z) - assert allclose(x[3], y * z) - assert allclose(x[4], y // z) - assert allclose(x[5], y % z) - assert allclose(x[6], y & z) - assert allclose(x[7], y | z) - assert allclose(x[8], y ^ z) - assert allclose(x[10], np.minimum(y, z)) - assert allclose(x[11], np.maximum(y, z)) - - -@ti.test(exclude=[ti.vulkan]) + assert test_utils.allclose(x[1], y + z) + assert test_utils.allclose(x[2], y - z) + assert test_utils.allclose(x[3], y * z) + assert test_utils.allclose(x[4], y // z) + assert test_utils.allclose(x[5], y % z) + assert test_utils.allclose(x[6], y & z) + assert test_utils.allclose(x[7], y | z) + assert test_utils.allclose(x[8], y ^ z) + assert test_utils.allclose(x[10], np.minimum(y, z)) + assert test_utils.allclose(x[11], np.maximum(y, z)) + + +@test_utils.test() def test_unary(): xi = ti.Matrix.field(3, 2, ti.i32, 4) yi = ti.Matrix.field(3, 2, ti.i32, ()) - xf = ti.Matrix.field(3, 2, ti.f32, 14) + xf = ti.Matrix.field(3, 2, ti.f32, 15) yf = ti.Matrix.field(3, 2, ti.f32, ()) yi.from_numpy(np.array([[3, 2], [9, 0], [7, 4]], np.int32)) @@ -256,10 +256,10 @@ def test_unary(): def func(): xi[0] = -yi[None] xi[1] = ~yi[None] - xi[2] = ti.logical_not(yi[None]) - xi[3] = ti.abs(yi[None]) + xi[2] = not yi[None] + xi[3] = abs(yi[None]) xf[0] = -yf[None] - xf[1] = ti.abs(yf[None]) + xf[1] = abs(yf[None]) xf[2] = ti.sqrt(yf[None]) xf[3] = ti.sin(yf[None]) xf[4] = ti.cos(yf[None]) @@ -272,35 +272,37 @@ def func(): xf[11] = ti.exp(yf[None]) xf[12] = ti.log(yf[None]) xf[13] = ti.rsqrt(yf[None]) + xf[14] = ti.round(yf[None]) func() xi = xi.to_numpy() yi = yi.to_numpy() xf = xf.to_numpy() yf = yf.to_numpy() - assert allclose(xi[0], -yi) - assert allclose(xi[1], ~yi) - assert allclose(xi[3], np.abs(yi)) - assert allclose(xf[0], -yf) - assert allclose(xf[1], np.abs(yf)) - assert allclose(xf[2], np.sqrt(yf), rel=1e-5) - assert allclose(xf[3], np.sin(yf), rel=1e-4) - assert allclose(xf[4], np.cos(yf), rel=1e-4) - assert allclose(xf[5], np.tan(yf), rel=1e-4) - assert allclose(xf[6], np.arcsin(yf), rel=1e-4) - assert allclose(xf[7], np.arccos(yf), rel=1e-4) - assert allclose(xf[8], np.tanh(yf), rel=1e-4) - assert allclose(xf[9], np.floor(yf), rel=1e-5) - assert allclose(xf[10], np.ceil(yf), rel=1e-5) - assert allclose(xf[11], np.exp(yf), rel=1e-5) - assert allclose(xf[12], np.log(yf), rel=1e-5) - assert allclose(xf[13], 1 / np.sqrt(yf), rel=1e-5) + assert test_utils.allclose(xi[0], -yi) + assert test_utils.allclose(xi[1], ~yi) + assert test_utils.allclose(xi[3], np.abs(yi)) + assert test_utils.allclose(xf[0], -yf) + assert test_utils.allclose(xf[1], np.abs(yf)) + assert test_utils.allclose(xf[2], np.sqrt(yf), rel=1e-5) + assert test_utils.allclose(xf[3], np.sin(yf), rel=1e-4) + assert test_utils.allclose(xf[4], np.cos(yf), rel=1e-4) + assert test_utils.allclose(xf[5], np.tan(yf), rel=1e-4) + assert test_utils.allclose(xf[6], np.arcsin(yf), rel=1e-4) + assert test_utils.allclose(xf[7], np.arccos(yf), rel=1e-4) + assert test_utils.allclose(xf[8], np.tanh(yf), rel=1e-4) + assert test_utils.allclose(xf[9], np.floor(yf), rel=1e-5) + assert test_utils.allclose(xf[10], np.ceil(yf), rel=1e-5) + assert test_utils.allclose(xf[11], np.exp(yf), rel=1e-5) + assert test_utils.allclose(xf[12], np.log(yf), rel=1e-5) + assert test_utils.allclose(xf[13], 1 / np.sqrt(yf), rel=1e-5) + assert test_utils.allclose(xf[14], np.round(yf), rel=1e-5) @pytest.mark.parametrize('is_mat', [(True, True, True), (True, False, False), (False, True, False), (False, False, True), (False, True, True)]) -@ti.test() +@test_utils.test() def test_ternary_i(is_mat): cond_is_mat, lhs_is_mat, rhs_is_mat = is_mat x = ti.Matrix.field(3, 2, ti.i32, 1) @@ -339,5 +341,6 @@ def func(): y = y.to_numpy() z = z.to_numpy() w = w.to_numpy() - assert allclose(x[0], - np.int32(np.bool_(y)) * z + np.int32(1 - np.bool_(y)) * w) + assert test_utils.allclose( + x[0], + np.int32(np.bool_(y)) * z + np.int32(1 - np.bool_(y)) * w) diff --git a/tests/python/test_empty.py b/tests/python/test_empty.py index 9343162c5e1da..80ac246bbab28 100644 --- a/tests/python/test_empty.py +++ b/tests/python/test_empty.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_empty(): @ti.kernel def func(): @@ -10,7 +11,7 @@ def func(): func() -@ti.test() +@test_utils.test() def test_empty_args(): @ti.kernel def func(x: ti.i32, arr: ti.ext_arr()): diff --git a/tests/python/test_exception.py b/tests/python/test_exception.py new file mode 100644 index 0000000000000..2cace74874862 --- /dev/null +++ b/tests/python/test_exception.py @@ -0,0 +1,158 @@ +from inspect import currentframe, getframeinfo +from sys import version_info + +import pytest +from tests import test_utils + +import taichi as ti + + +@test_utils.test() +def test_exception_multiline(): + frameinfo = getframeinfo(currentframe()) + with pytest.raises(ti.TaichiNameError) as e: + # yapf: disable + @ti.kernel + def foo(): + aaaa(111, + 1211222, + + 23) + foo() + # yapf: enable + + if version_info < (3, 8): + msg = f""" +On line {frameinfo.lineno + 5} of file "{frameinfo.filename}", in foo: + aaaa(111,""" + else: + msg = f""" +On line {frameinfo.lineno + 5} of file "{frameinfo.filename}", in foo: + aaaa(111, + ^^^^""" + print(e.value.args[0]) + assert e.value.args[0][:len(msg)] == msg + + +@test_utils.test() +def test_exception_from_func(): + frameinfo = getframeinfo(currentframe()) + with pytest.raises(ti.TaichiNameError) as e: + + @ti.func + def baz(): + t() + + @ti.func + def bar(): + baz() + + @ti.kernel + def foo(): + bar() + + foo() + lineno = frameinfo.lineno + file = frameinfo.filename + if version_info < (3, 8): + msg = f""" +On line {lineno + 13} of file "{file}", in foo: + bar() +On line {lineno + 9} of file "{file}", in bar: + baz() +On line {lineno + 5} of file "{file}", in baz: + t()""" + else: + msg = f""" +On line {lineno + 13} of file "{file}", in foo: + bar() + ^^^^^ +On line {lineno + 9} of file "{file}", in bar: + baz() + ^^^^^ +On line {lineno + 5} of file "{file}", in baz: + t() + ^""" + print(e.value.args[0]) + assert e.value.args[0][:len(msg)] == msg + + +@test_utils.test() +def test_tab(): + frameinfo = getframeinfo(currentframe()) + with pytest.raises(ti.TaichiNameError) as e: + # yapf: disable + @ti.kernel + def foo(): + a(11, 22, 3) + foo() + # yapf: enable + lineno = frameinfo.lineno + file = frameinfo.filename + if version_info < (3, 8): + msg = f""" +On line {lineno + 5} of file "{file}", in foo: + a(11, 22, 3)""" + else: + msg = f""" +On line {lineno + 5} of file "{file}", in foo: + a(11, 22, 3) + ^""" + print(e.value.args[0]) + assert e.value.args[0][:len(msg)] == msg + + +@test_utils.test() +def test_super_long_line(): + frameinfo = getframeinfo(currentframe()) + with pytest.raises(ti.TaichiNameError) as e: + # yapf: disable + @ti.kernel + def foo(): + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa(111) + foo() + # yapf: enable + lineno = frameinfo.lineno + file = frameinfo.filename + if version_info < (3, 8): + msg = f""" +On line {lineno + 5} of file "{file}", in foo: + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa(111) +""" + else: + msg = f""" +On line {lineno + 5} of file "{file}", in foo: + aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbaaaaaa + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +bbbbbbbbbbbbbbbbbbbbbaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa(111) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^""" + print(e.value.args[0]) + assert e.value.args[0][:len(msg)] == msg + + +@pytest.mark.skipif(version_info < (3, 8), reason="This is a feature for python>=3.8") +@test_utils.test() +def test_exception_in_node_with_body(): + frameinfo = getframeinfo(currentframe()) + @ti.kernel + def foo(): + for i in range(1, 2, 3): + a = 1 + b = 1 + c = 1 + d = 1 + + with pytest.raises(ti.TaichiCompilationError) as e: + foo() + lineno = frameinfo.lineno + file = frameinfo.filename + msg = f""" +On line {lineno + 3} of file "{file}", in foo: + for i in range(1, 2, 3): + ^^^^^^^^^^^^^^^^^^^^^^^^ +Range should have 1 or 2 arguments, found 3""" + print(e.value.args[0]) + assert e.value.args[0] == msg + diff --git a/tests/python/test_expr_dict.py b/tests/python/test_expr_dict.py index 0735fdcfa444f..2cb32de83e3ff 100644 --- a/tests/python/test_expr_dict.py +++ b/tests/python/test_expr_dict.py @@ -1,17 +1,18 @@ import taichi as ti +from tests import test_utils -@ti.test(ti.cpu) +@test_utils.test(ti.cpu) def test_expr_dict_basic(): @ti.kernel def func(u: int, v: float) -> float: x = {'foo': 2 + u, 'bar': 3 + v} return x['foo'] * 100 + x['bar'] - assert func(2, 0.1) == ti.approx(403.1) + assert func(2, 0.1) == test_utils.approx(403.1) -@ti.test(ti.cpu) +@test_utils.test(ti.cpu) def test_expr_dict_field(): a = ti.field(ti.f32, shape=(4, )) @@ -22,10 +23,10 @@ def func() -> float: a[0] = 2 a[1] = 0.1 - assert func() == ti.approx(403.1) + assert func() == test_utils.approx(403.1) -@ti.test(ti.cpu) +@test_utils.test(ti.cpu) def test_dictcomp_multiple_ifs(): n = 8 x = ti.field(ti.i32, shape=(n, )) @@ -35,7 +36,7 @@ def test() -> ti.i32: # Taichi doesn't support global fields appearing anywhere after "for" # here. a = {x[j]: x[j] + j for j in range(100) if j > 2 if j < 5} - return sum(a) + return sum(a.values()) for i in range(n): x[i] = i * 2 diff --git a/tests/python/test_expr_list.py b/tests/python/test_expr_list.py index 73b968654a07c..5f9ea669cffb8 100644 --- a/tests/python/test_expr_list.py +++ b/tests/python/test_expr_list.py @@ -1,17 +1,18 @@ import taichi as ti +from tests import test_utils -@ti.test(ti.cpu) +@test_utils.test(ti.cpu) def test_expr_list_basic(): @ti.kernel def func(u: int, v: float) -> float: x = [2 + u, 3 + v] return x[0] * 100 + x[1] - assert func(1, 1.1) == ti.approx(304.1) + assert func(1, 1.1) == test_utils.approx(304.1) -@ti.test() +@test_utils.test() def test_listcomp_multiple_ifs(): x = ti.field(ti.i32, shape=(4, )) diff --git a/tests/python/test_external_func.py b/tests/python/test_external_func.py index b00ac72a21243..b7c4bb22fdd98 100644 --- a/tests/python/test_external_func.py +++ b/tests/python/test_external_func.py @@ -1,15 +1,89 @@ +import ctypes +import os +import shutil +import tempfile + import pytest +from taichi.lang.util import has_clangpp import taichi as ti +from tests import test_utils + +@pytest.mark.skipif(not has_clangpp(), reason='Clang not installed.') +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_source_builder_from_source(): + source_bc = ''' + extern "C" { + void add_and_mul(float *a, float *b, float *c, float *d, int *e) { + *c = (*a) + (*b); + *d = (*a) * (*b); + *e = int((*a) * (*b) + (*a)); + } + void pow_int(int *a, int *b, int *c) { + int ret = 1; + for (int i = 0; i < (*b); i++) + ret = ret * (*a); + *c = ret; + } + } + ''' + sb_bc = ti.lang.source_builder.SourceBuilder.from_source(source_bc) -@pytest.mark.parametrize('x,y', [(2, 3), (-1, 4)]) -@ti.test(exclude=ti.cpu, require=ti.extension.extfunc) -def test_asm(x, y): @ti.kernel - def func(x: ti.f32, y: ti.f32) -> ti.f32: - z = 0.0 - ti.asm('$0 = %0 * %1', inputs=[x, y], outputs=[z]) - return z + def func_bc() -> ti.i32: + a = 2.0 + b = 3.0 + c = 0.0 + d = 0.0 + e = 3 + sb_bc.add_and_mul(a, b, c, d, e) + p = 0 + c_plus_d = int(c + d) + sb_bc.pow_int(c_plus_d, e, p) + return p + + assert func_bc() == 11**8 + + +@pytest.mark.skipif(not has_clangpp(), reason='Clang not installed.') +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_source_builder_from_file(): + source_code = ''' + extern "C" { + void add_and_mul(float *a, float *b, float *c, float *d, int *e) { + *c = (*a) + (*b); + *d = (*a) * (*b); + *e = int((*a) * (*b) + (*a)); + } + void pow_int(int *a, int *b, int *c) { + int ret = 1; + for (int i = 0; i < (*b); i++) + ret = ret * (*a); + *c = ret; + } + } + ''' + + td = tempfile.mkdtemp() + fn = os.path.join(td, 'source.cpp') + with open(fn, 'w') as f: + f.write(source_code) + sb_bc = ti.lang.source_builder.SourceBuilder.from_file(fn) + + @ti.kernel + def func_bc() -> ti.i32: + a = 2.0 + b = 3.0 + c = 0.0 + d = 0.0 + e = 3 + sb_bc.add_and_mul(a, b, c, d, e) + p = 0 + c_plus_d = int(c + d) + sb_bc.pow_int(c_plus_d, e, p) + return p + + assert func_bc() == 11**8 - assert func(x, y) == x * y + shutil.rmtree(td) diff --git a/tests/python/test_f16.py b/tests/python/test_f16.py new file mode 100644 index 0000000000000..00709f0a91c8d --- /dev/null +++ b/tests/python/test_f16.py @@ -0,0 +1,303 @@ +import math + +import numpy as np +import pytest +from taichi.lang.util import has_pytorch + +import taichi as ti +from tests import test_utils + +archs_support_f16 = [ti.cpu, ti.cuda, ti.vulkan] + + +@test_utils.test(arch=archs_support_f16) +def test_snode_read_write(): + dtype = ti.f16 + x = ti.field(dtype, shape=()) + x[None] = 0.3 + print(x[None]) + assert (x[None] == test_utils.approx(0.3, rel=1e-3)) + + +@test_utils.test(arch=archs_support_f16) +def test_float16(): + dtype = ti.float16 + x = ti.field(dtype, shape=()) + x[None] = 0.3 + print(x[None]) + assert (x[None] == test_utils.approx(0.3, rel=1e-3)) + + +@test_utils.test(arch=archs_support_f16) +def test_to_numpy(): + n = 16 + x = ti.field(ti.f16, shape=n) + + @ti.kernel + def init(): + for i in x: + x[i] = i * 2 + + init() + y = x.to_numpy() + for i in range(n): + assert (y[i] == 2 * i) + + +@test_utils.test(arch=archs_support_f16) +def test_from_numpy(): + n = 16 + y = ti.field(dtype=ti.f16, shape=n) + x = np.arange(n, dtype=np.half) + y.from_numpy(x) + + @ti.kernel + def init(): + for i in y: + y[i] = 3 * i + + init() + z = y.to_numpy() + for i in range(n): + assert (z[i] == i * 3) + + +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(arch=archs_support_f16) +def test_to_torch(): + n = 16 + x = ti.field(ti.f16, shape=n) + + @ti.kernel + def init(): + for i in x: + x[i] = i * 2 + + init() + y = x.to_torch() + print(y) + for i in range(n): + assert (y[i] == 2 * i) + + +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(arch=archs_support_f16) +def test_from_torch(): + import torch + n = 16 + y = ti.field(dtype=ti.f16, shape=n) + # torch doesn't have rand implementation for float16 so we need to create float first and then convert + x = torch.range(0, n - 1).to(torch.float16) + y.from_torch(x) + + @ti.kernel + def init(): + for i in y: + y[i] = 3 * i + + init() + z = y.to_torch() + for i in range(n): + assert (z[i] == i * 3) + + +@test_utils.test(arch=archs_support_f16) +def test_binary_op(): + dtype = ti.f16 + x = ti.field(dtype, shape=()) + y = ti.field(dtype, shape=()) + z = ti.field(dtype, shape=()) + + @ti.kernel + def add(): + x[None] = y[None] + z[None] + x[None] = x[None] * z[None] + + y[None] = 0.2 + z[None] = 0.72 + add() + u = x.to_numpy() + assert (u[None] == test_utils.approx(0.6624, rel=1e-3)) + + +@test_utils.test(arch=archs_support_f16) +def test_rand_promote(): + dtype = ti.f16 + x = ti.field(dtype, shape=(4, 4)) + + @ti.kernel + def init(): + for i, j in x: + x[i, j] = ti.random(dtype=dtype) + print(x[i, j]) + + init() + + +@test_utils.test(arch=archs_support_f16) +def test_unary_op(): + dtype = ti.f16 + x = ti.field(dtype, shape=()) + y = ti.field(dtype, shape=()) + + @ti.kernel + def foo(): + x[None] = -y[None] + x[None] = ti.floor(x[None]) + y[None] = ti.ceil(y[None]) + + y[None] = -1.4 + foo() + assert (x[None] == test_utils.approx(1, rel=1e-3)) + assert (y[None] == test_utils.approx(-1, rel=1e-3)) + + +@test_utils.test(arch=archs_support_f16) +def test_extra_unary_promote(): + dtype = ti.f16 + x = ti.field(dtype, shape=()) + y = ti.field(dtype, shape=()) + + @ti.kernel + def foo(): + x[None] = abs(y[None]) + + y[None] = -0.3 + foo() + assert (x[None] == test_utils.approx(0.3, rel=1e-3)) + + +@test_utils.test(arch=archs_support_f16, exclude=ti.vulkan) +def test_binary_extra_promote(): + x = ti.field(dtype=ti.f16, shape=()) + y = ti.field(dtype=ti.f16, shape=()) + z = ti.field(dtype=ti.f16, shape=()) + + @ti.kernel + def foo(): + y[None] = x[None]**2 + z[None] = ti.atan2(y[None], 0.3) + + x[None] = 0.1 + foo() + assert (z[None] == test_utils.approx(math.atan2(0.1**2, 0.3), rel=1e-3)) + + +@test_utils.test(arch=archs_support_f16) +def test_arg_f16(): + dtype = ti.f16 + x = ti.field(dtype, shape=()) + y = ti.field(dtype, shape=()) + + @ti.kernel + def foo(a: ti.f16): + x[None] = y[None] + a + + y[None] = -0.3 + foo(1.2) + assert (x[None] == test_utils.approx(0.9, rel=1e-3)) + + +@test_utils.test(arch=archs_support_f16) +def test_fractal_f16(): + n = 320 + pixels = ti.field(dtype=ti.f16, shape=(n * 2, n)) + + @ti.func + def complex_sqr(z): + return ti.Vector([z[0]**2 - z[1]**2, z[1] * z[0] * 2], dt=ti.f16) + + @ti.kernel + def paint(t: float): + for i, j in pixels: # Parallelized over all pixels + c = ti.Vector([-0.8, ti.cos(t) * 0.2], dt=ti.f16) + z = ti.Vector([i / n - 1, j / n - 0.5], dt=ti.f16) * 2 + iterations = 0 + while z.norm() < 20 and iterations < 50: + z = complex_sqr(z) + c + iterations += 1 + pixels[i, j] = 1 - iterations * 0.02 + + paint(0.03) + + +# TODO(): Vulkan support +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_atomic_add_f16(): + f = ti.field(dtype=ti.f16, shape=(2)) + + @ti.kernel + def foo(): + # Parallel sum + for i in range(1000): + f[0] += 1.12 + + # Serial sum + for _ in range(1): + for i in range(1000): + f[1] = f[1] + 1.12 + + foo() + assert (f[0] == test_utils.approx(f[1], rel=1e-3)) + + +# TODO(): Vulkan support +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_atomic_max_f16(): + f = ti.field(dtype=ti.f16, shape=(2)) + + @ti.kernel + def foo(): + # Parallel max + for i in range(1000): + ti.atomic_max(f[0], 1.12 * i) + + # Serial max + for _ in range(1): + for i in range(1000): + f[1] = ti.max(1.12 * i, f[1]) + + foo() + assert (f[0] == test_utils.approx(f[1], rel=1e-3)) + + +# TODO(): Vulkan support +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_atomic_min_f16(): + f = ti.field(dtype=ti.f16, shape=(2)) + + @ti.kernel + def foo(): + # Parallel min + for i in range(1000): + ti.atomic_min(f[0], -3.13 * i) + + # Serial min + for _ in range(1): + for i in range(1000): + f[1] = ti.min(-3.13 * i, f[1]) + + foo() + assert (f[0] == test_utils.approx(f[1], rel=1e-3)) + + +@test_utils.test(arch=archs_support_f16) +def test_cast_f32_to_f16(): + @ti.kernel + def func() -> ti.f16: + a = ti.cast(23.0, ti.f32) + b = ti.cast(4.0, ti.f32) + return ti.cast(a * b, ti.f16) + + assert func() == pytest.approx(23.0 * 4.0, 1e-4) + + +@test_utils.test(arch=archs_support_f16, require=ti.extension.data64) +def test_cast_f64_to_f16(): + @ti.kernel + def func() -> ti.f16: + a = ti.cast(23.0, ti.f64) + b = ti.cast(4.0, ti.f64) + return ti.cast(a * b, ti.f16) + + assert func() == pytest.approx(23.0 * 4.0, 1e-4) diff --git a/tests/python/test_field.py b/tests/python/test_field.py index c361657dcfdcf..06af7ed63a0ac 100644 --- a/tests/python/test_field.py +++ b/tests/python/test_field.py @@ -3,8 +3,11 @@ ''' import pytest +from taichi.lang import impl +from taichi.lang.misc import get_host_arch_list import taichi as ti +from tests import test_utils data_types = [ti.i32, ti.f32, ti.i64, ti.f64] field_shapes = [(), 8, (6, 12)] @@ -14,7 +17,7 @@ @pytest.mark.parametrize('dtype', data_types) @pytest.mark.parametrize('shape', field_shapes) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_scalar_field(dtype, shape): x = ti.field(dtype, shape) @@ -29,7 +32,7 @@ def test_scalar_field(dtype, shape): @pytest.mark.parametrize('n', vector_dims) @pytest.mark.parametrize('dtype', data_types) @pytest.mark.parametrize('shape', field_shapes) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_vector_field(n, dtype, shape): x = ti.Vector.field(n, dtype, shape) @@ -46,7 +49,7 @@ def test_vector_field(n, dtype, shape): @pytest.mark.parametrize('n,m', matrix_dims) @pytest.mark.parametrize('dtype', data_types) @pytest.mark.parametrize('shape', field_shapes) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_matrix_field(n, m, dtype, shape): x = ti.Matrix.field(n, m, dtype=dtype, shape=shape) @@ -62,7 +65,7 @@ def test_matrix_field(n, m, dtype, shape): @pytest.mark.parametrize('dtype', data_types) @pytest.mark.parametrize('shape', field_shapes) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_scalr_field_from_numpy(dtype, shape): import numpy as np x = ti.field(dtype, shape) @@ -79,7 +82,7 @@ def test_scalr_field_from_numpy(dtype, shape): @pytest.mark.parametrize('dtype', data_types) @pytest.mark.parametrize('shape', field_shapes) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_scalr_field_from_numpy_with_mismatch_shape(dtype, shape): import numpy as np x = ti.field(dtype, shape) @@ -99,7 +102,7 @@ def test_scalr_field_from_numpy_with_mismatch_shape(dtype, shape): x.from_numpy(arr) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_field_needs_grad(): # Just make sure the usage doesn't crash, see #1545 n = 8 @@ -121,7 +124,7 @@ def test_default_fp(dtype): x = ti.Vector.field(2, float, ()) - assert x.dtype == ti.get_runtime().default_fp + assert x.dtype == impl.get_runtime().default_fp @pytest.mark.parametrize('dtype', [ti.i32, ti.i64]) @@ -130,19 +133,59 @@ def test_default_ip(dtype): x = ti.Vector.field(2, int, ()) - assert x.dtype == ti.get_runtime().default_ip + assert x.dtype == impl.get_runtime().default_ip -@ti.test() +@test_utils.test() def test_field_name(): a = ti.field(dtype=ti.f32, shape=(2, 3), name='a') b = ti.Vector.field(3, dtype=ti.f32, shape=(2, 3), name='b') c = ti.Matrix.field(3, 3, dtype=ti.f32, shape=(5, 4), name='c') - assert a.name == 'a' - assert b.name == 'b' - assert c.name == 'c' - assert b.snode.name == 'b' + assert a._name == 'a' + assert b._name == 'b' + assert c._name == 'c' + assert b.snode._name == 'b' d = [] for i in range(10): d.append(ti.field(dtype=ti.f32, shape=(2, 3), name=f'd{i}')) - assert d[i].name == f'd{i}' + assert d[i]._name == f'd{i}' + + +@test_utils.test() +@pytest.mark.parametrize('shape', field_shapes) +@pytest.mark.parametrize('dtype', [ti.i32, ti.f32]) +def test_field_copy_from(shape, dtype): + x = ti.field(dtype=ti.f32, shape=shape) + other = ti.field(dtype=dtype, shape=shape) + other.fill(1) + x.copy_from(other) + convert = lambda arr: arr[0] if len(arr) == 1 else arr + assert (convert(x.shape) == shape) + assert (x.dtype == ti.f32) + assert ((x.to_numpy() == 1).all()) + + +@test_utils.test() +def test_field_copy_from_with_mismatch_shape(): + x = ti.field(dtype=ti.f32, shape=(2, 3)) + for other_shape in [(2, ), (2, 2), (2, 3, 4)]: + other = ti.field(dtype=ti.f16, shape=other_shape) + with pytest.raises(ValueError): + x.copy_from(other) + + +@test_utils.test() +def test_field_copy_from_with_non_filed_object(): + import numpy as np + x = ti.field(dtype=ti.f32, shape=(2, 3)) + other = np.zeros((2, 3)) + with pytest.raises(TypeError): + x.copy_from(other) + + +@test_utils.test() +def test_field_shape_0(): + with pytest.raises( + ti._lib.core.TaichiRuntimeError, + match="Every dimension of a Taichi field should be positive"): + x = ti.field(dtype=ti.f32, shape=0) diff --git a/tests/python/test_fields_builder.py b/tests/python/test_fields_builder.py index 25d192b8cb3d1..cd72bc17b312d 100644 --- a/tests/python/test_fields_builder.py +++ b/tests/python/test_fields_builder.py @@ -1,195 +1,195 @@ import pytest -from taichi.lang.exception import InvalidOperationError +from taichi.lang.exception import TaichiRuntimeError import taichi as ti +from tests import test_utils -@ti.test(arch=[ti.cpu, ti.cuda, ti.vulkan, ti.metal]) +@test_utils.test(arch=[ti.cpu, ti.cuda, ti.vulkan, ti.metal]) def test_fields_with_shape(): - n = 5 - x = ti.field(ti.f32, [n]) + shape = 5 + x = ti.field(ti.f32, shape=shape) @ti.kernel - def func(): - for i in range(n): + def assign_field_single(): + for i in range(shape): x[i] = i - func() - - for i in range(n): + assign_field_single() + for i in range(shape): assert x[i] == i - y = ti.field(ti.f32, [n]) + y = ti.field(ti.f32, shape=shape) @ti.kernel - def func2(): - for i in range(n): + def assign_field_multiple(): + for i in range(shape): y[i] = i * 2 - for i in range(n): + for i in range(shape): x[i] = i * 3 - func2() - - for i in range(n): + assign_field_multiple() + for i in range(shape): assert x[i] == i * 3 assert y[i] == i * 2 - func() - - for i in range(n): + assign_field_single() + for i in range(shape): assert x[i] == i -@ti.test(arch=[ti.cpu, ti.cuda, ti.vulkan, ti.metal]) +@test_utils.test(arch=[ti.cpu, ti.cuda, ti.vulkan, ti.metal]) def test_fields_builder_dense(): - n = 5 - + shape = 5 fb1 = ti.FieldsBuilder() x = ti.field(ti.f32) - fb1.dense(ti.i, n).place(x) + fb1.dense(ti.i, shape).place(x) fb1.finalize() @ti.kernel - def func1(): - for i in range(n): + def assign_field_single(): + for i in range(shape): x[i] = i * 3 - func1() - for i in range(n): + assign_field_single() + for i in range(shape): assert x[i] == i * 3 fb2 = ti.FieldsBuilder() y = ti.field(ti.f32) - fb2.dense(ti.i, n).place(y) + fb2.dense(ti.i, shape).place(y) z = ti.field(ti.f32) - fb2.dense(ti.i, n).place(z) + fb2.dense(ti.i, shape).place(z) fb2.finalize() @ti.kernel - def func2(): - for i in range(n): + def assign_field_multiple(): + for i in range(shape): x[i] = i * 2 - for i in range(n): + for i in range(shape): y[i] = i + 5 - for i in range(n): + for i in range(shape): z[i] = i + 10 - func2() - for i in range(n): + assign_field_multiple() + for i in range(shape): assert x[i] == i * 2 assert y[i] == i + 5 assert z[i] == i + 10 - func1() - for i in range(n): + assign_field_single() + for i in range(shape): assert x[i] == i * 3 -@ti.test(arch=[ti.cpu, ti.cuda, ti.metal]) +@test_utils.test(arch=[ti.cpu, ti.cuda, ti.metal]) def test_fields_builder_pointer(): - n = 5 - + shape = 5 fb1 = ti.FieldsBuilder() x = ti.field(ti.f32) - fb1.pointer(ti.i, n).place(x) + fb1.pointer(ti.i, shape).place(x) fb1.finalize() @ti.kernel - def func1(): - for i in range(n): + def assign_field_single(): + for i in range(shape): x[i] = i * 3 - func1() - for i in range(n): + assign_field_single() + for i in range(shape): assert x[i] == i * 3 fb2 = ti.FieldsBuilder() y = ti.field(ti.f32) - fb2.pointer(ti.i, n).place(y) + fb2.pointer(ti.i, shape).place(y) z = ti.field(ti.f32) - fb2.pointer(ti.i, n).place(z) + fb2.pointer(ti.i, shape).place(z) fb2.finalize() - # test range-for @ti.kernel - def func2(): - for i in range(n): + def assign_field_multiple_range_for(): + for i in range(shape): x[i] = i * 2 - for i in range(n): + for i in range(shape): y[i] = i + 5 - for i in range(n): + for i in range(shape): z[i] = i + 10 - func2() - for i in range(n): + assign_field_multiple_range_for() + for i in range(shape): assert x[i] == i * 2 assert y[i] == i + 5 assert z[i] == i + 10 - # test struct-for @ti.kernel - def func3(): + def assign_field_multiple_struct_for(): for i in y: y[i] += 5 for i in z: z[i] -= 5 - func3() - for i in range(n): + assign_field_multiple_struct_for() + for i in range(shape): assert y[i] == i + 10 assert z[i] == i + 5 - func1() - for i in range(n): + assign_field_single() + for i in range(shape): assert x[i] == i * 3 -@ti.test(arch=[ti.cpu, ti.cuda, ti.vulkan]) -def test_fields_builder_destroy(): - def A(i): - n = i * 10**3 - fb = ti.FieldsBuilder() - a = ti.field(ti.f64) - fb.dense(ti.i, n).place(a) - c = fb.finalize() - c.destroy() - - def B(i): - n = i * 10**3 +# We currently only consider data types that all platforms support. +# See https://docs.taichi.graphics/lang/articles/basic/type#supported-primitive-types for more details. +@pytest.mark.parametrize('test_1d_size', [1, 10, 100]) +@pytest.mark.parametrize('field_type', [ti.f32, ti.i32]) +@test_utils.test(arch=[ti.cpu, ti.cuda, ti.vulkan, ti.metal]) +def test_fields_builder_destroy(test_1d_size, field_type): + def test_for_single_destroy_multi_fields(): fb = ti.FieldsBuilder() - a = ti.field(ti.f64) - fb.dense(ti.i, n).place(a) - c = fb.finalize() + for create_field_idx in range(10): + field = ti.field(field_type) + fb.dense(ti.i, test_1d_size).place(field) + fb_snode_tree = fb.finalize() + fb_snode_tree.destroy() - ni = i * 10**3 - fbi = ti.FieldsBuilder() - ai = ti.field(ti.f64) - fbi.dense(ti.i, n).place(ai) - ci = fbi.finalize() + def test_for_multi_destroy_multi_fields(): + fb0 = ti.FieldsBuilder() + fb1 = ti.FieldsBuilder() - c.destroy() - ci.destroy() + for create_field_idx in range(10): + field0 = ti.field(field_type) + field1 = ti.field(field_type) - for i in range(5): - A(5) - B(2) - A(4) + fb0.dense(ti.i, test_1d_size).place(field0) + fb1.pointer(ti.i, test_1d_size).place(field1) + fb0_snode_tree = fb0.finalize() + fb1_snode_tree = fb1.finalize() -@ti.test(arch=[ti.cpu, ti.cuda]) -def test_fields_builder_exceeds_max(): - sz = 4 + fb0_snode_tree.destroy() + fb1_snode_tree.destroy() - def create_fb(): + def test_for_raise_destroy_twice(): fb = ti.FieldsBuilder() - x = ti.field(ti.f32) - fb.dense(ti.ij, (sz, sz)).place(x) - fb.finalize() + a = ti.field(ti.f32) + fb.dense(ti.i, test_1d_size).place(a) + c = fb.finalize() + + with pytest.raises(TaichiRuntimeError): + c.destroy() + c.destroy() - # kMaxNumSnodeTreesLlvm=32 in taichi/inc/constants.h - for _ in range(32): - create_fb() - with pytest.raises(RuntimeError) as e: - create_fb() - assert 'LLVM backend supports up to 32 snode trees' in e.value.args[0] +@test_utils.test(arch=[ti.cpu, ti.cuda, ti.vulkan]) +def test_field_initialize_zero(): + fb0 = ti.FieldsBuilder() + a = ti.field(ti.i32) + fb0.dense(ti.i, 1).place(a) + c = fb0.finalize() + a[0] = 5 + c.destroy() + fb1 = ti.FieldsBuilder() + b = ti.field(ti.i32) + fb1.dense(ti.i, 1).place(b) + d = fb1.finalize() + assert b[0] == 0 diff --git a/tests/python/test_fill.py b/tests/python/test_fill.py index 7fc5fdcc16039..5692fe38860f4 100644 --- a/tests/python/test_fill.py +++ b/tests/python/test_fill.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_fill_scalar(): val = ti.field(ti.i32) n = 4 @@ -20,7 +21,7 @@ def test_fill_scalar(): assert val[i, j] == 2 -@ti.test() +@test_utils.test() def test_fill_matrix_scalar(): val = ti.Matrix.field(2, 3, ti.i32) @@ -44,7 +45,7 @@ def test_fill_matrix_scalar(): assert val[i, j][p, q] == 2 -@ti.test() +@test_utils.test() def test_fill_matrix_matrix(): val = ti.Matrix.field(2, 3, ti.i32) diff --git a/tests/python/test_for_break.py b/tests/python/test_for_break.py index 56d601b57cbbc..ba858f9f01293 100644 --- a/tests/python/test_for_break.py +++ b/tests/python/test_for_break.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_for_break(): x = ti.field(ti.i32) N, M = 4, 4 @@ -24,7 +25,7 @@ def func(): assert x[i, j] == 100 * i + j -@ti.test() +@test_utils.test() def test_for_break2(): x = ti.field(ti.i32) N, M = 8, 8 @@ -47,7 +48,7 @@ def func(): assert x[i, j] == 100 * i + j -@ti.archs_excluding(ti.vulkan) +@test_utils.test(exclude=ti.vulkan) def test_for_break3(): x = ti.field(ti.i32) N, M = 8, 8 @@ -70,7 +71,7 @@ def func(): assert x[i, j] == 100 * i + j -@ti.test() +@test_utils.test() def test_for_break_complex(): x = ti.field(ti.i32) N, M = 16, 32 diff --git a/tests/python/test_for_group_mismatch.py b/tests/python/test_for_group_mismatch.py index e31780ec6e78f..e25cea4503c8b 100644 --- a/tests/python/test_for_group_mismatch.py +++ b/tests/python/test_for_group_mismatch.py @@ -1,8 +1,11 @@ +import pytest +from taichi.lang.misc import get_host_arch_list + import taichi as ti +from tests import test_utils -@ti.test(arch=ti.get_host_arch_list()) -@ti.must_throw(IndexError) +@test_utils.test(arch=get_host_arch_list()) def test_struct_for_mismatch(): x = ti.field(ti.f32, (3, 4)) @@ -11,11 +14,11 @@ def func(): for i in x: print(i) - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test(arch=ti.get_host_arch_list()) -@ti.must_throw(IndexError) +@test_utils.test(arch=get_host_arch_list()) def test_struct_for_mismatch2(): x = ti.field(ti.f32, (3, 4)) @@ -24,11 +27,11 @@ def func(): for i, j, k in x: print(i, j, k) - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test(arch=ti.get_host_arch_list()) -@ti.must_throw(IndexError) +@test_utils.test(arch=get_host_arch_list()) def _test_grouped_struct_for_mismatch(): # doesn't work for now # need grouped refactor @@ -41,11 +44,11 @@ def func(): for i, j in ti.grouped(x): print(i, j) - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test(arch=ti.get_host_arch_list()) -@ti.must_throw(IndexError) +@test_utils.test(arch=get_host_arch_list()) def _test_ndrange_for_mismatch(): # doesn't work for now # need ndrange refactor @@ -54,11 +57,11 @@ def func(): for i in ti.ndrange(3, 4): print(i) - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test(arch=ti.get_host_arch_list()) -@ti.must_throw(IndexError) +@test_utils.test(arch=get_host_arch_list()) def _test_ndrange_for_mismatch2(): # doesn't work for now # need ndrange and grouped refactor @@ -67,11 +70,11 @@ def func(): for i, j, k in ti.ndrange(3, 4): print(i, j, k) - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test(arch=ti.get_host_arch_list()) -@ti.must_throw(IndexError) +@test_utils.test(arch=get_host_arch_list()) def _test_grouped_ndrange_for_mismatch(): # doesn't work for now # need ndrange and grouped refactor @@ -80,11 +83,11 @@ def func(): for i in ti.grouped(ti.ndrange(3, 4)): print(i) - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test(arch=ti.get_host_arch_list()) -@ti.must_throw(IndexError) +@test_utils.test(arch=get_host_arch_list()) def _test_static_ndrange_for_mismatch(): # doesn't work for now # need ndrange and static refactor @@ -93,4 +96,5 @@ def func(): for i in ti.static(ti.ndrange(3, 4)): print(i) - func() + with pytest.raises(ti.TaichiCompilationError): + func() diff --git a/tests/python/test_fp_flush_to_zero.py b/tests/python/test_fp_flush_to_zero.py index e0330bc3d2de2..d3c06e4c995c1 100644 --- a/tests/python/test_fp_flush_to_zero.py +++ b/tests/python/test_fp_flush_to_zero.py @@ -1,23 +1,22 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_ftz_f32(): a = ti.field(dtype=ti.f32, shape=2) @ti.kernel def foo(): a[0] = 1e-45 - x = 1e-10 - y = 1e-35 - a[1] = x * y + a[1] = 1e-10 * 1e-35 foo() assert a[0] == 0 assert a[1] == 0 -@ti.test(require=ti.extension.data64) +@test_utils.test(require=ti.extension.data64) def test_ftz_f64(): a = ti.field(dtype=ti.f64, shape=2) diff --git a/tests/python/test_function.py b/tests/python/test_function.py index fdbfcd6f5cb04..70c7ee004e3da 100644 --- a/tests/python/test_function.py +++ b/tests/python/test_function.py @@ -1,13 +1,14 @@ import pytest import taichi as ti +from tests import test_utils -@ti.test(experimental_real_function=True) +@test_utils.test(arch=[ti.cpu, ti.gpu]) def test_function_without_return(): x = ti.field(ti.i32, shape=()) - @ti.func + @ti.experimental.real_func def foo(val: ti.i32): x[None] += val @@ -21,11 +22,11 @@ def run(): assert x[None] == 42 -@ti.test(experimental_real_function=True) +@test_utils.test(arch=[ti.cpu, ti.gpu], debug=True) def test_function_with_return(): x = ti.field(ti.i32, shape=()) - @ti.func + @ti.experimental.real_func def foo(val: ti.i32) -> ti.i32: x[None] += val return val @@ -41,35 +42,11 @@ def run(): assert x[None] == 42 -@ti.test(experimental_real_function=True, exclude=[ti.opengl, ti.cc]) -def test_function_with_multiple_last_return(): - x = ti.field(ti.i32, shape=()) - - @ti.func - def foo(val: ti.i32) -> ti.i32: - if x[None]: - x[None] += val * 2 - return val * 2 - else: - x[None] += val - return val - - @ti.kernel - def run(): - a = foo(40) - foo(1) - assert a == 40 - - x[None] = 0 - run() - assert x[None] == 42 - - -@ti.test(experimental_real_function=True) +@test_utils.test(arch=[ti.cpu, ti.gpu]) def test_call_expressions(): x = ti.field(ti.i32, shape=()) - @ti.func + @ti.experimental.real_func def foo(val: ti.i32) -> ti.i32: if x[None] > 10: x[None] += 1 @@ -86,71 +63,7 @@ def run(): assert x[None] == 26 -@ti.test(arch=ti.cpu, experimental_real_function=True) -@ti.must_throw(AssertionError) -def test_failing_multiple_return(): - x = ti.field(ti.i32, shape=()) - - @ti.func - def foo(val: ti.i32) -> ti.i32: - if x[None] > 10: - if x[None] > 20: - return 1 - x[None] += 1 - x[None] += val - return 0 - - @ti.kernel - def run(): - assert foo(15) == 0 - assert foo(10) == 0 - assert foo(100) == 1 - - x[None] = 0 - run() - assert x[None] == 26 - - -@ti.test(experimental_real_function=True) -def test_python_function(): - x = ti.field(ti.i32, shape=()) - - @ti.func - def inc(val: ti.i32): - x[None] += val - - def identity(x): - return x - - @ti.data_oriented - class A: - def __init__(self): - self.count = ti.field(ti.i32, shape=()) - self.count[None] = 0 - - @ti.pyfunc - def dec(self, val: ti.i32) -> ti.i32: - self.count[None] += 1 - x[None] -= val - return self.count[None] - - @ti.kernel - def run(self) -> ti.i32: - a = self.dec(1) - identity(2) - inc(identity(3)) - return a - - a = A() - x[None] = 0 - assert a.run() == 1 - assert a.run() == 2 - assert x[None] == 4 - assert a.dec(4) == 3 - assert x[None] == 0 - - -@ti.test(arch=[ti.cpu, ti.cuda], debug=True) +@test_utils.test(arch=[ti.cpu, ti.cuda], debug=True) def test_default_templates(): @ti.func def func1(x: ti.template()): @@ -217,7 +130,7 @@ def run_func(): run_func() -@ti.test(experimental_real_function=True) +@test_utils.test(arch=[ti.cpu, ti.gpu]) def test_experimental_templates(): x = ti.field(ti.i32, shape=()) y = ti.field(ti.i32, shape=()) @@ -237,7 +150,7 @@ def run_kernel(): assert x[None] == 11 assert y[None] == 21 - @ti.func + @ti.experimental.real_func def inc(x: ti.template()): x[None] += 1 @@ -263,21 +176,21 @@ def verify(): verify() -@ti.test(experimental_real_function=True) +@test_utils.test(arch=[ti.cpu, ti.gpu]) def test_missing_arg_annotation(): - with pytest.raises(ti.KernelDefError, match='must be type annotated'): + with pytest.raises(ti.TaichiSyntaxError, match='must be type annotated'): - @ti.func + @ti.experimental.real_func def add(a, b: ti.i32) -> ti.i32: return a + b -@ti.test(experimental_real_function=True) +@test_utils.test(arch=[ti.cpu, ti.gpu]) def test_missing_return_annotation(): - with pytest.raises(ti.TaichiSyntaxError, + with pytest.raises(ti.TaichiCompilationError, match='return value must be annotated'): - @ti.func + @ti.experimental.real_func def add(a: ti.i32, b: ti.i32): return a + b @@ -286,3 +199,25 @@ def run(): add(30, 2) run() + + +@test_utils.test(arch=[ti.cpu, ti.gpu], cfg_optimization=False) +def test_recursion(): + @ti.experimental.real_func + def sum(f: ti.template(), l: ti.i32, r: ti.i32) -> ti.i32: + ret = 0 + if l == r: + ret = f[l] + else: + ret = sum(f, l, (l + r) // 2) + sum(f, (l + r) // 2 + 1, r) + return ret + + f = ti.field(ti.i32, shape=100) + for i in range(100): + f[i] = i + + @ti.kernel + def get_sum() -> ti.i32: + return sum(f, 0, 99) + + assert get_sum() == 99 * 50 diff --git a/tests/python/test_function_parameter_by_value.py b/tests/python/test_function_parameter_by_value.py index 61cd7ecf12983..c27d5988ca33f 100644 --- a/tests/python/test_function_parameter_by_value.py +++ b/tests/python/test_function_parameter_by_value.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_pass_by_value(): @ti.func def set_val(x, i): diff --git a/tests/python/test_fuse_dense.py b/tests/python/test_fuse_dense.py index 8e1e381fb2080..a08a0fe47e946 100644 --- a/tests/python/test_fuse_dense.py +++ b/tests/python/test_fuse_dense.py @@ -1,20 +1,21 @@ import taichi as ti +from tests import test_utils from .fuse_test_template import (template_fuse_dense_x2y2z, template_fuse_reduction) -@ti.test(require=ti.extension.async_mode, async_mode=True) +@test_utils.test(require=ti.extension.async_mode, async_mode=True) def test_fuse_dense_x2y2z(): template_fuse_dense_x2y2z(size=10 * 1024**2) -@ti.test(require=ti.extension.async_mode, async_mode=True) +@test_utils.test(require=ti.extension.async_mode, async_mode=True) def test_fuse_reduction(): template_fuse_reduction(size=10 * 1024**2) -@ti.test(require=ti.extension.async_mode, async_mode=True) +@test_utils.test(require=ti.extension.async_mode, async_mode=True) def test_no_fuse_sigs_mismatch(): n = 4096 x = ti.field(ti.i32, shape=(n, )) diff --git a/tests/python/test_fuse_dynamic.py b/tests/python/test_fuse_dynamic.py index dedae5000d353..e514fd05b1e45 100644 --- a/tests/python/test_fuse_dynamic.py +++ b/tests/python/test_fuse_dynamic.py @@ -3,6 +3,7 @@ import pytest import taichi as ti +from tests import test_utils def benchmark_fuse_dynamic_x2y2z(size=1024**2, repeat=10, first_n=100): @@ -55,7 +56,7 @@ def y_to_z(): assert z[i] == x[i] + 5 -@ti.test(require=[ti.extension.async_mode, ti.extension.sparse], - async_mode=True) +@test_utils.test(require=[ti.extension.async_mode, ti.extension.sparse], + async_mode=True) def test_fuse_dynamic_x2y2z(): benchmark_fuse_dynamic_x2y2z() diff --git a/tests/python/test_gc.py b/tests/python/test_gc.py index cee3221771062..afa050ca5882e 100644 --- a/tests/python/test_gc.py +++ b/tests/python/test_gc.py @@ -1,4 +1,5 @@ import taichi as ti +from tests import test_utils def _test_block_gc(): @@ -23,7 +24,7 @@ def init(): for i in x: x[i] = ti.Vector( [ti.random() * 0.1 + 0.5, - ti.random() * 0.1 + 0.5], dt=ti.f32) + ti.random() * 0.1 + 0.5]) init() @@ -38,7 +39,7 @@ def move(): for p in x: x[p] += ti.Vector([0.0, 0.1]) - assert grid.num_dynamically_allocated == 0 + assert grid._num_dynamically_allocated == 0 for _ in range(100): grid.deactivate_all() # Scatter the particles to the sparse grid @@ -50,57 +51,57 @@ def move(): # The block of particles can occupy at most two blocks on the sparse grid. # It's fine to run 100 times and do just one final check, because # num_dynamically_allocated stores the number of slots *ever* allocated. - assert 1 <= grid.num_dynamically_allocated <= 2, grid.num_dynamically_allocated + assert 1 <= grid._num_dynamically_allocated <= 2, grid._num_dynamically_allocated -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_block(): _test_block_gc() #TODO: Remove exclude of ti.metal. -@ti.test(require=[ti.extension.sparse, ti.extension.async_mode], - exclude=[ti.metal], - async_mode=True) +@test_utils.test(require=[ti.extension.sparse, ti.extension.async_mode], + exclude=[ti.metal], + async_mode=True) def test_block_async(): _test_block_gc() -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_dynamic_gc(): x = ti.field(dtype=ti.i32) L = ti.root.dynamic(ti.i, 1024 * 1024, chunk_size=1024) L.place(x) - assert L.num_dynamically_allocated == 0 + assert L._num_dynamically_allocated == 0 for i in range(100): x[1024] = 1 L.deactivate_all() - assert L.num_dynamically_allocated <= 2 + assert L._num_dynamically_allocated <= 2 -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_pointer_gc(): x = ti.field(dtype=ti.i32) L = ti.root.pointer(ti.ij, 32) L.pointer(ti.ij, 32).dense(ti.ij, 8).place(x) - assert L.num_dynamically_allocated == 0 + assert L._num_dynamically_allocated == 0 for i in range(1024): x[i * 8, i * 8] = 1 - assert L.num_dynamically_allocated == 1 + assert L._num_dynamically_allocated == 1 L.deactivate_all() # Note that being inactive doesn't mean it's not allocated. - assert L.num_dynamically_allocated == 1 + assert L._num_dynamically_allocated == 1 -@ti.test(require=[ti.extension.sparse, ti.extension.async_mode], - async_mode=True) +@test_utils.test(require=[ti.extension.sparse, ti.extension.async_mode], + async_mode=True) def test_fuse_allocator_state(): N = 16 x = ti.field(dtype=ti.i32, shape=N) @@ -127,7 +128,7 @@ def deactivate_y(): ti.sync() # TODO: assert that activate_y and deactivate_y are not fused. - assert y_parent.num_dynamically_allocated == N + assert y_parent._num_dynamically_allocated == N ys = y.to_numpy() for i, y in enumerate(ys): expected = N if i == N else 0 diff --git a/tests/python/test_get_external_tensor_shape.py b/tests/python/test_get_external_tensor_shape.py index 8796835808c1f..fa5578fc067dc 100644 --- a/tests/python/test_get_external_tensor_shape.py +++ b/tests/python/test_get_external_tensor_shape.py @@ -1,14 +1,16 @@ import numpy as np import pytest +from taichi.lang.util import has_pytorch import taichi as ti +from tests import test_utils -if ti.has_pytorch(): +if has_pytorch(): import torch @pytest.mark.parametrize('size', [[1], [1, 2, 3, 4]]) -@ti.test() +@test_utils.test() def test_get_external_tensor_shape_access_numpy(size): @ti.kernel def func(x: ti.ext_arr(), index: ti.template()) -> ti.i32: @@ -22,7 +24,7 @@ def func(x: ti.ext_arr(), index: ti.template()) -> ti.i32: @pytest.mark.parametrize('size', [[1, 1], [2, 2]]) -@ti.test() +@test_utils.test() def test_get_external_tensor_shape_sum_numpy(size): @ti.kernel def func(x: ti.ext_arr()) -> ti.i32: @@ -40,9 +42,9 @@ def func(x: ti.ext_arr()) -> ti.i32: y_ref, y_hat) -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') @pytest.mark.parametrize('size', [[1, 2, 3, 4]]) -@ti.test(exclude=ti.opengl) +@test_utils.test(exclude=ti.opengl) def test_get_external_tensor_shape_access_torch(size): @ti.kernel def func(x: ti.ext_arr(), index: ti.template()) -> ti.i32: @@ -55,9 +57,9 @@ def func(x: ti.ext_arr(), index: ti.template()) -> ti.i32: idx, y_ref, y_hat) -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') @pytest.mark.parametrize('size', [[1, 2, 3, 4]]) -@ti.test(exclude=ti.opengl) +@test_utils.test(arch=[ti.cpu, ti.cuda, ti.opengl]) def test_get_external_tensor_shape_access_ndarray(size): @ti.kernel def func(x: ti.any_arr(), index: ti.template()) -> ti.i32: diff --git a/tests/python/test_ggui.py b/tests/python/test_ggui.py new file mode 100644 index 0000000000000..6a7017e5b7f84 --- /dev/null +++ b/tests/python/test_ggui.py @@ -0,0 +1,292 @@ +import os +import pathlib +import platform +import tempfile + +import numpy as np +import pytest +from taichi._lib import core as _ti_core + +import taichi as ti +from tests import test_utils + +REGENERATE_GROUNDTRUTH_IMAGES = False +RENDER_REPEAT = 5 +supported_archs = [ti.vulkan, ti.cuda] + + +def get_temp_png(): + f, name = tempfile.mkstemp(suffix='.png') + os.close(f) + return name + + +def write_temp_image(window): + f = get_temp_png() + window.write_image(f) + try: + os.remove(f) + except OSError: + pass + + +def verify_image(window, image_name, tolerence=0.1): + if REGENERATE_GROUNDTRUTH_IMAGES: + ground_truth_name = f"tests/python/expected/{image_name}.png" + window.write_image(ground_truth_name) + else: + ground_truth_name = str( + pathlib.Path(__file__).parent) + f"/expected/{image_name}.png" + actual_name = get_temp_png() + window.write_image(actual_name) + ground_truth_np = ti.imread(ground_truth_name) + actual_np = ti.imread(actual_name) + assert len(ground_truth_np.shape) == len(actual_np.shape) + for i in range(len(ground_truth_np.shape)): + assert ground_truth_np.shape[i] == actual_np.shape[i] + diff = ground_truth_np - actual_np + mse = np.mean(diff * diff) + assert mse <= tolerence # the pixel values are 0~255 + os.remove(actual_name) + + +@pytest.mark.skipif(not _ti_core.GGUI_AVAILABLE, reason="GGUI Not Available") +@test_utils.test(arch=supported_archs) +def test_geometry_2d(): + window = ti.ui.Window('test', (640, 480), show_window=False) + canvas = window.get_canvas() + + # simple circles + n_circles_0 = 10 + circle_positions_0 = ti.Vector.field(2, ti.f32, shape=n_circles_0) + for i in range(n_circles_0): + circle_positions_0[i] = ti.Vector([0.1, i * 0.1]) + + # circles with per vertex colors + n_circles_1 = 10 + circle_positions_1 = ti.Vector.field(2, ti.f32, shape=n_circles_1) + circle_colors_1 = ti.Vector.field(3, ti.f32, shape=n_circles_1) + for i in range(n_circles_0): + circle_positions_1[i] = ti.Vector([0.2, i * 0.1]) + circle_colors_1[i] = ti.Vector([i * 0.1, 1.0 - i * 0.1, 0.5]) + + # simple triangles + n_triangles_0 = 10 + triangles_positions_0 = ti.Vector.field(2, ti.f32, shape=3 * n_triangles_0) + for i in range(n_triangles_0): + triangles_positions_0[3 * i] = ti.Vector([0.3, i * 0.1]) + triangles_positions_0[3 * i + 1] = ti.Vector([0.35, i * 0.1]) + triangles_positions_0[3 * i + 2] = ti.Vector([0.35, i * 0.1 + 0.05]) + + # triangles with per vertex colors and indices + triangles_positions_1 = ti.Vector.field(2, ti.f32, shape=4) + triangles_colors_1 = ti.Vector.field(3, ti.f32, shape=4) + triangles_positions_1[0] = ti.Vector([0.4, 0]) + triangles_positions_1[1] = ti.Vector([0.4, 1]) + triangles_positions_1[2] = ti.Vector([0.45, 0]) + triangles_positions_1[3] = ti.Vector([0.45, 1]) + triangles_colors_1[0] = ti.Vector([0, 0, 0]) + triangles_colors_1[1] = ti.Vector([1, 0, 0]) + triangles_colors_1[2] = ti.Vector([0, 1, 0]) + triangles_colors_1[3] = ti.Vector([1, 1, 0]) + triangle_indices_1 = ti.Vector.field(3, ti.i32, shape=2) + triangle_indices_1[0] = ti.Vector([0, 1, 3]) + triangle_indices_1[1] = ti.Vector([0, 2, 3]) + + # simple lines + n_lines_0 = 10 + lines_positions_0 = ti.Vector.field(2, ti.f32, shape=2 * n_lines_0) + for i in range(n_lines_0): + lines_positions_0[2 * i] = ti.Vector([0.5, i * 0.1]) + lines_positions_0[2 * i + 1] = ti.Vector([0.5, i * 0.1 + 0.05]) + + # lines with per vertex colors and indices + lines_positions_1 = ti.Vector.field(2, ti.f32, shape=4) + lines_colors_1 = ti.Vector.field(3, ti.f32, shape=4) + lines_positions_1[0] = ti.Vector([0.6, 0]) + lines_positions_1[1] = ti.Vector([0.6, 1]) + lines_positions_1[2] = ti.Vector([0.65, 0]) + lines_positions_1[3] = ti.Vector([0.65, 1]) + lines_colors_1[0] = ti.Vector([0, 0, 0]) + lines_colors_1[1] = ti.Vector([1, 0, 0]) + lines_colors_1[2] = ti.Vector([0, 1, 0]) + lines_colors_1[3] = ti.Vector([1, 1, 0]) + lines_indices_1 = ti.Vector.field(2, ti.i32, shape=6) + line_id = 0 + for i in range(4): + for j in range(i + 1, 4): + lines_indices_1[line_id] = ti.Vector([i, j]) + line_id += 1 + + def render(): + + canvas.circles(circle_positions_0, radius=0.05, color=(1, 0, 0)) + + canvas.circles(circle_positions_1, + radius=0.05, + per_vertex_color=circle_colors_1) + + canvas.triangles(triangles_positions_0, color=(0, 0, 1)) + + canvas.triangles(triangles_positions_1, + per_vertex_color=triangles_colors_1, + indices=triangle_indices_1) + + canvas.lines(lines_positions_0, width=0.01, color=(0, 1, 0)) + + canvas.lines(lines_positions_1, + width=0.01, + per_vertex_color=lines_colors_1, + indices=lines_indices_1) + + for _ in range(RENDER_REPEAT): + render() + write_temp_image(window) + render() + if (platform.system() == 'Darwin'): + # FIXME: Use lower tolerence when macOS ggui supports wide lines + verify_image(window, 'test_geometry_2d', 1.0) + else: + verify_image(window, 'test_geometry_2d') + window.destroy() + + +@pytest.mark.skipif(not _ti_core.GGUI_AVAILABLE, reason="GGUI Not Available") +@test_utils.test(arch=supported_archs) +def test_geometry_3d(): + window = ti.ui.Window('test', (640, 480), show_window=False) + canvas = window.get_canvas() + scene = ti.ui.Scene() + camera = ti.ui.make_camera() + camera.position(0.0, 0.0, 1.5) + camera.lookat(0.0, 0.0, 0) + scene.set_camera(camera) + + # simple particles + num_per_dim = 32 + num_particles_0 = int(num_per_dim**3) + particles_positions_0 = ti.Vector.field(3, ti.f32, shape=num_particles_0) + + @ti.kernel + def init_particles_0(): + for x, y, z in ti.ndrange(num_per_dim, num_per_dim, num_per_dim): + i = x * (num_per_dim**2) + y * num_per_dim + z + gap = 0.01 + particles_positions_0[i] = ti.Vector( + [-0.4, 0, 0.0], + dt=ti.f32) + ti.Vector([x, y, z], dt=ti.f32) * gap + + init_particles_0() + + # particles with individual colors + num_per_dim = 32 + num_particles_1 = int(num_per_dim**3) + particles_positions_1 = ti.Vector.field(3, ti.f32, shape=num_particles_1) + particles_colors_1 = ti.Vector.field(3, ti.f32, shape=num_particles_1) + + @ti.kernel + def init_particles_1(): + for x, y, z in ti.ndrange(num_per_dim, num_per_dim, num_per_dim): + i = x * (num_per_dim**2) + y * num_per_dim + z + gap = 0.01 + particles_positions_1[i] = ti.Vector( + [0.2, 0, 0.0], + dt=ti.f32) + ti.Vector([x, y, z], dt=ti.f32) * gap + particles_colors_1[i] = ti.Vector([x, y, z], + dt=ti.f32) / num_per_dim + + init_particles_1() + + # mesh + vertices = ti.Vector.field(3, ti.f32, shape=8) + colors = ti.Vector.field(3, ti.f32, shape=8) + + @ti.kernel + def init_mesh(): + for i, j, k in ti.ndrange(2, 2, 2): + index = i * 4 + j * 2 + k + vertices[index] = ti.Vector( + [-0.1, -0.3, 0.0], + dt=ti.f32) + ti.Vector([i, j, k], dt=ti.f32) * 0.25 + colors[index] = ti.Vector([i, j, k], dt=ti.f32) + + init_mesh() + indices = ti.field(ti.i32, shape=36) + indices_np = np.array([ + 0, 1, 2, 3, 1, 2, 4, 5, 6, 7, 5, 6, 0, 1, 4, 5, 1, 4, 2, 3, 6, 7, 3, 6, + 0, 2, 4, 6, 2, 4, 1, 3, 5, 7, 3, 5 + ], + dtype=np.int32) + indices.from_numpy(indices_np) + + def render(): + scene.point_light(pos=(2, 2, 2), color=(1, 1, 1)) + + scene.particles(particles_positions_0, radius=0.01, color=(0.5, 0, 0)) + + scene.particles(particles_positions_1, + radius=0.01, + per_vertex_color=particles_colors_1) + + scene.mesh(vertices, + per_vertex_color=colors, + indices=indices, + two_sided=True) + + canvas.scene(scene) + + for _ in range(RENDER_REPEAT): + render() + write_temp_image(window) + render() + verify_image(window, 'test_geometry_3d') + window.destroy() + + +@pytest.mark.skipif(not _ti_core.GGUI_AVAILABLE, reason="GGUI Not Available") +@test_utils.test(arch=supported_archs) +def test_set_image(): + window = ti.ui.Window('test', (640, 480), show_window=False) + canvas = window.get_canvas() + + img = ti.Vector.field(4, ti.f32, (512, 512)) + + @ti.kernel + def init_img(): + for i, j in img: + img[i, j] = ti.Vector([i, j, 0, 512], dt=ti.f32) / 512 + + init_img() + + def render(): + canvas.set_image(img) + + for _ in range(RENDER_REPEAT): + render() + write_temp_image(window) + render() + verify_image(window, 'test_set_image') + window.destroy() + + +@pytest.mark.skipif(not _ti_core.GGUI_AVAILABLE, reason="GGUI Not Available") +@test_utils.test(arch=supported_archs) +def test_imgui(): + window = ti.ui.Window('test', (640, 480), show_window=False) + + def render(): + with window.GUI.sub_window("window 0", 0.1, 0.1, 0.8, 0.2) as w: + w.text("Hello Taichi!") + w.text("Hello Again!") + with window.GUI.sub_window("window 1", 0.1, 0.4, 0.8, 0.2) as w: + w.button("Press to unlease creativity") + w.slider_float('creativity level', 100.0, 0.0, 100.0) + with window.GUI.sub_window("window 2", 0.1, 0.7, 0.8, 0.2) as w: + w.color_edit_3('Heyy', (0, 0, 1)) + + for _ in range(RENDER_REPEAT): + render() + write_temp_image(window) + render() + verify_image(window, 'test_imgui') + window.destroy() diff --git a/tests/python/test_global_buffer_misalined.py b/tests/python/test_global_buffer_misalined.py index 93600a1f82908..eab4524d225aa 100644 --- a/tests/python/test_global_buffer_misalined.py +++ b/tests/python/test_global_buffer_misalined.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.data64) +@test_utils.test(require=ti.extension.data64) def test_global_buffer_misalignment(): @ti.kernel def test(x: ti.f32): diff --git a/tests/python/test_global_store_grad.py b/tests/python/test_global_store_grad.py index 958afe7b53acf..6f98e9f016759 100644 --- a/tests/python/test_global_store_grad.py +++ b/tests/python/test_global_store_grad.py @@ -1,14 +1,13 @@ """ import taichi as ti -ti.cfg.print_ir = True +ti.lang.impl.current_cfg().print_ir = True def test_global_store_branching(): # ti.reset() N = 16 - ti.runtime.print_preprocessed = True x = ti.field(ti.f32) y = ti.field(ti.f32) diff --git a/tests/python/test_global_thread_idx.py b/tests/python/test_global_thread_idx.py new file mode 100644 index 0000000000000..6e0d5869023a7 --- /dev/null +++ b/tests/python/test_global_thread_idx.py @@ -0,0 +1,19 @@ +import numpy as np + +import taichi as ti +from tests import test_utils + + +@test_utils.test(arch=ti.cuda) +def test_global_thread_idx(): + n = 2048 + x = ti.field(ti.i32, shape=n) + + @ti.kernel + def func(): + for i in range(n): + tid = ti.global_thread_idx() + x[tid] = tid + + func() + assert np.arange(n).sum() == x.to_numpy().sum() diff --git a/tests/python/test_grouped.py b/tests/python/test_grouped.py index d1a66519a730b..0757c7f921814 100644 --- a/tests/python/test_grouped.py +++ b/tests/python/test_grouped.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_vector_index(): val = ti.field(ti.i32) @@ -16,7 +17,7 @@ def test(): for i in range(n): for j in range(m): for k in range(p): - I = ti.Vector([i, j, k], dt=ti.i32) + I = ti.Vector([i, j, k]) val[I] = i + j * 2 + k * 3 test() @@ -27,7 +28,7 @@ def test(): assert val[i, j, k] == i + j * 2 + k * 3 -@ti.test() +@test_utils.test() def test_grouped(): val = ti.field(ti.i32) @@ -50,7 +51,7 @@ def test(): assert val[i, j, k] == i + j * 2 + k * 3 -@ti.test() +@test_utils.test() def test_grouped_ndrange(): val = ti.field(ti.i32) @@ -77,7 +78,7 @@ def test(): j * 2 if x0 <= i < y0 and x1 <= j < y1 else 0) -@ti.test() +@test_utils.test() def test_static_grouped_ndrange(): val = ti.field(ti.i32) @@ -104,7 +105,7 @@ def test(): j * 2 if x0 <= i < y0 and x1 <= j < y1 else 0) -@ti.test() +@test_utils.test() def test_grouped_ndrange_starred(): val = ti.field(ti.i32) @@ -129,7 +130,7 @@ def test(): k] == (i + j * 2 + k * 3 if j < n and k < n else 0) -@ti.test() +@test_utils.test() def test_grouped_ndrange_0d(): val = ti.field(ti.i32, shape=()) @@ -143,7 +144,7 @@ def test(): assert val[None] == 42 -@ti.test() +@test_utils.test() def test_static_grouped_ndrange_0d(): val = ti.field(ti.i32, shape=()) @@ -157,7 +158,7 @@ def test(): assert val[None] == 42 -@ti.test() +@test_utils.test() def test_static_grouped_func(): K = 3 diff --git a/tests/python/test_gui.py b/tests/python/test_gui.py index ee4683d012927..dbb0bae7ab651 100644 --- a/tests/python/test_gui.py +++ b/tests/python/test_gui.py @@ -1,12 +1,13 @@ import numpy as np import pytest +from taichi.lang.misc import get_host_arch_list import taichi as ti -from taichi import make_temp_file +from tests import test_utils @pytest.mark.parametrize('dtype', [ti.u8, ti.f32]) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_save_image_without_window(dtype): n = 255 pixels = ti.field(dtype=dtype, shape=(n, n, 3)) @@ -23,7 +24,7 @@ def paint(c: dtype): else: paint(i * 1.0 / n) gui.set_image(pixels) - image_path = make_temp_file(suffix='.png') + image_path = test_utils.make_temp_file(suffix='.png') gui.show(image_path) image = ti.imread(image_path) delta = (image - i).sum() diff --git a/tests/python/test_image_io.py b/tests/python/test_image_io.py index 7816fdca4e543..fd04ea19b64a1 100644 --- a/tests/python/test_image_io.py +++ b/tests/python/test_image_io.py @@ -2,9 +2,11 @@ import numpy as np import pytest +from taichi.lang.misc import get_host_arch_list +from taichi.lang.util import to_numpy_type import taichi as ti -from taichi import make_temp_file +from tests import test_utils # jpg is also supported but hard to test here since it's lossy: @@ -13,7 +15,7 @@ @pytest.mark.parametrize('resx,resy', [(201, 173)]) @pytest.mark.parametrize('is_field', [False, True]) @pytest.mark.parametrize('dt', [ti.u8]) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_image_io(resx, resy, comp, ext, is_field, dt): if comp != 1: shape = (resx, resy, comp) @@ -21,10 +23,10 @@ def test_image_io(resx, resy, comp, ext, is_field, dt): shape = (resx, resy) if is_field: pixel_t = ti.field(dt, shape) - pixel = np.random.randint(256, size=shape, dtype=ti.to_numpy_type(dt)) + pixel = np.random.randint(256, size=shape, dtype=to_numpy_type(dt)) if is_field: pixel_t.from_numpy(pixel) - fn = make_temp_file(suffix='.' + ext) + fn = test_utils.make_temp_file(suffix='.' + ext) if is_field: ti.imwrite(pixel_t, fn) else: @@ -40,15 +42,15 @@ def test_image_io(resx, resy, comp, ext, is_field, dt): @pytest.mark.parametrize('comp,ext', [(3, 'png'), (4, 'png')]) @pytest.mark.parametrize('resx,resy', [(91, 81)]) @pytest.mark.parametrize('dt', [ti.f32, ti.f64]) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_image_io_vector(resx, resy, comp, ext, dt): shape = (resx, resy) - pixel = np.random.rand(*shape, comp).astype(ti.to_numpy_type(dt)) + pixel = np.random.rand(*shape, comp).astype(to_numpy_type(dt)) pixel_t = ti.Vector.field(comp, dt, shape) pixel_t.from_numpy(pixel) - fn = make_temp_file(suffix='.' + ext) + fn = test_utils.make_temp_file(suffix='.' + ext) ti.imwrite(pixel_t, fn) - pixel_r = (ti.imread(fn).astype(ti.to_numpy_type(dt)) + 0.5) / 256.0 + pixel_r = (ti.imread(fn).astype(to_numpy_type(dt)) + 0.5) / 256.0 assert np.allclose(pixel_r, pixel, atol=2e-2) os.remove(fn) @@ -56,17 +58,17 @@ def test_image_io_vector(resx, resy, comp, ext, dt): @pytest.mark.parametrize('comp,ext', [(3, 'png')]) @pytest.mark.parametrize('resx,resy', [(91, 81)]) @pytest.mark.parametrize('dt', [ti.u16, ti.u32, ti.u64]) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_image_io_uint(resx, resy, comp, ext, dt): shape = (resx, resy) - np_type = ti.to_numpy_type(dt) + np_type = to_numpy_type(dt) # When saving to disk, pixel data will be truncated into 8 bits. # Be careful here if you want lossless saving. np_max = np.iinfo(np_type).max // 256 pixel = np.random.randint(256, size=(*shape, comp), dtype=np_type) * np_max pixel_t = ti.Vector.field(comp, dt, shape) pixel_t.from_numpy(pixel) - fn = make_temp_file(suffix='.' + ext) + fn = test_utils.make_temp_file(suffix='.' + ext) ti.imwrite(pixel_t, fn) pixel_r = ti.imread(fn).astype(np_type) * np_max assert (pixel_r == pixel).all() @@ -76,7 +78,7 @@ def test_image_io_uint(resx, resy, comp, ext, dt): @pytest.mark.parametrize('comp', [1, 3]) @pytest.mark.parametrize('resx,resy', [(91, 81)]) @pytest.mark.parametrize('scale', [1, 2, 3]) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_image_resize_sum(resx, resy, comp, scale): shape = (resx, resy) if comp != 1: @@ -86,4 +88,4 @@ def test_image_resize_sum(resx, resy, comp, scale): new_img = ti.imresize(old_img, resx * scale) else: new_img = ti.imresize(old_img, resx * scale, resy * scale) - assert np.sum(old_img) * scale**2 == ti.approx(np.sum(new_img)) + assert np.sum(old_img) * scale**2 == test_utils.approx(np.sum(new_img)) diff --git a/tests/python/test_immediate_layout.py b/tests/python/test_immediate_layout.py index bf60a8ba7dabb..065b3b04cffe7 100644 --- a/tests/python/test_immediate_layout.py +++ b/tests/python/test_immediate_layout.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_1D(): N = 2 x = ti.field(ti.f32) diff --git a/tests/python/test_indices.py b/tests/python/test_indices.py index 362153e161af3..fd676fae9b24c 100644 --- a/tests/python/test_indices.py +++ b/tests/python/test_indices.py @@ -1,18 +1,21 @@ +from taichi.lang.misc import get_host_arch_list + import taichi as ti +from tests import test_utils -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_indices(): a = ti.field(ti.f32, shape=(128, 32, 8)) b = ti.field(ti.f32) ti.root.dense(ti.j, 32).dense(ti.i, 16).place(b) - mapping_a = a.snode.physical_index_position() + mapping_a = a.snode._physical_index_position() assert mapping_a == {0: 0, 1: 1, 2: 2} - mapping_b = b.snode.physical_index_position() + mapping_b = b.snode._physical_index_position() assert mapping_b == {0: 0, 1: 1} # Note that b is column-major: @@ -34,7 +37,7 @@ def get_field_addr(i: ti.i32, j: ti.i32) -> ti.u64: assert get_field_addr(0, 1) + 4 == get_field_addr(1, 1) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_float_as_index(): a = ti.field(ti.f32, (8, 5)) diff --git a/tests/python/test_indices_assert.py b/tests/python/test_indices_assert.py index ce6a0ce9a52f0..99053c1f83773 100644 --- a/tests/python/test_indices_assert.py +++ b/tests/python/test_indices_assert.py @@ -3,19 +3,19 @@ import pytest import taichi as ti +from tests import test_utils @pytest.mark.skipif(platform.system() == 'Windows', reason="Too much virtual memory for github windows env.") -@ti.test(debug=True, gdb_trigger=False, packed=False, arch=[ti.cpu]) +@test_utils.test(debug=True, gdb_trigger=False, packed=False, arch=[ti.cpu]) def test_indices_assert(): - overflow = ti.field(ti.i32, (334, 334, 334, 2 * 10)) + overflow = ti.field(ti.u8, (3, 715827883)) @ti.kernel def access_overflow(): - overflow[0, 0, 0, 0] = 10 - print(overflow[333, 333, 333, 0]) + overflow[2, 715827882] = 10 with pytest.raises(RuntimeError, match='The indices provided are too big!'): diff --git a/tests/python/test_internal_func.py b/tests/python/test_internal_func.py index b7043e8fdf33e..78dcd8eef90f9 100644 --- a/tests/python/test_internal_func.py +++ b/tests/python/test_internal_func.py @@ -1,31 +1,28 @@ import time -import taichi as ti - +from taichi.lang import impl -# TODO: these are not really tests... -def all_archs_for_this(test): - # ti.call_internal() is not supported on CUDA, Metal, OpenGL yet - return ti.archs_excluding(ti.metal, ti.opengl, ti.cuda, ti.vulkan)(test) +import taichi as ti +from tests import test_utils -@ti.test(exclude=[ti.metal, ti.opengl, ti.cuda, ti.vulkan]) +@test_utils.test(exclude=[ti.metal, ti.opengl, ti.cuda, ti.vulkan, ti.cc]) def test_basic(): @ti.kernel def test(): for _ in range(10): - ti.call_internal("do_nothing") + impl.call_internal("do_nothing") test() -@ti.test(exclude=[ti.metal, ti.opengl, ti.cuda, ti.vulkan]) +@test_utils.test(exclude=[ti.metal, ti.opengl, ti.cuda, ti.vulkan, ti.cc]) def test_host_polling(): return @ti.kernel def test(): - ti.call_internal("refresh_counter") + impl.call_internal("refresh_counter") for i in range(10): print('updating tail to', i) @@ -33,40 +30,40 @@ def test(): time.sleep(0.1) -@ti.test(exclude=[ti.metal, ti.opengl, ti.cuda, ti.vulkan]) +@test_utils.test(exclude=[ti.metal, ti.opengl, ti.cuda, ti.vulkan, ti.cc]) def test_list_manager(): @ti.kernel def test(): - ti.call_internal("test_list_manager") + impl.call_internal("test_list_manager") test() test() -@ti.test(exclude=[ti.metal, ti.opengl, ti.cuda, ti.vulkan]) +@test_utils.test(exclude=[ti.metal, ti.opengl, ti.cuda, ti.vulkan, ti.cc]) def test_node_manager(): @ti.kernel def test(): - ti.call_internal("test_node_allocator") + impl.call_internal("test_node_allocator") test() test() -@ti.test(exclude=[ti.metal, ti.opengl, ti.cuda, ti.vulkan]) +@test_utils.test(exclude=[ti.metal, ti.opengl, ti.cuda, ti.vulkan, ti.cc]) def test_node_manager_gc(): @ti.kernel def test_cpu(): - ti.call_internal("test_node_allocator_gc_cpu") + impl.call_internal("test_node_allocator_gc_cpu") test_cpu() -@ti.test(arch=[ti.cpu, ti.cuda], debug=True) +@test_utils.test(arch=[ti.cpu, ti.cuda], debug=True) def test_return(): @ti.kernel def test_cpu(): - ret = ti.call_internal("test_internal_func_args", 1.0, 2.0, 3) + ret = impl.call_internal("test_internal_func_args", 1.0, 2.0, 3) assert ret == 9 test_cpu() diff --git a/tests/python/test_kernel_arg_errors.py b/tests/python/test_kernel_arg_errors.py index 9df816cd92f19..2e0797795b77a 100644 --- a/tests/python/test_kernel_arg_errors.py +++ b/tests/python/test_kernel_arg_errors.py @@ -1,17 +1,17 @@ import pytest import taichi as ti +from tests import test_utils -@ti.test(arch=ti.cpu) +@test_utils.test(arch=ti.cpu) def test_pass_float_as_i32(): @ti.kernel def foo(a: ti.i32): pass - with pytest.raises(ti.KernelArgError) as e: + with pytest.raises(ti.TaichiRuntimeTypeError) as e: foo(1.2) - assert e.type is ti.KernelArgError assert e.value.args[ 0] == "Argument 0 (type=) cannot be converted into required type i32" diff --git a/tests/python/test_kernel_templates.py b/tests/python/test_kernel_templates.py index e21c2076080fd..09637436c3ff1 100644 --- a/tests/python/test_kernel_templates.py +++ b/tests/python/test_kernel_templates.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_kernel_template_basic(): x = ti.field(ti.i32) y = ti.field(ti.f32) @@ -32,7 +33,7 @@ def inc2(z: ti.i32, a: ti.template(), b: ti.i32): assert x[i] == 12 -@ti.test() +@test_utils.test() def test_kernel_template_gradient(): x = ti.field(ti.f32) y = ti.field(ti.f32) @@ -66,7 +67,7 @@ def compute_loss(): assert x.grad[i] == 4 -@ti.test() +@test_utils.test() def test_func_template(): a = [ti.field(dtype=ti.f32) for _ in range(2)] b = [ti.field(dtype=ti.f32) for _ in range(2)] @@ -98,7 +99,7 @@ def aTob(l: ti.template()): assert b[l][i, j] == l -@ti.test() +@test_utils.test() def test_func_template2(): a = ti.field(dtype=ti.f32) b = ti.field(dtype=ti.f32) diff --git a/tests/python/test_lang.py b/tests/python/test_lang.py index e913db5365fca..1ce3dcde85b23 100644 --- a/tests/python/test_lang.py +++ b/tests/python/test_lang.py @@ -1,10 +1,12 @@ import numpy as np import pytest +from taichi.lang.misc import get_host_arch_list import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_nested_subscript(): x = ti.field(ti.i32) y = ti.field(ti.i32) @@ -24,7 +26,7 @@ def inc(): assert x[0] == 1 -@ti.test() +@test_utils.test() def test_norm(): val = ti.field(ti.i32) f = ti.field(ti.f32) @@ -55,7 +57,7 @@ def test2(): assert val[i] == 96 + i -@ti.test() +@test_utils.test() def test_simple2(): val = ti.field(ti.i32) f = ti.field(ti.f32) @@ -82,7 +84,7 @@ def test2(): assert val[i] == 1 + i * 2 -@ti.test() +@test_utils.test() def test_recreate(): @ti.kernel def test(): @@ -92,7 +94,7 @@ def test(): test() -@ti.test() +@test_utils.test() def test_local_atomics(): n = 32 val = ti.field(ti.i32, shape=n) @@ -112,8 +114,7 @@ def test(): assert val[i] == i + 45 -@ti.test(arch=ti.get_host_arch_list()) -@ti.must_throw(UnboundLocalError) +@test_utils.test(arch=get_host_arch_list()) def test_loop_var_life(): @ti.kernel def test(): @@ -121,11 +122,11 @@ def test(): pass print(i) - test() + with pytest.raises(Exception): + test() -@ti.test(arch=ti.get_host_arch_list()) -@ti.must_throw(UnboundLocalError) +@test_utils.test(arch=get_host_arch_list()) def test_loop_var_life_double_iters(): @ti.kernel def test(): @@ -133,13 +134,14 @@ def test(): pass print(i) - test() + with pytest.raises(Exception): + test() @pytest.mark.parametrize('dtype', [ti.i32, ti.f32, ti.i64, ti.f64]) @pytest.mark.parametrize('ti_zero,zero', [(ti.zero, 0), (ti.one, 1)]) @pytest.mark.parametrize('is_mat', [False, True]) -@ti.test(arch=ti.cpu) +@test_utils.test(arch=ti.cpu) def test_meta_zero_one(dtype, ti_zero, zero, is_mat): if is_mat: x = ti.Matrix.field(2, 3, dtype, ()) @@ -153,7 +155,7 @@ def func(): y[None] = ti_zero(x[None]) for a in [-1, -2.3, -1, -0.3, 0, 1, 1.9, 2, 3]: - if ti.core.is_integral(dtype): + if ti.types.is_integral(dtype): a = int(a) x.fill(a) func() diff --git a/tests/python/test_lexical_scope.py b/tests/python/test_lexical_scope.py index 33d39df21db2e..4c5909c93437b 100644 --- a/tests/python/test_lexical_scope.py +++ b/tests/python/test_lexical_scope.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test(ti.cpu) +@test_utils.test(ti.cpu) def test_func_closure(): def my_test(): a = 32 diff --git a/tests/python/test_linalg.py b/tests/python/test_linalg.py index 947d1964b1d9f..3a882ce699e51 100644 --- a/tests/python/test_linalg.py +++ b/tests/python/test_linalg.py @@ -2,12 +2,13 @@ import numpy as np import pytest +from taichi.lang.misc import get_host_arch_list import taichi as ti -from taichi import approx +from tests import test_utils -@ti.test() +@test_utils.test() def test_const_init(): a = ti.Matrix.field(2, 3, dtype=ti.i32, shape=()) b = ti.Vector.field(3, dtype=ti.i32, shape=()) @@ -27,7 +28,7 @@ def init(): assert b[None][j] == j -@ti.test() +@test_utils.test() def test_basic_utils(): a = ti.Vector.field(3, dtype=ti.f32) b = ti.Vector.field(2, dtype=ti.f32) @@ -42,7 +43,7 @@ def test_basic_utils(): @ti.kernel def init(): - a[None] = ti.Vector([1.0, 2.0, 3.0]) + a[None] = ti.Vector([1.0, 2.0, -3.0]) b[None] = ti.Vector([4.0, 5.0]) abT[None] = a[None].outer_product(b[None]) @@ -60,15 +61,15 @@ def init(): sqrt14 = np.sqrt(14.0) invSqrt14 = 1.0 / sqrt14 - assert normSqrA[None] == approx(14.0) - assert normInvA[None] == approx(invSqrt14) - assert normA[None] == approx(sqrt14) - assert aNormalized[None][0] == approx(1.0 * invSqrt14) - assert aNormalized[None][1] == approx(2.0 * invSqrt14) - assert aNormalized[None][2] == approx(3.0 * invSqrt14) + assert normSqrA[None] == test_utils.approx(14.0) + assert normInvA[None] == test_utils.approx(invSqrt14) + assert normA[None] == test_utils.approx(sqrt14) + assert aNormalized[None][0] == test_utils.approx(1.0 * invSqrt14) + assert aNormalized[None][1] == test_utils.approx(2.0 * invSqrt14) + assert aNormalized[None][2] == test_utils.approx(-3.0 * invSqrt14) -@ti.test() +@test_utils.test() def test_cross(): a = ti.Vector.field(3, dtype=ti.f32) b = ti.Vector.field(3, dtype=ti.f32) @@ -97,7 +98,7 @@ def init(): assert c2[None] == -3.0 -@ti.test() +@test_utils.test() def test_dot(): a = ti.Vector.field(3, dtype=ti.f32) b = ti.Vector.field(3, dtype=ti.f32) @@ -124,7 +125,7 @@ def init(): assert c2[None] == 14.0 -@ti.test() +@test_utils.test() def test_transpose(): dim = 3 m = ti.Matrix.field(dim, dim, ti.f32) @@ -144,7 +145,7 @@ def transpose(): for i in range(dim): for j in range(dim): - assert m[None][j, i] == approx(i * 2 + j * 7) + assert m[None][j, i] == test_utils.approx(i * 2 + j * 7) def _test_polar_decomp(dim, dt): @@ -178,24 +179,25 @@ def V(i, j): for i in range(dim): for j in range(dim): - assert m[None][i, j] == approx(V(i, j), abs=tol) - assert I[None][i, j] == approx(int(i == j), abs=tol) - assert D[None][i, j] == approx(0, abs=tol) + assert m[None][i, j] == test_utils.approx(V(i, j), abs=tol) + assert I[None][i, j] == test_utils.approx(int(i == j), abs=tol) + assert D[None][i, j] == test_utils.approx(0, abs=tol) def test_polar_decomp(): for dim in [2, 3]: for dt in [ti.f32, ti.f64]: - @ti.test(require=ti.extension.data64 if dt == ti.f64 else [], - default_fp=dt) + @test_utils.test( + require=ti.extension.data64 if dt == ti.f64 else [], + default_fp=dt) def wrapped(): _test_polar_decomp(dim, dt) wrapped() -@ti.test() +@test_utils.test() def test_matrix(): x = ti.Matrix.field(2, 2, dtype=ti.i32) @@ -219,7 +221,7 @@ def inc(): assert x[i][1, 1] == 1 + i -@ti.test() +@test_utils.test() def _test_mat_inverse_size(n): m = ti.Matrix.field(n, n, dtype=ti.f32, shape=()) M = np.empty(shape=(n, n), dtype=np.float32) @@ -245,7 +247,7 @@ def test_mat_inverse(): _test_mat_inverse_size(n) -@ti.test() +@test_utils.test() def test_matrix_factories(): a = ti.Vector.field(3, dtype=ti.i32, shape=3) b = ti.Matrix.field(2, 2, dtype=ti.f32, shape=2) @@ -267,17 +269,17 @@ def fill(): assert a[i][j] == int(i == j) sqrt3o2 = math.sqrt(3) / 2 - assert b[0].value.to_numpy() == approx(np.eye(2)) - assert b[1].value.to_numpy() == approx( + assert b[0].to_numpy() == test_utils.approx(np.eye(2)) + assert b[1].to_numpy() == test_utils.approx( np.array([[0.5, -sqrt3o2], [sqrt3o2, 0.5]])) - assert c[0].value.to_numpy() == approx(np.zeros((2, 3))) - assert c[1].value.to_numpy() == approx(np.ones((2, 3))) + assert c[0].to_numpy() == test_utils.approx(np.zeros((2, 3))) + assert c[1].to_numpy() == test_utils.approx(np.ones((2, 3))) # TODO: move codes below to test_matrix.py: -@ti.test() +@test_utils.test() def test_init_matrix_from_vectors(): m1 = ti.Matrix.field(3, 3, dtype=ti.f32, shape=(3)) m2 = ti.Matrix.field(3, 3, dtype=ti.f32, shape=(3)) @@ -309,7 +311,7 @@ def fill(): # TODO: Remove this once the APIs are obsolete. @pytest.mark.filterwarnings('ignore') -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_init_matrix_from_vectors_deprecated(): m1 = ti.Matrix.field(3, 3, dtype=ti.f32, shape=(3)) m2 = ti.Matrix.field(3, 3, dtype=ti.f32, shape=(3)) @@ -339,17 +341,7 @@ def fill(): assert m4[0][j, i] == int(i + 3 * j + 1) -@pytest.mark.filterwarnings('ignore') -@ti.test(arch=ti.get_host_arch_list()) -def test_to_numpy_as_vector_deprecated(): - v = ti.Vector.field(3, dtype=ti.f32, shape=(2)) - u = np.array([[2, 3, 4], [5, 6, 7]]) - v.from_numpy(u) - assert v.to_numpy(as_vector=True) == approx(u) - assert v.to_numpy() == approx(u) - - -@ti.test() +@test_utils.test() def test_any_all(): a = ti.Matrix.field(2, 2, dtype=ti.i32, shape=()) b = ti.field(dtype=ti.i32, shape=()) @@ -379,7 +371,7 @@ def func(): assert c[None] == 0 -@ti.test() +@test_utils.test() def test_min_max(): a = ti.Matrix.field(2, 2, dtype=ti.i32, shape=()) b = ti.field(dtype=ti.i32, shape=()) @@ -403,7 +395,7 @@ def func(): # must not throw any error: -@ti.test() +@test_utils.test() def test_matrix_list_assign(): m = ti.Matrix.field(2, 2, dtype=ti.i32, shape=(2, 2, 1)) @@ -427,7 +419,7 @@ def func(): assert np.allclose(v.to_numpy()[1, 0, 0, :], np.array([10, 12])) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_vector_xyzw_accessor(): u = ti.Vector.field(2, dtype=ti.i32, shape=(2, 2, 1)) v = ti.Vector.field(4, dtype=ti.i32, shape=(2, 2, 1)) @@ -450,7 +442,7 @@ def func(): assert np.allclose(v.to_numpy()[1, 0, 0, :], np.array([6, 0, -3, 4])) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_diag(): m1 = ti.Matrix.field(3, 3, dtype=ti.f32, shape=()) @@ -463,6 +455,6 @@ def fill(): for i in range(3): for j in range(3): if i == j: - assert m1[None][i, j] == approx(1.4) + assert m1[None][i, j] == test_utils.approx(1.4) else: assert m1[None][i, j] == 0.0 diff --git a/tests/python/test_listgen.py b/tests/python/test_listgen.py index cf68e70550108..eb61946246ead 100644 --- a/tests/python/test_listgen.py +++ b/tests/python/test_listgen.py @@ -1,9 +1,10 @@ from random import randrange import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_listgen(): x = ti.field(ti.i32) n = 1024 @@ -33,7 +34,7 @@ def fill(c: ti.i32): assert x[i, j] == i * 10 + j + c -@ti.test() +@test_utils.test() def test_nested_3d(): x = ti.field(ti.i32) n = 128 diff --git a/tests/python/test_literal.py b/tests/python/test_literal.py new file mode 100644 index 0000000000000..88ac3827df213 --- /dev/null +++ b/tests/python/test_literal.py @@ -0,0 +1,82 @@ +import pytest + +import taichi as ti +from tests import test_utils + + +@test_utils.test() +def test_literal_u32(): + @ti.kernel + def pcg_hash(inp: ti.u32) -> ti.u32: + state: ti.u32 = inp * ti.u32(747796405) + ti.u32(2891336453) + word: ti.u32 = ((state >> ( + (state >> ti.u32(28)) + ti.u32(4))) ^ state) * ti.u32(277803737) + return (word >> ti.u32(22)) ^ word + + assert pcg_hash(12345678) == 119515934 + assert pcg_hash(98765432) == 4244201195 + + +@test_utils.test() +def test_literal_multi_args_error(): + @ti.kernel + def multi_args_error(): + a = ti.i64(1, 2) + + with pytest.raises( + ti.TaichiSyntaxError, + match="Type annotation can only be given to a single literal."): + multi_args_error() + + +@test_utils.test() +def test_literal_keywords_error(): + @ti.kernel + def keywords_error(): + a = ti.f64(1, x=2) + + with pytest.raises( + ti.TaichiSyntaxError, + match="Type annotation can only be given to a single literal."): + keywords_error() + + +@test_utils.test() +def test_literal_expr_error(): + @ti.kernel + def expr_error(): + a = 1 + b = ti.f16(a) + + with pytest.raises( + ti.TaichiSyntaxError, + match="Type annotation can only be given to a single literal."): + expr_error() + + +@test_utils.test() +def test_literal_int_annotation_error(): + @ti.kernel + def int_annotation_error(): + a = ti.f32(0) + + with pytest.raises( + ti.TaichiTypeError, + match= + "Integer literals must be annotated with a integer type. For type casting, use `ti.cast`." + ): + int_annotation_error() + + +@test_utils.test() +def test_literal_float_annotation_error(): + @ti.kernel + def float_annotation_error(): + a = ti.i32(0.0) + + with pytest.raises( + ti.TaichiTypeError, + match= + "Floating-point literals must be annotated with a floating-point type. For type casting, use `ti.cast`." + ): + float_annotation_error() diff --git a/tests/python/test_local_atomic_opt.py b/tests/python/test_local_atomic_opt.py index 758c6efe342a5..560fb9071f2fb 100644 --- a/tests/python/test_local_atomic_opt.py +++ b/tests/python/test_local_atomic_opt.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_cse(): A = ti.field(ti.f32, shape=()) @@ -16,7 +17,7 @@ def func(): assert A[None] == 133 -@ti.test() +@test_utils.test() def test_store_forward(): A = ti.field(ti.f32, shape=()) diff --git a/tests/python/test_local_atomics.py b/tests/python/test_local_atomics.py index 660bf3f87252f..272834d97b956 100644 --- a/tests/python/test_local_atomics.py +++ b/tests/python/test_local_atomics.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_explicit_local_atomic_add(): A = ti.field(ti.f32, shape=()) @@ -16,7 +17,7 @@ def func(): assert A[None] == 45 -@ti.test() +@test_utils.test() def test_implicit_local_atomic_add(): A = ti.field(ti.f32, shape=()) @@ -31,7 +32,7 @@ def func(): assert A[None] == 45 -@ti.test() +@test_utils.test() def test_explicit_local_atomic_sub(): A = ti.field(ti.f32, shape=()) @@ -46,7 +47,7 @@ def func(): assert A[None] == -45 -@ti.test() +@test_utils.test() def test_implicit_local_atomic_sub(): A = ti.field(ti.f32, shape=()) @@ -61,7 +62,7 @@ def func(): assert A[None] == -45 -@ti.test() +@test_utils.test() def test_explicit_local_atomic_min(): A = ti.field(ti.f32, shape=()) @@ -76,7 +77,7 @@ def func(): assert A[None] == 0 -@ti.test() +@test_utils.test() def test_explicit_local_atomic_max(): A = ti.field(ti.f32, shape=()) @@ -91,7 +92,7 @@ def func(): assert A[None] == 9 -@ti.test() +@test_utils.test() def test_explicit_local_atomic_and(): A = ti.field(ti.i32, shape=()) max_int = 2147483647 @@ -107,7 +108,7 @@ def func(): assert A[None] == 0 -@ti.test() +@test_utils.test() def test_implicit_local_atomic_and(): A = ti.field(ti.i32, shape=()) max_int = 2147483647 @@ -123,7 +124,7 @@ def func(): assert A[None] == 0 -@ti.test() +@test_utils.test() def test_explicit_local_atomic_or(): A = ti.field(ti.i32, shape=()) @@ -138,7 +139,7 @@ def func(): assert A[None] == 1023 -@ti.test() +@test_utils.test() def test_implicit_local_atomic_or(): A = ti.field(ti.i32, shape=()) @@ -153,7 +154,7 @@ def func(): assert A[None] == 1023 -@ti.test() +@test_utils.test() def test_explicit_local_atomic_xor(): A = ti.field(ti.i32, shape=()) @@ -168,7 +169,7 @@ def func(): assert A[None] == 0 -@ti.test() +@test_utils.test() def test_implicit_local_atomic_xor(): A = ti.field(ti.i32, shape=()) diff --git a/tests/python/test_loop_grad.py b/tests/python/test_loop_grad.py index d91966e0e31ce..ab209613430a2 100644 --- a/tests/python/test_loop_grad.py +++ b/tests/python/test_loop_grad.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test(exclude=[ti.vulkan]) +@test_utils.test(exclude=[ti.vulkan]) def test_loop_grad(): x = ti.field(ti.f32) @@ -31,7 +32,7 @@ def func(): assert x.grad[k, i] == 2**(m - 1 - i) -@ti.test(exclude=[ti.vulkan]) +@test_utils.test(exclude=[ti.vulkan]) def test_loop_grad_complex(): return # This case is not supported yet x = ti.field(ti.f32) diff --git a/tests/python/test_loop_unique.py b/tests/python/test_loop_unique.py index c4a33d15d5c50..8c439b3bd8166 100644 --- a/tests/python/test_loop_unique.py +++ b/tests/python/test_loop_unique.py @@ -1,7 +1,10 @@ +from taichi.lang.misc import loop_unique + import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_loop_unique_simple_1d(): x, y = ti.field(ti.i32), ti.field(ti.i32) @@ -12,7 +15,7 @@ def test_loop_unique_simple_1d(): @ti.kernel def inc_y(): for i in x: - a = ti.loop_unique(x[i]) + a = loop_unique(x[i]) y[a] += 1 x[1] = 2 @@ -26,7 +29,7 @@ def inc_y(): assert y[i] == expected_result.get(i, 0) -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_loop_unique_binary_op_1d(): x, y = ti.field(ti.i32), ti.field(ti.i32) @@ -37,7 +40,7 @@ def test_loop_unique_binary_op_1d(): @ti.kernel def inc_y(): for i in x: - a = ti.loop_unique(x[i]) + a = loop_unique(x[i]) y[a + 1] += 1 x[1] = 2 @@ -51,7 +54,7 @@ def inc_y(): assert y[i] == expected_result.get(i, 0) -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_loop_unique_nested_1d(): x, y = ti.field(ti.i32), ti.field(ti.i32) @@ -63,7 +66,7 @@ def test_loop_unique_nested_1d(): def inc_y(): for i in x: for j in range(i): - a = ti.loop_unique(x[i]) + a = loop_unique(x[i]) y[a] += 1 x[1] = 2 @@ -77,7 +80,7 @@ def inc_y(): assert y[i] == expected_result.get(i, 0) -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_loop_unique_2d(): x, y, z = ti.field(ti.i32), ti.field(ti.i32), ti.field(ti.i32) @@ -89,7 +92,7 @@ def test_loop_unique_2d(): @ti.kernel def inc_y_z(): for i, j in x: - a = ti.loop_unique(x[i, j]) + a = loop_unique(x[i, j]) y[a, j] += 1 z[i, i] += 1 # cannot demote this @@ -123,7 +126,7 @@ def inc_y_z(): assert z[i, j] == expected_result_z.get((i, j), 0) -@ti.test() +@test_utils.test() def test_loop_unique_ndrange(): x, y, z = ti.field(ti.i32), ti.field(ti.i32), ti.field(ti.i32) @@ -144,7 +147,7 @@ def prepare_x(): @ti.kernel def inc_y_z(): for i, j in ti.ndrange(a, b): - u = ti.loop_unique(x[i, j]) + u = loop_unique(x[i, j]) y[u] += i z[i, j + 1] += 10 # TODO: demote this diff --git a/tests/python/test_loops.py b/tests/python/test_loops.py index a2bfdb9f97cb3..81b0456df8cd3 100644 --- a/tests/python/test_loops.py +++ b/tests/python/test_loops.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_loops(): x = ti.field(ti.f32) y = ti.field(ti.f32) @@ -18,7 +19,7 @@ def test_loops(): @ti.kernel def func(): for i in range(ti.static(N // 2 + 3), N): - x[i] = ti.abs(y[i]) + x[i] = abs(y[i]) func() @@ -29,7 +30,7 @@ def func(): assert x[i] == abs(y[i]) -@ti.test() +@test_utils.test() def test_numpy_loops(): x = ti.field(ti.f32) y = ti.field(ti.f32) @@ -44,13 +45,13 @@ def test_numpy_loops(): y[i] = i - 300 import numpy as np - begin = (np.ones(1) * (N // 2 + 3)).astype(np.int32) - end = (np.ones(1) * N).astype(np.int32) + begin = (np.ones(1) * (N // 2 + 3)).astype(np.int32).reshape(()) + end = (np.ones(1) * N).astype(np.int32).reshape(()) @ti.kernel def func(): for i in range(begin, end): - x[i] = ti.abs(y[i]) + x[i] = abs(y[i]) func() @@ -61,7 +62,7 @@ def func(): assert x[i] == abs(y[i]) -@ti.test() +@test_utils.test() def test_nested_loops(): # this may crash if any LLVM allocas are called in the loop body x = ti.field(ti.i32) @@ -79,7 +80,7 @@ def paint(): paint() -@ti.test() +@test_utils.test() def test_zero_outer_loop(): x = ti.field(ti.i32, shape=()) @@ -93,7 +94,7 @@ def test(): assert x[None] == 0 -@ti.test() +@test_utils.test() def test_zero_inner_loop(): x = ti.field(ti.i32, shape=()) @@ -108,7 +109,7 @@ def test(): assert x[None] == 0 -@ti.test() +@test_utils.test() def test_dynamic_loop_range(): x = ti.field(ti.i32) c = ti.field(ti.i32) @@ -129,7 +130,7 @@ def test(): assert sum(x.to_numpy()) == (n * (n - 1) // 2) + n * n -@ti.test() +@test_utils.test() def test_loop_arg_as_range(): # Dynamic range loops are intended to make sure global tmps work x = ti.field(ti.i32) @@ -153,7 +154,7 @@ def test(b: ti.i32, e: ti.i32): assert x[i - b] == i -@ti.test() +@test_utils.test() def test_assignment_in_nested_loops(): # https://github.com/taichi-dev/taichi/issues/1109 m = ti.field(ti.f32, 3) @@ -171,3 +172,18 @@ def func(): x[None] = 1 func() assert x[None] == 1 + + +@test_utils.test() +def test_break_in_outermost_for_not_in_outermost_scope(): + @ti.kernel + def foo() -> ti.i32: + a = 0 + if True: + for i in range(1000): + if i == 100: + break + a += 1 + return a + + assert foo() == 100 diff --git a/tests/python/test_materialize_check.py b/tests/python/test_materialize_check.py new file mode 100644 index 0000000000000..e7acc6689143d --- /dev/null +++ b/tests/python/test_materialize_check.py @@ -0,0 +1,35 @@ +import pytest + +import taichi as ti +from tests import test_utils + + +@test_utils.test() +def test_check_field_not_placed(): + a = ti.field(ti.i32) + + @ti.kernel + def foo(): + pass + + with pytest.raises(RuntimeError, + match=r"These field\(s\) are not placed.*"): + foo() + + +@test_utils.test() +def test_check_matrix_field_member_shape(): + a = ti.Matrix.field(2, 2, ti.i32) + ti.root.dense(ti.i, 10).place(a.get_scalar_field(0, 0)) + ti.root.dense(ti.i, 11).place(a.get_scalar_field(0, 1)) + ti.root.dense(ti.i, 10).place(a.get_scalar_field(1, 0)) + ti.root.dense(ti.i, 11).place(a.get_scalar_field(1, 1)) + + @ti.kernel + def foo(): + pass + + with pytest.raises( + RuntimeError, + match=r"Members of the following field have different shapes.*"): + foo() diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 8e1af4a0ca561..22d56d630e7b2 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -1,10 +1,13 @@ +import math import operator import numpy as np import pytest +from taichi.lang import impl +from taichi.lang.misc import get_host_arch_list import taichi as ti -from taichi import approx +from tests import test_utils operation_types = [operator.add, operator.sub, operator.matmul] test_matrix_arrays = [ @@ -21,7 +24,7 @@ ] -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_python_scope_vector_operations(): for ops in vector_operation_types: a, b = test_vector_arrays[:2] @@ -30,7 +33,7 @@ def test_python_scope_vector_operations(): assert np.allclose(c.to_numpy(), ops(a, b)) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_python_scope_matrix_operations(): for ops in operation_types: a, b = test_matrix_arrays[:2] @@ -45,19 +48,19 @@ def test_python_scope_matrix_operations(): # ideally we should use pytest.fixture to parameterize the tests # over explicit loops @pytest.mark.parametrize('ops', vector_operation_types) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_python_scope_vector_field(ops): t1 = ti.Vector.field(2, dtype=ti.i32, shape=()) t2 = ti.Vector.field(2, dtype=ti.i32, shape=()) a, b = test_vector_arrays[:2] t1[None], t2[None] = a.tolist(), b.tolist() - c = ops(t1[None].value, t2[None].value) + c = ops(t1[None], t2[None]) assert np.allclose(c.to_numpy(), ops(a, b)) @pytest.mark.parametrize('ops', vector_operation_types) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_python_scope_matrix_field(ops): t1 = ti.Matrix.field(2, 2, dtype=ti.i32, shape=()) t2 = ti.Matrix.field(2, 2, dtype=ti.i32, shape=()) @@ -65,26 +68,25 @@ def test_python_scope_matrix_field(ops): # ndarray not supported here t1[None], t2[None] = a.tolist(), b.tolist() - c = ops(t1[None].value, t2[None].value) + c = ops(t1[None], t2[None]) print(c) assert np.allclose(c.to_numpy(), ops(a, b)) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_constant_matrices(): - assert ti.cos(ti.math.pi / 3) == approx(0.5) + assert ti.cos(math.pi / 3) == test_utils.approx(0.5) assert np.allclose((-ti.Vector([2, 3])).to_numpy(), np.array([-2, -3])) - assert ti.cos(ti.Vector([2, - 3])).to_numpy() == approx(np.cos(np.array([2, - 3]))) + assert ti.cos(ti.Vector([2, 3])).to_numpy() == test_utils.approx( + np.cos(np.array([2, 3]))) assert ti.max(2, 3) == 3 res = ti.max(4, ti.Vector([3, 4, 5])) assert np.allclose(res.to_numpy(), np.array([4, 4, 5])) res = ti.Vector([2, 3]) + ti.Vector([3, 4]) assert np.allclose(res.to_numpy(), np.array([5, 7])) res = ti.atan2(ti.Vector([2, 3]), ti.Vector([3, 4])) - assert res.to_numpy() == approx( + assert res.to_numpy() == test_utils.approx( np.arctan2(np.array([2, 3]), np.array([3, 4]))) res = ti.Matrix([[2, 3], [4, 5]]) @ ti.Vector([2, 3]) assert np.allclose(res.to_numpy(), np.array([13, 23])) @@ -92,8 +94,8 @@ def test_constant_matrices(): w = ti.Vector([5, -12]) r = ti.Vector([1, 2, 3, 4]) s = ti.Matrix([[1, 2], [3, 4]]) - assert v.normalized().to_numpy() == approx(np.array([0.6, 0.8])) - assert v.cross(w) == approx(-12 * 3 - 4 * 5) + assert v.normalized().to_numpy() == test_utils.approx(np.array([0.6, 0.8])) + assert v.cross(w) == test_utils.approx(-12 * 3 - 4 * 5) w.y = v.x * w[0] r.x = r.y r.y = r.z @@ -120,7 +122,7 @@ def func(t: ti.i32): @pytest.mark.parametrize('ops', vector_operation_types) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_taichi_scope_vector_operations_with_global_vectors(ops): a, b, c = test_vector_arrays[:3] m1, m2 = ti.Vector(a), ti.Vector(b) @@ -136,12 +138,12 @@ def run(): run() - assert np.allclose(r1[None].value.to_numpy(), ops(a, b)) - assert np.allclose(r2[None].value.to_numpy(), ops(a, c)) + assert np.allclose(r1[None].to_numpy(), ops(a, b)) + assert np.allclose(r2[None].to_numpy(), ops(a, c)) @pytest.mark.parametrize('ops', vector_operation_types) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_taichi_scope_matrix_operations_with_global_matrices(ops): a, b, c = test_matrix_arrays[:3] m1, m2 = ti.Matrix(a), ti.Matrix(b) @@ -157,11 +159,11 @@ def run(): run() - assert np.allclose(r1[None].value.to_numpy(), ops(a, b)) - assert np.allclose(r2[None].value.to_numpy(), ops(a, c)) + assert np.allclose(r1[None].to_numpy(), ops(a, b)) + assert np.allclose(r2[None].to_numpy(), ops(a, c)) -@ti.test() +@test_utils.test() def test_matrix_non_constant_index_numpy(): @ti.kernel def func1(a: ti.any_arr(element_dim=2)): @@ -190,7 +192,9 @@ def func2(b: ti.any_arr(element_dim=1, layout=ti.Layout.SOA)): assert v[9][1] == 9 -@ti.test(require=ti.extension.dynamic_index, dynamic_index=True) +@test_utils.test(require=ti.extension.dynamic_index, + dynamic_index=True, + debug=True) def test_matrix_non_constant_index(): m = ti.Matrix.field(2, 2, ti.i32, 5) v = ti.Vector.field(10, ti.i32, 5) @@ -221,10 +225,10 @@ def func2(): @ti.kernel def func3(): - tmp = ti.Vector([1, 2, 3], dt=ti.i32) + tmp = ti.Vector([1, 2, 3]) for i in range(3): tmp[i] = i * i - vec = ti.Vector([4, 5, 6], dt=ti.i32) + vec = ti.Vector([4, 5, 6]) for j in range(3): vec[tmp[i] % 3] += vec[j % 3] assert tmp[0] == 0 @@ -236,12 +240,14 @@ def func3(): @ti.kernel def func4(k: ti.i32): tmp = ti.Vector([k, k * 2, k * 3]) + assert tmp[0] == k + assert tmp[1] == k * 2 + assert tmp[2] == k * 3 - with pytest.raises(Exception): - func4(10) + func4(10) -@ti.test(arch=ti.cpu) +@test_utils.test(arch=ti.cpu) def test_matrix_constant_index(): m = ti.Matrix.field(2, 2, ti.i32, 5) @@ -256,7 +262,7 @@ def func(): assert np.allclose(m.to_numpy(), np.ones((5, 2, 2), np.int32) * 12) -@ti.test(arch=ti.cpu) +@test_utils.test(arch=ti.cpu) def test_vector_to_list(): a = ti.Vector.field(2, float, ()) @@ -266,10 +272,10 @@ def test_vector_to_list(): assert len(b) == len(data) a[None] = b - assert all(a[None].value == ti.Vector(data)) + assert all(a[None] == ti.Vector(data)) -@ti.test(arch=ti.cpu) +@test_utils.test(arch=ti.cpu) def test_matrix_to_list(): a = ti.Matrix.field(2, 3, float, ()) @@ -279,10 +285,10 @@ def test_matrix_to_list(): assert len(b) == len(data) a[None] = b - assert all(a[None].value == ti.Matrix(data)) + assert all(a[None] == ti.Matrix(data)) -@ti.test() +@test_utils.test() def test_matrix_needs_grad(): # Just make sure the usage doesn't crash, see https://github.com/taichi-dev/taichi/pull/1545 n = 8 @@ -296,3 +302,225 @@ def func(): gr[i] = m1.grad[i] + m2.grad[i] func() + + +@test_utils.test(debug=True) +def test_copy_python_scope_matrix_to_taichi_scope(): + a = ti.Vector([1, 2, 3]) + + @ti.kernel + def test(): + b = a + assert b[0] == 1 + assert b[1] == 2 + assert b[2] == 3 + b = ti.Vector([4, 5, 6]) + assert b[0] == 4 + assert b[1] == 5 + assert b[2] == 6 + + test() + + +@test_utils.test(debug=True) +def test_copy_matrix_field_element_to_taichi_scope(): + a = ti.Vector.field(3, ti.i32, shape=()) + a[None] = ti.Vector([1, 2, 3]) + + @ti.kernel + def test(): + b = a[None] + assert b[0] == 1 + assert b[1] == 2 + assert b[2] == 3 + b[0] = 5 + b[1] = 9 + b[2] = 7 + assert b[0] == 5 + assert b[1] == 9 + assert b[2] == 7 + assert a[None][0] == 1 + assert a[None][1] == 2 + assert a[None][2] == 3 + + test() + + +@test_utils.test(debug=True) +def test_copy_matrix_in_taichi_scope(): + @ti.kernel + def test(): + a = ti.Vector([1, 2, 3]) + b = a + assert b[0] == 1 + assert b[1] == 2 + assert b[2] == 3 + b[0] = 5 + b[1] = 9 + b[2] = 7 + assert b[0] == 5 + assert b[1] == 9 + assert b[2] == 7 + assert a[0] == 1 + assert a[1] == 2 + assert a[2] == 3 + + test() + + +@test_utils.test(arch=[ti.cpu, ti.cuda], dynamic_index=True, debug=True) +def test_matrix_field_dynamic_index_stride(): + # placeholders + temp_a = ti.field(ti.f32) + temp_b = ti.field(ti.f32) + temp_c = ti.field(ti.f32) + # target + v = ti.Vector.field(3, ti.i32) + x = v.get_scalar_field(0) + y = v.get_scalar_field(1) + z = v.get_scalar_field(2) + + S0 = ti.root + S1 = S0.pointer(ti.i, 4) + S2 = S1.dense(ti.i, 2) + S3 = S2.pointer(ti.i, 8) + S3.place(temp_a) + S4 = S2.dense(ti.i, 16) + S4.place(x) + S5 = S1.dense(ti.i, 2) + S6 = S5.pointer(ti.i, 8) + S6.place(temp_b) + S7 = S5.dense(ti.i, 16) + S7.place(y) + S8 = S1.dense(ti.i, 2) + S9 = S8.dense(ti.i, 32) + S9.place(temp_c) + S10 = S8.dense(ti.i, 16) + S10.place(z) + + @ti.kernel + def check_stride(): + for i in range(128): + assert ti.get_addr(y, i) - ti.get_addr(x, + i) == v.dynamic_index_stride + assert ti.get_addr(z, i) - ti.get_addr(y, + i) == v.dynamic_index_stride + + check_stride() + + @ti.kernel + def run(): + for i in range(128): + for j in range(3): + v[i][j] = i * j + + run() + for i in range(128): + for j in range(3): + assert v[i][j] == i * j + + +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_matrix_field_dynamic_index_different_path_length(): + v = ti.Vector.field(2, ti.i32) + x = v.get_scalar_field(0) + y = v.get_scalar_field(1) + + ti.root.dense(ti.i, 8).place(x) + ti.root.dense(ti.i, 2).dense(ti.i, 4).place(y) + + impl.get_runtime().materialize() + assert v.dynamic_index_stride is None + + +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_matrix_field_dynamic_index_not_pure_dense(): + v = ti.Vector.field(2, ti.i32) + x = v.get_scalar_field(0) + y = v.get_scalar_field(1) + + ti.root.dense(ti.i, 2).pointer(ti.i, 4).place(x) + ti.root.dense(ti.i, 2).dense(ti.i, 4).place(y) + + impl.get_runtime().materialize() + assert v.dynamic_index_stride is None + + +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_matrix_field_dynamic_index_different_cell_size_bytes(): + temp = ti.field(ti.f32) + + v = ti.Vector.field(2, ti.i32) + x = v.get_scalar_field(0) + y = v.get_scalar_field(1) + + ti.root.dense(ti.i, 8).place(x, temp) + ti.root.dense(ti.i, 8).place(y) + + impl.get_runtime().materialize() + assert v.dynamic_index_stride is None + + +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_matrix_field_dynamic_index_different_offset_bytes_in_parent_cell(): + temp_a = ti.field(ti.f32) + temp_b = ti.field(ti.f32) + + v = ti.Vector.field(2, ti.i32) + x = v.get_scalar_field(0) + y = v.get_scalar_field(1) + + ti.root.dense(ti.i, 8).place(temp_a, x) + ti.root.dense(ti.i, 8).place(y, temp_b) + + impl.get_runtime().materialize() + assert v.dynamic_index_stride is None + + +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_matrix_field_dynamic_index_different_stride(): + temp = ti.field(ti.f32) + + v = ti.Vector.field(3, ti.i32) + x = v.get_scalar_field(0) + y = v.get_scalar_field(1) + z = v.get_scalar_field(2) + + ti.root.dense(ti.i, 8).place(x, y, temp, z) + + impl.get_runtime().materialize() + assert v.dynamic_index_stride is None + + +@test_utils.test(arch=[ti.cpu, ti.cuda], dynamic_index=True) +def test_matrix_field_dynamic_index_multiple_materialize(): + @ti.kernel + def empty(): + pass + + empty() + + n = 5 + a = ti.Vector.field(3, dtype=ti.i32, shape=n) + + @ti.kernel + def func(): + for i in a: + a[i][i % 3] = i + + func() + for i in range(n): + for j in range(3): + assert a[i][j] == (i if j == i % 3 else 0) + + +@test_utils.test(arch=[ti.cpu, ti.cuda], dynamic_index=True, debug=True) +def test_local_vector_initialized_in_a_loop(): + @ti.kernel + def foo(): + for c in range(10): + p = ti.Vector([c, c * 2]) + for i in range(2): + assert p[i] == c * (i + 1) + + foo() diff --git a/tests/python/test_matrix_arg.py b/tests/python/test_matrix_arg.py new file mode 100644 index 0000000000000..f07722f5fea82 --- /dev/null +++ b/tests/python/test_matrix_arg.py @@ -0,0 +1,37 @@ +import pytest + +import taichi as ti +from tests import test_utils + + +@test_utils.test() +def test_matrix_arg(): + mat1 = ti.Matrix([[1, 2, 3], [4, 5, 6]]) + + @ti.kernel + def foo(mat: ti.types.matrix(2, 3, ti.i32)) -> ti.i32: + return mat[0, 0] + mat[1, 2] + + assert foo(mat1) == 7 + + mat3 = ti.Matrix([[1, 2], [3, 4], [5, 6]]) + + @ti.kernel + def foo2(var: ti.i32, mat: ti.types.matrix(3, 2, ti.i32)) -> ti.i32: + for i in ti.static(range(3)): + for j in ti.static(range(2)): + mat[i, j] += var + return mat[2, 1] + + assert foo2(3, mat3) == 9 + + +@test_utils.test() +def test_vector_arg(): + vec1 = ti.Vector([1, 2, 3]) + + @ti.kernel + def foo(vec: ti.types.vector(3, ti.i32)) -> int: + return vec[0] + vec[1] + vec[2] + + assert foo(vec1) == 6 diff --git a/tests/python/test_matrix_different_type.py b/tests/python/test_matrix_different_type.py index 30aeebe7fc01e..5fd00d78eeb2f 100644 --- a/tests/python/test_matrix_different_type.py +++ b/tests/python/test_matrix_different_type.py @@ -1,10 +1,11 @@ from pytest import approx import taichi as ti +from tests import test_utils # TODO: test more matrix operations -@ti.test() +@test_utils.test() def test_vector(): type_list = [ti.f32, ti.i32] @@ -31,7 +32,7 @@ def verify(): # TODO: Support different element types of Matrix on opengl -@ti.test(require=ti.extension.data64, exclude=ti.opengl) +@test_utils.test(require=ti.extension.data64, exclude=ti.opengl) def test_matrix(): type_list = [[ti.f32, ti.i32], [ti.i64, ti.f32]] a = ti.Matrix.field(len(type_list), @@ -67,12 +68,12 @@ def verify(): verify() -@ti.test(require=ti.extension.quant_basic) +@test_utils.test(require=ti.extension.quant_basic) def test_custom_type(): - cit1 = ti.quant.int(bits=10, signed=True) - cft1 = ti.type_factory.custom_float(cit1, scale=0.1) - cit2 = ti.quant.int(bits=22, signed=False) - cft2 = ti.type_factory.custom_float(cit2, scale=0.1) + cit1 = ti.types.quantized_types.quant.int(bits=10, signed=True) + cft1 = ti.types.quantized_types.type_factory.custom_float(cit1, scale=0.1) + cit2 = ti.types.quantized_types.quant.int(bits=22, signed=False) + cft2 = ti.types.quantized_types.type_factory.custom_float(cit2, scale=0.1) type_list = [[cit1, cft2], [cft1, cit2]] a = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list) b = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list) diff --git a/tests/python/test_matrix_return.py b/tests/python/test_matrix_return.py new file mode 100644 index 0000000000000..846171e4f4816 --- /dev/null +++ b/tests/python/test_matrix_return.py @@ -0,0 +1,31 @@ +import taichi as ti +from tests import test_utils + + +@test_utils.test() +def test_arch(): + @ti.kernel + def func() -> ti.types.vector(3, ti.i32): + return ti.Vector([1, 2, 3]) + + assert func()[1] == 2 + + +@test_utils.test(arch=[ti.cpu, ti.cuda, ti.metal]) +def test_ret_i16(): + @ti.kernel + def func() -> ti.types.matrix(2, 3, ti.i16): + return ti.Matrix([[1, 2, 3], [4, 5, 6]]) + + assert func()[1, 2] == 6 + + +@test_utils.test() +def test_arch_exceed_limit(): + @ti.kernel + def func() -> ti.types.matrix(3, 10, ti.i32): + return ti.Matrix([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], + [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]]) + + assert func()[1, 2] == 12 diff --git a/tests/python/test_memory.py b/tests/python/test_memory.py index f9e46179c2a46..f90c2fc307e7d 100644 --- a/tests/python/test_memory.py +++ b/tests/python/test_memory.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test(arch=ti.cuda) +@test_utils.test(arch=ti.cuda) def test_memory_allocate(): HUGE_SIZE = 1024**2 * 128 x = ti.field(ti.i32, shape=(HUGE_SIZE, )) diff --git a/tests/python/test_mesh.py b/tests/python/test_mesh.py new file mode 100644 index 0000000000000..19a1b178e3dd1 --- /dev/null +++ b/tests/python/test_mesh.py @@ -0,0 +1,344 @@ +import os + +import numpy as np + +import taichi as ti +from tests import test_utils + +this_dir = os.path.dirname(os.path.abspath(__file__)) +model_file_path = os.path.join(this_dir, 'ell.json') + + +@test_utils.test(require=ti.extension.mesh) +def test_mesh_patch_idx(): + mesh_builder = ti.Mesh.Tet() + mesh_builder.verts.place({'idx': ti.i32}) + model = mesh_builder.build(ti.Mesh.load_meta(model_file_path)) + + @ti.kernel + def foo(): + for v in model.verts: + v.idx = ti.mesh_patch_idx() + + foo() + idx = model.verts.idx.to_numpy() + assert idx[0] == 6 + assert idx.sum() == 89 + + +def _test_mesh_for(cell_reorder=False, vert_reorder=False, extra_tests=True): + mesh_builder = ti.Mesh.Tet() + mesh_builder.verts.place({'t': ti.i32}, reorder=vert_reorder) + mesh_builder.cells.place({'t': ti.i32}, reorder=cell_reorder) + mesh_builder.cells.link(mesh_builder.verts) + mesh_builder.verts.link(mesh_builder.cells) + mesh_builder.cells.link(mesh_builder.cells) + mesh_builder.verts.link(mesh_builder.verts) + model = mesh_builder.build(ti.Mesh.load_meta(model_file_path)) + + @ti.kernel + def cell_vert(): + for c in model.cells: + for j in range(c.verts.size): + c.t += c.verts[j].id + + cell_vert() + total = model.cells.t.to_numpy().sum() + model.cells.t.fill(0) + assert total == 892 + + @ti.kernel + def vert_cell(): + for v in model.verts: + for j in range(v.cells.size): + v.t += v.cells[j].id + + vert_cell() + total = model.verts.t.to_numpy().sum() + model.verts.t.fill(0) + assert total == 1104 + + if not extra_tests: + return + + @ti.kernel + def cell_cell(): + for c in model.cells: + for j in range(c.cells.size): + c.t += c.cells[j].id + + cell_cell() + total = model.cells.t.to_numpy().sum() + model.cells.t.fill(0) + assert total == 690 + + @ti.kernel + def vert_vert(): + for v in model.verts: + for j in range(v.verts.size): + v.t += v.verts[j].id + + vert_vert() + total = model.verts.t.to_numpy().sum() + model.verts.t.fill(0) + assert total == 1144 + + +@test_utils.test(require=ti.extension.mesh) +def test_mesh_for(): + _test_mesh_for(False, False) + _test_mesh_for(False, True) + + +@test_utils.test(require=ti.extension.mesh, + optimize_mesh_reordered_mapping=False) +def test_mesh_reordered_opt(): + _test_mesh_for(True, True, False) + + +@test_utils.test(require=ti.extension.mesh, mesh_localize_to_end_mapping=False) +def test_mesh_localize_mapping0(): + _test_mesh_for(False, False, False) + _test_mesh_for(True, True, False) + + +@test_utils.test(require=ti.extension.mesh, + mesh_localize_from_end_mapping=True) +def test_mesh_localize_mapping1(): + _test_mesh_for(False, False, False) + _test_mesh_for(True, True, False) + + +@test_utils.test(require=ti.extension.mesh) +def test_mesh_reorder(): + vec3i = ti.types.vector(3, ti.i32) + mesh_builder = ti.Mesh.Tet() + mesh_builder.verts.place({'s': ti.i32, 's3': vec3i}, reorder=True) + mesh_builder.cells.link(mesh_builder.verts) + model = mesh_builder.build(ti.Mesh.load_meta(model_file_path)) + + id2 = np.array([x**2 for x in range(len(model.verts))]) + id123 = np.array([[x**1, x**2, x**3] for x in range(len(model.verts))]) + model.verts.s.from_numpy(id2) + model.verts.s3.from_numpy(id123) + + @ti.kernel + def foo(): + for v in model.verts: + assert v.s == v.id**2 + assert v.s3[0] == v.id**1 and v.s3[1] == v.id**2 and v.s3[ + 2] == v.id**3 + v.s = v.id**3 + v.s3 *= v.id + + foo() + + id3 = model.verts.s.to_numpy() + id234 = model.verts.s3.to_numpy() + + for i in range(len(model.verts)): + assert model.verts.s[i] == i**3 + assert id3[i] == i**3 + assert model.verts.s3[i][0] == i**2 + assert model.verts.s3[i][1] == i**3 + assert model.verts.s3[i][2] == i**4 + assert id234[i][0] == i**2 + assert id234[i][1] == i**3 + assert id234[i][2] == i**4 + + +@test_utils.test(require=ti.extension.mesh) +def test_mesh_minor_relations(): + mesh_builder = ti.Mesh.Tet() + mesh_builder.verts.place({'y': ti.i32}) + mesh_builder.edges.place({'x': ti.i32}) + mesh_builder.cells.link(mesh_builder.edges) + mesh_builder.verts.link(mesh_builder.cells) + model = mesh_builder.build(ti.Mesh.load_meta(model_file_path)) + model.edges.x.fill(1) + + @ti.kernel + def foo(): + for v in model.verts: + for i in range(v.cells.size): + c = v.cells[i] + for j in range(c.edges.size): + e = c.edges[j] + v.y += e.x + + foo() + total = model.verts.y.to_numpy().sum() + assert total == 576 + + +@test_utils.test(require=ti.extension.mesh, demote_no_access_mesh_fors=True) +def test_multiple_meshes(): + mesh_builder = ti.Mesh.Tet() + mesh_builder.verts.place({'y': ti.i32}) + meta = ti.Mesh.load_meta(model_file_path) + model1 = mesh_builder.build(meta) + model2 = mesh_builder.build(meta) + + model1.verts.y.from_numpy( + np.array([x**2 for x in range(len(model1.verts))])) + + @ti.kernel + def foo(): + for v in model1.verts: + model2.verts.y[v.id] = v.y + + foo() + out = model2.verts.y.to_numpy() + for i in range(len(out)): + assert out[i] == i**2 + + +@test_utils.test(require=ti.extension.mesh) +def test_mesh_local(): + mesh_builder = ti.Mesh.Tet() + mesh_builder.verts.place({'a': ti.i32}) + mesh_builder.faces.link(mesh_builder.verts) + model = mesh_builder.build(ti.Mesh.load_meta(model_file_path)) + ext_a = ti.field(ti.i32, shape=len(model.verts)) + + @ti.kernel + def foo(cache: ti.template()): + if ti.static(cache): + ti.mesh_local(ext_a, model.verts.a) + for f in model.faces: + m = f.verts[0].id + f.verts[1].id + f.verts[2].id + f.verts[0].a += m + f.verts[1].a += m + f.verts[2].a += m + ext_a[f.verts[0].id] += m + ext_a[f.verts[1].id] += m + ext_a[f.verts[2].id] += m + + foo(False) + res1 = model.verts.a.to_numpy() + res2 = ext_a.to_numpy() + model.verts.a.fill(0) + ext_a.fill(0) + foo(True) + res3 = model.verts.a.to_numpy() + res4 = ext_a.to_numpy() + + for i in range(len(model.verts)): + assert res1[i] == res2[i] + assert res1[i] == res3[i] + assert res1[i] == res4[i] + + +@test_utils.test(require=ti.extension.mesh, experimental_auto_mesh_local=True) +def test_auto_mesh_local(): + mesh_builder = ti.Mesh.Tet() + mesh_builder.verts.place({'a': ti.i32, 's': ti.i32}) + mesh_builder.faces.link(mesh_builder.verts) + model = mesh_builder.build(ti.Mesh.load_meta(model_file_path)) + ext_a = ti.field(ti.i32, shape=len(model.verts)) + + @ti.kernel + def foo(cache: ti.template()): + for v in model.verts: + v.s = v.id + if ti.static(cache): + ti.mesh_local(ext_a, model.verts.a) + for f in model.faces: + m = f.verts[0].s + f.verts[1].s + f.verts[2].s + f.verts[0].a += m + f.verts[1].a += m + f.verts[2].a += m + for i in range(3): + ext_a[f.verts[i].id] += m + + foo(False) + res1 = model.verts.a.to_numpy() + res2 = ext_a.to_numpy() + model.verts.a.fill(0) + ext_a.fill(0) + foo(True) + res3 = model.verts.a.to_numpy() + res4 = ext_a.to_numpy() + + for i in range(len(model.verts)): + assert res1[i] == res2[i] + assert res1[i] == res3[i] + assert res1[i] == res4[i] + + +@test_utils.test(require=ti.extension.mesh) +def test_nested_mesh_for(): + mesh_builder = ti.Mesh.Tet() + mesh_builder.faces.place({'a': ti.i32, 'b': ti.i32}) + mesh_builder.faces.link(mesh_builder.verts) + model = mesh_builder.build(ti.Mesh.load_meta(model_file_path)) + + @ti.kernel + def foo(): + for f in model.faces: + for i in range(f.verts.size): + f.a += f.verts[i].id + for v in f.verts: + f.b += v.id + + a = model.faces.a.to_numpy() + b = model.faces.b.to_numpy() + assert (a == b).all() == 1 + + +@test_utils.test(require=ti.extension.mesh) +def test_multiple_mesh_major_relations(): + mesh = ti.TetMesh() + mesh.verts.place({ + 's': ti.i32, + 's_': ti.i32, + 's1': ti.i32, + 'a': ti.i32, + 'b': ti.i32, + 'c': ti.i32 + }) + mesh.edges.place({'s2': ti.i32}) + mesh.cells.place({'s3': ti.i32}) + mesh.verts.link(mesh.verts) + mesh.verts.link(mesh.edges) + mesh.verts.link(mesh.cells) + + model = mesh.build(ti.Mesh.load_meta(model_file_path)) + + @ti.kernel + def foo(): + for u in model.verts: + u.s1 = u.id + for e in model.edges: + e.s2 = e.id + for c in model.cells: + c.s3 = c.id + + ti.mesh_local(model.verts.s1, model.edges.s2, model.cells.s3) + for u in model.verts: + a, b, c = 0, 0, 0 + for i in range(u.verts.size): + a += u.verts[i].s1 + for i in range(u.edges.size): + b += u.edges[i].s2 + for i in range(u.cells.size): + c += u.cells[i].s3 + u.s = a * b * c + + for u in model.verts: + for i in range(u.verts.size): + u.a += u.verts[i].s1 + for u in model.verts: + for i in range(u.edges.size): + u.b += u.edges[i].s2 + for u in model.verts: + for i in range(u.cells.size): + u.c += u.cells[i].s3 + for u in model.verts: + u.s_ = u.a * u.b * u.c + + foo() + + sum1 = model.verts.s.to_numpy().sum() + sum2 = model.verts.s_.to_numpy().sum() + assert sum1 == sum2 diff --git a/tests/python/test_mod.py b/tests/python/test_mod.py index 0490c7a0b126c..987b1d7c46920 100644 --- a/tests/python/test_mod.py +++ b/tests/python/test_mod.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def _test_py_style_mod(arg1, a, arg2, b, arg3, c): z = ti.field(arg3, shape=()) @@ -13,7 +14,7 @@ def func(x: arg1, y: arg2): assert z[None] == c -@ti.test() +@test_utils.test() def _test_c_style_mod(arg1, a, arg2, b, arg3, c): z = ti.field(arg3, shape=()) @@ -49,7 +50,7 @@ def func(a, b): func(-10, -3) -@ti.test() +@test_utils.test() def test_mod_scan(): z = ti.field(ti.i32, shape=()) w = ti.field(ti.i32, shape=()) diff --git a/tests/python/test_module_import.py b/tests/python/test_module_import.py new file mode 100644 index 0000000000000..35f4ff4a9ae6e --- /dev/null +++ b/tests/python/test_module_import.py @@ -0,0 +1,12 @@ +import taichi as myowntaichi +from tests import test_utils + + +@test_utils.test() +def test_module_import(): + @myowntaichi.kernel + def func(): + for _ in myowntaichi.static(range(8)): + pass + + func() diff --git a/tests/python/test_mpm88.py b/tests/python/test_mpm88.py index 20a661ade1a22..c52d30acc8c72 100644 --- a/tests/python/test_mpm88.py +++ b/tests/python/test_mpm88.py @@ -3,7 +3,7 @@ import pytest import taichi as ti -from taichi import approx +from tests import test_utils def run_mpm88_test(): @@ -33,16 +33,15 @@ def substep(): fx = x[p] * inv_dx - base.cast(float) w = [0.5 * (1.5 - fx)**2, 0.75 - (fx - 1)**2, 0.5 * (fx - 0.5)**2] stress = -dt * p_vol * (J[p] - 1) * 4 * inv_dx * inv_dx * E - affine = ti.Matrix([[stress, 0], [0, stress]], - dt=ti.f32) + p_mass * C[p] + affine = ti.Matrix([[stress, 0], [0, stress]]) + p_mass * C[p] for i in ti.static(range(3)): for j in ti.static(range(3)): offset = ti.Vector([i, j]) dpos = (offset.cast(float) - fx) * dx weight = w[i][0] * w[j][1] - grid_v[base + offset].atomic_add( - weight * (p_mass * v[p] + affine @ dpos)) - grid_m[base + offset].atomic_add(weight * p_mass) + ti.atomic_add(grid_v[base + offset], + weight * (p_mass * v[p] + affine @ dpos)) + ti.atomic_add(grid_m[base + offset], weight * p_mass) for i, j in grid_m: if grid_m[i, j] > 0: @@ -79,9 +78,6 @@ def substep(): J[p] *= 1 + dt * new_C.trace() C[p] = new_C - # gui = ti.core.GUI("MPM88", ti.veci(512, 512)) - # canvas = gui.get_canvas() - for i in range(n_particles): x[i] = [i % N / N * 0.4 + 0.2, i / N / N * 0.4 + 0.05] v[i] = [0, -3] @@ -102,10 +98,11 @@ def substep(): 0.07810827, ] for i in range(4): - assert (pos**(i + 1)).mean() == approx(regression[i], rel=1e-2) + assert (pos**(i + 1)).mean() == test_utils.approx(regression[i], + rel=1e-2) -@ti.test() +@test_utils.test() def test_mpm88(): run_mpm88_test() @@ -118,7 +115,9 @@ def _is_appveyor(): #TODO: Remove exclude of ti.metal @pytest.mark.skipif(_is_appveyor(), reason='Stuck on Appveyor.') -@ti.test(require=ti.extension.async_mode, exclude=[ti.metal], async_mode=True) +@test_utils.test(require=ti.extension.async_mode, + exclude=[ti.metal], + async_mode=True) def test_mpm88_async(): # It seems that all async tests on Appveyor run super slow. For example, # on Appveyor, 10+ tests have passed during the execution of @@ -126,7 +125,7 @@ def test_mpm88_async(): run_mpm88_test() -@ti.test(exclude=[ti.vulkan]) +@test_utils.test(arch=[ti.cpu, ti.cuda, ti.opengl]) def test_mpm88_numpy_and_ndarray(): import numpy as np @@ -151,16 +150,15 @@ def substep(x: ti.any_arr(element_dim=1), v: ti.any_arr(element_dim=1), fx = x[p] * inv_dx - base.cast(float) w = [0.5 * (1.5 - fx)**2, 0.75 - (fx - 1)**2, 0.5 * (fx - 0.5)**2] stress = -dt * p_vol * (J[p] - 1) * 4 * inv_dx * inv_dx * E - affine = ti.Matrix([[stress, 0], [0, stress]], - dt=ti.f32) + p_mass * C[p] + affine = ti.Matrix([[stress, 0], [0, stress]]) + p_mass * C[p] for i in ti.static(range(3)): for j in ti.static(range(3)): offset = ti.Vector([i, j]) dpos = (offset.cast(float) - fx) * dx weight = w[i][0] * w[j][1] - grid_v[base + offset].atomic_add( - weight * (p_mass * v[p] + affine @ dpos)) - grid_m[base + offset].atomic_add(weight * p_mass) + ti.atomic_add(grid_v[base + offset], + weight * (p_mass * v[p] + affine @ dpos)) + ti.atomic_add(grid_m[base + offset], weight * p_mass) for i, j in grid_m: if grid_m[i, j] > 0: @@ -218,7 +216,8 @@ def run_test(x, v, C, J, grid_v, grid_m): 0.07810827, ] for i in range(4): - assert (pos**(i + 1)).mean() == approx(regression[i], rel=1e-2) + assert (pos**(i + 1)).mean() == test_utils.approx(regression[i], + rel=1e-2) def test_numpy(): x = np.zeros((n_particles, dim), dtype=np.float32) diff --git a/tests/python/test_mpm_particle_list.py b/tests/python/test_mpm_particle_list.py index a5d40615d2771..3724713678265 100644 --- a/tests/python/test_mpm_particle_list.py +++ b/tests/python/test_mpm_particle_list.py @@ -1,6 +1,7 @@ import random import taichi as ti +from tests import test_utils @ti.data_oriented @@ -20,7 +21,7 @@ def __init__(self, res): voxel = block.dense(indices, 8) voxel.place(self.grid_m) - block.dynamic(ti.indices(dim), 1024 * 1024, + block.dynamic(ti.axes(dim), 1024 * 1024, chunk_size=4096).place(self.pid) ti.root.dynamic(ti.i, 2**25, 2**20).place(self.x) @@ -43,17 +44,19 @@ def step(self): self.build_pid() -@ti.test(require=ti.extension.sparse, exclude=[ti.metal], device_memory_GB=1.0) +@test_utils.test(require=ti.extension.sparse, + exclude=[ti.metal], + device_memory_GB=1.0) def test_mpm_particle_list_no_leakage(): # By default Taichi will allocate 0.5 GB for testing. mpm = MPMSolver(res=(128, 128)) mpm.step() -@ti.test(require=[ti.extension.sparse, ti.extension.packed], - exclude=[ti.metal], - device_memory_GB=1.0, - packed=True) +@test_utils.test(require=[ti.extension.sparse, ti.extension.packed], + exclude=[ti.metal], + device_memory_GB=1.0, + packed=True) def test_mpm_particle_list_no_leakage_packed(): # By default Taichi will allocate 0.5 GB for testing. mpm = MPMSolver(res=(128, 128)) diff --git a/tests/python/test_name_error.py b/tests/python/test_name_error.py new file mode 100644 index 0000000000000..e604c436d423d --- /dev/null +++ b/tests/python/test_name_error.py @@ -0,0 +1,15 @@ +import pytest + +import taichi as ti +from tests import test_utils + + +@test_utils.test() +def test_name_error(): + with pytest.raises(ti.TaichiNameError, match='Name "a" is not defined'): + + @ti.kernel + def foo(): + a + 1 + + foo() diff --git a/tests/python/test_native_functions.py b/tests/python/test_native_functions.py index 4d92930b8cab2..a0d537d0116ed 100644 --- a/tests/python/test_native_functions.py +++ b/tests/python/test_native_functions.py @@ -1,9 +1,10 @@ import numpy as np import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_abs(): x = ti.field(ti.f32) @@ -24,7 +25,7 @@ def func(): assert x[i] == i -@ti.test() +@test_utils.test() def test_int(): x = ti.field(ti.f32) @@ -47,7 +48,7 @@ def func(): assert x[i] == i // 2 -@ti.test() +@test_utils.test() def test_minmax(): x = ti.field(ti.f32) y = ti.field(ti.f32) diff --git a/tests/python/test_ndarray.py b/tests/python/test_ndarray.py index ef703d91b898e..4cb6ebd70b3d2 100644 --- a/tests/python/test_ndarray.py +++ b/tests/python/test_ndarray.py @@ -1,7 +1,16 @@ +import copy + import numpy as np import pytest +from taichi.lang import impl +from taichi.lang.misc import get_host_arch_list +from taichi.lang.util import has_pytorch import taichi as ti +from tests import test_utils + +if has_pytorch(): + import torch # properties @@ -9,76 +18,87 @@ ndarray_shapes = [(), 8, (6, 12)] vector_dims = [3] matrix_dims = [(1, 2), (2, 3)] +supported_archs_taichi_ndarray = [ti.cpu, ti.cuda, ti.opengl, ti.vulkan] -@pytest.mark.parametrize('dtype', data_types) -@pytest.mark.parametrize('shape', ndarray_shapes) -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(arch=ti.get_host_arch_list()) -def test_scalar_ndarray(dtype, shape): +def _test_scalar_ndarray(dtype, shape): x = ti.ndarray(dtype, shape) if isinstance(shape, tuple): assert x.shape == shape else: assert x.shape == (shape, ) + assert x.element_shape == () assert x.dtype == dtype -@pytest.mark.parametrize('n', vector_dims) @pytest.mark.parametrize('dtype', data_types) @pytest.mark.parametrize('shape', ndarray_shapes) -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(arch=ti.get_host_arch_list()) -def test_vector_ndarray(n, dtype, shape): +@test_utils.test(arch=get_host_arch_list()) +def test_scalar_ndarray(dtype, shape): + _test_scalar_ndarray(dtype, shape) + + +def _test_vector_ndarray(n, dtype, shape): x = ti.Vector.ndarray(n, dtype, shape) if isinstance(shape, tuple): assert x.shape == shape else: assert x.shape == (shape, ) + assert x.element_shape == (n, ) assert x.dtype == dtype assert x.n == n -@pytest.mark.parametrize('n,m', matrix_dims) +@pytest.mark.parametrize('n', vector_dims) @pytest.mark.parametrize('dtype', data_types) @pytest.mark.parametrize('shape', ndarray_shapes) -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(arch=ti.get_host_arch_list()) -def test_matrix_ndarray(n, m, dtype, shape): +@test_utils.test(arch=get_host_arch_list()) +def test_vector_ndarray(n, dtype, shape): + _test_vector_ndarray(n, dtype, shape) + + +def _test_matrix_ndarray(n, m, dtype, shape): x = ti.Matrix.ndarray(n, m, dtype, shape) if isinstance(shape, tuple): assert x.shape == shape else: assert x.shape == (shape, ) + assert x.element_shape == (n, m) assert x.dtype == dtype assert x.n == n assert x.m == m +@pytest.mark.parametrize('n,m', matrix_dims) +@pytest.mark.parametrize('dtype', data_types) +@pytest.mark.parametrize('shape', ndarray_shapes) +@test_utils.test(arch=get_host_arch_list()) +def test_matrix_ndarray(n, m, dtype, shape): + _test_matrix_ndarray(n, m, dtype, shape) + + @pytest.mark.parametrize('dtype', [ti.f32, ti.f64]) -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') def test_default_fp_ndarray(dtype): - ti.init(default_fp=dtype) + ti.init(arch=supported_archs_taichi_ndarray, default_fp=dtype) x = ti.Vector.ndarray(2, float, ()) - assert x.dtype == ti.get_runtime().default_fp + assert x.dtype == impl.get_runtime().default_fp @pytest.mark.parametrize('dtype', [ti.i32, ti.i64]) -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') def test_default_ip_ndarray(dtype): - ti.init(default_ip=dtype) + ti.init(arch=supported_archs_taichi_ndarray, default_ip=dtype) x = ti.Vector.ndarray(2, int, ()) - assert x.dtype == ti.get_runtime().default_ip + assert x.dtype == impl.get_runtime().default_ip # access @@ -86,9 +106,28 @@ def test_default_ip_ndarray(dtype): layouts = [ti.Layout.SOA, ti.Layout.AOS] -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) -def test_ndarray_2d(): +@test_utils.test(arch=supported_archs_taichi_ndarray) +def test_ndarray_1d(): + n = 4 + + @ti.kernel + def run(x: ti.any_arr(), y: ti.any_arr()): + for i in range(n): + x[i] += i + y[i] + + a = ti.ndarray(ti.i32, shape=(n, )) + for i in range(n): + a[i] = i * i + b = np.ones((n, ), dtype=np.int32) + run(a, b) + for i in range(n): + assert a[i] == i * i + i + 1 + run(b, a) + for i in range(n): + assert b[i] == i * i + (i + 1) * 2 + + +def _test_ndarray_2d(): n = 4 m = 7 @@ -113,9 +152,158 @@ def run(x: ti.any_arr(), y: ti.any_arr()): assert b[i, j] == i * j + (i + j + 1) * 2 -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) -def test_ndarray_numpy_io(): +@test_utils.test(arch=supported_archs_taichi_ndarray) +def test_ndarray_2d(): + _test_ndarray_2d() + + +def _test_ndarray_copy_from_ndarray(): + n = 16 + a = ti.ndarray(ti.i32, shape=n) + b = ti.ndarray(ti.i32, shape=n) + a[0] = 1 + a[4] = 2 + b[0] = 4 + b[4] = 5 + + a.copy_from(b) + + assert a[0] == 4 + assert a[4] == 5 + + x = ti.Vector.ndarray(10, ti.i32, 5, layout=ti.Layout.SOA) + y = ti.Vector.ndarray(10, ti.i32, 5, layout=ti.Layout.SOA) + x[1][0] = 1 + x[2][4] = 2 + y[1][0] = 4 + y[2][4] = 5 + + x.copy_from(y) + + assert x[1][0] == 4 + assert x[2][4] == 5 + + x = ti.Matrix.ndarray(2, 2, ti.i32, 5, layout=ti.Layout.AOS) + y = ti.Matrix.ndarray(2, 2, ti.i32, 5, layout=ti.Layout.AOS) + x[0][0, 0] = 1 + x[4][1, 0] = 3 + y[0][0, 0] = 4 + y[4][1, 0] = 6 + + x.copy_from(y) + + assert x[0][0, 0] == 4 + assert x[4][1, 0] == 6 + + +@test_utils.test(arch=supported_archs_taichi_ndarray) +def test_ndarray_copy_from_ndarray(): + _test_ndarray_copy_from_ndarray() + + +def _test_ndarray_deepcopy(): + n = 16 + x = ti.ndarray(ti.i32, shape=n) + x[0] = 1 + x[4] = 2 + + y = copy.deepcopy(x) + + assert y.shape == x.shape + assert y.dtype == x.dtype + assert y[0] == 1 + assert y[4] == 2 + x[0] = 4 + x[4] = 5 + assert y[0] == 1 + assert y[4] == 2 + + x = ti.Vector.ndarray(10, ti.i32, 5, layout=ti.Layout.SOA) + x[1][0] = 4 + x[2][4] = 5 + + y = copy.deepcopy(x) + + assert y.shape == x.shape + assert y.dtype == x.dtype + assert y.n == x.n + assert y.layout == x.layout + assert y[1][0] == 4 + assert y[2][4] == 5 + x[1][0] = 1 + x[2][4] = 2 + assert y[1][0] == 4 + assert y[2][4] == 5 + + x = ti.Matrix.ndarray(2, 2, ti.i32, 5, layout=ti.Layout.AOS) + x[0][0, 0] = 7 + x[4][1, 0] = 9 + + y = copy.deepcopy(x) + + assert y.shape == x.shape + assert y.dtype == x.dtype + assert y.m == x.m + assert y.n == x.n + assert y.layout == x.layout + assert y[0][0, 0] == 7 + assert y[4][1, 0] == 9 + x[0][0, 0] = 3 + x[4][1, 0] = 5 + assert y[0][0, 0] == 7 + assert y[4][1, 0] == 9 + + +def test_ndarray_cuda_caching_allocator(): + ti.init(arch=ti.cuda, ndarray_use_cached_allocator=True) + n = 8 + a = ti.ndarray(ti.i32, shape=(n)) + a.fill(2) + a = 1 + b = ti.ndarray(ti.i32, shape=(n)) + b.fill(2) + + +@test_utils.test(arch=supported_archs_taichi_ndarray) +def test_ndarray_fill(): + n = 8 + a = ti.ndarray(ti.i32, shape=(n)) + anp = np.ones((n, ), dtype=np.int32) + a.fill(2) + anp.fill(2) + assert (a.to_numpy() == anp).all() + + b = ti.Vector.ndarray(4, ti.f32, shape=(n)) + bnp = np.ones(shape=b.arr.shape, dtype=np.float32) + b.fill(2.5) + bnp.fill(2.5) + assert (b.to_numpy() == bnp).all() + + c = ti.Matrix.ndarray(4, 4, ti.f32, shape=(n)) + cnp = np.ones(shape=c.arr.shape, dtype=np.float32) + c.fill(1.5) + cnp.fill(1.5) + assert (c.to_numpy() == cnp).all() + + +@test_utils.test(arch=supported_archs_taichi_ndarray) +def test_ndarray_rw_cache(): + a = ti.Vector.ndarray(3, ti.f32, ()) + b = ti.Vector.ndarray(3, ti.f32, 12) + + n = 1000 + for i in range(n): + c_a = copy.deepcopy(a) + c_b = copy.deepcopy(b) + c_a[None] = c_b[10] + + +@test_utils.test(arch=supported_archs_taichi_ndarray) +def test_ndarray_deepcopy(): + _test_ndarray_deepcopy() + + +def _test_ndarray_numpy_io(): n = 7 m = 4 a = ti.ndarray(ti.i32, shape=(n, m)) @@ -124,11 +312,30 @@ def test_ndarray_numpy_io(): b.from_numpy(np.ones((n, m), dtype=np.int32) * 2) assert (a.to_numpy() == b.to_numpy()).all() + d = 2 + p = 4 + x = ti.Vector.ndarray(d, ti.f32, p) + x.fill(2) + y = ti.Vector.ndarray(d, ti.f32, p) + y.from_numpy(np.ones((p, d), dtype=np.int32) * 2) + assert (x.to_numpy() == y.to_numpy()).all() + + c = 2 + d = 2 + p = 4 + x = ti.Matrix.ndarray(c, d, ti.f32, p) + x.fill(2) + y = ti.Matrix.ndarray(c, d, ti.f32, p) + y.from_numpy(np.ones((p, c, d), dtype=np.int32) * 2) + assert (x.to_numpy() == y.to_numpy()).all() + + +@test_utils.test(arch=supported_archs_taichi_ndarray) +def test_ndarray_numpy_io(): + _test_ndarray_numpy_io() -@pytest.mark.parametrize('layout', layouts) -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) -def test_matrix_ndarray_python_scope(layout): + +def _test_matrix_ndarray_python_scope(layout): a = ti.Matrix.ndarray(2, 2, ti.i32, 5, layout=layout) for i in range(5): for j, k in ti.ndrange(2, 2): @@ -141,9 +348,12 @@ def test_matrix_ndarray_python_scope(layout): @pytest.mark.parametrize('layout', layouts) -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) -def test_matrix_ndarray_taichi_scope(layout): +@test_utils.test(arch=supported_archs_taichi_ndarray) +def test_matrix_ndarray_python_scope(layout): + _test_matrix_ndarray_python_scope(layout) + + +def _test_matrix_ndarray_taichi_scope(layout): @ti.kernel def func(a: ti.any_arr()): for i in range(5): @@ -160,9 +370,12 @@ def func(a: ti.any_arr()): @pytest.mark.parametrize('layout', layouts) -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) -def test_matrix_ndarray_taichi_scope_struct_for(layout): +@test_utils.test(arch=supported_archs_taichi_ndarray) +def test_matrix_ndarray_taichi_scope(layout): + _test_matrix_ndarray_taichi_scope(layout) + + +def _test_matrix_ndarray_taichi_scope_struct_for(layout): @ti.kernel def func(a: ti.any_arr()): for i in a: @@ -179,14 +392,19 @@ def func(a: ti.any_arr()): @pytest.mark.parametrize('layout', layouts) -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) +@test_utils.test(arch=supported_archs_taichi_ndarray) +def test_matrix_ndarray_taichi_scope_struct_for(layout): + _test_matrix_ndarray_taichi_scope_struct_for(layout) + + +@pytest.mark.parametrize('layout', layouts) +@test_utils.test(arch=supported_archs_taichi_ndarray) def test_vector_ndarray_python_scope(layout): a = ti.Vector.ndarray(10, ti.i32, 5, layout=layout) for i in range(5): for j in range(4): a[i][j * j] = j * j - assert a[0][6] == 0 + assert a[0][9] == 9 assert a[1][0] == 0 assert a[2][1] == 1 assert a[3][4] == 4 @@ -194,8 +412,7 @@ def test_vector_ndarray_python_scope(layout): @pytest.mark.parametrize('layout', layouts) -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) +@test_utils.test(arch=supported_archs_taichi_ndarray) def test_vector_ndarray_taichi_scope(layout): @ti.kernel def func(a: ti.any_arr()): @@ -205,7 +422,7 @@ def func(a: ti.any_arr()): v = ti.Vector.ndarray(10, ti.i32, 5, layout=layout) func(v) - assert v[0][6] == 0 + assert v[0][9] == 9 assert v[1][0] == 0 assert v[2][1] == 1 assert v[3][4] == 4 @@ -215,9 +432,7 @@ def func(a: ti.any_arr()): # number of compiled functions -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) -def test_compiled_functions(): +def _test_compiled_functions(): @ti.kernel def func(a: ti.any_arr(element_dim=1)): for i in range(5): @@ -226,25 +441,27 @@ def func(a: ti.any_arr(element_dim=1)): v = ti.Vector.ndarray(10, ti.i32, 5) func(v) - assert ti.get_runtime().get_num_compiled_functions() == 1 + assert impl.get_runtime().get_num_compiled_functions() == 1 v = np.zeros((6, 10), dtype=np.int32) func(v) - assert ti.get_runtime().get_num_compiled_functions() == 1 - import torch - v = torch.zeros((6, 11), dtype=torch.int32) + assert impl.get_runtime().get_num_compiled_functions() == 1 + v = np.zeros((6, 11), dtype=np.int32) func(v) - assert ti.get_runtime().get_num_compiled_functions() == 2 + assert impl.get_runtime().get_num_compiled_functions() == 2 v = ti.Vector.ndarray(10, ti.i32, 5, layout=ti.Layout.SOA) func(v) - assert ti.get_runtime().get_num_compiled_functions() == 3 + assert impl.get_runtime().get_num_compiled_functions() == 3 + + +@test_utils.test(arch=supported_archs_taichi_ndarray) +def test_compiled_functions(): + _test_compiled_functions() # annotation compatibility -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(arch=ti.get_host_arch_list()) -def test_arg_not_match(): +def _test_arg_not_match(): @ti.kernel def func1(a: ti.any_arr(element_dim=1)): pass @@ -292,3 +509,71 @@ def func4(a: ti.any_arr(layout=ti.Layout.SOA)): r'Invalid argument into ti\.any_arr\(\) - required layout=Layout\.SOA, but .* is provided' ): func4(x) + + @ti.kernel + def func5(a: ti.any_arr(element_shape=(2, 3))): + pass + + x = ti.Vector.ndarray(2, ti.i32, shape=(4, 7)) + with pytest.raises( + ValueError, + match= + r'Invalid argument into ti\.any_arr\(\) - required element_dim'): + func5(x) + + with pytest.raises( + ValueError, + match=r'Both element_shape and element_dim are specified'): + + @ti.kernel + def func6(a: ti.any_arr(element_dim=1, element_shape=(2, 3))): + pass + + @ti.kernel + def func7(a: ti.any_arr(field_dim=2)): + pass + + x = ti.ndarray(ti.i32, shape=(3, )) + with pytest.raises( + ValueError, + match=r'Invalid argument into ti\.any_arr\(\) - required field_dim' + ): + func7(x) + + +@test_utils.test(arch=get_host_arch_list()) +def test_arg_not_match(): + _test_arg_not_match() + + +def _test_size_in_bytes(): + a = ti.ndarray(ti.i32, 8) + assert a._get_element_size() == 4 + assert a._get_nelement() == 8 + + b = ti.Vector.ndarray(10, ti.f64, 5) + assert b._get_element_size() == 8 + assert b._get_nelement() == 50 + + +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_size_in_bytes(): + _test_size_in_bytes() + + +@test_utils.test(arch=supported_archs_taichi_ndarray) +def test_different_shape(): + n1 = 4 + x = ti.ndarray(dtype=ti.f32, shape=(n1, n1)) + + @ti.kernel + def init(d: ti.i32, arr: ti.any_arr()): + for i, j in arr: + arr[i, j] = d + + init(2, x) + assert (x.to_numpy() == (np.ones(shape=(n1, n1)) * 2)).all() + n2 = 8 + y = ti.ndarray(dtype=ti.f32, shape=(n2, n2)) + init(3, y) + assert (y.to_numpy() == (np.ones(shape=(n2, n2)) * 3)).all() diff --git a/tests/python/test_ndrange.py b/tests/python/test_ndrange.py index 5688ee76ff91a..22def2d89bc71 100644 --- a/tests/python/test_ndrange.py +++ b/tests/python/test_ndrange.py @@ -1,9 +1,11 @@ import numpy as np +import pytest import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_1d(): x = ti.field(ti.f32, shape=(16)) @@ -21,7 +23,7 @@ def func(): assert x[i] == 0 -@ti.test() +@test_utils.test() def test_2d(): x = ti.field(ti.f32, shape=(16, 32)) @@ -42,7 +44,7 @@ def func(): assert x[i, j] == 0 -@ti.test() +@test_utils.test() def test_3d(): x = ti.field(ti.f32, shape=(16, 32, 64)) @@ -61,7 +63,7 @@ def func(): assert x[i, j, k] == 0 -@ti.test() +@test_utils.test() def test_tensor_based_3d(): x = ti.field(ti.i32, shape=(6, 6, 6)) y = ti.field(ti.i32, shape=(6, 6, 6)) @@ -87,7 +89,7 @@ def func(): assert x[i, j, k] == y[i, j, k] -@ti.test() +@test_utils.test() def test_static_grouped(): x = ti.field(ti.f32, shape=(16, 32, 64)) @@ -106,7 +108,7 @@ def func(): assert x[i, j, k] == 0 -@ti.test() +@test_utils.test() def test_static_grouped_static(): x = ti.Matrix.field(2, 3, dtype=ti.f32, shape=(16, 4)) @@ -124,7 +126,7 @@ def func(): assert x[i, j][k, l] == k + l * 10 + i + j * 4 -@ti.test() +@test_utils.test() def test_field_init_eye(): # https://github.com/taichi-dev/taichi/issues/1824 @@ -142,7 +144,7 @@ def init(): assert np.allclose(A.to_numpy(), np.eye(n, dtype=np.float32)) -@ti.test() +@test_utils.test() def test_ndrange_index_floordiv(): # https://github.com/taichi-dev/taichi/issues/1829 @@ -165,7 +167,7 @@ def init(): assert A[i, j] == 0 -@ti.test() +@test_utils.test() def test_nested_ndrange(): # https://github.com/taichi-dev/taichi/issues/1829 @@ -189,7 +191,7 @@ def init(): assert A[i, j, k, l] == r -@ti.test(ti.cpu) +@test_utils.test(ti.cpu) def test_ndrange_ast_transform(): n, u, v = 4, 3, 2 @@ -217,3 +219,30 @@ def func(): else: r = 0 assert A[i, j] == r + + +@test_utils.test() +def test_grouped_ndrange_star(): + @ti.kernel + def foo() -> ti.i32: + ret = 0 + for I in ti.grouped(ti.ndrange(*[[1, 3]] * 3)): + ret += I[0] + I[1] + I[2] + return ret + + assert foo() == 36 + + +@test_utils.test() +def test_ndrange_three_arguments(): + @ti.kernel + def foo(): + for i in ti.ndrange((1, 2, 3)): + pass + + with pytest.raises( + ti.TaichiSyntaxError, + match= + r"Every argument of ndrange should be a scalar or a tuple/list like \(begin, end\)" + ): + foo() diff --git a/tests/python/test_nested_kernel_error.py b/tests/python/test_nested_kernel_error.py index f5632149781a6..6054322630c00 100644 --- a/tests/python/test_nested_kernel_error.py +++ b/tests/python/test_nested_kernel_error.py @@ -1,8 +1,10 @@ +import pytest + import taichi as ti +from tests import test_utils -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_nested_kernel_error(): @ti.kernel def B(): @@ -12,4 +14,5 @@ def B(): def A(): B() - A() + with pytest.raises(ti.TaichiCompilationError): + A() diff --git a/tests/python/test_new_allocator.py b/tests/python/test_new_allocator.py index efdfe1d50e976..361f9af04e1e9 100644 --- a/tests/python/test_new_allocator.py +++ b/tests/python/test_new_allocator.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_1d(): N = 16 @@ -22,7 +23,7 @@ def func(): assert y[i] == i * 2 -@ti.test() +@test_utils.test() def test_3d(): N = 2 M = 2 @@ -46,7 +47,7 @@ def func(): assert y[i, j] == i * 10 + j -@ti.test() +@test_utils.test() def test_matrix(): N = 16 @@ -66,7 +67,7 @@ def func(): assert x[i][1, 1] == i + 3 -@ti.test() +@test_utils.test() def test_alloc_in_kernel(): return # build bots may not have this much memory to tests... x = ti.field(ti.f32) diff --git a/tests/python/test_no_activate.py b/tests/python/test_no_activate.py index 3ec81b6f29504..43311f2c1dec1 100644 --- a/tests/python/test_no_activate.py +++ b/tests/python/test_no_activate.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_no_activate(): x = ti.field(ti.f32) diff --git a/tests/python/test_no_grad.py b/tests/python/test_no_grad.py index 40bd480f12a74..e89275e0b176b 100644 --- a/tests/python/test_no_grad.py +++ b/tests/python/test_no_grad.py @@ -2,9 +2,10 @@ import pytest import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_no_grad(): x = ti.field(ti.f32) loss = ti.field(ti.f32) @@ -24,7 +25,7 @@ def func(): func() -@ti.test() +@test_utils.test() def test_raise_no_gradient(): y = ti.field(shape=(), name='y', dtype=ti.f32, needs_grad=True) x = ti.field(shape=(), name='x', dtype=ti.f32) @@ -36,9 +37,9 @@ def func(x: ti.template()): z[0] = x.grad[None] x[None] = 5. - with pytest.raises(RuntimeError) as e: + with pytest.raises( + ti.TaichiCompilationError, + match= + 'Gradient x.grad has not been placed, check whether `needs_grad=True`' + ): func(x) - - assert e.type is RuntimeError - assert e.value.args[ - 0] == f"Gradient x.grad has not been placed, check whether `needs_grad=True`" diff --git a/tests/python/test_non_taichi_types_in_kernel.py b/tests/python/test_non_taichi_types_in_kernel.py index 6be1c8f7a8716..96ed9032e544d 100644 --- a/tests/python/test_non_taichi_types_in_kernel.py +++ b/tests/python/test_non_taichi_types_in_kernel.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_subscript_user_classes_in_kernel(): class MyList: def __init__(self, elements): diff --git a/tests/python/test_numpy.py b/tests/python/test_numpy.py index 84f6838086281..58019da6ecf5e 100644 --- a/tests/python/test_numpy.py +++ b/tests/python/test_numpy.py @@ -1,6 +1,8 @@ import numpy as np +import pytest import taichi as ti +from tests import test_utils def with_data_type(dt): @@ -26,27 +28,27 @@ def test_numpy(arr: ti.ext_arr()): assert a[i] == i * i * 4 -@ti.test() +@test_utils.test() def test_numpy_f32(): with_data_type(np.float32) -@ti.test(require=ti.extension.data64) +@test_utils.test(require=ti.extension.data64) def test_numpy_f64(): with_data_type(np.float64) -@ti.test() +@test_utils.test() def test_numpy_i32(): with_data_type(np.int32) -@ti.test(require=ti.extension.data64) +@test_utils.test(require=ti.extension.data64) def test_numpy_i64(): with_data_type(np.int64) -@ti.test() +@test_utils.test() def test_numpy_2d(): val = ti.field(ti.i32) @@ -74,7 +76,7 @@ def test_numpy(arr: ti.ext_arr()): assert a[i, j] == i * j + i + j -@ti.test() +@test_utils.test() def test_numpy_2d_transpose(): val = ti.field(ti.i32) @@ -101,7 +103,7 @@ def test_numpy(arr: ti.ext_arr()): assert val[i, j] == i * j + j * 4 -@ti.test() +@test_utils.test() def test_numpy_3d(): val = ti.field(ti.i32) @@ -133,8 +135,7 @@ def test_numpy(arr: ti.ext_arr()): assert a[i, j, k] == i * j * (k + 1) + i + j + k * 2 -@ti.test() -@ti.must_throw(IndexError) +@test_utils.test() def test_numpy_3d_error(): val = ti.field(ti.i32) @@ -153,10 +154,11 @@ def test_numpy(arr: ti.ext_arr()): a = np.empty(shape=(n, m, p), dtype=np.int32) - test_numpy(a) + with pytest.raises(ti.TaichiCompilationError): + test_numpy(a) -@ti.test() +@test_utils.test() def test_numpy_multiple_external_arrays(): n = 4 @@ -178,14 +180,14 @@ def test_numpy(a: ti.ext_arr(), b: ti.ext_arr()): assert b[i] == d[i] -@ti.test() -@ti.must_throw(AssertionError) +@test_utils.test() def test_index_mismatch(): - val = ti.field(ti.i32, shape=(1, 2, 3)) - val[0, 0] = 1 + with pytest.raises(AssertionError): + val = ti.field(ti.i32, shape=(1, 2, 3)) + val[0, 0] = 1 -@ti.test() +@test_utils.test() def test_numpy_zero(): @ti.kernel def test_numpy(arr: ti.ext_arr()): @@ -196,7 +198,7 @@ def test_numpy(arr: ti.ext_arr()): test_numpy(np.empty(shape=(5, 0), dtype=np.int32)) -@ti.test() +@test_utils.test() def test_numpy_struct_for(): @ti.kernel def func1(a: ti.any_arr()): diff --git a/tests/python/test_numpy_io.py b/tests/python/test_numpy_io.py index 23adb0e561410..e2795ff8a5f9f 100644 --- a/tests/python/test_numpy_io.py +++ b/tests/python/test_numpy_io.py @@ -1,9 +1,10 @@ import numpy as np import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_to_numpy_2d(): val = ti.field(ti.i32) @@ -24,7 +25,7 @@ def test_to_numpy_2d(): assert arr[i, j] == i + j * 3 -@ti.test() +@test_utils.test() def test_from_numpy_2d(): val = ti.field(ti.i32) @@ -46,7 +47,7 @@ def test_from_numpy_2d(): assert val[i, j] == i + j * 3 -@ti.test() +@test_utils.test() def test_to_numpy_struct(): n = 16 f = ti.Struct.field({"a": ti.i32, "b": ti.f32}, shape=(n, )) @@ -62,7 +63,7 @@ def test_to_numpy_struct(): assert arr_dict["b"][i] == i * 2 -@ti.test() +@test_utils.test() def test_from_numpy_struct(): n = 16 f = ti.Struct.field({"a": ti.i32, "b": ti.f32}, shape=(n, )) @@ -79,7 +80,7 @@ def test_from_numpy_struct(): assert f[i].b == i * 2 -@ti.test(require=ti.extension.data64) +@test_utils.test(require=ti.extension.data64) def test_f64(): val = ti.field(ti.f64) @@ -99,7 +100,7 @@ def test_f64(): assert val[i, j] == (i + j * 3) * 2e100 -@ti.test() +@test_utils.test() def test_matrix(): n = 4 m = 7 @@ -117,7 +118,7 @@ def test_matrix(): assert (nparr == new_nparr).all() -@ti.test() +@test_utils.test() def test_numpy_io_example(): n = 4 m = 7 diff --git a/tests/python/test_offload.py b/tests/python/test_offload.py index ab9c10ae3a1e1..f1ebfe7bd6f35 100644 --- a/tests/python/test_offload.py +++ b/tests/python/test_offload.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_running_loss(): return steps = 16 @@ -19,8 +20,8 @@ def test_running_loss(): def compute_loss(): total_loss[None] = 0.0 for i in range(steps): - total_loss[None].atomic_add(running_loss[i] * 2) - total_loss[None].atomic_add(additional_loss[None] * 3) + ti.atomic_add(total_loss[None], running_loss[i] * 2) + ti.atomic_add(total_loss[None], additional_loss[None] * 3) compute_loss() @@ -30,7 +31,7 @@ def compute_loss(): assert additional_loss.grad[None] == 3 -@ti.test() +@test_utils.test() def test_reduce_separate(): a = ti.field(ti.f32, shape=(16)) b = ti.field(ti.f32, shape=(4)) @@ -58,7 +59,7 @@ def reduce2(): assert a.grad[i] == 1 -@ti.test() +@test_utils.test() def test_reduce_merged(): a = ti.field(ti.f32, shape=(16)) b = ti.field(ti.f32, shape=(4)) diff --git a/tests/python/test_offload_cross.py b/tests/python/test_offload_cross.py index 89048dbd86186..4a2b6d313f34f 100644 --- a/tests/python/test_offload_cross.py +++ b/tests/python/test_offload_cross.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_offload_with_cross_block_locals(): ret = ti.field(ti.f32) @@ -19,7 +20,7 @@ def ker(): assert ret[None] == 45 -@ti.test() +@test_utils.test() def test_offload_with_cross_block_locals2(): ret = ti.field(ti.f32) @@ -40,7 +41,7 @@ def ker(): assert ret[None] == 45 * 21 -@ti.test() +@test_utils.test() def test_offload_with_cross_block_locals3(): ret = ti.field(ti.f32, shape=()) @@ -57,7 +58,7 @@ def ker(): assert ret[None] == 1 -@ti.test() +@test_utils.test() def test_offload_with_cross_block_locals4(): ret = ti.field(ti.f32, shape=()) @@ -74,7 +75,7 @@ def ker(): assert ret[None] == 10 -@ti.test() +@test_utils.test() def test_offload_with_flexible_bounds(): s = ti.field(ti.i32, shape=()) lower = ti.field(ti.i32, shape=()) @@ -92,7 +93,7 @@ def ker(): assert s[None] == 29 * 10 // 2 -@ti.test() +@test_utils.test() def test_offload_with_cross_block_globals(): ret = ti.field(ti.f32) @@ -108,3 +109,27 @@ def ker(): ker() assert ret[None] == 46 + + +@test_utils.test() +def test_offload_with_cross_nested_for(): + @ti.kernel + def run(a: ti.i32): + b = a + 1 + for x in range(1): + for i in range(b): + print('OK') + + run(2) + + +@test_utils.test() +def test_offload_with_cross_if_inside_for(): + @ti.kernel + def run(a: ti.i32): + b = a > 2 + for x in range(1): + if b: + print('OK') + + run(2) diff --git a/tests/python/test_offset.py b/tests/python/test_offset.py index 7620015c94375..fd41b732a0f12 100644 --- a/tests/python/test_offset.py +++ b/tests/python/test_offset.py @@ -1,7 +1,10 @@ +import pytest + import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_accessor(): a = ti.field(dtype=ti.i32) @@ -11,7 +14,7 @@ def test_accessor(): assert a[1029, 2100, 2200] == 1 -@ti.test() +@test_utils.test() def test_struct_for_huge_offsets(): a = ti.field(dtype=ti.i32) @@ -32,7 +35,7 @@ def test(): assert a[i, j, k, l] == i + j * 10 + k * 100 + l * 1000 -@ti.test() +@test_utils.test() def test_struct_for_negative(): a = ti.field(dtype=ti.i32) @@ -51,7 +54,7 @@ def test(): assert a[i, j] == i + j * 10 -@ti.test() +@test_utils.test() def test_offset_for_var(): a = ti.field(dtype=ti.i32, shape=16, offset=-48) b = ti.field(dtype=ti.i32, shape=(16, ), offset=(16, )) @@ -73,7 +76,7 @@ def test(): assert e[i, j] == i * j -@ti.test() +@test_utils.test() def test_offset_for_vector(): a = ti.field(dtype=ti.i32, shape=16, offset=-48) b = ti.field(dtype=ti.i32, shape=16, offset=None) @@ -92,7 +95,7 @@ def test(): assert c[i][0] == 2 * i -@ti.test() +@test_utils.test() def test_offset_for_matrix(): a = ti.Matrix.field(3, 3, @@ -113,22 +116,26 @@ def test(): assert a[i, j][0, 0] == i + j -@ti.test() -@ti.must_throw(AssertionError) +@test_utils.test() def test_offset_must_throw_var(): - a = ti.field(dtype=ti.float32, shape=3, offset=(3, 4)) - b = ti.field(dtype=ti.float32, shape=None, offset=(3, 4)) + with pytest.raises(AssertionError): + a = ti.field(dtype=ti.float32, shape=3, offset=(3, 4)) + b = ti.field(dtype=ti.float32, shape=None, offset=(3, 4)) -@ti.test() -@ti.must_throw(AssertionError) +@test_utils.test() def test_offset_must_throw_vector(): - a = ti.Vector.field(3, dtype=ti.float32, shape=3, offset=(3, 4)) - b = ti.Vector.field(3, dtype=ti.float32, shape=None, offset=(3, )) + with pytest.raises(AssertionError): + a = ti.Vector.field(3, dtype=ti.float32, shape=3, offset=(3, 4)) + b = ti.Vector.field(3, dtype=ti.float32, shape=None, offset=(3, )) -@ti.test() -@ti.must_throw(AssertionError) +@test_utils.test() def test_offset_must_throw_matrix(): - c = ti.Matrix.field(3, 3, dtype=ti.i32, shape=(32, 16, 8), offset=(32, 16)) - d = ti.Matrix.field(3, 3, dtype=ti.i32, shape=None, offset=(32, 16)) + with pytest.raises(AssertionError): + c = ti.Matrix.field(3, + 3, + dtype=ti.i32, + shape=(32, 16, 8), + offset=(32, 16)) + d = ti.Matrix.field(3, 3, dtype=ti.i32, shape=None, offset=(32, 16)) diff --git a/tests/python/test_oop.py b/tests/python/test_oop.py index 45735cb952855..059d774f93578 100644 --- a/tests/python/test_oop.py +++ b/tests/python/test_oop.py @@ -1,7 +1,11 @@ +import pytest +from taichi.lang.misc import get_host_arch_list + import taichi as ti +from tests import test_utils -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_classfunc(): @ti.data_oriented class Array2D: @@ -33,7 +37,7 @@ def fill(self): assert arr.val[i, j] == i * j * 2 -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_oop(): @ti.data_oriented class Array2D: @@ -95,7 +99,7 @@ def double(): assert arr.val.grad[i, j] == 8 -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_oop_two_items(): @ti.data_oriented class Array2D: @@ -145,7 +149,7 @@ def reduce(self): assert arr2.val.grad[i, j] == arr2_mult -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_oop_inherit_ok(): # Array1D inherits from object, which makes the callstack being 'class Array2D(object)' # instead of '@ti.data_oriented'. Make sure this also works. @@ -175,8 +179,7 @@ def reduce(self): assert arr.val.grad[i, j] == 42 -@ti.test(arch=ti.get_host_arch_list()) -@ti.must_throw(ti.KernelDefError) +@test_utils.test(arch=get_host_arch_list()) def test_oop_class_must_be_data_oriented(): class Array1D(object): def __init__(self, n, mul): @@ -197,10 +200,11 @@ def reduce(self): ti.root.lazy_grad() # Array1D is not properly decorated, this will raise an Exception - arr.reduce() + with pytest.raises(ti.TaichiSyntaxError): + arr.reduce() -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_hook(): @ti.data_oriented class Solver: @@ -224,7 +228,7 @@ def hook(x: ti.template()): assert (solver.val[i, j] == 1.0) -@ti.test() +@test_utils.test() def test_oop_with_portery_decorator(): @ti.data_oriented class TestPortery: @@ -243,7 +247,7 @@ def raw_proterty(self): assert a.raw_proterty == 3 -@ti.test() +@test_utils.test() def test_oop_with_static_decorator(): @ti.data_oriented class TestStatic: diff --git a/tests/python/test_optimization.py b/tests/python/test_optimization.py index 4e6b96fef8b8c..258fd41fa8467 100644 --- a/tests/python/test_optimization.py +++ b/tests/python/test_optimization.py @@ -1,7 +1,10 @@ +from taichi.lang.misc import serialize + import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_advanced_store_forwarding_nested_loops(): val = ti.field(ti.i32) ti.root.place(val) @@ -22,7 +25,7 @@ def func(): assert val[None] == 10 -@ti.test() +@test_utils.test() def test_advanced_unused_store_elimination_if(): val = ti.field(ti.i32) ti.root.place(val) @@ -47,14 +50,14 @@ def func(): assert val[None] == 3 -@ti.test() +@test_utils.test() def test_local_store_in_nested_for_and_if(): # See https://github.com/taichi-dev/taichi/pull/862. val = ti.field(ti.i32, shape=(3, 3, 3)) @ti.kernel def func(): - ti.serialize() + serialize() for i, j, k in val: if i < 2 and j < 2 and k < 2: a = 0 @@ -74,7 +77,7 @@ def func(): assert (val[i, j, k] == 1) -@ti.test() +@test_utils.test() def test_advanced_store_forwarding_continue_in_if(): val = ti.field(ti.i32) ti.root.place(val) @@ -104,7 +107,7 @@ def func(n: ti.i32): assert val[None] == 1515 -@ti.test() +@test_utils.test() def test_advanced_store_elimination_in_loop(): val = ti.field(ti.i32) ti.root.place(val) @@ -127,7 +130,7 @@ def func(): assert val[None] == 8 -@ti.test() +@test_utils.test() def test_parallel_assignment(): mat = ti.field(ti.i32, shape=(3, 4)) diff --git a/tests/python/test_packed_size.py b/tests/python/test_packed_size.py index 1e934e76cb752..e52c0e5e95dfc 100644 --- a/tests/python/test_packed_size.py +++ b/tests/python/test_packed_size.py @@ -1,9 +1,10 @@ import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.packed, packed=True) +@test_utils.test(require=ti.extension.packed, packed=True) def test_packed_size(): x = ti.field(ti.i32) ti.root.dense(ti.i, 17).dense(ti.ijk, 129).place(x) assert x.shape == (17 * 129, 129, 129) - assert x.snode.parent().parent().cell_size_bytes == 4 * 129**3 + assert x.snode.parent().parent()._cell_size_bytes == 4 * 129**3 diff --git a/tests/python/test_parallel_range_for.py b/tests/python/test_parallel_range_for.py index 04b042a9a72be..429a2809271c4 100644 --- a/tests/python/test_parallel_range_for.py +++ b/tests/python/test_parallel_range_for.py @@ -1,8 +1,9 @@ import taichi as ti +from tests import test_utils # such small block_dim will cause grid_dim too large for OpenGL... -@ti.test(exclude=ti.opengl) +@test_utils.test(exclude=ti.opengl) def test_parallel_range_for(): n = 1024 * 1024 val = ti.field(ti.i32, shape=(n)) diff --git a/tests/python/test_pow.py b/tests/python/test_pow.py index 7263916dc90de..8ee43894a0de9 100644 --- a/tests/python/test_pow.py +++ b/tests/python/test_pow.py @@ -1,4 +1,5 @@ import taichi as ti +from tests import test_utils def _test_pow_f(dt): @@ -27,21 +28,21 @@ def func(x: dt, y: ti.template()): assert z[None] == x**y -@ti.test() +@test_utils.test() def test_pow_f32(): _test_pow_f(ti.f32) -@ti.test(require=ti.extension.data64) +@test_utils.test(require=ti.extension.data64) def test_pow_f64(): _test_pow_f(ti.f64) -@ti.test() +@test_utils.test() def test_pow_i32(): _test_pow_i(ti.i32) -@ti.test(require=ti.extension.data64) +@test_utils.test(require=ti.extension.data64) def test_pow_i64(): _test_pow_i(ti.i64) diff --git a/tests/python/test_print.py b/tests/python/test_print.py index 1442fa76108f3..3e62d19199378 100644 --- a/tests/python/test_print.py +++ b/tests/python/test_print.py @@ -1,6 +1,7 @@ import pytest import taichi as ti +from tests import test_utils # Not really testable.. @@ -8,7 +9,7 @@ # Metal doesn't support print() or 64-bit data # While OpenGL does support print, but not 64-bit data @pytest.mark.parametrize('dt', [ti.i32, ti.f32, ti.i64, ti.f64]) -@ti.test(exclude=[ti.metal, ti.opengl, ti.vulkan]) +@test_utils.test(exclude=[ti.metal, ti.opengl, ti.vulkan]) def test_print(dt): @ti.kernel def func(): @@ -22,7 +23,7 @@ def func(): # TODO: As described by @k-ye above, what we want to ensure # is that, the content shows on console is *correct*. -@ti.test(exclude=[ti.vulkan]) # TODO(changyu): enable ti.vulkan +@test_utils.test(exclude=[ti.vulkan]) # TODO(changyu): enable ti.vulkan def test_multi_print(): @ti.kernel def func(x: ti.i32, y: ti.f32): @@ -32,7 +33,7 @@ def func(x: ti.i32, y: ti.f32): ti.sync() -@ti.test(exclude=[ti.vulkan]) # TODO(changyu): enable ti.vulkan +@test_utils.test(exclude=[ti.vulkan]) # TODO(changyu): enable ti.vulkan def test_print_string(): @ti.kernel def func(x: ti.i32, y: ti.f32): @@ -44,7 +45,7 @@ def func(x: ti.i32, y: ti.f32): ti.sync() -@ti.test(exclude=[ti.vulkan]) # TODO(changyu): enable ti.vulkan +@test_utils.test(exclude=[ti.vulkan]) # TODO(changyu): enable ti.vulkan def test_print_matrix(): x = ti.Matrix.field(2, 3, dtype=ti.f32, shape=()) y = ti.Vector.field(3, dtype=ti.f32, shape=3) @@ -60,7 +61,7 @@ def func(k: ti.f32): ti.sync() -@ti.test(exclude=[ti.vulkan]) # TODO(changyu): enable ti.vulkan +@test_utils.test(exclude=[ti.vulkan]) # TODO(changyu): enable ti.vulkan def test_print_sep_end(): @ti.kernel def func(): @@ -80,7 +81,7 @@ def func(): ti.sync() -@ti.test(exclude=[ti.vulkan]) # TODO(changyu): enable ti.vulkan +@test_utils.test(exclude=[ti.vulkan]) # TODO(changyu): enable ti.vulkan def test_print_multiple_threads(): x = ti.field(dtype=ti.f32, shape=(128, )) @@ -96,7 +97,7 @@ def func(k: ti.f32): ti.sync() -@ti.test(exclude=[ti.vulkan]) # TODO(changyu): enable ti.vulkan +@test_utils.test(exclude=[ti.vulkan]) # TODO(changyu): enable ti.vulkan def test_print_list(): x = ti.Matrix.field(2, 3, dtype=ti.f32, shape=(2, 3)) y = ti.Vector.field(3, dtype=ti.f32, shape=()) @@ -117,7 +118,7 @@ def func(k: ti.f32): ti.sync() -@ti.test(arch=ti.cpu) +@test_utils.test(arch=ti.cpu) def test_python_scope_print_field(): x = ti.Matrix.field(2, 3, dtype=ti.f32, shape=()) y = ti.Vector.field(3, dtype=ti.f32, shape=3) @@ -126,3 +127,32 @@ def test_python_scope_print_field(): print(x) print(y) print(z) + + +@test_utils.test(arch=ti.cpu) +def test_print_string_format(): + @ti.kernel + def func(k: ti.f32): + print(123) + print("{} abc".format(123)) + print("{} {} {}".format(1, 2, 3)) + print("{} {name} {value}".format(k, name=999, value=123)) + name = 123.4 + value = 456.7 + print("{} {name} {value}".format(k, name=name, value=value)) + + func(233.3) + ti.sync() + + +@test_utils.test(arch=ti.cpu) +def test_print_fstring(): + def foo1(x): + return x + 1 + + @ti.kernel + def func(i: ti.i32, f: ti.f32): + print(f'qwe {foo1(1)} {foo1(2) * 2 - 1} {i} {f} {4} {True} {1.23}') + + func(123, 4.56) + ti.sync() diff --git a/tests/python/test_ptr_assign.py b/tests/python/test_ptr_assign.py index 266a454e0c68b..63e5a1e7db990 100644 --- a/tests/python/test_ptr_assign.py +++ b/tests/python/test_ptr_assign.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_ptr_scalar(): a = ti.field(dtype=ti.f32, shape=()) @@ -18,7 +19,7 @@ def func(t: ti.f32): assert a[None] == x * y + y -@ti.test() +@test_utils.test() def test_ptr_matrix(): a = ti.Matrix.field(2, 2, dtype=ti.f32, shape=()) @@ -33,7 +34,7 @@ def func(t: ti.f32): assert a[None][1, 0] == x -@ti.test() +@test_utils.test() def test_ptr_field(): a = ti.field(dtype=ti.f32, shape=(3, 4)) @@ -51,7 +52,7 @@ def func(t: ti.f32): assert a[2, 0] == x + y -@ti.test() +@test_utils.test() def test_pythonish_tuple_assign(): a = ti.field(dtype=ti.f32, shape=()) b = ti.field(dtype=ti.f32, shape=()) @@ -68,7 +69,7 @@ def func(x: ti.f32, y: ti.f32): assert b[None] == x -@ti.test() +@test_utils.test() def test_ptr_func(): a = ti.field(dtype=ti.f32, shape=()) @@ -85,7 +86,7 @@ def func(): assert a[None] == 5.0 -@ti.test() +@test_utils.test() def test_ptr_class_func(): @ti.data_oriented class MyClass: diff --git a/tests/python/test_random.py b/tests/python/test_random.py index d9e40ba147474..fd609b1cbfca5 100644 --- a/tests/python/test_random.py +++ b/tests/python/test_random.py @@ -1,12 +1,8 @@ import taichi as ti -from taichi import approx +from tests import test_utils -def archs_support_random(func): - return ti.archs_excluding(ti.metal)(func) - - -@ti.test(exclude=ti.metal) +@test_utils.test() def test_random_float(): for precision in [ti.f32, ti.f64]: ti.init() @@ -22,10 +18,10 @@ def fill(): fill() X = x.to_numpy() for i in range(1, 4): - assert (X**i).mean() == approx(1 / (i + 1), rel=1e-2) + assert (X**i).mean() == test_utils.approx(1 / (i + 1), rel=1e-2) -@ti.test(exclude=ti.metal) +@test_utils.test() def test_random_int(): for precision in [ti.i32, ti.i64]: ti.init() @@ -45,10 +41,10 @@ def fill(): fill() X = x.to_numpy() for i in range(1, 4): - assert (X**i).mean() == approx(1 / (i + 1), rel=1e-2) + assert (X**i).mean() == test_utils.approx(1 / (i + 1), rel=1e-2) -@ti.test(exclude=ti.metal) +@test_utils.test() def test_random_independent_product(): n = 1024 x = ti.field(ti.f32, shape=n * n) @@ -63,10 +59,10 @@ def fill(): fill() X = x.to_numpy() for i in range(4): - assert X.mean() == approx(1 / 4, rel=1e-2) + assert X.mean() == test_utils.approx(1 / 4, rel=1e-2) -@ti.test(exclude=ti.metal) +@test_utils.test() def test_random_2d_dist(): n = 8192 @@ -75,7 +71,7 @@ def test_random_2d_dist(): @ti.kernel def gen(): for i in range(n): - x[i] = ti.Vector([ti.random(), ti.random()], dt=ti.f32) + x[i] = ti.Vector([ti.random(), ti.random()]) gen() @@ -86,10 +82,10 @@ def gen(): counters[c] += 1 for c in range(4): - assert counters[c] / n == approx(1 / 4, rel=0.2) + assert counters[c] / n == test_utils.approx(1 / 4, rel=0.2) -@ti.test(exclude=ti.metal) +@test_utils.test() def test_random_seed_per_launch(): n = 10 x = ti.field(ti.f32, shape=n) @@ -107,7 +103,7 @@ def gen(i: ti.i32): assert count <= n * 0.15 -@ti.test(arch=[ti.cpu, ti.cuda]) +@test_utils.test(arch=[ti.cpu, ti.cuda, ti.metal]) def test_random_seed_per_program(): import numpy as np n = 10 @@ -128,7 +124,7 @@ def gen(): assert not np.allclose(result[0], result[1]) -@ti.test(arch=[ti.cpu, ti.cuda]) +@test_utils.test(arch=[ti.cpu, ti.cuda]) def test_random_f64(): ''' Tests the granularity of float64 random numbers. @@ -148,7 +144,7 @@ def foo(): assert np.max(frac) > 0 -@ti.test(exclude=ti.metal) +@test_utils.test() def test_randn(): ''' Tests the generation of Gaussian random numbers. @@ -170,4 +166,5 @@ def fill(): # https://en.wikipedia.org/wiki/Normal_distribution#Moments moments = [0.0, 1.0, 0.0, 3.0] for i in range(4): - assert (X**(i + 1)).mean() == approx(moments[i], abs=3e-2) + assert (X**(i + 1)).mean() == test_utils.approx(moments[i], + abs=3e-2) diff --git a/tests/python/test_reduction.py b/tests/python/test_reduction.py index 71ad62b0b997c..434a1dc589979 100644 --- a/tests/python/test_reduction.py +++ b/tests/python/test_reduction.py @@ -3,6 +3,7 @@ from pytest import approx import taichi as ti +from tests import test_utils OP_ADD = 0 OP_MIN = 1 @@ -32,8 +33,8 @@ def _test_reduction_single(dtype, criterion, op): N = 1024 * 1024 - if (ti.cfg.arch == ti.opengl - or ti.cfg.arch == ti.vulkan) and dtype == ti.f32: + if (ti.lang.impl.current_cfg().arch == ti.opengl or + ti.lang.impl.current_cfg().arch == ti.vulkan) and dtype == ti.f32: # OpenGL/Vulkan are not capable of such large number in its float32... N = 1024 * 16 @@ -80,42 +81,42 @@ def reduce_tmp() -> dtype: @pytest.mark.parametrize('op', [OP_ADD, OP_MIN, OP_MAX, OP_AND, OP_OR, OP_XOR]) -@ti.all_archs +@test_utils.test() def test_reduction_single_i32(op): _test_reduction_single(ti.i32, lambda x, y: x % 2**32 == y % 2**32, op) @pytest.mark.parametrize('op', [OP_ADD]) -@ti.test(exclude=ti.opengl) +@test_utils.test(exclude=ti.opengl) def test_reduction_single_u32(op): _test_reduction_single(ti.u32, lambda x, y: x % 2**32 == y % 2**32, op) @pytest.mark.parametrize('op', [OP_ADD, OP_MIN, OP_MAX]) -@ti.all_archs +@test_utils.test() def test_reduction_single_f32(op): _test_reduction_single(ti.f32, lambda x, y: x == approx(y, 3e-4), op) @pytest.mark.parametrize('op', [OP_ADD]) -@ti.test(require=ti.extension.data64) +@test_utils.test(require=ti.extension.data64) def test_reduction_single_i64(op): _test_reduction_single(ti.i64, lambda x, y: x % 2**64 == y % 2**64, op) @pytest.mark.parametrize('op', [OP_ADD]) -@ti.test(exclude=ti.opengl, require=ti.extension.data64) +@test_utils.test(exclude=ti.opengl, require=ti.extension.data64) def test_reduction_single_u64(op): _test_reduction_single(ti.u64, lambda x, y: x % 2**64 == y % 2**64, op) @pytest.mark.parametrize('op', [OP_ADD]) -@ti.test(require=ti.extension.data64) +@test_utils.test(require=ti.extension.data64) def test_reduction_single_f64(op): _test_reduction_single(ti.f64, lambda x, y: x == approx(y, 1e-12), op) -@ti.test() +@test_utils.test() def test_reduction_different_scale(): @ti.kernel def func(n: ti.template()) -> ti.i32: @@ -130,7 +131,7 @@ def func(n: ti.template()) -> ti.i32: assert n == func(n) -@ti.test() +@test_utils.test() def test_reduction_any_arr(): @ti.kernel def reduce(a: ti.any_arr()) -> ti.i32: diff --git a/tests/python/test_rescale.py b/tests/python/test_rescale.py index 3fe9bf0af7b25..abf95da4ea93a 100644 --- a/tests/python/test_rescale.py +++ b/tests/python/test_rescale.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_rescale(): a = ti.field(ti.f32) b = ti.field(ti.f32) diff --git a/tests/python/test_return.py b/tests/python/test_return.py index 7384621ad3b35..30bcc273dfcde 100644 --- a/tests/python/test_return.py +++ b/tests/python/test_return.py @@ -1,28 +1,17 @@ -import taichi as ti -from taichi import approx - - -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) -def _test_return_not_last_stmt(): # TODO: make this work - x = ti.field(ti.i32, ()) +import pytest - @ti.kernel - def kernel() -> ti.i32: - return 1 - x[None] = 233 - - kernel() +import taichi as ti +from tests import test_utils -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_return_without_type_hint(): @ti.kernel def kernel(): return 1 - kernel() + with pytest.raises(ti.TaichiCompilationError): + kernel() def test_const_func_ret(): @@ -36,28 +25,28 @@ def func1() -> ti.f32: def func2() -> ti.i32: return 3.3 # return type mismatch, will be auto-casted into ti.i32 - assert func1() == approx(3) + assert func1() == test_utils.approx(3) assert func2() == 3 -@ti.test() +@test_utils.test() def _test_binary_func_ret(dt1, dt2, dt3, castor): @ti.kernel def func(a: dt1, b: dt2) -> dt3: return a * b - if ti.core.is_integral(dt1): + if ti.types.is_integral(dt1): xs = list(range(4)) else: xs = [0.2, 0.4, 0.8, 1.0] - if ti.core.is_integral(dt2): + if ti.types.is_integral(dt2): ys = list(range(4)) else: ys = [0.2, 0.4, 0.8, 1.0] for x, y in zip(xs, ys): - assert func(x, y) == approx(castor(x * y)) + assert func(x, y) == test_utils.approx(castor(x * y)) def test_binary_func_ret(): @@ -65,3 +54,96 @@ def test_binary_func_ret(): _test_binary_func_ret(ti.f32, ti.i32, ti.f32, float) _test_binary_func_ret(ti.i32, ti.f32, ti.i32, int) _test_binary_func_ret(ti.f32, ti.i32, ti.i32, int) + + +@test_utils.test() +def test_return_in_static_if(): + @ti.kernel + def foo(a: ti.template()) -> ti.i32: + if ti.static(a == 1): + return 1 + elif ti.static(a == 2): + return 2 + return 3 + + assert foo(1) == 1 + assert foo(2) == 2 + assert foo(123) == 3 + + +@test_utils.test() +def test_func_multiple_return(): + @ti.func + def safe_sqrt(a): + if a > 0: + return ti.sqrt(a) + else: + return 0.0 + + @ti.kernel + def kern(a: float): + print(safe_sqrt(a)) + + with pytest.raises( + ti.TaichiCompilationError, + match='Return inside non-static if/for is not supported'): + kern(-233) + + +@test_utils.test() +def test_return_inside_static_for(): + @ti.kernel + def foo() -> ti.i32: + a = 0 + for i in ti.static(range(10)): + a += i * i + if ti.static(i == 8): + return a + + assert foo() == 204 + + +@test_utils.test() +def test_return_inside_non_static_for(): + with pytest.raises( + ti.TaichiCompilationError, + match='Return inside non-static if/for is not supported'): + + @ti.kernel + def foo() -> ti.i32: + for i in range(10): + return i + + foo() + + +@test_utils.test() +def test_kernel_no_return(): + with pytest.raises( + ti.TaichiSyntaxError, + match= + "Kernel has a return type but does not have a return statement"): + + @ti.kernel + def foo() -> ti.i32: + pass + + foo() + + +@test_utils.test() +def test_func_no_return(): + with pytest.raises( + ti.TaichiCompilationError, + match= + "Function has a return type but does not have a return statement"): + + @ti.func + def bar() -> ti.i32: + pass + + @ti.kernel + def foo() -> ti.i32: + return bar() + + foo() diff --git a/tests/python/test_runtime.py b/tests/python/test_runtime.py index 89c86825d803c..a397e0c6057d6 100644 --- a/tests/python/test_runtime.py +++ b/tests/python/test_runtime.py @@ -6,6 +6,7 @@ import pytest import taichi as ti +from tests import test_utils @contextmanager @@ -47,10 +48,8 @@ def patch_os_environ_helper(custom_environ: dict, excludes: dict): TF = [True, False] init_args = { # 'key': [default, choices], - 'print_preprocessed': [False, TF], 'log_level': ['info', ['error', 'warn', 'info', 'debug', 'trace']], 'gdb_trigger': [False, TF], - 'excepthook': [False, TF], 'advanced_optimization': [True, TF], 'debug': [False, TF], 'print_ir': [False, TF], @@ -60,7 +59,6 @@ def patch_os_environ_helper(custom_environ: dict, excludes: dict): 'flatten_if': [False, TF], 'simplify_before_lower_access': [True, TF], 'simplify_after_lower_access': [True, TF], - 'print_benchmark_stat': [False, TF], 'kernel_profiler': [False, TF], 'check_out_of_bound': [False, TF], 'print_accessor_ir': [False, TF], @@ -76,10 +74,8 @@ def patch_os_environ_helper(custom_environ: dict, excludes: dict): env_configs = ['TI_' + key.upper() for key in init_args.keys()] special_init_cfgs = [ - 'print_preprocessed', 'log_level', 'gdb_trigger', - 'excepthook', ] @@ -89,11 +85,12 @@ def test_init_arg(key, values): # helper function: def test_arg(key, value, kwargs={}): - spec_cfg = ti.init(_test_mode=True, **kwargs) if key in special_init_cfgs: + spec_cfg = ti.init(_test_mode=True, **kwargs) cfg = spec_cfg else: - cfg = ti.cfg + ti.init(**kwargs) + cfg = ti.lang.impl.current_cfg() assert getattr(cfg, key) == value with patch_os_environ_helper({}, excludes=env_configs): @@ -114,15 +111,15 @@ def test_arg(key, value, kwargs={}): test_arg(key, value) -@pytest.mark.parametrize('arch', ti.supported_archs()) +@pytest.mark.parametrize('arch', test_utils.expected_archs()) def test_init_arch(arch): with patch_os_environ_helper({}, excludes=['TI_ARCH']): ti.init(arch=arch) - assert ti.cfg.arch == arch - with patch_os_environ_helper({'TI_ARCH': ti.core.arch_name(arch)}, + assert ti.lang.impl.current_cfg().arch == arch + with patch_os_environ_helper({'TI_ARCH': ti._lib.core.arch_name(arch)}, excludes=['TI_ARCH']): ti.init(arch=ti.cc) - assert ti.cfg.arch == arch + assert ti.lang.impl.current_cfg().arch == arch def test_init_bad_arg(): @@ -130,33 +127,37 @@ def test_init_bad_arg(): ti.init(_test_mode=True, debug=True, foo_bar=233) -@ti.test(arch=ti.cpu) -def test_materialize_callback(): - x = ti.field(ti.f32, (3, 4)) +def test_init_require_version(): + ti_core = ti._lib.utils.import_ti_core() + require_version = '{}.{}.{}'.format(ti_core.get_version_major(), + ti_core.get_version_minor(), + ti_core.get_version_patch()) + ti.init(_test_mode=True, debug=True, require_version=require_version) - @ti.materialize_callback - @ti.kernel - def init_x(): - for i in range(3): - for j in range(4): - x[i, j] = i + j + 1 - # x will be initialized on first invocation - for i in range(3): - for j in range(4): - assert x[i, j] == i + j + 1 +def test_init_bad_require_version(): + with pytest.raises(Exception): + ti_core = ti._lib.utils.import_ti_core() + bad_require_version = '{}.{}.{}'.format( + ti_core.get_version_major(), ti_core.get_version_minor(), + ti_core.get_version_patch() + 1) + ti.init(_test_mode=True, + debug=True, + require_version=bad_require_version) -@pytest.mark.parametrize('level', ti.supported_log_levels) -@ti.test() +@pytest.mark.parametrize( + 'level', [ti.DEBUG, ti.TRACE, ti.INFO, ti.WARN, ti.ERROR, ti.CRITICAL]) +@test_utils.test() def test_supported_log_levels(level): spec_cfg = ti.init(_test_mode=True, log_level=level) assert spec_cfg.log_level == level -@pytest.mark.parametrize('level', ti.supported_log_levels) -@ti.test() +@pytest.mark.parametrize( + 'level', [ti.DEBUG, ti.TRACE, ti.INFO, ti.WARN, ti.ERROR, ti.CRITICAL]) +@test_utils.test() def test_supported_log_levels(level): spec_cfg = ti.init(_test_mode=True) ti.set_logging_level(level) - assert ti.is_logging_effective(level) + assert ti._logging.is_logging_effective(level) diff --git a/tests/python/test_scalar_op.py b/tests/python/test_scalar_op.py new file mode 100644 index 0000000000000..1be7606302f8d --- /dev/null +++ b/tests/python/test_scalar_op.py @@ -0,0 +1,192 @@ +import operator as ops + +import numpy as np +import pytest + +import taichi as ti +from tests import test_utils + +binary_func_table = [ + (ops.add, ) * 2, + (ops.sub, ) * 2, + (ops.mul, ) * 2, + (ops.truediv, ) * 2, + (ops.floordiv, ) * 2, + (ops.mod, ) * 2, + (ops.pow, ) * 2, + (ops.and_, ) * 2, + (ops.or_, ) * 2, + (ops.xor, ) * 2, + (ops.eq, ) * 2, + (ops.ne, ) * 2, + (ops.lt, ) * 2, + (ops.le, ) * 2, + (ops.gt, ) * 2, + (ops.ge, ) * 2, + (ti.max, np.maximum), + (ti.min, np.minimum), + (ti.atan2, np.arctan2), +] + +unary_func_table = [ + (ops.neg, ) * 2, + (ops.invert, ) * 2, + (ti.lang.ops.logical_not, np.logical_not), + (ti.lang.ops.abs, np.abs), + (ti.exp, np.exp), + (ti.log, np.log), + (ti.sin, np.sin), + (ti.cos, np.cos), + (ti.tan, np.tan), + (ti.asin, np.arcsin), + (ti.acos, np.arccos), + (ti.tanh, np.tanh), + (ti.round, np.round), + (ti.floor, np.floor), + (ti.ceil, np.ceil), +] + + +@pytest.mark.parametrize('ti_func,np_func', binary_func_table) +def test_python_scope_vector_binary(ti_func, np_func): + ti.init() + x = ti.Vector([2, 3]) + y = ti.Vector([5, 4]) + + result = ti_func(x, y).to_numpy() + if ti_func in [ops.eq, ops.ne, ops.lt, ops.le, ops.gt, ops.ge]: + result = result.astype(bool) + expected = np_func(x.to_numpy(), y.to_numpy()) + assert test_utils.allclose(result, expected) + + +@pytest.mark.parametrize('ti_func,np_func', unary_func_table) +def test_python_scope_vector_unary(ti_func, np_func): + ti.init() + x = ti.Vector([2, 3] if ti_func in + [ops.invert, ti.lang.ops.logical_not] else [0.2, 0.3]) + + result = ti_func(x).to_numpy() + if ti_func in [ti.lang.ops.logical_not]: + result = result.astype(bool) + expected = np_func(x.to_numpy()) + assert test_utils.allclose(result, expected) + + +def test_python_scope_matmul(): + ti.init() + a = np.array([[1, 2], [3, 4]]) + b = np.array([[5, 6], [7, 8]]) + x = ti.Vector(a) + y = ti.Vector(b) + + result = (x @ y).to_numpy() + expected = a @ b + assert test_utils.allclose(result, expected) + + +def test_python_scope_linalg(): + ti.init() + a = np.array([3, 4, -2]) + b = np.array([-5, 0, 6]) + x = ti.Vector(a) + y = ti.Vector(b) + + assert test_utils.allclose(x.dot(y), np.dot(a, b)) + assert test_utils.allclose(x.norm(), np.sqrt(np.dot(a, a))) + assert test_utils.allclose(x.normalized(), a / np.sqrt(np.dot(a, a))) + assert x.any() == 1 # To match that of Taichi IR, we return -1 for True + assert y.all() == 0 + + +@test_utils.test(arch=[ti.x64, ti.cuda, ti.metal]) +def test_16_min_max(): + @ti.kernel + def min_u16(a: ti.u16, b: ti.u16) -> ti.u16: + return ti.min(a, b) + + @ti.kernel + def min_i16(a: ti.i16, b: ti.i16) -> ti.i16: + return ti.min(a, b) + + @ti.kernel + def max_u16(a: ti.u16, b: ti.u16) -> ti.u16: + return ti.max(a, b) + + @ti.kernel + def max_i16(a: ti.i16, b: ti.i16) -> ti.i16: + return ti.max(a, b) + + a, b = 4, 2 + assert min_u16(a, b) == min(a, b) + assert min_i16(a, b) == min(a, b) + assert max_u16(a, b) == max(a, b) + assert max_i16(a, b) == max(a, b) + + +@test_utils.test(exclude=[ti.opengl, ti.cc]) +def test_32_min_max(): + @ti.kernel + def min_u32(a: ti.u32, b: ti.u32) -> ti.u32: + return ti.min(a, b) + + @ti.kernel + def min_i32(a: ti.i32, b: ti.i32) -> ti.i32: + return ti.min(a, b) + + @ti.kernel + def max_u32(a: ti.u32, b: ti.u32) -> ti.u32: + return ti.max(a, b) + + @ti.kernel + def max_i32(a: ti.i32, b: ti.i32) -> ti.i32: + return ti.max(a, b) + + a, b = 4, 2 + assert min_u32(a, b) == min(a, b) + assert min_i32(a, b) == min(a, b) + assert max_u32(a, b) == max(a, b) + assert max_i32(a, b) == max(a, b) + + +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_64_min_max(): + @ti.kernel + def min_u64(a: ti.u64, b: ti.u64) -> ti.u64: + return ti.min(a, b) + + @ti.kernel + def min_i64(a: ti.i64, b: ti.i64) -> ti.i64: + return ti.min(a, b) + + @ti.kernel + def max_u64(a: ti.u64, b: ti.u64) -> ti.u64: + return ti.max(a, b) + + @ti.kernel + def max_i64(a: ti.i64, b: ti.i64) -> ti.i64: + return ti.max(a, b) + + a, b = 4, 2 + assert min_u64(a, b) == min(a, b) + assert min_i64(a, b) == min(a, b) + assert max_u64(a, b) == max(a, b) + assert max_i64(a, b) == max(a, b) + + +@test_utils.test() +def test_min_max_vector_starred(): + @ti.kernel + def min_starred() -> ti.i32: + a = ti.Vector([1, 2, 3]) + b = ti.Vector([4, 5, 6]) + return ti.min(*a, *b) + + @ti.kernel + def max_starred() -> ti.i32: + a = ti.Vector([1, 2, 3]) + b = ti.Vector([4, 5, 6]) + return ti.max(*a, *b) + + assert min_starred() == 1 + assert max_starred() == 6 diff --git a/tests/python/test_scope_errors.py b/tests/python/test_scope_errors.py index c376034c28d49..ce6ef9231e0d1 100644 --- a/tests/python/test_scope_errors.py +++ b/tests/python/test_scope_errors.py @@ -1,8 +1,10 @@ +import pytest + import taichi as ti +from tests import test_utils -@ti.test() -@ti.must_throw(UnboundLocalError) +@test_utils.test() def test_if(): x = ti.field(ti.f32) @@ -16,11 +18,11 @@ def func(): a = 1 print(a) - func() + with pytest.raises(Exception): + func() -@ti.test() -@ti.must_throw(UnboundLocalError) +@test_utils.test() def test_for(): x = ti.field(ti.f32) @@ -32,11 +34,11 @@ def func(): a = i print(a) - func() + with pytest.raises(Exception): + func() -@ti.test() -@ti.must_throw(UnboundLocalError) +@test_utils.test() def test_while(): x = ti.field(ti.f32) @@ -48,4 +50,5 @@ def func(): a = 0 print(a) - func() + with pytest.raises(Exception): + func() diff --git a/tests/python/test_serial_execution.py b/tests/python/test_serial_execution.py index 2fdc4623180df..5acd815024124 100644 --- a/tests/python/test_serial_execution.py +++ b/tests/python/test_serial_execution.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test(arch=ti.cpu, cpu_max_num_threads=1) +@test_utils.test(arch=ti.cpu, cpu_max_num_threads=1) def test_serial_range_for(): n = 1024 * 32 s = ti.field(dtype=ti.i32, shape=n) @@ -19,7 +20,7 @@ def fill_range(): assert s[i] == i -@ti.test(arch=ti.cpu, cpu_max_num_threads=1) +@test_utils.test(arch=ti.cpu, cpu_max_num_threads=1) def test_serial_struct_for(): n = 1024 * 32 s = ti.field(dtype=ti.i32, shape=n) diff --git a/tests/python/test_sfg.py b/tests/python/test_sfg.py index d4221913dc540..edee8e2879ddd 100644 --- a/tests/python/test_sfg.py +++ b/tests/python/test_sfg.py @@ -1,11 +1,11 @@ import numpy as np -import pytest import taichi as ti +from tests import test_utils -@ti.test(require=[ti.extension.async_mode, ti.extension.sparse], - async_mode=True) +@test_utils.test(require=[ti.extension.async_mode, ti.extension.sparse], + async_mode=True) def test_remove_clear_list_from_fused_serial(): x = ti.field(ti.i32) y = ti.field(ti.i32) @@ -26,7 +26,7 @@ def init_xy(): init_xy() ti.sync() - stats = ti.get_kernel_stats() + stats = ti.tools.async_utils.get_kernel_stats() stats.clear() @ti.kernel @@ -63,7 +63,7 @@ def serial_z(): assert xs[i] == 0 -@ti.test(require=ti.extension.async_mode, async_mode=True) +@test_utils.test(require=ti.extension.async_mode, async_mode=True) def test_sfg_dead_store_elimination(): n = 32 @@ -86,7 +86,7 @@ def scatter(): x.from_numpy(xnp) ti.sync() - stats = ti.get_kernel_stats() + stats = ti.tools.async_utils.get_kernel_stats() stats.clear() for _ in range(5): @@ -102,10 +102,10 @@ def scatter(): x_grad = x.grad.to_numpy() for i in range(n): - assert ti.approx(x_grad[i]) == 2.0 * i + assert test_utils.approx(x_grad[i]) == 2.0 * i -@ti.test(require=ti.extension.async_mode, async_mode=True) +@test_utils.test(require=ti.extension.async_mode, async_mode=True) def test_global_tmp_value_state(): # https://github.com/taichi-dev/taichi/issues/2024 n = 10 @@ -121,4 +121,4 @@ def compute_mean_of_boundary_edges() -> ti.i32: x.from_numpy(np.arange(0, n, dtype=np.float32)) mean = compute_mean_of_boundary_edges() - assert ti.approx(mean) == 33 + assert test_utils.approx(mean) == 33 diff --git a/tests/python/test_simple_matrix_slice.py b/tests/python/test_simple_matrix_slice.py new file mode 100644 index 0000000000000..7c34649e77281 --- /dev/null +++ b/tests/python/test_simple_matrix_slice.py @@ -0,0 +1,22 @@ +import taichi as ti +from tests import test_utils + + +@test_utils.test() +def test_slice(): + b = 3 + + @ti.kernel + def foo1() -> ti.types.vector(3, dtype=ti.i32): + c = ti.Vector([0, 1, 2, 3, 4, 5, 6]) + return c[:5:2] + + @ti.kernel + def foo2() -> ti.types.matrix(2, 2, dtype=ti.i32): + a = ti.Matrix([[1, 2, 3], [4, 5, 6]]) + return a[:, :b:2] + + v1 = foo1() + assert (v1 == ti.Vector([0, 2, 4])).all() == 1 + m1 = foo2() + assert (m1 == ti.Matrix([[1, 3], [4, 6]])).all() == 1 diff --git a/tests/python/test_snode_layout_inspection.py b/tests/python/test_snode_layout_inspection.py new file mode 100644 index 0000000000000..ced2699a38623 --- /dev/null +++ b/tests/python/test_snode_layout_inspection.py @@ -0,0 +1,58 @@ +import taichi as ti +from tests import test_utils + + +@test_utils.test(arch=ti.cpu) +def test_primitives(): + x = ti.field(dtype=ti.i16) + y = ti.field(dtype=ti.f32) + z = ti.field(dtype=ti.f64) + + p = ti.field(dtype=ti.f32) + q = ti.field(dtype=ti.f32) + r = ti.field(dtype=ti.f64) + + n1 = ti.root.dense(ti.i, 32) + n1.place(x) + + n2 = ti.root.dense(ti.i, 32) + n2.place(y, z) + + n3 = ti.root.dense(ti.i, 1) + n3.place(p, q, r) + + assert n1._cell_size_bytes == 2 + assert n2._cell_size_bytes in [12, 16] + assert n3._cell_size_bytes == 16 + + assert n1._offset_bytes_in_parent_cell == 0 + assert n2._offset_bytes_in_parent_cell == 2 * 32 + assert n3._offset_bytes_in_parent_cell in [ + 2 * 32 + 12 * 32, 2 * 32 + 16 * 32 + ] + + assert x.snode._offset_bytes_in_parent_cell == 0 + assert y.snode._offset_bytes_in_parent_cell == 0 + assert z.snode._offset_bytes_in_parent_cell in [4, 8] + assert p.snode._offset_bytes_in_parent_cell == 0 + assert q.snode._offset_bytes_in_parent_cell == 4 + assert r.snode._offset_bytes_in_parent_cell == 8 + + +@test_utils.test(arch=ti.cpu) +def test_bit_struct(): + cit = ti.types.quantized_types.quant.int(16, False) + x = ti.field(dtype=cit) + y = ti.field(dtype=ti.types.quantized_types.type_factory.custom_float( + significand_type=cit)) + z = ti.field(dtype=ti.f32) + + n1 = ti.root.dense(ti.i, 32) + n1.bit_struct(num_bits=32).place(x) + + n2 = ti.root.dense(ti.i, 4) + n2.bit_struct(num_bits=32).place(y) + n2.place(z) + + assert n1._cell_size_bytes == 4 + assert n2._cell_size_bytes == 8 diff --git a/tests/python/test_sort.py b/tests/python/test_sort.py new file mode 100644 index 0000000000000..1eb3647e038fc --- /dev/null +++ b/tests/python/test_sort.py @@ -0,0 +1,33 @@ +import taichi as ti +from tests import test_utils + + +@test_utils.test(exclude=[ti.cc]) +def test_sort(): + def test_sort_for_dtype(dtype, N): + keys = ti.field(dtype, N) + values = ti.field(dtype, N) + + @ti.kernel + def fill(): + for i in keys: + keys[i] = ti.random() * N + values[i] = keys[i] + + fill() + ti._kernels.parallel_sort(keys, values) + + keys_host = keys.to_numpy() + values_host = values.to_numpy() + + for i in range(N): + if i < N - 1: + assert keys_host[i] <= keys_host[i + 1] + assert keys_host[i] == values_host[i] + + test_sort_for_dtype(ti.i32, 1) + test_sort_for_dtype(ti.i32, 256) + test_sort_for_dtype(ti.i32, 100001) + test_sort_for_dtype(ti.f32, 1) + test_sort_for_dtype(ti.f32, 256) + test_sort_for_dtype(ti.f32, 100001) diff --git a/tests/python/test_sparse_activate.py b/tests/python/test_sparse_activate.py index ccdec4c1fa01c..3c5d498204263 100644 --- a/tests/python/test_sparse_activate.py +++ b/tests/python/test_sparse_activate.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_pointer(): x = ti.field(ti.f32) s = ti.field(ti.i32) @@ -29,7 +30,7 @@ def func(): assert s[None] == 32 -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_non_dfs_snode_order(): x = ti.field(dtype=ti.i32) y = ti.field(dtype=ti.i32) diff --git a/tests/python/test_sparse_basics.py b/tests/python/test_sparse_basics.py index 8dab9a9db3529..118f3bfbbfd12 100644 --- a/tests/python/test_sparse_basics.py +++ b/tests/python/test_sparse_basics.py @@ -1,9 +1,10 @@ import pytest import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_pointer(): x = ti.field(ti.f32) s = ti.field(ti.i32) @@ -26,7 +27,7 @@ def func(): assert s[None] == 256 -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_pointer_is_active(): x = ti.field(ti.f32) s = ti.field(ti.i32) @@ -78,18 +79,19 @@ def func(): assert s[None] == 5 * n -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_pointer2(): _test_pointer2() -@ti.test(require=[ti.extension.sparse, ti.extension.packed], packed=True) +@test_utils.test(require=[ti.extension.sparse, ti.extension.packed], + packed=True) def test_pointer2_packed(): _test_pointer2() @pytest.mark.skip(reason='https://github.com/taichi-dev/taichi/issues/2520') -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_pointer_direct_place(): x, y = ti.field(ti.i32), ti.field(ti.i32) diff --git a/tests/python/test_sparse_deactivate.py b/tests/python/test_sparse_deactivate.py index 5f352ac009dc4..e2bce91b47d76 100644 --- a/tests/python/test_sparse_deactivate.py +++ b/tests/python/test_sparse_deactivate.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_pointer(): x = ti.field(ti.f32) s = ti.field(ti.i32, shape=()) @@ -33,7 +34,7 @@ def deactivate(): assert s[None] == 16 -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_pointer1(): x = ti.field(ti.f32) s = ti.field(ti.i32) @@ -68,7 +69,7 @@ def deactivate(): assert s[None] == 32 -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_pointer2(): x = ti.field(ti.f32) @@ -106,7 +107,7 @@ def clear(): assert x[i] == 10.0 -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_pointer3(): x = ti.field(ti.f32) x_temp = ti.field(ti.f32) @@ -169,7 +170,7 @@ def clear_temp(): assert xn[i, j] == i + j -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_dynamic(): x = ti.field(ti.i32) s = ti.field(ti.i32) diff --git a/tests/python/test_sparse_linear_solver.py b/tests/python/test_sparse_linear_solver.py index 463ec5c7aa69c..a3a2a30573b8a 100644 --- a/tests/python/test_sparse_linear_solver.py +++ b/tests/python/test_sparse_linear_solver.py @@ -2,6 +2,7 @@ import pytest import taichi as ti +from tests import test_utils """ The symmetric positive definite matrix is created in matlab using the following script: @@ -29,15 +30,17 @@ ]) +@pytest.mark.parametrize("dtype", [ti.f32]) @pytest.mark.parametrize("solver_type", ["LLT", "LDLT", "LU"]) -@ti.test(arch=ti.cpu) -def test_sparse_LLT_solver(solver_type): +@pytest.mark.parametrize("ordering", ["AMD", "COLAMD"]) +@test_utils.test(arch=ti.cpu) +def test_sparse_LLT_solver(dtype, solver_type, ordering): n = 4 Abuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) b = ti.field(ti.f32, shape=n) @ti.kernel - def fill(Abuilder: ti.linalg.sparse_matrix_builder(), + def fill(Abuilder: ti.types.sparse_matrix_builder(), InputArray: ti.ext_arr(), b: ti.template()): for i, j in ti.ndrange(n, n): Abuilder[i, j] += InputArray[i, j] @@ -46,9 +49,11 @@ def fill(Abuilder: ti.linalg.sparse_matrix_builder(), fill(Abuilder, Aarray, b) A = Abuilder.build() - solver = ti.linalg.SparseSolver(solver_type=solver_type) + solver = ti.linalg.SparseSolver(dtype=dtype, + solver_type=solver_type, + ordering=ordering) solver.analyze_pattern(A) solver.factorize(A) x = solver.solve(b) for i in range(n): - assert x[i] == ti.approx(res[i]) + assert x[i] == test_utils.approx(res[i]) diff --git a/tests/python/test_sparse_matrix.py b/tests/python/test_sparse_matrix.py index 2b75120e88e6c..179cbefedc31a 100644 --- a/tests/python/test_sparse_matrix.py +++ b/tests/python/test_sparse_matrix.py @@ -1,13 +1,41 @@ +import pytest + import taichi as ti +from tests import test_utils + + +@pytest.mark.parametrize('dtype', [ti.f32, ti.f64]) +@test_utils.test(arch=ti.cpu) +def test_sparse_matrix_builder_deprecated_anno(dtype): + n = 8 + Abuilder = ti.linalg.SparseMatrixBuilder(n, + n, + max_num_triplets=100, + dtype=dtype) + + @ti.kernel + def fill(Abuilder: ti.types.sparse_matrix_builder()): + for i, j in ti.ndrange(n, n): + Abuilder[i, j] += i + j + + fill(Abuilder) + A = Abuilder.build() + for i in range(n): + for j in range(n): + assert A[i, j] == i + j -@ti.test(arch=ti.cpu) -def test_sparse_matrix_builder(): +@pytest.mark.parametrize('dtype', [ti.f32, ti.f64]) +@test_utils.test(arch=ti.cpu) +def test_sparse_matrix_builder(dtype): n = 8 - Abuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) + Abuilder = ti.linalg.SparseMatrixBuilder(n, + n, + max_num_triplets=100, + dtype=dtype) @ti.kernel - def fill(Abuilder: ti.linalg.sparse_matrix_builder()): + def fill(Abuilder: ti.types.sparse_matrix_builder()): for i, j in ti.ndrange(n, n): Abuilder[i, j] += i + j @@ -18,13 +46,36 @@ def fill(Abuilder: ti.linalg.sparse_matrix_builder()): assert A[i, j] == i + j -@ti.test(arch=ti.cpu) -def test_sparse_matrix_element_access(): +@pytest.mark.parametrize('dtype', [ti.f32, ti.f64]) +@test_utils.test(arch=ti.cpu) +def test_sparse_matrix_shape(dtype): + n, m = 8, 9 + Abuilder = ti.linalg.SparseMatrixBuilder(n, + m, + max_num_triplets=100, + dtype=dtype) + + @ti.kernel + def fill(Abuilder: ti.types.sparse_matrix_builder()): + for i, j in ti.ndrange(n, m): + Abuilder[i, j] += i + j + + fill(Abuilder) + A = Abuilder.build() + assert A.shape() == (n, m) + + +@pytest.mark.parametrize('dtype', [ti.f32, ti.f64]) +@test_utils.test(arch=ti.cpu) +def test_sparse_matrix_element_access(dtype): n = 8 - Abuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) + Abuilder = ti.linalg.SparseMatrixBuilder(n, + n, + max_num_triplets=100, + dtype=dtype) @ti.kernel - def fill(Abuilder: ti.linalg.sparse_matrix_builder()): + def fill(Abuilder: ti.types.sparse_matrix_builder()): for i in range(n): Abuilder[i, i] += i @@ -34,13 +85,17 @@ def fill(Abuilder: ti.linalg.sparse_matrix_builder()): assert A[i, i] == i -@ti.test(arch=ti.cpu) -def test_sparse_matrix_element_modify(): +@pytest.mark.parametrize('dtype', [ti.f32, ti.f64]) +@test_utils.test(arch=ti.cpu) +def test_sparse_matrix_element_modify(dtype): n = 8 - Abuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) + Abuilder = ti.linalg.SparseMatrixBuilder(n, + n, + max_num_triplets=100, + dtype=dtype) @ti.kernel - def fill(Abuilder: ti.linalg.sparse_matrix_builder()): + def fill(Abuilder: ti.types.sparse_matrix_builder()): for i in range(n): Abuilder[i, i] += i @@ -50,15 +105,22 @@ def fill(Abuilder: ti.linalg.sparse_matrix_builder()): assert A[0, 0] == 1024.0 -@ti.test(arch=ti.cpu) -def test_sparse_matrix_addition(): +@pytest.mark.parametrize('dtype', [ti.f32, ti.f64]) +@test_utils.test(arch=ti.cpu) +def test_sparse_matrix_addition(dtype): n = 8 - Abuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) - Bbuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) + Abuilder = ti.linalg.SparseMatrixBuilder(n, + n, + max_num_triplets=100, + dtype=dtype) + Bbuilder = ti.linalg.SparseMatrixBuilder(n, + n, + max_num_triplets=100, + dtype=dtype) @ti.kernel - def fill(Abuilder: ti.linalg.sparse_matrix_builder(), - Bbuilder: ti.linalg.sparse_matrix_builder()): + def fill(Abuilder: ti.types.sparse_matrix_builder(), + Bbuilder: ti.types.sparse_matrix_builder()): for i, j in ti.ndrange(n, n): Abuilder[i, j] += i + j Bbuilder[i, j] += i - j @@ -72,15 +134,22 @@ def fill(Abuilder: ti.linalg.sparse_matrix_builder(), assert C[i, j] == 2 * i -@ti.test(arch=ti.cpu) -def test_sparse_matrix_subtraction(): +@pytest.mark.parametrize('dtype', [ti.f32, ti.f64]) +@test_utils.test(arch=ti.cpu) +def test_sparse_matrix_subtraction(dtype): n = 8 - Abuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) - Bbuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) + Abuilder = ti.linalg.SparseMatrixBuilder(n, + n, + max_num_triplets=100, + dtype=dtype) + Bbuilder = ti.linalg.SparseMatrixBuilder(n, + n, + max_num_triplets=100, + dtype=dtype) @ti.kernel - def fill(Abuilder: ti.linalg.sparse_matrix_builder(), - Bbuilder: ti.linalg.sparse_matrix_builder()): + def fill(Abuilder: ti.types.sparse_matrix_builder(), + Bbuilder: ti.types.sparse_matrix_builder()): for i, j in ti.ndrange(n, n): Abuilder[i, j] += i + j Bbuilder[i, j] += i - j @@ -94,13 +163,17 @@ def fill(Abuilder: ti.linalg.sparse_matrix_builder(), assert C[i, j] == 2 * j -@ti.test(arch=ti.cpu) -def test_sparse_matrix_scalar_multiplication(): +@pytest.mark.parametrize('dtype', [ti.f32, ti.f64]) +@test_utils.test(arch=ti.cpu) +def test_sparse_matrix_scalar_multiplication(dtype): n = 8 - Abuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) + Abuilder = ti.linalg.SparseMatrixBuilder(n, + n, + max_num_triplets=100, + dtype=dtype) @ti.kernel - def fill(Abuilder: ti.linalg.sparse_matrix_builder()): + def fill(Abuilder: ti.types.sparse_matrix_builder()): for i, j in ti.ndrange(n, n): Abuilder[i, j] += i + j @@ -112,13 +185,17 @@ def fill(Abuilder: ti.linalg.sparse_matrix_builder()): assert B[i, j] == 3 * (i + j) -@ti.test(arch=ti.cpu) -def test_sparse_matrix_transpose(): +@pytest.mark.parametrize('dtype', [ti.f32, ti.f64]) +@test_utils.test(arch=ti.cpu) +def test_sparse_matrix_transpose(dtype): n = 8 - Abuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) + Abuilder = ti.linalg.SparseMatrixBuilder(n, + n, + max_num_triplets=100, + dtype=dtype) @ti.kernel - def fill(Abuilder: ti.linalg.sparse_matrix_builder()): + def fill(Abuilder: ti.types.sparse_matrix_builder()): for i, j in ti.ndrange(n, n): Abuilder[i, j] += i + j @@ -130,15 +207,22 @@ def fill(Abuilder: ti.linalg.sparse_matrix_builder()): assert B[i, j] == A[j, i] -@ti.test(arch=ti.cpu) -def test_sparse_matrix_elementwise_multiplication(): +@pytest.mark.parametrize('dtype', [ti.f32, ti.f64]) +@test_utils.test(arch=ti.cpu) +def test_sparse_matrix_elementwise_multiplication(dtype): n = 8 - Abuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) - Bbuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) + Abuilder = ti.linalg.SparseMatrixBuilder(n, + n, + max_num_triplets=100, + dtype=dtype) + Bbuilder = ti.linalg.SparseMatrixBuilder(n, + n, + max_num_triplets=100, + dtype=dtype) @ti.kernel - def fill(Abuilder: ti.linalg.sparse_matrix_builder(), - Bbuilder: ti.linalg.sparse_matrix_builder()): + def fill(Abuilder: ti.types.sparse_matrix_builder(), + Bbuilder: ti.types.sparse_matrix_builder()): for i, j in ti.ndrange(n, n): Abuilder[i, j] += i + j Bbuilder[i, j] += i - j @@ -152,15 +236,22 @@ def fill(Abuilder: ti.linalg.sparse_matrix_builder(), assert C[i, j] == (i + j) * (i - j) -@ti.test(arch=ti.cpu) -def test_sparse_matrix_multiplication(): +@pytest.mark.parametrize('dtype', [ti.f32, ti.f64]) +@test_utils.test(arch=ti.cpu) +def test_sparse_matrix_multiplication(dtype): n = 2 - Abuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) - Bbuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) + Abuilder = ti.linalg.SparseMatrixBuilder(n, + n, + max_num_triplets=100, + dtype=dtype) + Bbuilder = ti.linalg.SparseMatrixBuilder(n, + n, + max_num_triplets=100, + dtype=dtype) @ti.kernel - def fill(Abuilder: ti.linalg.sparse_matrix_builder(), - Bbuilder: ti.linalg.sparse_matrix_builder()): + def fill(Abuilder: ti.types.sparse_matrix_builder(), + Bbuilder: ti.types.sparse_matrix_builder()): for i, j in ti.ndrange(n, n): Abuilder[i, j] += i + j Bbuilder[i, j] += i - j @@ -175,15 +266,22 @@ def fill(Abuilder: ti.linalg.sparse_matrix_builder(), assert C[1, 1] == -1.0 -@ti.test(arch=ti.cpu) -def test_sparse_matrix_nonsymmetric_multiplication(): +@pytest.mark.parametrize('dtype', [ti.f32, ti.f64]) +@test_utils.test(arch=ti.cpu) +def test_sparse_matrix_nonsymmetric_multiplication(dtype): n, k, m = 2, 3, 4 - Abuilder = ti.linalg.SparseMatrixBuilder(n, k, max_num_triplets=100) - Bbuilder = ti.linalg.SparseMatrixBuilder(k, m, max_num_triplets=100) + Abuilder = ti.linalg.SparseMatrixBuilder(n, + k, + max_num_triplets=100, + dtype=dtype) + Bbuilder = ti.linalg.SparseMatrixBuilder(k, + m, + max_num_triplets=100, + dtype=dtype) @ti.kernel - def fill(Abuilder: ti.linalg.sparse_matrix_builder(), - Bbuilder: ti.linalg.sparse_matrix_builder()): + def fill(Abuilder: ti.types.sparse_matrix_builder(), + Bbuilder: ti.types.sparse_matrix_builder()): for i, j in ti.ndrange(n, k): Abuilder[i, j] += i + j for i, j in ti.ndrange(k, m): diff --git a/tests/python/test_sparse_multi_tree.py b/tests/python/test_sparse_multi_tree.py index 94565ec7ac448..89ebe6a3c89d7 100644 --- a/tests/python/test_sparse_multi_tree.py +++ b/tests/python/test_sparse_multi_tree.py @@ -1,9 +1,10 @@ import pytest import taichi as ti +from tests import test_utils -@ti.test(arch=[ti.cpu, ti.cuda]) +@test_utils.test(arch=[ti.cpu, ti.cuda]) def test_pointer(): e = ti.Vector.field(2, dtype=int, shape=16) diff --git a/tests/python/test_sparse_parallel.py b/tests/python/test_sparse_parallel.py index 4680396bc570d..1a797c15cfeb7 100644 --- a/tests/python/test_sparse_parallel.py +++ b/tests/python/test_sparse_parallel.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_pointer(): x = ti.field(ti.f32) s = ti.field(ti.i32) @@ -26,7 +27,7 @@ def func(): assert s[None] == n * n -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_pointer2(): x = ti.field(ti.f32) s = ti.field(ti.i32) @@ -52,7 +53,7 @@ def func(): assert s[None] == N * (N - 1) / 2 -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_nested_struct_fill_and_clear(): a = ti.field(dtype=ti.f32) N = 512 diff --git a/tests/python/test_spmv.py b/tests/python/test_spmv.py index 87481c4617b40..ee0c78b4fd6be 100644 --- a/tests/python/test_spmv.py +++ b/tests/python/test_spmv.py @@ -1,14 +1,15 @@ import taichi as ti +from tests import test_utils -@ti.test(arch=ti.cpu) +@test_utils.test(arch=ti.cpu) def test_sparse_matrix_vector_multiplication1(): n = 8 Abuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) b = ti.field(ti.f32, shape=n) @ti.kernel - def fill(Abuilder: ti.linalg.sparse_matrix_builder(), b: ti.template()): + def fill(Abuilder: ti.types.sparse_matrix_builder(), b: ti.template()): for i, j in ti.ndrange(n, n): Abuilder[i, j] += i @@ -22,14 +23,14 @@ def fill(Abuilder: ti.linalg.sparse_matrix_builder(), b: ti.template()): assert x[i] == 8 * i -@ti.test(arch=ti.cpu) +@test_utils.test(arch=ti.cpu) def test_sparse_matrix_vector_multiplication2(): n = 8 Abuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) b = ti.field(ti.f32, shape=n) @ti.kernel - def fill(Abuilder: ti.linalg.sparse_matrix_builder(), b: ti.template()): + def fill(Abuilder: ti.types.sparse_matrix_builder(), b: ti.template()): for i, j in ti.ndrange(n, n): Abuilder[i, j] += i - j @@ -46,14 +47,14 @@ def fill(Abuilder: ti.linalg.sparse_matrix_builder(), b: ti.template()): assert x[i] == res[i] -@ti.test(arch=ti.cpu) +@test_utils.test(arch=ti.cpu) def test_sparse_matrix_vector_multiplication3(): n = 8 Abuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100) b = ti.field(ti.f32, shape=n) @ti.kernel - def fill(Abuilder: ti.linalg.sparse_matrix_builder(), b: ti.template()): + def fill(Abuilder: ti.types.sparse_matrix_builder(), b: ti.template()): for i, j in ti.ndrange(n, n): Abuilder[i, j] += i + j diff --git a/tests/python/test_ssa.py b/tests/python/test_ssa.py index 5fd8c4cef31c5..585af04392f99 100644 --- a/tests/python/test_ssa.py +++ b/tests/python/test_ssa.py @@ -8,10 +8,10 @@ import numpy as np import taichi as ti -from taichi import approx +from tests import test_utils -@ti.test() +@test_utils.test() def test_matrix_self_assign(): a = ti.Vector.field(2, ti.f32, ()) b = ti.Matrix.field(2, 2, ti.f32, ()) @@ -21,7 +21,7 @@ def test_matrix_self_assign(): def func(): a[None] = a[None].normalized() b[None] = b[None].transpose() - c[None] = ti.Vector([c[None][1], c[None][0]], dt=ti.f32) + c[None] = ti.Vector([c[None][1], c[None][0]]) inv_sqrt2 = 1 / math.sqrt(2) @@ -29,25 +29,25 @@ def func(): b[None] = [[1, 2], [3, 4]] c[None] = [2, 3] func() - assert a[None].value == ti.Vector([inv_sqrt2, inv_sqrt2]) - assert b[None].value == ti.Matrix([[1, 3], [2, 4]]) - assert c[None].value == ti.Vector([3, 2]) + assert a[None] == ti.Vector([inv_sqrt2, inv_sqrt2]) + assert b[None] == ti.Matrix([[1, 3], [2, 4]]) + assert c[None] == ti.Vector([3, 2]) -@ti.test() +@test_utils.test() def test_random_vector_dup_eval(): a = ti.Vector.field(2, ti.f32, ()) @ti.kernel def func(): - a[None] = ti.Vector([ti.random(), 1], dt=ti.f32).normalized() + a[None] = ti.Vector([ti.random(), 1]).normalized() for i in range(4): func() - assert a[None].value.norm_sqr() == approx(1) + assert a[None].norm_sqr() == test_utils.approx(1) -@ti.test() +@test_utils.test() def test_func_argument_dup_eval(): @ti.func def func(a, t): @@ -61,15 +61,15 @@ def kern(t: ti.f32) -> ti.f32: assert kern(1.0) == 0.0 -@ti.test() +@test_utils.test() def test_func_random_argument_dup_eval(): @ti.func def func(a): - return ti.Vector([ti.cos(a), ti.sin(a)], dt=ti.f32) + return ti.Vector([ti.cos(a), ti.sin(a)]) @ti.kernel def kern() -> ti.f32: return func(ti.random()).norm_sqr() for i in range(4): - assert kern() == approx(1.0, rel=5e-5) + assert kern() == test_utils.approx(1.0, rel=5e-5) diff --git a/tests/python/test_static.py b/tests/python/test_static.py index 08f36eab9fea7..767cf31ab8660 100644 --- a/tests/python/test_static.py +++ b/tests/python/test_static.py @@ -2,10 +2,11 @@ import pytest import taichi as ti +from tests import test_utils @pytest.mark.parametrize('val', [0, 1]) -@ti.test(ti.cpu) +@test_utils.test(ti.cpu) def test_static_if(val): x = ti.field(ti.i32) @@ -22,7 +23,7 @@ def static(): assert x[0] == val -@ti.test(ti.cpu) +@test_utils.test(ti.cpu) def test_static_if_error(): x = ti.field(ti.i32) @@ -35,11 +36,12 @@ def static(val: float): else: x[0] = 0 - with pytest.raises(ValueError, match='must be compile-time constants'): + with pytest.raises(ti.TaichiCompilationError, + match='must be compile-time constants'): static(42) -@ti.test() +@test_utils.test() def test_static_ndrange(): n = 3 x = ti.Matrix.field(n, n, dtype=ti.f32, shape=(n, n)) @@ -56,7 +58,7 @@ def fill(): assert x[i, j][i, j] == i + j * 2 -@ti.test(ti.cpu) +@test_utils.test(ti.cpu) def test_static_break(): x = ti.field(ti.i32, 5) @@ -72,7 +74,7 @@ def func(): assert np.allclose(x.to_numpy(), np.array([1, 1, 1, 0, 0])) -@ti.test(ti.cpu) +@test_utils.test(ti.cpu) def test_static_continue(): x = ti.field(ti.i32, 5) diff --git a/tests/python/test_stencils.py b/tests/python/test_stencils.py index 8892a834d8e77..3f34a6d5c9529 100644 --- a/tests/python/test_stencils.py +++ b/tests/python/test_stencils.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_simple(): # Note: access simplification does not work in this case. Maybe worth fixing. x = ti.field(ti.i32) diff --git a/tests/python/test_stop_grad.py b/tests/python/test_stop_grad.py index 2354509af5f60..450782b3733a9 100644 --- a/tests/python/test_stop_grad.py +++ b/tests/python/test_stop_grad.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_normal_grad(): x = ti.field(ti.f32) loss = ti.field(ti.f32) @@ -27,7 +28,7 @@ def func(): assert x.grad[i] == i * 2 -@ti.test() +@test_utils.test() def test_stop_grad(): x = ti.field(ti.f32) loss = ti.field(ti.f32) @@ -41,7 +42,7 @@ def test_stop_grad(): @ti.kernel def func(): for i in range(n): - ti.core.stop_grad(x.snode.ptr) + ti.stop_grad(x) loss[None] += x[i]**2 for i in range(n): @@ -54,7 +55,7 @@ def func(): assert x.grad[i] == 0 -@ti.test() +@test_utils.test() def test_stop_grad2(): x = ti.field(ti.f32) loss = ti.field(ti.f32) diff --git a/tests/python/test_struct.py b/tests/python/test_struct.py index a81d7a1f1e4c7..4a1343045511f 100644 --- a/tests/python/test_struct.py +++ b/tests/python/test_struct.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_linear(): x = ti.field(ti.i32) y = ti.field(ti.i32) @@ -25,7 +26,7 @@ def test_linear_repeated(): test_linear() -@ti.test() +@test_utils.test() def test_linear_nested(): x = ti.field(ti.i32) y = ti.field(ti.i32) @@ -44,7 +45,7 @@ def test_linear_nested(): assert y[i] == i + 123 -@ti.test() +@test_utils.test() def test_linear_nested_aos(): x = ti.field(ti.i32) y = ti.field(ti.i32) @@ -62,7 +63,7 @@ def test_linear_nested_aos(): assert y[i] == i + 123 -@ti.test(exclude=[ti.vulkan]) +@test_utils.test(exclude=[ti.vulkan]) def test_2d_nested(): x = ti.field(ti.i32) diff --git a/tests/python/test_struct_for.py b/tests/python/test_struct_for.py index 60825eff1b37c..95ce85ab2efcd 100644 --- a/tests/python/test_struct_for.py +++ b/tests/python/test_struct_for.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_singleton(): x = ti.field(ti.i32, shape=()) @@ -15,7 +16,7 @@ def fill(): assert x[None] == 3 -@ti.test() +@test_utils.test() def test_singleton2(): x = ti.field(ti.i32) @@ -31,7 +32,7 @@ def fill(): assert x[None] == 3 -@ti.test() +@test_utils.test() def test_linear(): x = ti.field(ti.i32) y = ti.field(ti.i32) @@ -54,7 +55,7 @@ def fill(): assert y[i] == i * 2 -@ti.test() +@test_utils.test() def test_nested(): x = ti.field(ti.i32) y = ti.field(ti.i32) @@ -77,7 +78,7 @@ def fill(): assert y[i] == i * 2 -@ti.test() +@test_utils.test() def test_nested2(): x = ti.field(ti.i32) y = ti.field(ti.i32) @@ -102,7 +103,7 @@ def fill(): assert y[i] == i * 2 -@ti.test() +@test_utils.test() def test_2d(): x = ti.field(ti.i32) y = ti.field(ti.i32) @@ -123,7 +124,7 @@ def fill(): assert x[i, j] == i + j * 2 -@ti.test() +@test_utils.test() def test_2d_non_POT(): x = ti.field(ti.i32) y = ti.field(ti.i32, shape=()) @@ -146,7 +147,7 @@ def fill(): assert y[None] == tot -@ti.test() +@test_utils.test() def test_nested_2d(): x = ti.field(ti.i32) y = ti.field(ti.i32) @@ -167,7 +168,7 @@ def fill(): assert x[i, j] == i + j * 2 -@ti.test() +@test_utils.test() def test_nested_2d_more_nests(): x = ti.field(ti.i32) y = ti.field(ti.i32) @@ -191,7 +192,7 @@ def fill(): assert x[i, j] == i + j * 2 -@ti.test() +@test_utils.test() def test_linear_k(): x = ti.field(ti.i32) @@ -210,7 +211,7 @@ def fill(): assert x[i] == i -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_struct_for_branching(): # Related issue: https://github.com/taichi-dev/taichi/issues/704 x = ti.field(dtype=ti.i32) @@ -240,7 +241,7 @@ def func3(): func3() -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_struct_for_pointer_block(): n = 16 block_size = 8 @@ -262,11 +263,11 @@ def count() -> int: assert count() == 1 -@ti.test(require=ti.extension.quant) +@test_utils.test(require=ti.extension.quant) def test_struct_for_quant(): n = 8 - ci13 = ti.quant.int(13, True) + ci13 = ti.types.quantized_types.quant.int(13, True) x = ti.field(dtype=ci13) ti.root.dense(ti.i, n).bit_struct(num_bits=32).place(x) @@ -279,3 +280,36 @@ def count() -> int: return tot assert count() == 28 + + +@test_utils.test(require=ti.extension.sparse) +def test_struct_for_continue(): + # Related issue: https://github.com/taichi-dev/taichi/issues/3272 + x = ti.field(dtype=ti.i32) + n = 4 + ti.root.pointer(ti.i, n).dense(ti.i, n).place(x) + + @ti.kernel + def init(): + for i in range(n): + x[i * n + i] = 1 + + @ti.kernel + def struct_for_continue() -> ti.i32: + cnt = 0 + for i in x: + if x[i]: continue + cnt += 1 + return cnt + + @ti.kernel + def range_for_continue() -> ti.i32: + cnt = 0 + for i in range(n * n): + if x[i]: continue + cnt += 1 + return cnt + + init() + assert struct_for_continue() == n * (n - 1) + assert range_for_continue() == n * (n - 1) diff --git a/tests/python/test_struct_for_dynamic.py b/tests/python/test_struct_for_dynamic.py index 74bbf747d2bd7..67595a6e08cf9 100644 --- a/tests/python/test_struct_for_dynamic.py +++ b/tests/python/test_struct_for_dynamic.py @@ -1,11 +1,8 @@ import taichi as ti +from tests import test_utils -def ti_support_dynamic(test): - return ti.archs_excluding(ti.opengl, ti.cc, ti.vulkan)(test) - - -@ti.test(exclude=[ti.opengl, ti.cc, ti.vulkan]) +@test_utils.test(exclude=[ti.opengl, ti.cc, ti.vulkan]) def test_dynamic(): x = ti.field(ti.i32) y = ti.field(ti.i32, shape=()) @@ -26,7 +23,7 @@ def count(): assert y[None] == n // 3 + 1 -@ti.test(exclude=[ti.opengl, ti.cc, ti.vulkan]) +@test_utils.test(exclude=[ti.opengl, ti.cc, ti.vulkan]) def test_dense_dynamic(): n = 128 diff --git a/tests/python/test_struct_for_intermediate.py b/tests/python/test_struct_for_intermediate.py index 0d6128edbce55..4971f20f19f61 100644 --- a/tests/python/test_struct_for_intermediate.py +++ b/tests/python/test_struct_for_intermediate.py @@ -1,4 +1,5 @@ import taichi as ti +from tests import test_utils def _test_nested(): @@ -20,27 +21,27 @@ def iterate(): assert x[i * n, j * m] == 1, (i, j) -@ti.test(require=ti.extension.sparse, - demote_dense_struct_fors=False, - packed=False) +@test_utils.test(require=ti.extension.sparse, + demote_dense_struct_fors=False, + packed=False) def test_nested(): _test_nested() -@ti.test(demote_dense_struct_fors=True, packed=False) +@test_utils.test(demote_dense_struct_fors=True, packed=False) def test_nested_demote(): _test_nested() -@ti.test(require=[ti.extension.sparse, ti.extension.packed], - demote_dense_struct_fors=False, - packed=True) +@test_utils.test(require=[ti.extension.sparse, ti.extension.packed], + demote_dense_struct_fors=False, + packed=True) def test_nested_packed(): _test_nested() -@ti.test(require=ti.extension.packed, - demote_dense_struct_fors=True, - packed=True) +@test_utils.test(require=ti.extension.packed, + demote_dense_struct_fors=True, + packed=True) def test_nested_demote_packed(): _test_nested() diff --git a/tests/python/test_struct_for_non_pot.py b/tests/python/test_struct_for_non_pot.py index c06337b04e763..c8eefbe3defb5 100644 --- a/tests/python/test_struct_for_non_pot.py +++ b/tests/python/test_struct_for_non_pot.py @@ -1,4 +1,5 @@ import taichi as ti +from tests import test_utils def _test_1d(): @@ -21,12 +22,12 @@ def accumulate(): assert sum[None] == 4950 -@ti.test() +@test_utils.test() def test_1d(): _test_1d() -@ti.test(require=ti.extension.packed, packed=True) +@test_utils.test(require=ti.extension.packed, packed=True) def test_1d_packed(): _test_1d() @@ -57,11 +58,11 @@ def accumulate(): assert sum[None] == gt -@ti.test() +@test_utils.test() def test_2d(): _test_2d() -@ti.test(require=ti.extension.packed, packed=True) +@test_utils.test(require=ti.extension.packed, packed=True) def test_2d_packed(): _test_2d() diff --git a/tests/python/test_svd.py b/tests/python/test_svd.py index 7b331fd1b8f0e..988dc764a248d 100644 --- a/tests/python/test_svd.py +++ b/tests/python/test_svd.py @@ -1,10 +1,10 @@ import numpy as np import taichi as ti -from taichi import approx +from tests import test_utils -@ti.test(require=ti.extension.data64, fast_math=False) +@test_utils.test(require=ti.extension.data64, fast_math=False) def test_precision(): u = ti.field(ti.f64, shape=()) v = ti.field(ti.f64, shape=()) @@ -16,8 +16,8 @@ def forward(): w[None] = ti.cast(u[None] + 7, ti.f64) / ti.cast(u[None] + 3, ti.f64) forward() - assert v[None]**2 == approx(3.25, abs=1e-12) - assert w[None] * 3 == approx(7, abs=1e-12) + assert v[None]**2 == test_utils.approx(3.25, abs=1e-12) + assert w[None] * 3 == test_utils.approx(7, abs=1e-12) def mat_equal(A, B, tol=1e-6): @@ -26,7 +26,7 @@ def mat_equal(A, B, tol=1e-6): def _test_svd(dt, n): print( - f'arch={ti.cfg.arch} default_fp={ti.cfg.default_fp} fast_math={ti.cfg.fast_math} dim={n}' + f'arch={ti.lang.impl.current_cfg().arch} default_fp={ti.lang.impl.current_cfg().default_fp} fast_math={ti.lang.impl.current_cfg().fast_math} dim={n}' ) A = ti.Matrix.field(n, n, dtype=dt, shape=()) A_reconstructed = ti.Matrix.field(n, n, dtype=dt, shape=()) @@ -58,23 +58,24 @@ def run(): for i in range(n): for j in range(n): if i != j: - assert sigma[None][i, j] == approx(0) + assert sigma[None][i, j] == test_utils.approx(0) def test_svd(): for fp in [ti.f32, ti.f64]: for d in [2, 3]: - @ti.test(require=ti.extension.data64 if fp == ti.f64 else [], - default_fp=fp, - fast_math=False) + @test_utils.test( + require=ti.extension.data64 if fp == ti.f64 else [], + default_fp=fp, + fast_math=False) def wrapped(): _test_svd(fp, d) wrapped() -@ti.test() +@test_utils.test() def test_transpose_no_loop(): A = ti.Matrix.field(3, 3, dtype=ti.f32, shape=()) U = ti.Matrix.field(3, 3, dtype=ti.f32, shape=()) diff --git a/tests/python/test_sync.py b/tests/python/test_sync.py index 2cbf659daf374..feec840081b23 100644 --- a/tests/python/test_sync.py +++ b/tests/python/test_sync.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_kernel_sync(): n = 128 x = ti.field(ti.i32, shape=(3, )) diff --git a/tests/python/test_syntax_errors.py b/tests/python/test_syntax_errors.py index a25ab12b6a969..6ecc89809c49a 100644 --- a/tests/python/test_syntax_errors.py +++ b/tests/python/test_syntax_errors.py @@ -1,10 +1,10 @@ import pytest import taichi as ti +from tests import test_utils -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_try(): x = ti.field(ti.f32) @@ -17,11 +17,11 @@ def func(): except: a = 1 - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_for_else(): x = ti.field(ti.f32) @@ -34,11 +34,11 @@ def func(): else: pass - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_while_else(): x = ti.field(ti.f32) @@ -51,11 +51,22 @@ def func(): else: pass - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() +def test_raise(): + @ti.kernel + def foo(): + raise Exception() + + with pytest.raises(ti.TaichiSyntaxError, + match='Unsupported node "Raise"') as e: + foo() + + +@test_utils.test() def test_loop_var_range(): x = ti.field(ti.f32) @@ -67,11 +78,11 @@ def func(): for i in range(10): pass - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_loop_var_struct(): x = ti.field(ti.f32) @@ -83,11 +94,11 @@ def func(): for i in x: pass - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_loop_var_struct(): x = ti.field(ti.f32) @@ -99,11 +110,11 @@ def func(): for i, j in x: pass - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_func_def_in_kernel(): @ti.kernel def kernel(): @@ -113,11 +124,11 @@ def func(): print(func()) - kernel() + with pytest.raises(ti.TaichiCompilationError): + kernel() -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_func_def_in_func(): @ti.func def func(): @@ -131,62 +142,62 @@ def func2(): def kernel(): print(func()) - kernel() + with pytest.raises(ti.TaichiCompilationError): + kernel() -@ti.test(arch=ti.cpu) +@test_utils.test(arch=ti.cpu) def test_kernel_bad_argument_annotation(): - with pytest.raises(ti.KernelDefError, match='annotation'): + with pytest.raises(ti.TaichiSyntaxError, match='annotation'): @ti.kernel def kernel(x: 'bar'): print(x) -@ti.test(arch=ti.cpu) +@test_utils.test(arch=ti.cpu) def test_func_bad_argument_annotation(): - with pytest.raises(ti.KernelDefError, match='annotation'): + with pytest.raises(ti.TaichiSyntaxError, match='annotation'): @ti.func def func(x: 'foo'): print(x) -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_nested_static(): @ti.kernel def func(): for i in ti.static(ti.static(range(1))): pass - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_nested_grouped(): @ti.kernel def func(): for i in ti.grouped(ti.grouped(range(1))): pass - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_nested_ndrange(): @ti.kernel def func(): for i in ti.ndrange(ti.ndrange(1)): pass - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_static_grouped_struct_for(): val = ti.field(ti.i32) @@ -197,11 +208,11 @@ def test(): for I in ti.static(ti.grouped(val)): pass - test() + with pytest.raises(ti.TaichiCompilationError): + test() -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_is(): b = ti.field(ti.i32, shape=()) c = ti.field(ti.i32, shape=()) @@ -210,11 +221,11 @@ def test_is(): def func(): a = b is c - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_is_not(): b = ti.field(ti.i32, shape=()) c = ti.field(ti.i32, shape=()) @@ -223,11 +234,11 @@ def test_is_not(): def func(): a = b is not c - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_in(): b = ti.field(ti.i32, shape=()) c = ti.field(ti.i32, shape=()) @@ -236,11 +247,11 @@ def test_in(): def func(): a = b in c - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_not_in(): b = ti.field(ti.i32, shape=()) c = ti.field(ti.i32, shape=()) @@ -249,78 +260,71 @@ def test_not_in(): def func(): a = b not in c - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test() -@ti.must_throw(ti.TaichiSyntaxError) +@test_utils.test() def test_expr_set(): @ti.kernel def func(): x = {2, 4, 6} - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test(arch=ti.cpu) -def test_func_multiple_return(): - @ti.func - def safe_sqrt(a): - if a > 0: - return ti.sqrt(a) - else: - return 0.0 - +@test_utils.test() +def test_redefining_template_args(): @ti.kernel - def kern(a: float): - print(safe_sqrt(a)) + def foo(a: ti.template()): + a = 5 - with pytest.raises(ti.TaichiSyntaxError, - match='cannot have multiple returns'): - kern(-233) + with pytest.raises( + ti.TaichiSyntaxError, + match= + "Variable 'a' cannot be assigned. Maybe it is not a Taichi object?" + ): + foo(1) -@ti.test(arch=ti.cpu) -def test_func_multiple_return_in_static_if(): - @ti.func - def safe_static_sqrt(a: ti.template()): - if ti.static(a > 0): - return ti.sqrt(a) - else: - return 0.0 - +@test_utils.test() +def test_break_in_outermost_for(): @ti.kernel - def kern(): - print(safe_static_sqrt(-233)) + def foo(): + for i in range(10): + break - kern() + with pytest.raises(ti.TaichiSyntaxError, + match="Cannot break in the outermost loop"): + foo() -@ti.test() -def test_func_def_inside_kernel(): +@test_utils.test() +def test_funcdef_in_kernel(): @ti.kernel - def k(): - @ti.func - def illegal(): - return 1 + def foo(): + def bar(): + pass - with pytest.raises(ti.TaichiSyntaxError, - match='Function definition not allowed'): - k() + with pytest.raises( + ti.TaichiSyntaxError, + match="Function definition is not allowed in 'ti.kernel'"): + foo() -@ti.test() -def test_func_def_inside_func(): +@test_utils.test() +def test_funcdef_in_func(): @ti.func - def f(): - @ti.func - def illegal(): - return 1 + def foo(): + def bar(): + pass @ti.kernel - def k(): - f() + def baz(): + foo() - with pytest.raises(ti.TaichiSyntaxError, - match='Function definition not allowed'): - k() + with pytest.raises( + ti.TaichiSyntaxError, + match="Function definition is not allowed in 'ti.func'"): + baz() diff --git a/tests/python/test_tensor_dimensionality.py b/tests/python/test_tensor_dimensionality.py index 7023adbe96df4..5eb079a0934ad 100644 --- a/tests/python/test_tensor_dimensionality.py +++ b/tests/python/test_tensor_dimensionality.py @@ -1,14 +1,15 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def _test_dimensionality(d): x = ti.Vector.field(2, dtype=ti.i32, shape=(2, ) * d) @ti.kernel def fill(): for I in ti.grouped(x): - x[I] += ti.Vector([I.sum(), I[0]], dt=ti.i32) + x[I] += ti.Vector([I.sum(), I[0]]) for i in range(2**d): indices = [] @@ -17,7 +18,7 @@ def fill(): x.__getitem__(tuple(indices))[0] = sum(indices) * 2 fill() # FIXME(yuanming-hu): snode_writer needs 9 arguments actually.. - if ti.cfg.arch == ti.cc and d >= 8: + if ti.lang.impl.current_cfg().arch == ti.cc and d >= 8: return for i in range(2**d): indices = [] @@ -27,5 +28,5 @@ def fill(): def test_dimensionality(): - for i in range(2, ti.core.get_max_num_indices() + 1): + for i in range(2, ti._lib.core.get_max_num_indices() + 1): _test_dimensionality(i) diff --git a/tests/python/test_tensor_reflection.py b/tests/python/test_tensor_reflection.py index fe31cba6ffdc0..ae662418b6819 100644 --- a/tests/python/test_tensor_reflection.py +++ b/tests/python/test_tensor_reflection.py @@ -1,9 +1,11 @@ import pytest +from taichi.lang import impl import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_POT(): val = ti.field(ti.i32) @@ -17,7 +19,7 @@ def test_POT(): assert val.dtype == ti.i32 -@ti.test() +@test_utils.test() def test_non_POT(): val = ti.field(ti.i32) @@ -34,7 +36,7 @@ def test_non_POT(): assert val.dtype == ti.i32 -@ti.test() +@test_utils.test() def test_unordered(): val = ti.field(ti.i32) @@ -56,18 +58,18 @@ def test_unordered(): assert val.snode.parent(3) == blk1 assert val.snode.parent(4) == ti.root - assert val.snode in blk3.get_children() - assert blk3 in blk2.get_children() - assert blk2 in blk1.get_children() - ti.get_runtime().materialize() - assert blk1 in ti.FieldsBuilder.finalized_roots()[0].get_children() + assert val.snode in blk3._get_children() + assert blk3 in blk2._get_children() + assert blk2 in blk1._get_children() + impl.get_runtime().materialize_root_fb(False) + assert blk1 in ti.FieldsBuilder._finalized_roots()[0]._get_children() expected_str = f'ti.root => dense {[n]} => dense {[m, n]}' \ f' => dense {[m, p, n]} => place {[m, p, n]}' assert str(val.snode) == expected_str -@ti.test() +@test_utils.test() def test_unordered_matrix(): val = ti.Matrix.field(3, 2, ti.i32) @@ -88,9 +90,12 @@ def test_unordered_matrix(): assert val.snode.parent(2) == blk2 assert val.snode.parent(3) == blk1 assert val.snode.parent(4) == ti.root + assert val.snode._path_from_root() == [ + ti.root, blk1, blk2, blk3, val.snode + ] -@ti.test() +@test_utils.test() def test_parent_exceeded(): val = ti.field(ti.f32) diff --git a/tests/python/test_test.py b/tests/python/test_test.py index 6421fe5feab86..c5444904c8fd3 100644 --- a/tests/python/test_test.py +++ b/tests/python/test_test.py @@ -3,80 +3,82 @@ TODO: Skips these tests after all tests are using @ti.test ''' +import os import pytest import taichi as ti +from tests import test_utils ### `ti.test` -@ti.test() +@test_utils.test() def test_all_archs(): - assert ti.cfg.arch in ti.supported_archs() + assert ti.lang.impl.current_cfg().arch in test_utils.expected_archs() -@ti.test(arch=ti.cpu) +@test_utils.test(arch=ti.cpu) def test_arch_cpu(): - assert ti.cfg.arch in [ti.cpu] + assert ti.lang.impl.current_cfg().arch in [ti.cpu] -@ti.test(arch=[ti.cpu]) +@test_utils.test(arch=[ti.cpu]) def test_arch_list_cpu(): - assert ti.cfg.arch in [ti.cpu] + assert ti.lang.impl.current_cfg().arch in [ti.cpu] -@ti.test(exclude=ti.cpu) +@test_utils.test(exclude=ti.cpu) def test_exclude_cpu(): - assert ti.cfg.arch not in [ti.cpu] + assert ti.lang.impl.current_cfg().arch not in [ti.cpu] -@ti.test(exclude=[ti.cpu]) +@test_utils.test(exclude=[ti.cpu]) def test_exclude_list_cpu(): - assert ti.cfg.arch not in [ti.cpu] + assert ti.lang.impl.current_cfg().arch not in [ti.cpu] -@ti.test(arch=ti.opengl) +@test_utils.test(arch=ti.opengl) def test_arch_opengl(): - assert ti.cfg.arch in [ti.opengl] + assert ti.lang.impl.current_cfg().arch in [ti.opengl] -@ti.test(arch=[ti.cpu, ti.opengl, ti.metal]) +@test_utils.test(arch=[ti.cpu, ti.opengl, ti.metal]) def test_multiple_archs(): - assert ti.cfg.arch in [ti.cpu, ti.opengl, ti.metal] + assert ti.lang.impl.current_cfg().arch in [ti.cpu, ti.opengl, ti.metal] -@ti.test(arch=ti.cpu, debug=True, advanced_optimization=False) +@test_utils.test(arch=ti.cpu, debug=True, advanced_optimization=False) def test_init_args(): - assert ti.cfg.debug == True - assert ti.cfg.advanced_optimization == False + assert ti.lang.impl.current_cfg().debug == True + assert ti.lang.impl.current_cfg().advanced_optimization == False -@ti.test(require=ti.extension.sparse) +@test_utils.test(require=ti.extension.sparse) def test_require_extensions_1(): - assert ti.cfg.arch in [ti.cpu, ti.cuda, ti.metal] + assert ti.lang.impl.current_cfg().arch in [ti.cpu, ti.cuda, ti.metal] -@ti.test(arch=[ti.cpu, ti.opengl], require=ti.extension.sparse) +@test_utils.test(arch=[ti.cpu, ti.opengl], require=ti.extension.sparse) def test_require_extensions_2(): - assert ti.cfg.arch in [ti.cpu] + assert ti.lang.impl.current_cfg().arch in [ti.cpu] -@ti.test(arch=[ti.cpu, ti.opengl], - require=[ti.extension.sparse, ti.extension.bls]) +@test_utils.test(arch=[ti.cpu, ti.opengl], + require=[ti.extension.sparse, ti.extension.bls]) def test_require_extensions_2(): - assert ti.cfg.arch in [ti.cuda] + assert ti.lang.impl.current_cfg().arch in [ti.cuda] -### `ti.approx` and `ti.allclose` +### `test_utils.approx` and `test_utils.allclose` @pytest.mark.parametrize('x', [0.1, 3]) -@pytest.mark.parametrize('allclose', - [ti.allclose, lambda x, y: x == ti.approx(y)]) -@ti.test() +@pytest.mark.parametrize( + 'allclose', [test_utils.allclose, lambda x, y: x == test_utils.approx(y)]) +@test_utils.test() def test_allclose_rel(x, allclose): - rel = ti.get_rel_eps() + rel = test_utils.get_rel_eps() assert not allclose(x + x * rel * 3.0, x) assert not allclose(x + x * rel * 1.2, x) assert allclose(x + x * rel * 0.9, x) @@ -89,11 +91,11 @@ def test_allclose_rel(x, allclose): @pytest.mark.parametrize('x', [0.1, 3]) -@pytest.mark.parametrize('allclose', - [ti.allclose, lambda x, y: x == ti.approx(y)]) -@ti.test() +@pytest.mark.parametrize( + 'allclose', [test_utils.allclose, lambda x, y: x == test_utils.approx(y)]) +@test_utils.test() def test_allclose_rel_reordered1(x, allclose): - rel = ti.get_rel_eps() + rel = test_utils.get_rel_eps() assert not allclose(x + x * rel * 3.0, x) assert not allclose(x + x * rel * 1.2, x) assert allclose(x + x * rel * 0.9, x) @@ -106,11 +108,11 @@ def test_allclose_rel_reordered1(x, allclose): @pytest.mark.parametrize('x', [0.1, 3]) -@pytest.mark.parametrize('allclose', - [ti.allclose, lambda x, y: x == ti.approx(y)]) -@ti.test() +@pytest.mark.parametrize( + 'allclose', [test_utils.allclose, lambda x, y: x == test_utils.approx(y)]) +@test_utils.test() def test_allclose_rel_reordered2(x, allclose): - rel = ti.get_rel_eps() + rel = test_utils.get_rel_eps() assert not allclose(x + x * rel * 3.0, x) assert not allclose(x + x * rel * 1.2, x) assert allclose(x + x * rel * 0.9, x) @@ -120,3 +122,20 @@ def test_allclose_rel_reordered2(x, allclose): assert allclose(x - x * rel * 0.9, x) assert not allclose(x - x * rel * 1.2, x) assert not allclose(x - x * rel * 3.0, x) + + +@pytest.mark.skipif(ti._lib.core.with_metal(), + reason="Skip metal because metal is used as the example") +def test_disable_fallback(): + old_environ = os.environ.get('TI_WANTED_ARCHS', '') + os.environ['TI_WANTED_ARCHS'] = "metal" + + with pytest.raises(RuntimeError): + + @test_utils.test(ti.metal) + def test(): + pass + + test() + os.environ['TI_WANTED_ARCHS'] = old_environ + os.environ['TI_WANTED_ARCHS'] = old_environ diff --git a/tests/python/test_threading.py b/tests/python/test_threading.py index 3bc5d5d678ee9..67980c6a8b863 100644 --- a/tests/python/test_threading.py +++ b/tests/python/test_threading.py @@ -1,6 +1,9 @@ +from taichi.lang.misc import get_host_arch_list + import taichi as ti +from tests import test_utils -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_while(): - assert ti.core.test_threading() + assert ti._lib.core.test_threading() diff --git a/tests/python/test_torch_ad.py b/tests/python/test_torch_ad.py index 7da20ff43b170..d2422b6ff2eb8 100644 --- a/tests/python/test_torch_ad.py +++ b/tests/python/test_torch_ad.py @@ -1,14 +1,18 @@ +import sys + import numpy as np import pytest +from taichi.lang.util import has_pytorch import taichi as ti +from tests import test_utils -if ti.has_pytorch(): +if has_pytorch(): import torch -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(exclude=ti.opengl) def test_torch_ad(): n = 32 @@ -47,8 +51,9 @@ def backward(ctx, outp_grad): assert ret[j] == 4 -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@pytest.mark.skipif(sys.platform == 'win32', reason='not working on Windows.') +@test_utils.test(exclude=ti.opengl) def test_torch_ad_gpu(): if not torch.cuda.is_available(): return diff --git a/tests/python/test_torch_io.py b/tests/python/test_torch_io.py index 75d2bdb6a812b..fc6182efbddd4 100644 --- a/tests/python/test_torch_io.py +++ b/tests/python/test_torch_io.py @@ -1,14 +1,17 @@ import numpy as np import pytest +from taichi.lang import impl +from taichi.lang.util import has_pytorch import taichi as ti +from tests import test_utils -if ti.has_pytorch(): +if has_pytorch(): import torch -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(exclude=[ti.opengl, ti.vulkan]) def test_io_devices(): n = 32 x = ti.field(dtype=ti.i32, shape=n) @@ -44,8 +47,8 @@ def store(y: ti.ext_arr()): assert y[i] == (11 + i) * 2 -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(exclude=[ti.opengl, ti.vulkan]) def test_io(): n = 32 @@ -84,8 +87,8 @@ def backward(ctx, outp_grad): assert ret[i] == 4 -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(exclude=[ti.opengl, ti.vulkan]) def test_io_2d(): n = 32 @@ -108,8 +111,8 @@ def forward(ctx, inp): assert val == 2 * 2 * n * n -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(exclude=[ti.opengl, ti.vulkan]) def test_io_3d(): n = 16 @@ -134,8 +137,8 @@ def forward(ctx, inp): assert val == 2 * 2 * n * n * n -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(exclude=[ti.opengl, ti.vulkan]) def test_io_simple(): n = 32 @@ -161,8 +164,8 @@ def test_io_simple(): assert (t2 == t3).all() -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(exclude=[ti.opengl, ti.vulkan]) def test_io_zeros(): mat = ti.Matrix.field(2, 6, dtype=ti.f32, shape=(), needs_grad=True) zeros = torch.zeros((2, 6)) @@ -175,8 +178,8 @@ def test_io_zeros(): assert zeros[1, 2] == 4 -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(exclude=[ti.opengl, ti.vulkan]) def test_io_struct(): n = 16 x1 = ti.Struct.field({"a": ti.i32, "b": ti.f32}, shape=(n, )) @@ -195,20 +198,20 @@ def test_io_struct(): assert (t1[k] == t2[k]).all() -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(exclude=[ti.opengl, ti.vulkan]) def test_fused_kernels(): n = 12 X = ti.Matrix.field(3, 2, ti.f32, shape=(n, n, n)) - s = ti.get_runtime().get_num_compiled_functions() + s = impl.get_runtime().get_num_compiled_functions() t = X.to_torch() - assert ti.get_runtime().get_num_compiled_functions() == s + 1 + assert impl.get_runtime().get_num_compiled_functions() == s + 1 X.from_torch(t) - assert ti.get_runtime().get_num_compiled_functions() == s + 2 + assert impl.get_runtime().get_num_compiled_functions() == s + 2 -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(exclude=[ti.opengl, ti.vulkan]) def test_device(): n = 12 X = ti.Matrix.field(3, 2, ti.f32, shape=(n, n, n)) @@ -218,8 +221,8 @@ def test_device(): assert X.to_torch(device='cuda:0').device == torch.device('cuda:0') -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(exclude=[ti.opengl, ti.vulkan]) def test_shape_matrix(): n = 12 x = ti.Matrix.field(3, 2, ti.f32, shape=(n, n)) @@ -238,8 +241,8 @@ def test_shape_matrix(): assert (X == X1).all() -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(exclude=[ti.opengl, ti.vulkan]) def test_shape_vector(): n = 12 x = ti.Vector.field(3, ti.f32, shape=(n, n)) @@ -257,8 +260,8 @@ def test_shape_vector(): assert (X == X1).all() -@pytest.mark.skipif(not ti.has_pytorch(), reason='Pytorch not installed.') -@ti.test(exclude=ti.opengl) +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(exclude=[ti.opengl, ti.vulkan]) def test_torch_zero(): @ti.kernel def test_torch(arr: ti.ext_arr()): @@ -267,3 +270,19 @@ def test_torch(arr: ti.ext_arr()): test_torch(torch.zeros((0), dtype=torch.int32)) test_torch(torch.zeros((0, 5), dtype=torch.int32)) test_torch(torch.zeros((5, 0, 5), dtype=torch.int32)) + + +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(exclude=[ti.opengl, ti.vulkan]) +def test_torch_view(): + @ti.kernel + def copy(x: ti.any_arr(), y: ti.any_arr()): + for i, j in x: + y[i, j] = x[i, j] + + x = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).T + y = ti.ndarray(int, (3, 3)) + + with pytest.raises(ValueError, + match=r'Non contiguous tensors are not supported'): + copy(x, y) diff --git a/tests/python/test_trailing_bits.py b/tests/python/test_trailing_bits.py deleted file mode 100644 index ece5a1d738d2a..0000000000000 --- a/tests/python/test_trailing_bits.py +++ /dev/null @@ -1,56 +0,0 @@ -import pytest - -import taichi as ti - - -def _test_trailing_bits(): - ti.init(arch=ti.cpu, debug=True, print_ir=True) - - x = ti.field(ti.f32) - y = ti.field(ti.f32) - - block = ti.root.pointer(ti.i, 8) - block.dense(ti.i, 32).place(x) - - # Here every 32 ti.i share the same dense node of 16 y along ti.j. - block.dense(ti.j, 16).place(y) - assert y.shape == (256, 16) - # instead of y.shape == (8, 16), - # since there are 5 trailing bits for ti.i for y's SNode - - assert x.shape == (256, ) - - y[255, 15] = 0 - - with pytest.raises(RuntimeError): - y[256, 15] = 0 - - with pytest.raises(RuntimeError): - y[255, 16] = 0 - - y[255, 3] = 123 - - # They are the same element... - assert y[255, 3] == 123 - assert y[254, 3] == 123 - assert y[240, 3] == 123 - - -def _test_inconsistent_trailing_bits(): - ti.init(arch=ti.cpu, debug=True, print_ir=True) - - x = ti.field(ti.f32) - y = ti.field(ti.f32) - z = ti.field(ti.f32) - - block = ti.root.pointer(ti.i, 8) - - # Here the numbers of bits of x and z are inconsistent, - # which leads to the RuntimeError below. - block.dense(ti.i, 32).place(x) - block.dense(ti.i, 16).place(z) - - block.dense(ti.j, 16).place(y) - - with pytest.raises(RuntimeError): - ti.get_runtime().materialize() diff --git a/tests/python/test_tuple_assign.py b/tests/python/test_tuple_assign.py index cdd8be458184d..03dc05bec7d18 100644 --- a/tests/python/test_tuple_assign.py +++ b/tests/python/test_tuple_assign.py @@ -1,7 +1,11 @@ +import pytest +from taichi.lang.misc import get_host_arch_list + import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_fibonacci(): @ti.kernel def ti_fibonacci(n: ti.i32) -> ti.i32: @@ -22,7 +26,7 @@ def py_fibonacci(n): assert ti_fibonacci(n) == py_fibonacci(n) -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_assign2(): a = ti.field(ti.f32, ()) b = ti.field(ti.f32, ()) @@ -36,8 +40,7 @@ def func(): assert b[None] == 3 -@ti.test(arch=ti.get_host_arch_list()) -@ti.must_throw(ValueError) +@test_utils.test(arch=get_host_arch_list()) def test_assign2_mismatch3(): a = ti.field(ti.f32, ()) b = ti.field(ti.f32, ()) @@ -46,11 +49,11 @@ def test_assign2_mismatch3(): def func(): a[None], b[None] = 2, 3, 4 - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test(arch=ti.get_host_arch_list()) -@ti.must_throw(TypeError) +@test_utils.test(arch=get_host_arch_list()) def test_assign2_mismatch1(): a = ti.field(ti.f32, ()) b = ti.field(ti.f32, ()) @@ -59,10 +62,11 @@ def test_assign2_mismatch1(): def func(): a[None], b[None] = 2 - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_swap2(): a = ti.field(ti.f32, ()) b = ti.field(ti.f32, ()) @@ -78,7 +82,7 @@ def func(): assert b[None] == 2 -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_assign2_static(): a = ti.field(ti.f32, ()) b = ti.field(ti.f32, ()) @@ -94,7 +98,7 @@ def func(): assert b[None] == 2 -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_swap3(): a = ti.field(ti.f32, ()) b = ti.field(ti.f32, ()) @@ -113,7 +117,7 @@ def func(): assert c[None] == 2 -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_unpack_from_tuple(): a = ti.field(ti.f32, ()) b = ti.field(ti.f32, ()) @@ -131,8 +135,7 @@ def func(): assert c[None] == 4 -@ti.test(arch=ti.get_host_arch_list()) -@ti.must_throw(ValueError) +@test_utils.test(arch=get_host_arch_list()) def test_unpack_mismatch_tuple(): a = ti.field(ti.f32, ()) b = ti.field(ti.f32, ()) @@ -143,10 +146,11 @@ def test_unpack_mismatch_tuple(): def func(): a[None], b[None] = list - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_unpack_from_vector(): a = ti.field(ti.f32, ()) b = ti.field(ti.f32, ()) @@ -163,8 +167,7 @@ def func(): assert c[None] == 4 -@ti.test(arch=ti.get_host_arch_list()) -@ti.must_throw(ValueError) +@test_utils.test(arch=get_host_arch_list()) def test_unpack_mismatch_vector(): a = ti.field(ti.f32, ()) b = ti.field(ti.f32, ()) @@ -174,11 +177,11 @@ def func(): vector = ti.Vector([2, 3, 4]) a[None], b[None] = vector - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test(arch=ti.get_host_arch_list()) -@ti.must_throw(TypeError) +@test_utils.test(arch=get_host_arch_list()) def test_unpack_mismatch_type(): a = ti.field(ti.f32, ()) b = ti.field(ti.f32, ()) @@ -189,11 +192,11 @@ def test_unpack_mismatch_type(): def func(): a[None], b[None] = bad - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test(arch=ti.get_host_arch_list()) -@ti.must_throw(ValueError) +@test_utils.test(arch=get_host_arch_list()) def test_unpack_mismatch_matrix(): a = ti.field(ti.f32, ()) b = ti.field(ti.f32, ()) @@ -205,10 +208,11 @@ def func(): bad = ti.Matrix([[2, 3], [4, 5]]) a[None], b[None], c[None], d[None] = bad - func() + with pytest.raises(ti.TaichiCompilationError): + func() -@ti.test(arch=ti.get_host_arch_list()) +@test_utils.test(arch=get_host_arch_list()) def test_unpack_from_shape(): a = ti.field(ti.f32, ()) b = ti.field(ti.f32, ()) diff --git a/tests/python/test_type_check.py b/tests/python/test_type_check.py new file mode 100644 index 0000000000000..0330e7be74317 --- /dev/null +++ b/tests/python/test_type_check.py @@ -0,0 +1,82 @@ +import numpy as np +import pytest +from taichi.lang.util import has_pytorch + +import taichi as ti +from tests import test_utils + + +@test_utils.test(arch=ti.cpu) +def test_unary_op(): + @ti.kernel + def floor(): + a = 1 + b = ti.floor(a) + + with pytest.raises(ti.TaichiTypeError, + match="'floor' takes real inputs only"): + floor() + + +@test_utils.test(arch=ti.cpu) +def test_binary_op(): + @ti.kernel + def bitwise_float(): + a = 1 + b = 3.1 + c = a & b + + with pytest.raises(ti.TaichiTypeError, + match=r"unsupported operand type\(s\) for '&'"): + bitwise_float() + + +@test_utils.test(arch=ti.cpu) +def test_ternary_op(): + @ti.kernel + def select(): + a = 1.1 + b = 3 + c = 3.6 + d = b if a else c + + with pytest.raises(TypeError, + match="`if` conditions must be of type int32"): + select() + + +@pytest.mark.skipif(not has_pytorch(), reason='Pytorch not installed.') +@test_utils.test(arch=[ti.cpu, ti.opengl]) +def test_subscript(): + a = ti.ndarray(ti.i32, shape=(10, 10)) + + @ti.kernel + def any_array(x: ti.any_arr()): + b = x[3, 1.1] + + with pytest.raises(ti.TaichiTypeError, match="indices must be integers"): + any_array(a) + + +@test_utils.test() +def test_0d_ndarray(): + @ti.kernel + def foo() -> ti.i32: + a = np.array(3, dtype=np.int32) + return a + + assert foo() == 3 + + +@test_utils.test() +def test_non_0d_ndarray(): + @ti.kernel + def foo(): + a = np.array([1]) + + with pytest.raises( + ti.TaichiTypeError, + match= + "Only 0-dimensional numpy array can be used to initialize a scalar expression" + ): + foo() diff --git a/tests/python/test_types.py b/tests/python/test_types.py index 72598324ed77f..2d1cb785136be 100644 --- a/tests/python/test_types.py +++ b/tests/python/test_types.py @@ -1,6 +1,8 @@ import pytest +from taichi.lang import impl import taichi as ti +from tests import test_utils _TI_TYPES = [ti.i8, ti.i16, ti.i32, ti.u8, ti.u16, ti.u32, ti.f32] _TI_64_TYPES = [ti.i64, ti.u64, ti.f64] @@ -18,13 +20,13 @@ def func(value: dt): @pytest.mark.parametrize('dt', _TI_TYPES) -@ti.test(exclude=[ti.opengl, ti.vulkan]) +@test_utils.test(exclude=[ti.opengl, ti.vulkan]) def test_type_assign_argument(dt): _test_type_assign_argument(dt) @pytest.mark.parametrize('dt', _TI_64_TYPES) -@ti.test(exclude=[ti.opengl, ti.vulkan], require=ti.extension.data64) +@test_utils.test(exclude=[ti.opengl, ti.vulkan], require=ti.extension.data64) def test_type_assign_argument64(dt): _test_type_assign_argument(dt) @@ -50,13 +52,13 @@ def func(): @pytest.mark.parametrize('dt', _TI_TYPES) -@ti.test(exclude=[ti.opengl, ti.vulkan]) +@test_utils.test(exclude=[ti.opengl, ti.vulkan]) def test_type_operator(dt): _test_type_operator(dt) @pytest.mark.parametrize('dt', _TI_64_TYPES) -@ti.test(exclude=[ti.opengl, ti.vulkan], require=ti.extension.data64) +@test_utils.test(exclude=[ti.opengl, ti.vulkan], require=ti.extension.data64) def test_type_operator64(dt): _test_type_operator(dt) @@ -75,13 +77,13 @@ def func(i: ti.i32, j: ti.i32): @pytest.mark.parametrize('dt', _TI_TYPES) -@ti.test(exclude=[ti.opengl, ti.vulkan]) +@test_utils.test(exclude=[ti.opengl, ti.vulkan]) def test_type_field(dt): _test_type_field(dt) @pytest.mark.parametrize('dt', _TI_64_TYPES) -@ti.test(exclude=[ti.opengl, ti.vulkan], require=ti.extension.data64) +@test_utils.test(exclude=[ti.opengl, ti.vulkan], require=ti.extension.data64) def test_type_field64(dt): _test_type_field(dt) @@ -103,7 +105,7 @@ def func(): assert a[None] == 2**n // 3 assert b[None] == 2**n // 3 - if ti.core.is_signed(dt): + if ti.types.is_signed(dt): assert c[None] == 2**n // 3 * 2 - (2**n) # overflows else: assert c[None] == 2**n // 3 * 2 # does not overflow @@ -117,7 +119,7 @@ def func(): (ti.i32, 32), (ti.u32, 32), ]) -@ti.test(exclude=[ti.opengl, ti.vulkan]) +@test_utils.test(exclude=[ti.opengl, ti.vulkan]) def test_overflow(dt, n): _test_overflow(dt, n) @@ -126,7 +128,7 @@ def test_overflow(dt, n): (ti.i64, 64), (ti.u64, 64), ]) -@ti.test(exclude=[ti.opengl, ti.vulkan], require=ti.extension.data64) +@test_utils.test(exclude=[ti.opengl, ti.vulkan], require=ti.extension.data64) def test_overflow64(dt, n): _test_overflow(dt, n) @@ -135,10 +137,10 @@ def test_overflow64(dt, n): (ti.u32, 0xffffffff), (ti.u64, 0xffffffffffffffff), ]) -@ti.test(require=ti.extension.data64) +@test_utils.test(require=ti.extension.data64) def test_uint_max(dt, val): # https://github.com/taichi-dev/taichi/issues/2060 - ti.get_runtime().default_ip = dt + impl.get_runtime().default_ip = dt N = 16 f = ti.field(dt, shape=N) diff --git a/tests/python/test_unary_ops.py b/tests/python/test_unary_ops.py index faa4342132bb5..5aea17db3eaea 100644 --- a/tests/python/test_unary_ops.py +++ b/tests/python/test_unary_ops.py @@ -1,10 +1,13 @@ import numpy as np import taichi as ti +from tests import test_utils def _test_op(dt, taichi_op, np_op): - print('arch={} default_fp={}'.format(ti.cfg.arch, ti.cfg.default_fp)) + print('arch={} default_fp={}'.format( + ti.lang.impl.current_cfg().arch, + ti.lang.impl.current_cfg().default_fp)) n = 4 val = ti.field(dt, shape=n) @@ -23,9 +26,10 @@ def fill(): if dt == ti.f64: assert abs(np_op(float(f(i))) - val[i]) < 1e-15 else: - assert abs( - np_op(float(f(i))) - val[i] - ) < 1e-6 if ti.cfg.arch != ti.opengl and ti.cfg.arch != ti.vulkan else 1e-5 + assert abs(np_op(float(f(i))) - + val[i]) < 1e-6 if ti.lang.impl.current_cfg( + ).arch != ti.opengl and ti.lang.impl.current_cfg( + ).arch != ti.vulkan else 1e-5 def test_f64_trig(): @@ -42,8 +46,9 @@ def test_f64_trig(): for dt in [ti.f32, ti.f64]: for taichi_op, np_op in op_pairs: - @ti.test(require=ti.extension.data64 if dt == ti.f64 else [], - default_fp=dt) + @test_utils.test( + require=ti.extension.data64 if dt == ti.f64 else [], + default_fp=dt) def wrapped(): _test_op(dt, taichi_op, np_op) diff --git a/tests/python/test_while.py b/tests/python/test_while.py index f3a51815fbfca..a3c53e7973da7 100644 --- a/tests/python/test_while.py +++ b/tests/python/test_while.py @@ -1,7 +1,8 @@ import taichi as ti +from tests import test_utils -@ti.test() +@test_utils.test() def test_while(): x = ti.field(ti.f32) @@ -22,7 +23,7 @@ def func(): assert x[0] == 45 -@ti.test() +@test_utils.test() def test_break(): ret = ti.field(ti.i32, shape=()) diff --git a/tests/run_tests.py b/tests/run_tests.py new file mode 100644 index 0000000000000..1f00d4c2d3c54 --- /dev/null +++ b/tests/run_tests.py @@ -0,0 +1,216 @@ +import argparse +import os +import pdb +import subprocess +import sys +import warnings + +import taichi as ti + + +def _test_cpp(): + ti.reset() + print("Running C++ tests...") + ti_lib_dir = os.path.join(ti.__path__[0], '_lib', 'runtime') + + cpp_test_filename = 'taichi_cpp_tests' + curr_dir = os.path.dirname(os.path.abspath(__file__)) + build_dir = os.path.join(curr_dir, '../build') + if os.path.exists(os.path.join(build_dir, cpp_test_filename)): + subprocess.check_call(f'./{cpp_test_filename}', + env={'TI_LIB_DIR': ti_lib_dir}, + cwd=build_dir) + else: + warnings.warn( + f"C++ tests are skipped due to missing {cpp_test_filename} in {build_dir}." + "Try building taichi with `TAICHI_CMAKE_ARGS=\'-DTI_BUILD_TESTS:BOOL=ON\' python setup.py develop`" + "if you want to enable it.") + + +def _test_python(args): + print("\nRunning Python tests...\n") + + test_38 = sys.version_info >= (3, 8) + + curr_dir = os.path.dirname(os.path.abspath(__file__)) + test_dir = os.path.join(curr_dir, 'python') + pytest_args = [] + + # TODO: use pathlib to deal with suffix and stem name manipulation + if args.files: + # run individual tests + for f in args.files: + # auto-complete file names + if not f.startswith('test_'): + f = 'test_' + f + if not f.endswith('.py'): + f = f + '.py' + file = os.path.join(test_dir, f) + has_tests = False + if os.path.exists(file): + pytest_args.append(file) + has_tests = True + assert has_tests, f"Test {f} does not exist." + else: + # run all the tests + pytest_args = [test_dir] + if args.verbose: + pytest_args += ['-v'] + if args.rerun: + pytest_args += ['--reruns', args.rerun] + try: + if args.coverage: + pytest_args += ['--cov-branch', '--cov=python/taichi'] + if args.cov_append: + pytest_args += ['--cov-append'] + if args.keys: + pytest_args += ['-k', args.keys] + if args.marks: + pytest_args += ['-m', args.marks] + if args.failed_first: + pytest_args += ['--failed-first'] + if args.fail_fast: + pytest_args += ['--exitfirst'] + except AttributeError: + pass + + try: + from multiprocessing import cpu_count # pylint: disable=C0415 + threads = min(8, cpu_count()) # To prevent running out of memory + except NotImplementedError: + threads = 2 + + if not os.environ.get('TI_DEVICE_MEMORY_GB'): + os.environ['TI_DEVICE_MEMORY_GB'] = '1.0' # Discussion: #769 + + env_threads = os.environ.get('TI_TEST_THREADS', '') + threads = args.threads or env_threads or threads + print(f'Starting {threads} testing thread(s)...') + if args.show_output: + pytest_args += ['-s'] + print( + f'Due to how pytest-xdist is implemented, the -s option does not work with multiple thread...' + ) + else: + if int(threads) > 1: + pytest_args += ['-n', str(threads)] + import pytest # pylint: disable=C0415 + return int(pytest.main(pytest_args)) + + +def test(): + """Run the tests""" + parser = argparse.ArgumentParser( + description=f"Run taichi cpp & python tess") + parser.add_argument('files', + nargs='*', + help='Test name(s) to be run, e.g. "cli"') + parser.add_argument('-c', + '--cpp', + dest='cpp', + default=True, + action='store_true', + help='Run the C++ tests') + parser.add_argument('-s', + '--show', + dest='show_output', + action='store_true', + help='Show output (do not capture)') + parser.add_argument('-v', + '--verbose', + dest='verbose', + action='store_true', + help='Run with verbose outputs') + parser.add_argument('-r', + '--rerun', + required=False, + default=None, + dest='rerun', + type=str, + help='Rerun failed tests for given times') + parser.add_argument('-k', + '--keys', + required=False, + default=None, + dest='keys', + type=str, + help='Only run tests that match the keys') + parser.add_argument('-m', + '--marks', + required=False, + default=None, + dest='marks', + type=str, + help='Only run tests with specific marks') + parser.add_argument('-f', + '--failed-first', + required=False, + default=None, + dest='failed_first', + action='store_true', + help='Run the previously failed test first') + parser.add_argument('-x', + '--fail-fast', + required=False, + default=None, + dest='fail_fast', + action='store_true', + help='Exit instantly on the first failed test') + parser.add_argument('-C', + '--coverage', + required=False, + default=None, + dest='coverage', + action='store_true', + help='Run tests and record the coverage result') + parser.add_argument( + '-A', + '--cov-append', + required=False, + default=None, + dest='cov_append', + action='store_true', + help='Append coverage result to existing one instead of overriding it') + parser.add_argument('-t', + '--threads', + required=False, + default=None, + dest='threads', + type=str, + help='Custom number of threads for parallel testing') + parser.add_argument('-a', + '--arch', + required=False, + default=None, + dest='arch', + type=str, + help='Custom the arch(s) (backend) to run tests on') + parser.add_argument( + '-n', + '--exclusive', + required=False, + default=False, + dest='exclusive', + action='store_true', + help= + 'Exclude arch(s) from test instead of include them, together with -a') + + args = parser.parse_args() + print(args) + + if args.arch: + arch = args.arch + if args.exclusive: + arch = '^' + arch + print(f'Running on Arch={arch}') + os.environ['TI_WANTED_ARCHS'] = arch + + if args.cpp: + _test_cpp() + + if _test_python(args) != 0: + exit(1) + + +if __name__ == '__main__': + test() diff --git a/python/taichi/testing.py b/tests/test_utils.py similarity index 59% rename from python/taichi/testing.py rename to tests/test_utils.py index 80bb3561b4d55..90d7affa4a8da 100644 --- a/python/taichi/testing.py +++ b/tests/test_utils.py @@ -2,19 +2,22 @@ import functools import itertools import os +from errno import EEXIST from tempfile import mkstemp -from taichi.core import ti_core as _ti_core +from taichi._lib import core as _ti_core +from taichi.lang import cc, cpu, cuda, gpu, metal, opengl, vulkan +from taichi.lang.misc import is_arch_supported import taichi as ti # Helper functions def get_rel_eps(): - arch = ti.cfg.arch + arch = ti.lang.impl.current_cfg().arch if arch == ti.opengl: return 1e-3 - elif arch == ti.metal: + if arch == ti.metal: # Debatable, different hardware could yield different precisions # On AMD Radeon Pro 5500M, 1e-6 works fine... # https://github.com/taichi-dev/taichi/pull/1779 @@ -22,6 +25,18 @@ def get_rel_eps(): return 1e-6 +def mkdir_p(dir_path): + '''Creates a directory. equivalent to using mkdir -p on the command line''' + + try: + os.makedirs(dir_path) + except OSError as exc: # Python > 2.5 + if exc.errno == EEXIST and os.path.isdir(dir_path): + pass + else: + raise + + def approx(expected, **kwargs): '''Tweaked pytest.approx for OpenGL low precisions''' class boolean_integer: @@ -80,15 +95,58 @@ def required_extensions(self): } +def expected_archs(): + """ + Reads the environment variable `TI_WANTED_ARCHS` (usually set by option `-a` in `python tests/run_tests.py`) + and gets all expected archs on the machine. + If `TI_WANTED_ARCHS` is set and does not start with `^`, archs specified in it will be returned. + If `TI_WANTED_ARCHS` starts with `^` (usually when option `-n` is specified in `python tests/run_tests.py`), + all supported archs except archs specified in it will be returned. + If `TI_WANTED_ARCHS` is not set, all supported archs will be returned. + Returns: + List[taichi_core.Arch]: All expected archs on the machine. + """ + archs = set([cpu, cuda, metal, vulkan, opengl, cc]) + # TODO: now expected_archs is not called per test so we cannot test it + archs = set( + filter(functools.partial(is_arch_supported, use_gles=False), archs)) + + wanted_archs = os.environ.get('TI_WANTED_ARCHS', '') + want_exclude = wanted_archs.startswith('^') + if want_exclude: + wanted_archs = wanted_archs[1:] + wanted_archs = wanted_archs.split(',') + # Note, ''.split(',') gives you [''], which is not an empty array. + expanded_wanted_archs = set([]) + for arch in wanted_archs: + if arch == '': + continue + if arch == 'cpu': + expanded_wanted_archs.add(cpu) + elif arch == 'gpu': + expanded_wanted_archs.update(gpu) + else: + expanded_wanted_archs.add(_ti_core.arch_from_name(arch)) + if len(expanded_wanted_archs) == 0: + return list(archs) + if want_exclude: + expected = archs - expanded_wanted_archs + else: + expected = expanded_wanted_archs + return list(expected) + + def test(arch=None, exclude=None, require=None, **options): - ''' + """ + Performs tests on archs in `expected_archs()` which are in `arch` and not in `exclude` and satisfy `require` .. function:: ti.test(arch=[], exclude=[], require=[], **options) :parameter arch: backends to include :parameter exclude: backends to exclude :parameter require: extensions required :parameter options: other options to be passed into ``ti.init`` - ''' + + """ if arch is None: arch = [] @@ -102,17 +160,19 @@ def test(arch=None, exclude=None, require=None, **options): exclude = [exclude] if not isinstance(require, (list, tuple)): require = [require] - supported_archs = ti.supported_archs() + archs_expected = expected_archs() if len(arch) == 0: - arch = supported_archs + arch = archs_expected else: - arch = list(filter(lambda x: x in supported_archs, arch)) - if len(arch) == 0: - return lambda x: print('No supported arch found. Skipping') + arch = list(filter(lambda x: x in archs_expected, arch)) def decorator(foo): @functools.wraps(foo) def wrapped(*args, **kwargs): + if len(arch) == 0: + print('No supported arch found. Skipping.') + return + arch_params_sets = [arch, *_test_features.values()] arch_params_combinations = list( itertools.product(*arch_params_sets)) @@ -143,7 +203,9 @@ def wrapped(*args, **kwargs): if skip: continue - ti.init(arch=req_arch, **current_options) + ti.init(arch=req_arch, + enable_fallback=False, + **current_options) foo(*args, **kwargs) ti.reset() @@ -153,9 +215,5 @@ def wrapped(*args, **kwargs): __all__ = [ - 'get_rel_eps', - 'approx', - 'allclose', - 'make_temp_file', 'test', ] diff --git a/version.txt b/version.txt new file mode 100644 index 0000000000000..e6e6db4c47c64 --- /dev/null +++ b/version.txt @@ -0,0 +1 @@ +v0.9.2